package genericlistener import ( "bytes" "context" "crypto/tls" "encoding/hex" "log" "net" "strconv" "strings" "time" jwt "github.com/dgrijalva/jwt-go" "github.com/gorilla/mux" "github.com/gorilla/websocket" "net/http" "bufio" "git.daplie.com/Daplie/go-rvpn-server/rvpn/packer" ) type contextKey string //CtxConnectionTrack const ( ctxSecretKey contextKey = "secretKey" ctxServerStatus contextKey = "serverstatus" //ctxConnectionTable contextKey = "connectionTable" ctxConfig contextKey = "config" ctxListenerRegistration contextKey = "listenerRegistration" ctxConnectionTrack contextKey = "connectionTrack" ctxWssHostName contextKey = "wsshostname" ctxAdminHostName contextKey = "adminHostName" ctxCancelCheck contextKey = "cancelcheck" ctxLoadbalanceDefaultMethod contextKey = "lbdefaultmethod" ) const ( encryptNone int = iota encryptSSLV2 encryptSSLV3 encryptTLS10 encryptTLS11 encryptTLS12 ) //GenericListenAndServe -- used to lisen for any https traffic on 443 (8443) // - setup generic TCP listener, unencrypted TCP, with a Deadtime out // - leaverage the wedgeConn to peek into the buffer. // - if TLS, consume connection with TLS certbundle, pass to request identifier // - else, just pass to the request identififer func GenericListenAndServe(ctx context.Context, listenerRegistration *ListenerRegistration) { loginfo.Println(":" + string(listenerRegistration.port)) cancelCheck := ctx.Value(ctxCancelCheck).(int) listenAddr, err := net.ResolveTCPAddr("tcp", ":"+strconv.Itoa(listenerRegistration.port)) if nil != err { loginfo.Println(err) return } ln, err := net.ListenTCP("tcp", listenAddr) if err != nil { loginfo.Println("unable to bind", err) listenerRegistration.status = listenerFault listenerRegistration.err = err listenerRegistration.commCh <- listenerRegistration return } listenerRegistration.status = listenerAdded listenerRegistration.commCh <- listenerRegistration for { select { case <-ctx.Done(): loginfo.Println("Cancel signal hit") return default: ln.SetDeadline(time.Now().Add(time.Duration(cancelCheck) * time.Second)) conn, err := ln.Accept() if nil != err { if opErr, ok := err.(*net.OpError); ok && opErr.Timeout() { continue } log.Println(err) return } wedgeConn := NewWedgeConn(conn) go handleConnection(ctx, wedgeConn) } } } //handleConnection - // - accept a wedgeConnection along with all the other required attritvues // - peek into the buffer, determine TLS or unencrypted // - if TSL, then terminate with a TLS endpoint, pass to handleStream // - if clearText, pass to handleStream func handleConnection(ctx context.Context, wConn *WedgeConn) { defer wConn.Close() peekCnt := 10 encryptMode := encryptNone loginfo.Println("conn", wConn, wConn.LocalAddr().String(), wConn.RemoteAddr().String()) peek, err := wConn.Peek(peekCnt) if err != nil { loginfo.Println("error while peeking") return } //take a look for a TLS header. if bytes.Contains(peek[0:0], []byte{0x80}) && bytes.Contains(peek[2:4], []byte{0x01, 0x03}) { encryptMode = encryptSSLV2 } else if bytes.Contains(peek[0:3], []byte{0x16, 0x03, 0x00}) { encryptMode = encryptSSLV3 } else if bytes.Contains(peek[0:3], []byte{0x16, 0x03, 0x01}) { encryptMode = encryptTLS10 loginfo.Println("TLS10") } else if bytes.Contains(peek[0:3], []byte{0x16, 0x03, 0x02}) { encryptMode = encryptTLS11 } else if bytes.Contains(peek[0:3], []byte{0x16, 0x03, 0x03}) { encryptMode = encryptTLS12 } oneConn := &oneConnListener{wConn} config := ctx.Value(ctxConfig).(*tls.Config) if encryptMode == encryptSSLV2 { loginfo.Println("SSLv2 is not accepted") return } else if encryptMode != encryptNone { loginfo.Println("Handle Encryption") // check SNI heading // if matched, then looks like a WSS connection // else external don't pull off TLS. peek, err := wConn.PeekAll() if err != nil { loginfo.Println("error while peeking") loginfo.Println(hex.Dump(peek[0:])) return } wssHostName := ctx.Value(ctxWssHostName).(string) adminHostName := ctx.Value(ctxAdminHostName).(string) sniHostName, err := getHello(peek) if err != nil { loginfo.Println(err) return } loginfo.Println("sni:", sniHostName) if sniHostName == wssHostName { //handle WSS Path tlsListener := tls.NewListener(oneConn, config) conn, err := tlsListener.Accept() if err != nil { loginfo.Println(err) return } tlsWedgeConn := NewWedgeConn(conn) handleStream(ctx, tlsWedgeConn) return } else if sniHostName == adminHostName { // handle admin path tlsListener := tls.NewListener(oneConn, config) conn, err := tlsListener.Accept() if err != nil { loginfo.Println(err) return } tlsWedgeConn := NewWedgeConn(conn) handleStream(ctx, tlsWedgeConn) return } else { //traffic not terminating on the rvpn do not decrypt loginfo.Println("processing non terminating traffic", wssHostName, sniHostName) handleExternalHTTPRequest(ctx, wConn, sniHostName, "https") } } loginfo.Println("Handle Unencrypted") handleStream(ctx, wConn) return } //handleStream -- // - we have an unencrypted stream connection with the ability to peek // - attempt to identify HTTP // - handle http // - attempt to identify as WSS session // - attempt to identify as ADMIN/API session // - else handle as raw http // - handle other? func handleStream(ctx context.Context, wConn *WedgeConn) { loginfo.Println("handle Stream") loginfo.Println("conn", wConn.LocalAddr().String(), wConn.RemoteAddr().String()) peek, err := wConn.PeekAll() if err != nil { loginfo.Println("error while peeking", err) loginfo.Println(hex.Dump(peek[0:])) return } // HTTP Identifcation if bytes.Contains(peek[:], []byte{0x0d, 0x0a}) { //string protocol if bytes.ContainsAny(peek[:], "HTTP/") { loginfo.Println("identifed HTTP") r, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(peek))) if err != nil { loginfo.Println("identifed as HTTP, failed request parsing", err) return } // do we have a valid wss_client? secretKey := ctx.Value(ctxSecretKey).(string) tokenString := r.URL.Query().Get("access_token") result, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { return []byte(secretKey), nil }) if err == nil && result.Valid { loginfo.Println("Valid WSS dected...sending to handler") oneConn := &oneConnListener{wConn} handleWssClient(ctx, oneConn) //do we have a invalid domain indicating Admin? //if yes, prep the oneConn and send it to the handler } else if strings.Contains(r.Host, "rvpn.daplie.invalid") { loginfo.Println("admin") oneConn := &oneConnListener{wConn} handleAdminClient(ctx, oneConn) return } else { loginfo.Println("unsupported") loginfo.Println(hex.Dump(peek)) return } } } } //handleExternalHTTPRequest - // - get a wConn and start processing requests func handleExternalHTTPRequest(ctx context.Context, extConn *WedgeConn, hostname string, service string) { //connectionTracking := ctx.Value(ctxConnectionTrack).(*Tracking) serverStatus := ctx.Value(ctxServerStatus).(*Status) defer func() { serverStatus.ExtConnectionUnregister(extConn) extConn.Close() }() //find the connection by domain name conn, ok := serverStatus.ConnectionTable.ConnByDomain(hostname) if !ok { //matching connection can not be found based on ConnByDomain loginfo.Println("unable to match ", hostname, " to an existing connection") //http.Error(, "Domain not supported", http.StatusBadRequest) return } track := NewTrack(extConn, hostname) serverStatus.ExtConnectionRegister(track) loginfo.Println("Domain Accepted", hostname, extConn.RemoteAddr().String()) rAddr, rPort, err := net.SplitHostPort(extConn.RemoteAddr().String()) if err != nil { loginfo.Println("unable to decode hostport", extConn.RemoteAddr().String()) return } for { buffer, err := extConn.PeekAll() if err != nil { loginfo.Println("unable to peekAll", err) return } loginfo.Println("Before Packer", hex.Dump(buffer)) cnt := len(buffer) p := packer.NewPacker() p.Header.SetAddress(rAddr) p.Header.Port, err = strconv.Atoi(rPort) if err != nil { loginfo.Println("Unable to set Remote port", err) return } p.Header.Service = service p.Data.AppendBytes(buffer[0:cnt]) buf := p.PackV1() //loginfo.Println(hex.Dump(buf.Bytes())) //Bundle up the send request and dispatch sendTrack := NewSendTrack(buf.Bytes(), hostname) serverStatus.SendExtRequest(conn, sendTrack) _, err = extConn.Discard(cnt) if err != nil { loginfo.Println("unable to discard", cnt, err) return } } } //handleWssClient - // - expecting an existing oneConnListener with a qualified wss client connected. // - auth will happen again since we were just peeking at the token. func handleWssClient(ctx context.Context, oneConn *oneConnListener) { secretKey := ctx.Value(ctxSecretKey).(string) serverStatus := ctx.Value(ctxServerStatus).(*Status) //connectionTable := ctx.Value(ctxConnectionTable).(*Table) router := mux.NewRouter().StrictSlash(true) router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { loginfo.Println("HandleFunc /") switch url := r.URL.Path; url { case "/": loginfo.Println("websocket opening ", r.RemoteAddr, " ", r.Host) tokenString := r.URL.Query().Get("access_token") result, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { return []byte(secretKey), nil }) if err != nil || !result.Valid { w.WriteHeader(http.StatusForbidden) w.Write([]byte("Not Authorized")) loginfo.Println("access_token invalid...closing connection") return } claims := result.Claims.(jwt.MapClaims) domains, ok := claims["domains"].([]interface{}) var upgrader = websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, } conn, err := upgrader.Upgrade(w, r, nil) if err != nil { loginfo.Println("WebSocket upgrade failed", err) return } loginfo.Println("before connection table") //newConnection := connection.NewConnection(connectionTable, conn, r.RemoteAddr, domains) //connectionTrack := ctx.Value(ctxConnectionTrack).(*Tracking) newRegistration := NewRegistration(conn, r.RemoteAddr, domains, serverStatus.ConnectionTracking) serverStatus.WSSConnectionRegister(newRegistration) ok = <-newRegistration.CommCh() if !ok { loginfo.Println("connection registration failed ", newRegistration) return } loginfo.Println("connection registration accepted ", newRegistration) } }) s := &http.Server{ Addr: ":80", Handler: router, } err := s.Serve(oneConn) if err != nil { loginfo.Println("Serve error: ", err) } select { case <-ctx.Done(): loginfo.Println("Cancel signal hit") return } }