From 2b48a8b8b944b3cd523ab9db7d4045ed7ac47e0e Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Tue, 9 Jun 2020 04:41:38 -0600 Subject: [PATCH] add full SNI peeking, fix TLS routing --- cmd/telebit/telebit.go | 4 +-- mplexer/conn.go | 2 -- mplexer/connwrap.go | 59 +++++++++++++++++++++++++++++++++--------- mplexer/listener.go | 3 ++- mplexer/routemux.go | 13 ++++++++-- mplexer/telebit.go | 42 +++++++++++++++++++++++++++--- 6 files changed, 100 insertions(+), 23 deletions(-) diff --git a/cmd/telebit/telebit.go b/cmd/telebit/telebit.go index 9474caf..b6ace7b 100644 --- a/cmd/telebit/telebit.go +++ b/cmd/telebit/telebit.go @@ -208,8 +208,8 @@ func main() { } mux.HandleTLS("*", acme, mux) for _, fwd := range forwards { - mux.ForwardTCP("*", "localhost:"+fwd.port, 120*time.Second) - //mux.ForwardTCP(fwd.pattern, "localhost:"+fwd.port, 120*time.Second) + //mux.ForwardTCP("*", "localhost:"+fwd.port, 120*time.Second) + mux.ForwardTCP(fwd.pattern, "localhost:"+fwd.port, 120*time.Second) } done := make(chan error) diff --git a/mplexer/conn.go b/mplexer/conn.go index 4ce53e3..dd490e0 100644 --- a/mplexer/conn.go +++ b/mplexer/conn.go @@ -83,13 +83,11 @@ func (c *Conn) LocalAddr() net.Addr { // LocalAddr returns the local network address. func (c *Conn) LocalAddr() net.Addr { - // TODO is this the right one? return &c.relaySourceAddr } // RemoteAddr returns the remote network address. func (c *Conn) RemoteAddr() net.Addr { - // TODO is this the right one? return &c.relayTargetAddr } diff --git a/mplexer/connwrap.go b/mplexer/connwrap.go index 94a73ed..6d956fc 100644 --- a/mplexer/connwrap.go +++ b/mplexer/connwrap.go @@ -2,17 +2,22 @@ package telebit import ( "bufio" + "fmt" "net" "time" + + "git.coolaj86.com/coolaj86/go-telebitd/sni" ) // ConnWrap is just a cheap way to DRY up some switch conn.(type) statements to handle special features of Conn type ConnWrap struct { // TODO use io.MultiReader to unbuffer the peeker //Conn net.Conn - peeker *bufio.Reader - Conn net.Conn - Plain net.Conn + peeker *bufio.Reader + servername string + scheme string + Conn net.Conn + Plain net.Conn } type Peeker interface { @@ -60,11 +65,19 @@ func (c *ConnWrap) Close() error { // Scheme returns one of "https", "http", "tcp", "tls", or "" func (c *ConnWrap) Scheme() string { - if nil != c.Plain { - tlsConn := &ConnWrap{Conn: c.Plain} - return tlsConn.Scheme() + if "" != c.scheme { + return c.scheme } + /* + if nil != c.Plain { + tlsConn := &ConnWrap{Conn: c.Plain} + // TODO upgrade tls+http => https + c.scheme = tlsConn.Scheme() + return c.scheme + } + */ + switch conn := c.Conn.(type) { case *ConnWrap: return conn.Scheme() @@ -74,20 +87,34 @@ func (c *ConnWrap) Scheme() string { return "" } +/* +func (c *ConnWrap) SetServername(name string) { + c.servername = name +} +*/ + // Servername may return Servername or Hostname as hinted by a tunnel or buffered peeking func (c *ConnWrap) Servername() string { + if "" != c.servername { + return c.servername + } + if nil != c.Plain { tlsConn := &ConnWrap{Conn: c.Plain} - return tlsConn.Servername() + c.servername = tlsConn.Servername() + return c.servername } switch conn := c.Conn.(type) { case *ConnWrap: - return conn.Scheme() + //c.servername = conn.Servername() + return conn.Servername() case *Conn: - return string(conn.relaySourceAddr.scheme) + // TODO XXX + //c.servername = string(conn.relayTargetAddr.addr) + return string(conn.relayTargetAddr.addr) } - return "" + return c.servername } // isTerminated returns true if net.Conn is either a ConnWrap{ tls.Conn }, @@ -104,16 +131,24 @@ func (c *ConnWrap) isTerminated() bool { c.SetDeadline(time.Now().Add(5 * time.Second)) n := 6 b, _ := c.Peek(n) + fmt.Println("Peek(n)", b) defer c.SetDeadline(time.Time{}) if len(b) >= n { // SSL v3.x / TLS v1.x // 0: TLS Byte // 1: Major Version - // 2: 0-Indexed Minor Version + // 2: Minor Version - 1 // 3-4: Header Length + + // Payload // 5: TLS Client Hello Marker Byte if 0x16 == b[0] && 0x03 == b[1] && 0x01 == b[5] { - //length := (int(b[3]) << 8) + int(b[4]) + length := (int(b[3]) << 8) + int(b[4]) + b, err := c.Peek(n - 1 + length) + if nil != err { + return true + } + c.servername, _ = sni.GetHostname(b) return false } } diff --git a/mplexer/listener.go b/mplexer/listener.go index 412b616..bd9e2f2 100644 --- a/mplexer/listener.go +++ b/mplexer/listener.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/http" + "strings" ) // A Listener transforms a multiplexed websocket connection into individual net.Conn-like connections. @@ -76,7 +77,7 @@ func Serve(listener net.Listener, mux Handler) error { go func() { err = mux.Serve(client) if nil != err { - if io.EOF != err { + if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) { fmt.Printf("client could not be served: %q\n", err.Error()) } } diff --git a/mplexer/routemux.go b/mplexer/routemux.go index 930d0f9..8e5c303 100644 --- a/mplexer/routemux.go +++ b/mplexer/routemux.go @@ -2,6 +2,7 @@ package telebit import ( "fmt" + "io" "net" "strconv" "strings" @@ -65,7 +66,11 @@ func (m *RouteMux) Serve(client net.Conn) error { for _, meta := range m.routes { // TODO '*.example.com' - fmt.Println("Meta:", meta.addr) + if meta.terminate && "" == servername { + wconn.isTerminated() + servername = wconn.servername + } + fmt.Println("Meta:", meta.addr, servername) if servername == meta.addr || "*" == meta.addr || port == meta.addr { //fmt.Println("[debug] test of route:", meta) if err := meta.handler.Serve(wconn); nil != err { @@ -124,7 +129,11 @@ func (m *RouteMux) HandleTLS(servername string, acme *ACME, handler Handler) err //NewTerminator(acme, handler)(client) //return handler.Serve(client) - return handler.Serve(TerminateTLS(wconn, acme)) + err := handler.Serve(TerminateTLS(wconn, acme)) + if nil == err || io.EOF == err { + return io.EOF + } + return err }), }) return nil diff --git a/mplexer/telebit.go b/mplexer/telebit.go index c2fd790..fb67c2b 100644 --- a/mplexer/telebit.go +++ b/mplexer/telebit.go @@ -180,11 +180,15 @@ type ACME struct { var acmecert *certmagic.Config = nil -func NewTerminator(acme *ACME, handler Handler) HandlerFunc { +/* +func NewTerminator(servername string, acme *ACME, handler Handler) HandlerFunc { return func(client net.Conn) error { - return handler.Serve(TerminateTLS(client, acme)) + return handler.Serve(TerminateTLS("", client, acme)) } } +*/ + +//func TerminateTLS(client *ConnWrap, acme *ACME) net.Conn func TerminateTLS(client net.Conn, acme *ACME) net.Conn { var magic *certmagic.Config = nil @@ -231,10 +235,40 @@ func TerminateTLS(client net.Conn, acme *ACME) net.Conn { }, } + var servername string + var scheme string + // I think this must always be ConnWrap, but I'm not sure + switch conn := client.(type) { + case *ConnWrap: + servername = conn.Servername() + scheme = conn.Scheme() + client = conn + default: + wconn := &ConnWrap{ + Conn: client, + } + wconn.isTerminated() + servername = wconn.Servername() + scheme = wconn.Scheme() + client = wconn + } + + /* + // TODO ? + if "" == scheme { + scheme = "tls" + } + if "http" == scheme { + scheme = "https" + } + */ + tlsconn := tls.Server(client, tlsConfig) return &ConnWrap{ - Conn: tlsconn, - Plain: client, + Conn: tlsconn, + Plain: client, + servername: servername, + scheme: scheme, } }