diff --git a/go-rvpn-client/main.go b/go-rvpn-client/main.go new file mode 100644 index 0000000..9ab43bc --- /dev/null +++ b/go-rvpn-client/main.go @@ -0,0 +1,208 @@ +package main + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + + jwt "github.com/dgrijalva/jwt-go" + flag "github.com/spf13/pflag" + "github.com/spf13/viper" + + "git.daplie.com/Daplie/go-rvpn-server/rvpn/client" +) + +var httpRegexp = regexp.MustCompile(`(?i)^http`) + +func init() { + flag.StringSlice("locals", []string{}, "comma separated list of : or "+ + ":: to which matching incoming connections should forward. "+ + "Ex: smtps:8465,https:example.com:8443") + flag.StringSlice("domains", []string{}, "comma separated list of domain names to set to the tunnel") + viper.BindPFlag("locals", flag.Lookup("locals")) + viper.BindPFlag("domains", flag.Lookup("domains")) + + flag.BoolP("insecure", "k", false, "Allow TLS connections to stunneld without valid certs") + flag.String("stunneld", "", "the domain (or ip address) at which the RVPN server is running") + flag.String("secret", "", "the same secret used by stunneld (used for JWT authentication)") + flag.String("token", "", "a pre-generated token to give the server (instead of generating one with --secret)") + viper.BindPFlag("raw.insecure", flag.Lookup("insecure")) + viper.BindPFlag("raw.stunneld", flag.Lookup("stunneld")) + viper.BindPFlag("raw.secret", flag.Lookup("secret")) + viper.BindPFlag("raw.token", flag.Lookup("token")) +} + +type proxy struct { + protocol string + hostname string + port int +} + +func addLocals(proxies []proxy, location string) []proxy { + parts := strings.Split(location, ":") + if len(parts) > 3 { + panic(fmt.Sprintf("provided invalid location %q", location)) + } + + // If all that was provided as a "local" is the domain name we assume that domain + // has HTTP and HTTPS handlers on the default ports. + if len(parts) == 1 { + proxies = append(proxies, proxy{"http", parts[0], 80}) + proxies = append(proxies, proxy{"https", parts[0], 443}) + return proxies + } + + // Make everything lower case and trim any slashes in something like https://john.example.com + parts[0] = strings.ToLower(parts[0]) + parts[1] = strings.ToLower(strings.Trim(parts[1], "/")) + + if len(parts) == 2 { + if strings.Contains(parts[1], ".") { + if parts[0] == "http" { + parts = append(parts, "80") + } else if parts[0] == "https" { + parts = append(parts, "443") + } else { + panic(fmt.Sprintf("port must be specified for %q", location)) + } + } else { + // https:3443 -> https:*:3443 + parts = []string{parts[0], "*", parts[1]} + } + } + + if port, err := strconv.Atoi(parts[2]); err != nil { + panic(fmt.Sprintf("port must be a valid number, not %q: %v", parts[2], err)) + } else if port <= 0 || port > 65535 { + panic(fmt.Sprintf("%d is an invalid port for local services", port)) + } else { + proxies = append(proxies, proxy{parts[0], parts[1], port}) + } + return proxies +} + +func addDomains(proxies []proxy, location string) []proxy { + parts := strings.Split(location, ":") + if len(parts) > 3 { + panic(fmt.Sprintf("provided invalid location %q", location)) + } else if len(parts) == 2 { + panic("invalid argument for --domains, use format or ::") + } + + // If the scheme and port weren't provided use the zero values + if len(parts) == 1 { + return append(proxies, proxy{"", parts[0], 0}) + } + + if port, err := strconv.Atoi(parts[2]); err != nil { + panic(fmt.Sprintf("port must be a valid number, not %q: %v", parts[2], err)) + } else if port <= 0 || port > 65535 { + panic(fmt.Sprintf("%d is an invalid port for local services", port)) + } else { + proxies = append(proxies, proxy{parts[0], parts[1], port}) + } + return proxies +} + +func extractServicePorts(proxies []proxy) map[string]map[string]int { + result := make(map[string]map[string]int, 2) + + for _, p := range proxies { + if p.protocol != "" && p.port != 0 { + hostPorts := result[p.protocol] + if hostPorts == nil { + result[p.protocol] = make(map[string]int) + hostPorts = result[p.protocol] + } + + // Only HTTP and HTTPS allow us to determine the hostname from the request, so only + // those protocols support different ports for the same service. + if !httpRegexp.MatchString(p.protocol) || p.hostname == "" { + p.hostname = "*" + } + if port, ok := hostPorts[p.hostname]; ok && port != p.port { + panic(fmt.Sprintf("duplicate ports for %s://%s", p.protocol, p.hostname)) + } + hostPorts[p.hostname] = p.port + } + } + + // Make sure we have defaults for HTTPS and HTTP. + if result["https"] == nil { + result["https"] = make(map[string]int, 1) + } + if result["https"]["*"] == 0 { + result["https"]["*"] = 8443 + } + + if result["http"] == nil { + result["http"] = make(map[string]int, 1) + } + if result["http"]["*"] == 0 { + result["http"]["*"] = result["https"]["*"] + } + + return result +} + +func main() { + flag.Parse() + + proxies := make([]proxy, 0) + for _, option := range viper.GetStringSlice("locals") { + for _, location := range strings.Split(option, ",") { + proxies = addLocals(proxies, location) + } + } + for _, option := range viper.GetStringSlice("domains") { + for _, location := range strings.Split(option, ",") { + proxies = addDomains(proxies, location) + } + } + + servicePorts := extractServicePorts(proxies) + domainMap := make(map[string]bool) + for _, p := range proxies { + if p.hostname != "" && p.hostname != "*" { + domainMap[p.hostname] = true + } + } + + if viper.GetString("raw.stunneld") == "" { + panic("must provide remote RVPN server to connect to") + } + + var token string + if viper.GetString("raw.token") != "" { + token = viper.GetString("raw.token") + } else if viper.GetString("raw.secret") != "" { + domains := make([]string, 0, len(domainMap)) + for name := range domainMap { + domains = append(domains, name) + } + tokenData := jwt.MapClaims{"domains": domains} + + secret := []byte(viper.GetString("raw.secret")) + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData) + if tokenStr, err := jwtToken.SignedString(secret); err != nil { + panic(err) + } else { + token = tokenStr + } + } else { + panic("must provide either token or secret") + } + + ctx, quit := context.WithCancel(context.Background()) + defer quit() + + config := client.Config{ + Insecure: viper.GetBool("raw.insecure"), + Server: viper.GetString("raw.stunneld"), + Services: servicePorts, + Token: token, + } + panic(client.Run(ctx, &config)) +} diff --git a/rvpn/client/client.go b/rvpn/client/client.go new file mode 100644 index 0000000..a903305 --- /dev/null +++ b/rvpn/client/client.go @@ -0,0 +1,71 @@ +package client + +import ( + "context" + "crypto/tls" + "fmt" + "net/url" + "time" + + "github.com/gorilla/websocket" +) + +// The Config struct holds all of the information needed to establish and handle a connection +// with the RVPN server. +type Config struct { + Server string + Token string + Insecure bool + Services map[string]map[string]int +} + +// Run establishes a connection with the RVPN server specified in the config. If the first attempt +// to connect fails it is assumed that something is wrong with the authentication and it will +// return an error. Otherwise it will continuously attempt to reconnect whenever the connection +// is broken. +func Run(ctx context.Context, config *Config) error { + serverURL, err := url.Parse(config.Server) + if err != nil { + return fmt.Errorf("Invalid server URL: %v", err) + } + if serverURL.Scheme == "" { + serverURL.Scheme = "wss" + } + serverURL.Path = "" + + query := make(url.Values) + query.Set("access_token", config.Token) + serverURL.RawQuery = query.Encode() + + dialer := websocket.Dialer{} + if config.Insecure { + dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + + for name, portList := range config.Services { + if _, ok := portList["*"]; !ok { + return fmt.Errorf(`service %s missing port for "*"`, name) + } + } + handler := NewWsHandler(config.Services) + + authenticated := false + for { + if conn, _, err := dialer.Dial(serverURL.String(), nil); err == nil { + loginfo.Println("connected to remote server") + authenticated = true + handler.HandleConn(ctx, conn) + } else if !authenticated { + return fmt.Errorf("First connection to server failed - check auth: %v", err) + } + loginfo.Println("disconnected from remote server") + + // Sleep for a few seconds before trying again, but only if the context is still active + select { + case <-ctx.Done(): + return nil + case <-time.After(5 * time.Second): + } + loginfo.Println("attempting reconnect to remote server") + } +} diff --git a/rvpn/client/setup.go b/rvpn/client/setup.go new file mode 100644 index 0000000..e08f076 --- /dev/null +++ b/rvpn/client/setup.go @@ -0,0 +1,15 @@ +package client + +import ( + "log" + "os" +) + +const ( + logFlags = log.Ldate | log.Lmicroseconds | log.Lshortfile +) + +var ( + loginfo = log.New(os.Stdout, "INFO: client: ", logFlags) + logdebug = log.New(os.Stdout, "DEBUG: client:", logFlags) +) diff --git a/rvpn/client/ws_handler.go b/rvpn/client/ws_handler.go new file mode 100644 index 0000000..7501d88 --- /dev/null +++ b/rvpn/client/ws_handler.go @@ -0,0 +1,233 @@ +package client + +import ( + "context" + "fmt" + "io" + "net" + "regexp" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + + "git.daplie.com/Daplie/go-rvpn-server/rvpn/packer" + "git.daplie.com/Daplie/go-rvpn-server/rvpn/sni" +) + +var hostRegexp = regexp.MustCompile(`(?im)(?:^|[\r\n])Host: *([^\r\n]+)[\r\n]`) + +// WsHandler handles all of reading and writing for the websocket connection to the RVPN server +// and the TCP connections to the local servers. +type WsHandler struct { + lock sync.Mutex + localConns map[string]net.Conn + + servicePorts map[string]map[string]int + + ctx context.Context + dataChan chan *packer.Packer +} + +// NewWsHandler creates a new handler ready to be given a websocket connection. The services +// argument specifies what port each service type should be directed to on the local interface. +func NewWsHandler(services map[string]map[string]int) *WsHandler { + h := new(WsHandler) + h.servicePorts = services + h.localConns = make(map[string]net.Conn) + return h +} + +// HandleConn handles all of the traffic on the provided websocket connection. The function +// will not return until the connection ends. +// +// The WsHandler is designed to handle exactly one connection at a time. If HandleConn is called +// again while the instance is still handling another connection (or if the previous connection +// failed to fully cleanup) calling HandleConn again will panic. +func (h *WsHandler) HandleConn(ctx context.Context, conn *websocket.Conn) { + if h.dataChan != nil { + panic("WsHandler.HandleConn called while handling a previous connection") + } + if len(h.localConns) > 0 { + panic(fmt.Sprintf("WsHandler has lingering local connections: %v", h.localConns)) + } + h.dataChan = make(chan *packer.Packer) + + // The sub context allows us to clean up all of the goroutines associated with this websocket + // if it closes at any point for any reason. + subCtx, socketQuit := context.WithCancel(ctx) + defer socketQuit() + h.ctx = subCtx + + // Start the routine that will write all of the data from the local connection to the + // remote websocket connection. + go h.writeRemote(conn) + + for { + _, message, err := conn.ReadMessage() + if err != nil { + loginfo.Println("failed to read message from websocket", err) + return + } + + p, err := packer.ReadMessage(message) + if err != nil { + loginfo.Println("failed to parse message from websocket", err) + return + } + + h.writeLocal(p) + } +} + +func (h *WsHandler) writeRemote(conn *websocket.Conn) { + defer h.closeConnections() + defer func() { h.dataChan = nil }() + + for { + select { + case <-h.ctx.Done(): + // We can't tell if this happened because the websocket is already closed/errored or + // if it happened because the main context closed (in which case it would be preferable + // to properly close the connection). As such we try to close the connection and ignore + // all errors if it doesn't work. + message := websocket.FormatCloseMessage(websocket.CloseGoingAway, "closing connection") + deadline := time.Now().Add(10 * time.Second) + conn.WriteControl(websocket.CloseMessage, message, deadline) + conn.Close() + return + + case p := <-h.dataChan: + packed := p.PackV1() + conn.WriteMessage(websocket.BinaryMessage, packed.Bytes()) + } + } +} + +func (h *WsHandler) sendPackedMessage(header *packer.Header, data []byte, service string) { + p := packer.NewPacker(header) + if len(data) > 0 { + p.Data.AppendBytes(data) + } + if service != "" { + p.SetService(service) + } + + // Avoid blocking on the data channel if the websocket closes or is already closed + select { + case h.dataChan <- p: + case <-h.ctx.Done(): + } +} + +func (h *WsHandler) getLocalConn(p *packer.Packer) net.Conn { + h.lock.Lock() + defer h.lock.Unlock() + + key := fmt.Sprintf("%s:%d", p.Address(), p.Port()) + // Simplest case: it's already open, just return it. + if conn := h.localConns[key]; conn != nil { + return conn + } + + service := strings.ToLower(p.Service()) + portList := h.servicePorts[service] + if portList == nil { + loginfo.Println("cannot open connection for invalid service", service) + return nil + } + + var hostname string + if service == "http" { + if match := hostRegexp.FindSubmatch(p.Data.Data()); match != nil { + hostname = strings.Split(string(match[1]), ":")[0] + } + } else if service == "https" { + hostname, _ = sni.GetHostname(p.Data.Data()) + } else { + hostname = "*" + } + if hostname == "" { + loginfo.Println("missing servername for", service, key) + return nil + } + hostname = strings.ToLower(hostname) + + port := portList[hostname] + if port == 0 { + port = portList["*"] + } + if port == 0 { + loginfo.Println("unable to determine local port for", service, hostname) + return nil + } + + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + loginfo.Println("unable to open local connection on port", port, err) + return nil + } + + h.localConns[key] = conn + loginfo.Printf("new client %q for %s:%d (%d clients)\n", key, hostname, port, len(h.localConns)) + go h.readLocal(key, &p.Header) + return conn +} + +func (h *WsHandler) writeLocal(p *packer.Packer) { + conn := h.getLocalConn(p) + if conn == nil { + h.sendPackedMessage(&p.Header, nil, "error") + return + } + + if p.Service() == "error" || p.Service() == "end" { + conn.Close() + return + } + + if _, err := conn.Write(p.Data.Data()); err != nil { + h.sendPackedMessage(&p.Header, nil, "error") + loginfo.Println("failed to write to local connection", err) + } +} + +func (h *WsHandler) readLocal(key string, header *packer.Header) { + h.lock.Lock() + conn := h.localConns[key] + h.lock.Unlock() + + defer conn.Close() + defer func() { + h.lock.Lock() + delete(h.localConns, key) + loginfo.Printf("closing client %q: (%d clients)\n", key, len(h.localConns)) + h.lock.Unlock() + }() + + buf := make([]byte, 4096) + for { + size, err := conn.Read(buf) + if err != nil { + if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") { + h.sendPackedMessage(header, nil, "end") + } else { + loginfo.Println("failed to read from local connection for", key, err) + h.sendPackedMessage(header, nil, "error") + } + return + } + + h.sendPackedMessage(header, buf[:size], "") + } +} + +func (h *WsHandler) closeConnections() { + h.lock.Lock() + defer h.lock.Unlock() + + for _, conn := range h.localConns { + conn.Close() + } +} diff --git a/rvpn/packer/packer.go b/rvpn/packer/packer.go index 7f4af1b..3948bf2 100644 --- a/rvpn/packer/packer.go +++ b/rvpn/packer/packer.go @@ -16,16 +16,19 @@ const ( //Packer -- contains both header and data type Packer struct { - Header *packerHeader - Data *packerData + Header + Data packerData } -//NewPacker -- Structre -func NewPacker() (p *Packer) { - p = new(Packer) - p.Header = newPackerHeader() - p.Data = newPackerData() - return +// NewPacker creates a new Packer struct using the information from the provided header as +// its own header. (Because the header is stored directly and not as a pointer/reference +// it should be safe to override items like the service without affecting the template header.) +func NewPacker(header *Header) *Packer { + p := new(Packer) + if header != nil { + p.Header = *header + } + return p } func splitHeader(header []byte, names []string) (map[string]string, error) { @@ -48,7 +51,7 @@ func ReadMessage(b []byte) (*Packer, error) { // Detect protocol in use if b[0] == packerV1 { // Separate the header and body using the header length in the second byte. - p := NewPacker() + p := NewPacker(nil) header := b[2 : b[1]+2] data := b[b[1]+2:] @@ -69,8 +72,8 @@ func ReadMessage(b []byte) (*Packer, error) { p.Header.address = net.ParseIP(parts["address"]) if p.Header.address == nil { return nil, fmt.Errorf("Invalid network address %q", parts["address"]) - } else if p.Header.Family() == FamilyIPv4 && p.Header.address.To4() == nil { - return nil, fmt.Errorf("Address %q is not in address family %s", parts["address"], p.Header.FamilyText()) + } else if p.Header.family == FamilyIPv4 && p.Header.address.To4() == nil { + return nil, fmt.Errorf("Address %q is not in address family %s", parts["address"], p.Header.Family()) } //handle port @@ -79,7 +82,7 @@ func ReadMessage(b []byte) (*Packer, error) { } else if port <= 0 || port > 65535 { return nil, fmt.Errorf("Port %d out of range", port) } else { - p.Header.Port = port + p.Header.port = port } //handle data length @@ -90,7 +93,7 @@ func ReadMessage(b []byte) (*Packer, error) { } //handle Service - p.Header.Service = parts["service"] + p.Header.service = parts["service"] //handle payload p.Data.AppendBytes(data) @@ -103,11 +106,11 @@ func ReadMessage(b []byte) (*Packer, error) { //PackV1 -- Outputs version 1 of packer func (p *Packer) PackV1() bytes.Buffer { header := strings.Join([]string{ - p.Header.FamilyText(), - p.Header.AddressString(), - strconv.Itoa(p.Header.Port), + p.Header.Family(), + p.Header.Address(), + strconv.Itoa(p.Header.Port()), strconv.Itoa(p.Data.DataLen()), - p.Header.Service, + p.Header.Service(), }, ",") var buf bytes.Buffer diff --git a/rvpn/packer/packer_data.go b/rvpn/packer/packer_data.go index 08775ad..3dfbadc 100644 --- a/rvpn/packer/packer_data.go +++ b/rvpn/packer/packer_data.go @@ -9,10 +9,6 @@ type packerData struct { buffer bytes.Buffer } -func newPackerData() *packerData { - return new(packerData) -} - func (p *packerData) AppendString(dataString string) (int, error) { return p.buffer.WriteString(dataString) } diff --git a/rvpn/packer/packer_header.go b/rvpn/packer/packer_header.go index be7cbb2..668cec4 100644 --- a/rvpn/packer/packer_header.go +++ b/rvpn/packer/packer_header.go @@ -7,12 +7,15 @@ import ( type addressFamily int -// packerHeader structure to hold our header information. -type packerHeader struct { +// The Header struct holds most of the information contained in the header for packets +// between the client and the server (the length of the data is not included here). It +// is used to uniquely identify remote connections on the servers end and to communicate +// which service the remote client is trying to connect to. +type Header struct { family addressFamily address net.IP - Port int - Service string + port int + service string } //Family -- ENUM for Address Family @@ -26,16 +29,19 @@ var addressFamilyText = [...]string{ "IPv6", } -func newPackerHeader() (p *packerHeader) { - p = new(packerHeader) - p.SetAddress("127.0.0.1") - p.Port = 65535 - p.Service = "na" - return +// NewHeader create a new Header object. +func NewHeader(address string, port int, service string) (*Header, error) { + h := new(Header) + if err := h.setAddress(address); err != nil { + return nil, err + } + h.port = port + h.service = service + return h, nil } -//SetAddress -- Set Address. which sets address family automatically -func (p *packerHeader) SetAddress(addr string) { +// setAddress parses the provided address string and automatically sets the IP family. +func (p *Header) setAddress(addr string) error { p.address = net.ParseIP(addr) if p.address.To4() != nil { @@ -43,30 +49,33 @@ func (p *packerHeader) SetAddress(addr string) { } else if p.address.To16() != nil { p.family = FamilyIPv6 } else { - panic(fmt.Sprintf("setAddress does not support %q", addr)) + return fmt.Errorf("invalid IP address %q", addr) } + return nil } -func (p *packerHeader) AddressBytes() []byte { - if ip4 := p.address.To4(); ip4 != nil { - p.address = ip4 - } - - return []byte(p.address) +// Family returns the string corresponding to the address's IP family. +func (p *Header) Family() string { + return addressFamilyText[p.family] } -func (p *packerHeader) AddressString() string { +// Address returns the string form of the header's remote address. +func (p *Header) Address() string { return p.address.String() } -func (p *packerHeader) Address() net.IP { - return p.address +// Port returns the connected port of the remote connection. +func (p *Header) Port() int { + return p.port } -func (p *packerHeader) Family() addressFamily { - return p.family +// SetService overrides the header's original service. This is primarily useful +// for sending 'error' and 'end' messages. +func (p *Header) SetService(service string) { + p.service = service } -func (p *packerHeader) FamilyText() string { - return addressFamilyText[p.family] +// Service returns the service stored in the header. +func (p *Header) Service() string { + return p.service } diff --git a/rvpn/server/api_collect_status_dead time.go b/rvpn/server/api_collect_status_dead_time.go similarity index 100% rename from rvpn/server/api_collect_status_dead time.go rename to rvpn/server/api_collect_status_dead_time.go diff --git a/rvpn/server/connection.go b/rvpn/server/connection.go index a2ab6f2..7048870 100755 --- a/rvpn/server/connection.go +++ b/rvpn/server/connection.go @@ -2,8 +2,8 @@ package server import ( "context" + "fmt" "io" - "strconv" "sync" "time" @@ -252,7 +252,7 @@ func (c *Connection) Reader(ctx context.Context) { // unpack the message. p, err := packer.ReadMessage(message) - key := p.Header.Address().String() + ":" + strconv.Itoa(p.Header.Port) + key := fmt.Sprintf("%s:%d", p.Address(), p.Port()) track, err := connectionTrack.Lookup(key) //loginfo.Println(hex.Dump(p.Data.Data())) diff --git a/rvpn/server/listener_generic.go b/rvpn/server/listener_generic.go index 201d2ee..0f68e52 100644 --- a/rvpn/server/listener_generic.go +++ b/rvpn/server/listener_generic.go @@ -18,6 +18,7 @@ import ( "github.com/gorilla/websocket" "git.daplie.com/Daplie/go-rvpn-server/rvpn/packer" + "git.daplie.com/Daplie/go-rvpn-server/rvpn/sni" ) type contextKey string @@ -160,7 +161,7 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) { wssHostName := ctx.Value(ctxWssHostName).(string) adminHostName := ctx.Value(ctxAdminHostName).(string) - sniHostName, err := getHello(peek) + sniHostName, err := sni.GetHostname(peek) if err != nil { loginfo.Println(err) return @@ -292,11 +293,19 @@ func handleExternalHTTPRequest(ctx context.Context, extConn *WedgeConn, hostname track := NewTrack(extConn, hostname) serverStatus.ExtConnectionRegister(track) - loginfo.Println("Domain Accepted", hostname, extConn.RemoteAddr().String()) + remoteStr := extConn.RemoteAddr().String() + loginfo.Println("Domain Accepted", hostname, remoteStr) - rAddr, rPort, err := net.SplitHostPort(extConn.RemoteAddr().String()) - if err != nil { - loginfo.Println("unable to decode hostport", extConn.RemoteAddr().String()) + var header *packer.Header + if rAddr, rPort, err := net.SplitHostPort(remoteStr); err != nil { + loginfo.Println("unable to decode hostport", remoteStr, err) + } else if port, err := strconv.Atoi(rPort); err != nil { + loginfo.Printf("unable to parse port string %q: %v\n", rPort, err) + } else if header, err = packer.NewHeader(rAddr, port, service); err != nil { + loginfo.Println("unable to create packer header", err) + } + + if header == nil { return } @@ -309,18 +318,8 @@ func handleExternalHTTPRequest(ctx context.Context, extConn *WedgeConn, hostname loginfo.Println("Before Packer", hex.Dump(buffer)) - cnt := len(buffer) - - p := packer.NewPacker() - p.Header.SetAddress(rAddr) - p.Header.Port, err = strconv.Atoi(rPort) - if err != nil { - loginfo.Println("Unable to set Remote port", err) - return - } - - p.Header.Service = service - p.Data.AppendBytes(buffer[0:cnt]) + p := packer.NewPacker(header) + p.Data.AppendBytes(buffer) buf := p.PackV1() //loginfo.Println(hex.Dump(buf.Bytes())) @@ -329,8 +328,8 @@ func handleExternalHTTPRequest(ctx context.Context, extConn *WedgeConn, hostname sendTrack := NewSendTrack(buf.Bytes(), hostname) serverStatus.SendExtRequest(conn, sendTrack) - _, err = extConn.Discard(cnt) - if err != nil { + cnt := len(buffer) + if _, err = extConn.Discard(cnt); err != nil { loginfo.Println("unable to discard", cnt, err) return } diff --git a/rvpn/server/manager.go b/rvpn/server/manager.go index 17a6b20..f3033f9 100644 --- a/rvpn/server/manager.go +++ b/rvpn/server/manager.go @@ -52,7 +52,7 @@ type servers struct { secretKey string certbundle tls.Certificate register chan *ListenerRegistration - servers *servers + servers *servers wssHostName string adminHostName string cancelCheck int diff --git a/rvpn/server/tls_get_hello.go b/rvpn/sni/tls_get_hostname.go similarity index 91% rename from rvpn/server/tls_get_hello.go rename to rvpn/sni/tls_get_hostname.go index 8aa78c4..606fdf7 100644 --- a/rvpn/server/tls_get_hello.go +++ b/rvpn/sni/tls_get_hostname.go @@ -1,10 +1,11 @@ -package server +package sni import ( "errors" ) -func getHello(b []byte) (string, error) { +// GetHostname uses SNI to determine the intended target of a new TLS connection. +func GetHostname(b []byte) (string, error) { rest := b[5:] current := 0 handshakeType := rest[0]