added cmd line arguments to the client

This commit is contained in:
tigerbot 2017-04-04 12:09:20 -06:00
parent c1c18b71b0
commit c024c1b88f
1 changed files with 190 additions and 24 deletions

View File

@ -2,41 +2,207 @@ package main
import ( import (
"context" "context"
"fmt"
"regexp"
"strconv"
"strings"
jwt "github.com/dgrijalva/jwt-go"
flag "github.com/spf13/pflag"
"github.com/spf13/viper"
"git.daplie.com/Daplie/go-rvpn-server/rvpn/client" "git.daplie.com/Daplie/go-rvpn-server/rvpn/client"
jwt "github.com/dgrijalva/jwt-go"
) )
func main() { var httpRegexp = regexp.MustCompile(`(?i)^http`)
tokenData := jwt.MapClaims{
"domains": []string{ func init() {
"localhost.foo.daplie.me", flag.StringSlice("locals", []string{}, "comma separated list of <proto>:<port> or "+
"localhost.bar.daplie.me", "<proto>:<hostname>:<port> to which matching incoming connections should forward. "+
}, "Ex: smtps:8465,https:example.com:8443")
flag.StringSlice("domains", []string{}, "comma separated list of domain names to set to the tunnel")
viper.BindPFlag("locals", flag.Lookup("locals"))
viper.BindPFlag("domains", flag.Lookup("domains"))
flag.BoolP("insecure", "k", false, "Allow TLS connections to stunneld without valid certs")
flag.String("stunneld", "", "the domain (or ip address) at which the RVPN server is running")
flag.String("secret", "", "the same secret used by stunneld (used for JWT authentication)")
flag.String("token", "", "a pre-generated token to give the server (instead of generating one with --secret)")
viper.BindPFlag("raw.insecure", flag.Lookup("insecure"))
viper.BindPFlag("raw.stunneld", flag.Lookup("stunneld"))
viper.BindPFlag("raw.secret", flag.Lookup("secret"))
viper.BindPFlag("raw.token", flag.Lookup("token"))
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData)
tokenStr, err := token.SignedString([]byte("abc123")) type proxy struct {
if err != nil { protocol string
hostname string
port int
}
func addLocals(proxies []proxy, location string) []proxy {
parts := strings.Split(location, ":")
if len(parts) > 3 {
panic(fmt.Sprintf("provided invalid location %q", location))
}
// If all that was provided as a "local" is the domain name we assume that domain
// has HTTP and HTTPS handlers on the default ports.
if len(parts) == 1 {
proxies = append(proxies, proxy{"http", parts[0], 80})
proxies = append(proxies, proxy{"https", parts[0], 443})
return proxies
}
// Make everything lower case and trim any slashes in something like https://john.example.com
parts[0] = strings.ToLower(parts[0])
parts[1] = strings.ToLower(strings.Trim(parts[1], "/"))
if len(parts) == 2 {
if strings.Contains(parts[1], ".") {
if parts[0] == "http" {
parts = append(parts, "80")
} else if parts[0] == "https" {
parts = append(parts, "443")
} else {
panic(fmt.Sprintf("port must be specified for %q", location))
}
} else {
// https:3443 -> https:*:3443
parts = []string{parts[0], "*", parts[1]}
}
}
if port, err := strconv.Atoi(parts[2]); err != nil {
panic(fmt.Sprintf("port must be a valid number, not %q: %v", parts[2], err))
} else if port <= 0 || port > 65535 {
panic(fmt.Sprintf("%d is an invalid port for local services", port))
} else {
proxies = append(proxies, proxy{parts[0], parts[1], port})
}
return proxies
}
func addDomains(proxies []proxy, location string) []proxy {
parts := strings.Split(location, ":")
if len(parts) > 3 {
panic(fmt.Sprintf("provided invalid location %q", location))
} else if len(parts) == 2 {
panic("invalid argument for --domains, use format <domainname> or <scheme>:<domainname>:<local-port>")
}
// If the scheme and port weren't provided use the zero values
if len(parts) == 1 {
return append(proxies, proxy{"", parts[0], 0})
}
if port, err := strconv.Atoi(parts[2]); err != nil {
panic(fmt.Sprintf("port must be a valid number, not %q: %v", parts[2], err))
} else if port <= 0 || port > 65535 {
panic(fmt.Sprintf("%d is an invalid port for local services", port))
} else {
proxies = append(proxies, proxy{parts[0], parts[1], port})
}
return proxies
}
func extractServicePorts(proxies []proxy) map[string]map[string]int {
result := make(map[string]map[string]int, 2)
for _, p := range proxies {
if p.protocol != "" && p.port != 0 {
hostPorts := result[p.protocol]
if hostPorts == nil {
result[p.protocol] = make(map[string]int)
hostPorts = result[p.protocol]
}
// Only HTTP and HTTPS allow us to determine the hostname from the request, so only
// those protocols support different ports for the same service.
if !httpRegexp.MatchString(p.protocol) || p.hostname == "" {
p.hostname = "*"
}
if port, ok := hostPorts[p.hostname]; ok && port != p.port {
panic(fmt.Sprintf("duplicate ports for %s://%s", p.protocol, p.hostname))
}
hostPorts[p.hostname] = p.port
}
}
// Make sure we have defaults for HTTPS and HTTP.
if result["https"] == nil {
result["https"] = make(map[string]int, 1)
}
if result["https"]["*"] == 0 {
result["https"]["*"] = 8443
}
if result["http"] == nil {
result["http"] = make(map[string]int, 1)
}
if result["http"]["*"] == 0 {
result["http"]["*"] = result["https"]["*"]
}
return result
}
func main() {
flag.Parse()
proxies := make([]proxy, 0)
for _, option := range viper.GetStringSlice("locals") {
for _, location := range strings.Split(option, ",") {
proxies = addLocals(proxies, location)
}
}
for _, option := range viper.GetStringSlice("domains") {
for _, location := range strings.Split(option, ",") {
proxies = addDomains(proxies, location)
}
}
servicePorts := extractServicePorts(proxies)
domainMap := make(map[string]bool)
for _, p := range proxies {
if p.hostname != "" && p.hostname != "*" {
domainMap[p.hostname] = true
}
}
if viper.GetString("raw.stunneld") == "" {
panic("must provide remote RVPN server to connect to")
}
var token string
if viper.GetString("raw.token") != "" {
token = viper.GetString("raw.token")
} else if viper.GetString("raw.secret") != "" {
domains := make([]string, 0, len(domainMap))
for name := range domainMap {
domains = append(domains, name)
}
tokenData := jwt.MapClaims{"domains": domains}
secret := []byte(viper.GetString("raw.secret"))
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData)
if tokenStr, err := jwtToken.SignedString(secret); err != nil {
panic(err) panic(err)
} else {
token = tokenStr
}
} else {
panic("must provide either token or secret")
} }
ctx, quit := context.WithCancel(context.Background()) ctx, quit := context.WithCancel(context.Background())
defer quit() defer quit()
config := client.Config{ config := client.Config{
Server: "wss://localhost.daplie.me:9999", Insecure: viper.GetBool("raw.insecure"),
Services: map[string]map[string]int{ Server: viper.GetString("raw.stunneld"),
"https": map[string]int{ Services: servicePorts,
"*": 8443, Token: token,
"localhost.foo.daplie.me": 4443,
},
"http": map[string]int{
"*": 8443,
"localhost.foo.daplie.me": 4443,
},
},
Token: tokenStr,
Insecure: true,
} }
panic(client.Run(ctx, &config)) panic(client.Run(ctx, &config))
} }