From 5334649fba8d0b620e786ab690811b119b7ef389 Mon Sep 17 00:00:00 2001 From: Henry Camacho Date: Fri, 10 Mar 2017 23:36:42 -0600 Subject: [PATCH] got non-terminating traffic identified, and used SNI to figure direction --- main.go | 3 +- rvpn/genericlistener/listener_generic.go | 97 ++++++++++++++---------- rvpn/genericlistener/manager.go | 5 +- rvpn/genericlistener/tls_get_hello.go | 69 +++++++++++++++++ 4 files changed, 131 insertions(+), 43 deletions(-) create mode 100644 rvpn/genericlistener/tls_get_hello.go diff --git a/main.go b/main.go index a62f1a0..c53838d 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,7 @@ var ( argDeadTime int connectionTable *genericlistener.Table secretKey = "abc123" + wssHostName = "localhost.daplie.me" ) func init() { @@ -70,7 +71,7 @@ func main() { connectionTable = genericlistener.NewTable() go connectionTable.Run(ctx) - genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime) + genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime, wssHostName) go genericListeners.Run(ctx, 9999) //Run for 10 minutes and then shutdown cleanly diff --git a/rvpn/genericlistener/listener_generic.go b/rvpn/genericlistener/listener_generic.go index 3885537..7cefc7a 100644 --- a/rvpn/genericlistener/listener_generic.go +++ b/rvpn/genericlistener/listener_generic.go @@ -34,6 +34,7 @@ const ( ctxDeadTime contextKey = "deadtime" ctxListenerRegistration contextKey = "listenerRegistration" ctxConnectionTrack contextKey = "connectionTrack" + ctxWssHostName contextKey = "wsshostname" ) const ( @@ -146,6 +147,35 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) { } else if encryptMode != encryptNone { loginfo.Println("Handle Encryption") + + // check SNI heading + // if matched, then looks like a WSS connection + // else external don't pull off TLS. + + peek, err := wConn.PeekAll() + if err != nil { + loginfo.Println("error while peeking") + loginfo.Println(hex.Dump(peek[0:])) + return + } + + wssHostName := ctx.Value(ctxWssHostName).(string) + sniHostName, err := getHello(peek) + if err != nil { + loginfo.Println(err) + return + } + + loginfo.Println("sni:", sniHostName) + + if wssHostName != sniHostName { + //traffic not terminating on the rvpn do not decrypt + loginfo.Println("processing non terminating traffic") + handleExternalHTTPRequest(ctx, wConn, sniHostName) + } + + loginfo.Println("processing traffic terminating on RVPN") + tlsListener := tls.NewListener(oneConn, config) conn, err := tlsListener.Accept() @@ -175,11 +205,11 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) { // - handle other? func handleStream(ctx context.Context, wConn *WedgeConn) { loginfo.Println("handle Stream") - loginfo.Println("conn", wConn, wConn.LocalAddr().String(), wConn.RemoteAddr().String()) + loginfo.Println("conn", wConn.LocalAddr().String(), wConn.RemoteAddr().String()) peek, err := wConn.PeekAll() if err != nil { - loginfo.Println("error while peeking") + loginfo.Println("error while peeking", err) loginfo.Println(hex.Dump(peek[0:])) return } @@ -217,8 +247,8 @@ func handleStream(ctx context.Context, wConn *WedgeConn) { return } else { - loginfo.Println("default connection") - handleExternalHTTPRequest(ctx, wConn) + loginfo.Println("unsupported") + loginfo.Println(hex.Dump(peek)) return } } @@ -227,7 +257,7 @@ func handleStream(ctx context.Context, wConn *WedgeConn) { //handleExternalHTTPRequest - // - get a wConn and start processing requests -func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) { +func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn, hostname string) { connectionTracking := ctx.Value(ctxConnectionTrack).(*Tracking) defer func() { @@ -236,6 +266,25 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) { }() connectionTable := ctx.Value(ctxConnectionTable).(*Table) + //find the connection by domain name + conn, ok := connectionTable.ConnByDomain(hostname) + if !ok { + //matching connection can not be found based on ConnByDomain + loginfo.Println("unable to match ", hostname, " to an existing connection") + //http.Error(, "Domain not supported", http.StatusBadRequest) + return + } + + track := NewTrack(extConn, hostname) + connectionTracking.register <- track + + loginfo.Println("Domain Accepted", hostname, extConn.RemoteAddr().String()) + + rAddr, rPort, err := net.SplitHostPort(extConn.RemoteAddr().String()) + if err != nil { + loginfo.Println("unable to decode hostport", extConn.RemoteAddr().String()) + return + } var buffer [512]byte for { @@ -244,42 +293,6 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) { return } - readBuffer := bytes.NewBuffer(buffer[0:cnt]) - reader := bufio.NewReader(readBuffer) - r, err := http.ReadRequest(reader) - - if err != nil { - loginfo.Println("error parsing request") - return - } - - hostname := r.Host - loginfo.Println("Host: ", hostname) - - if strings.Contains(hostname, ":") { - arr := strings.Split(hostname, ":") - hostname = arr[0] - } - - loginfo.Println("Remote: ", extConn.RemoteAddr().String()) - - remoteSplit := strings.Split(extConn.RemoteAddr().String(), ":") - rAddr := remoteSplit[0] - rPort := remoteSplit[1] - - //find the connection by domain name - conn, ok := connectionTable.ConnByDomain(hostname) - if !ok { - //matching connection can not be found based on ConnByDomain - loginfo.Println("unable to match ", hostname, " to an existing connection") - //http.Error(, "Domain not supported", http.StatusBadRequest) - return - } - - track := NewTrack(extConn, hostname) - connectionTracking.register <- track - - loginfo.Println("Domain Accepted", conn, rAddr, rPort) p := packer.NewPacker() p.Header.SetAddress(rAddr) p.Header.Port, err = strconv.Atoi(rPort) @@ -292,6 +305,8 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) { p.Data.AppendBytes(buffer[0:cnt]) buf := p.PackV1() + loginfo.Println(hex.Dump(buf.Bytes())) + sendTrack := NewSendTrack(buf.Bytes(), hostname) conn.SendCh() <- sendTrack } diff --git a/rvpn/genericlistener/manager.go b/rvpn/genericlistener/manager.go index 9abda26..f31c305 100644 --- a/rvpn/genericlistener/manager.go +++ b/rvpn/genericlistener/manager.go @@ -54,10 +54,11 @@ type GenericListeners struct { deadTime int register chan *ListenerRegistration genericListeners *GenericListeners + wssHostName string } //NewGenerListeners -- -func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTrack *Tracking, secretKey string, certbundle tls.Certificate, deadTime int) (p *GenericListeners) { +func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTrack *Tracking, secretKey string, certbundle tls.Certificate, deadTime int, wssHostName string) (p *GenericListeners) { p = new(GenericListeners) p.listeners = make(map[*net.Listener]int) p.ctx = ctx @@ -67,6 +68,7 @@ func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTr p.certbundle = certbundle p.deadTime = deadTime p.register = make(chan *ListenerRegistration) + p.wssHostName = wssHostName return } @@ -87,6 +89,7 @@ func (gl *GenericListeners) Run(ctx context.Context, initialPort int) { ctx = context.WithValue(ctx, ctxConfig, config) ctx = context.WithValue(ctx, ctxDeadTime, gl.deadTime) ctx = context.WithValue(ctx, ctxListenerRegistration, gl.register) + ctx = context.WithValue(ctx, ctxWssHostName, gl.wssHostName) go func(ctx context.Context) { for { diff --git a/rvpn/genericlistener/tls_get_hello.go b/rvpn/genericlistener/tls_get_hello.go new file mode 100644 index 0000000..0ddb051 --- /dev/null +++ b/rvpn/genericlistener/tls_get_hello.go @@ -0,0 +1,69 @@ +package genericlistener + +import "errors" + +func getHello(b []byte) (string, error) { + rest := b[5:] + current := 0 + handshakeType := rest[0] + current++ + if handshakeType != 0x1 { + return "", errors.New("Not a ClientHello") + } + + // Skip over another length + current += 3 + // Skip over protocolversion + current += 2 + // Skip over random number + current += 4 + 28 + // Skip over session ID + sessionIDLength := int(rest[current]) + current++ + current += sessionIDLength + + cipherSuiteLength := (int(rest[current]) << 8) + int(rest[current+1]) + current += 2 + current += cipherSuiteLength + + compressionMethodLength := int(rest[current]) + current++ + current += compressionMethodLength + + if current > len(rest) { + return "", errors.New("no extensions") + } + + current += 2 + + hostname := "" + for current < len(rest) && hostname == "" { + extensionType := (int(rest[current]) << 8) + int(rest[current+1]) + current += 2 + + extensionDataLength := (int(rest[current]) << 8) + int(rest[current+1]) + current += 2 + + if extensionType == 0 { + + // Skip over number of names as we're assuming there's just one + current += 2 + + nameType := rest[current] + current++ + if nameType != 0 { + return "", errors.New("Not a hostname") + } + nameLen := (int(rest[current]) << 8) + int(rest[current+1]) + current += 2 + hostname = string(rest[current : current+nameLen]) + } + + current += extensionDataLength + } + if hostname == "" { + return "", errors.New("No hostname") + } + return hostname, nil + +}