diff --git a/go-rvpn-client/main.go b/go-rvpn-client/main.go index 63aba8a..9ab43bc 100644 --- a/go-rvpn-client/main.go +++ b/go-rvpn-client/main.go @@ -2,41 +2,207 @@ package main import ( "context" + "fmt" + "regexp" + "strconv" + "strings" + + jwt "github.com/dgrijalva/jwt-go" + flag "github.com/spf13/pflag" + "github.com/spf13/viper" "git.daplie.com/Daplie/go-rvpn-server/rvpn/client" - jwt "github.com/dgrijalva/jwt-go" ) -func main() { - tokenData := jwt.MapClaims{ - "domains": []string{ - "localhost.foo.daplie.me", - "localhost.bar.daplie.me", - }, +var httpRegexp = regexp.MustCompile(`(?i)^http`) + +func init() { + flag.StringSlice("locals", []string{}, "comma separated list of : or "+ + ":: to which matching incoming connections should forward. "+ + "Ex: smtps:8465,https:example.com:8443") + flag.StringSlice("domains", []string{}, "comma separated list of domain names to set to the tunnel") + viper.BindPFlag("locals", flag.Lookup("locals")) + viper.BindPFlag("domains", flag.Lookup("domains")) + + flag.BoolP("insecure", "k", false, "Allow TLS connections to stunneld without valid certs") + flag.String("stunneld", "", "the domain (or ip address) at which the RVPN server is running") + flag.String("secret", "", "the same secret used by stunneld (used for JWT authentication)") + flag.String("token", "", "a pre-generated token to give the server (instead of generating one with --secret)") + viper.BindPFlag("raw.insecure", flag.Lookup("insecure")) + viper.BindPFlag("raw.stunneld", flag.Lookup("stunneld")) + viper.BindPFlag("raw.secret", flag.Lookup("secret")) + viper.BindPFlag("raw.token", flag.Lookup("token")) +} + +type proxy struct { + protocol string + hostname string + port int +} + +func addLocals(proxies []proxy, location string) []proxy { + parts := strings.Split(location, ":") + if len(parts) > 3 { + panic(fmt.Sprintf("provided invalid location %q", location)) } - token := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData) - tokenStr, err := token.SignedString([]byte("abc123")) - if err != nil { - panic(err) + + // If all that was provided as a "local" is the domain name we assume that domain + // has HTTP and HTTPS handlers on the default ports. + if len(parts) == 1 { + proxies = append(proxies, proxy{"http", parts[0], 80}) + proxies = append(proxies, proxy{"https", parts[0], 443}) + return proxies + } + + // Make everything lower case and trim any slashes in something like https://john.example.com + parts[0] = strings.ToLower(parts[0]) + parts[1] = strings.ToLower(strings.Trim(parts[1], "/")) + + if len(parts) == 2 { + if strings.Contains(parts[1], ".") { + if parts[0] == "http" { + parts = append(parts, "80") + } else if parts[0] == "https" { + parts = append(parts, "443") + } else { + panic(fmt.Sprintf("port must be specified for %q", location)) + } + } else { + // https:3443 -> https:*:3443 + parts = []string{parts[0], "*", parts[1]} + } + } + + if port, err := strconv.Atoi(parts[2]); err != nil { + panic(fmt.Sprintf("port must be a valid number, not %q: %v", parts[2], err)) + } else if port <= 0 || port > 65535 { + panic(fmt.Sprintf("%d is an invalid port for local services", port)) + } else { + proxies = append(proxies, proxy{parts[0], parts[1], port}) + } + return proxies +} + +func addDomains(proxies []proxy, location string) []proxy { + parts := strings.Split(location, ":") + if len(parts) > 3 { + panic(fmt.Sprintf("provided invalid location %q", location)) + } else if len(parts) == 2 { + panic("invalid argument for --domains, use format or ::") + } + + // If the scheme and port weren't provided use the zero values + if len(parts) == 1 { + return append(proxies, proxy{"", parts[0], 0}) + } + + if port, err := strconv.Atoi(parts[2]); err != nil { + panic(fmt.Sprintf("port must be a valid number, not %q: %v", parts[2], err)) + } else if port <= 0 || port > 65535 { + panic(fmt.Sprintf("%d is an invalid port for local services", port)) + } else { + proxies = append(proxies, proxy{parts[0], parts[1], port}) + } + return proxies +} + +func extractServicePorts(proxies []proxy) map[string]map[string]int { + result := make(map[string]map[string]int, 2) + + for _, p := range proxies { + if p.protocol != "" && p.port != 0 { + hostPorts := result[p.protocol] + if hostPorts == nil { + result[p.protocol] = make(map[string]int) + hostPorts = result[p.protocol] + } + + // Only HTTP and HTTPS allow us to determine the hostname from the request, so only + // those protocols support different ports for the same service. + if !httpRegexp.MatchString(p.protocol) || p.hostname == "" { + p.hostname = "*" + } + if port, ok := hostPorts[p.hostname]; ok && port != p.port { + panic(fmt.Sprintf("duplicate ports for %s://%s", p.protocol, p.hostname)) + } + hostPorts[p.hostname] = p.port + } + } + + // Make sure we have defaults for HTTPS and HTTP. + if result["https"] == nil { + result["https"] = make(map[string]int, 1) + } + if result["https"]["*"] == 0 { + result["https"]["*"] = 8443 + } + + if result["http"] == nil { + result["http"] = make(map[string]int, 1) + } + if result["http"]["*"] == 0 { + result["http"]["*"] = result["https"]["*"] + } + + return result +} + +func main() { + flag.Parse() + + proxies := make([]proxy, 0) + for _, option := range viper.GetStringSlice("locals") { + for _, location := range strings.Split(option, ",") { + proxies = addLocals(proxies, location) + } + } + for _, option := range viper.GetStringSlice("domains") { + for _, location := range strings.Split(option, ",") { + proxies = addDomains(proxies, location) + } + } + + servicePorts := extractServicePorts(proxies) + domainMap := make(map[string]bool) + for _, p := range proxies { + if p.hostname != "" && p.hostname != "*" { + domainMap[p.hostname] = true + } + } + + if viper.GetString("raw.stunneld") == "" { + panic("must provide remote RVPN server to connect to") + } + + var token string + if viper.GetString("raw.token") != "" { + token = viper.GetString("raw.token") + } else if viper.GetString("raw.secret") != "" { + domains := make([]string, 0, len(domainMap)) + for name := range domainMap { + domains = append(domains, name) + } + tokenData := jwt.MapClaims{"domains": domains} + + secret := []byte(viper.GetString("raw.secret")) + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData) + if tokenStr, err := jwtToken.SignedString(secret); err != nil { + panic(err) + } else { + token = tokenStr + } + } else { + panic("must provide either token or secret") } ctx, quit := context.WithCancel(context.Background()) defer quit() config := client.Config{ - Server: "wss://localhost.daplie.me:9999", - Services: map[string]map[string]int{ - "https": map[string]int{ - "*": 8443, - "localhost.foo.daplie.me": 4443, - }, - "http": map[string]int{ - "*": 8443, - "localhost.foo.daplie.me": 4443, - }, - }, - Token: tokenStr, - Insecure: true, + Insecure: viper.GetBool("raw.insecure"), + Server: viper.GetString("raw.stunneld"), + Services: servicePorts, + Token: token, } panic(client.Run(ctx, &config)) }