From 1872302d666a4a3dc7f3ba6e39cf24761e7c55eb Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Tue, 9 Jun 2020 02:42:56 -0600 Subject: [PATCH] better TLS termination, port-forward support --- cmd/telebit/telebit.go | 47 +++++++++++++++++++++++++++++++++++-- mplexer/addr.go | 13 ++++++----- mplexer/conn.go | 2 +- mplexer/connwrap.go | 52 ++++++++++++++++++++++++++--------------- mplexer/listener.go | 16 ++++++++----- mplexer/routemux.go | 53 ++++++++++++++++++++++++++++++++++++------ mplexer/telebit.go | 24 +++++++++++++++---- mplexer/v1.go | 1 + 8 files changed, 163 insertions(+), 45 deletions(-) diff --git a/cmd/telebit/telebit.go b/cmd/telebit/telebit.go index 1fc61ea..9474caf 100644 --- a/cmd/telebit/telebit.go +++ b/cmd/telebit/telebit.go @@ -48,6 +48,7 @@ type Forward struct { func main() { var domains []string var forwards []Forward + var portForwards []Forward // TODO replace the websocket connection with a mock server appID := flag.String("app-id", "telebit.io", "a unique identifier for a deploy target environment") @@ -65,6 +66,7 @@ func main() { token := flag.String("token", "", "a pre-generated token to give the server (instead of generating one with --secret)") bindAddrsStr := flag.String("listen", "", "list of bind addresses on which to listen, such as localhost:80, or :443") locals := flag.String("locals", "", "a list of :") + portToPorts := flag.String("port-forward", "", "a list of : for raw port-forwarding") flag.Parse() if len(os.Args) >= 2 { @@ -110,6 +112,16 @@ func main() { domains = append(domains, domain) } + if 0 == len(*portToPorts) { + *portToPorts = os.Getenv("PORT_FORWARDS") + } + portForwards, err := parsePortForwards(portToPorts) + if nil != err { + fmt.Fprintf(os.Stderr, "%s", err) + os.Exit(1) + return + } + bindAddrs, err := parseBindAddrs(*bindAddrsStr) if nil != err { fmt.Fprintf(os.Stderr, "invalid bind address(es) given to --listen\n") @@ -188,6 +200,12 @@ func main() { //mux := telebit.NewRouteMux(acme) mux := telebit.NewRouteMux() + + // Port forward without TerminatingTLS + for _, fwd := range portForwards { + fmt.Println("Fwd:", fwd.pattern, fwd.port) + mux.ForwardTCP(fwd.pattern, "localhost:"+fwd.port, 120*time.Second) + } mux.HandleTLS("*", acme, mux) for _, fwd := range forwards { mux.ForwardTCP("*", "localhost:"+fwd.port, 120*time.Second) @@ -260,6 +278,31 @@ func main() { } } +func parsePortForwards(portToPorts *string) ([]Forward, error) { + var portForwards []Forward + + for _, cfg := range strings.Fields(strings.ReplaceAll(*portToPorts, ",", " ")) { + parts := strings.Split(cfg, ":") + if 2 != len(parts) { + return nil, fmt.Errorf("--port-forward should be in the format 1234:5678, not %q", cfg) + } + + if _, err := strconv.Atoi(parts[0]); nil != err { + return nil, fmt.Errorf("couldn't parse port %q of %q", parts[0], cfg) + } + if _, err := strconv.Atoi(parts[1]); nil != err { + return nil, fmt.Errorf("couldn't parse port %q of %q", parts[1], cfg) + } + + portForwards = append(portForwards, Forward{ + pattern: ":" + parts[0], + port: parts[1], + }) + } + + return portForwards, nil +} + func parseBindAddrs(bindAddrsStr string) ([]string, error) { bindAddrs := []string{} @@ -386,8 +429,8 @@ func newAPIDNSProvider(baseURL string, token string) (*dns01.DNSProvider, error) t.ListenAndServe("wss://example.com", mux) */ -func getToken(secret string, domains []string) (token string, err error) { - tokenData := jwt.MapClaims{"domains": domains} +func getToken(secret string, domains, ports []string) (token string, err error) { + tokenData := jwt.MapClaims{"domains": domains, "ports": ports} jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData) if token, err = jwtToken.SignedString([]byte(secret)); err != nil { diff --git a/mplexer/addr.go b/mplexer/addr.go index 1ca380c..1ffac69 100644 --- a/mplexer/addr.go +++ b/mplexer/addr.go @@ -1,9 +1,6 @@ package telebit -import ( - "fmt" - "strconv" -) +import "fmt" type Scheme string @@ -39,11 +36,15 @@ func NewAddr(s Scheme, t Termination, a string, p int) *Addr { } func (a *Addr) String() string { - return fmt.Sprintf("%s:%s:%s:%d", a.family, a.Scheme(), a.addr, a.port) + //return a.addr + ":" + strconv.Itoa(a.port) + return fmt.Sprintf("%s+%s:%s:%d", a.family, a.Scheme(), a.addr, a.port) } +// Network s typically network "family", such as "tcp" or "ip", +// but in this case will be "tun", which is a cue to do a `switch` +// to actually use the specific features of a telebit.Addr func (a *Addr) Network() string { - return a.addr + ":" + strconv.Itoa(a.port) + return "tun" } func (a *Addr) Port() int { diff --git a/mplexer/conn.go b/mplexer/conn.go index 12051d4..4ce53e3 100644 --- a/mplexer/conn.go +++ b/mplexer/conn.go @@ -47,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Peek(n int) (b []byte, err error) { if nil == c.peeker { - c.peeker = bufio.NewReaderSize(c, defaultPeekerSize) + c.peeker = bufio.NewReaderSize(c.relay, defaultPeekerSize) } return c.peeker.Peek(n) } diff --git a/mplexer/connwrap.go b/mplexer/connwrap.go index 7512149..94a73ed 100644 --- a/mplexer/connwrap.go +++ b/mplexer/connwrap.go @@ -8,6 +8,7 @@ import ( // ConnWrap is just a cheap way to DRY up some switch conn.(type) statements to handle special features of Conn type ConnWrap struct { + // TODO use io.MultiReader to unbuffer the peeker //Conn net.Conn peeker *bufio.Reader Conn net.Conn @@ -31,7 +32,7 @@ func (c *ConnWrap) Peek(n int) ([]byte, error) { default: // *net.UDPConn,*net.TCPConn,*net.IPConn,*net.UnixConn if nil == c.peeker { - c.peeker = bufio.NewReaderSize(c, defaultPeekerSize) + c.peeker = bufio.NewReaderSize(c.Conn, defaultPeekerSize) } return c.peeker.Peek(n) } @@ -93,31 +94,44 @@ func (c *ConnWrap) Servername() string { // or a telebit.Conn with a non-encrypted `scheme` such as "tcp" or "http". func (c *ConnWrap) isTerminated() bool { // TODO look at SNI, may need context for Peek() timeout - if nil != c.Plain { - return true - } + /* + if nil != c.Plain { + return true + } + */ // how to know how many bytes to read? really needs timeout - b, err := c.Peek(2) - if len(b) >= 2 { - // TODO better detection? + c.SetDeadline(time.Now().Add(5 * time.Second)) + n := 6 + b, _ := c.Peek(n) + defer c.SetDeadline(time.Time{}) + if len(b) >= n { // SSL v3.x / TLS v1.x - if 0x16 == b[0] && 0x03 == b[1] { + // 0: TLS Byte + // 1: Major Version + // 2: 0-Indexed Minor Version + // 3-4: Header Length + // 5: TLS Client Hello Marker Byte + if 0x16 == b[0] && 0x03 == b[1] && 0x01 == b[5] { + //length := (int(b[3]) << 8) + int(b[4]) return false } } - if nil != err { - return true - } + return true + /* + if nil != err { + return true + } - switch conn := c.Conn.(type) { - case *ConnWrap: - return conn.isTerminated() - case *Conn: - _, ok := encryptedSchemes[string(conn.relayTargetAddr.scheme)] - return !ok - } - return false + switch conn := c.Conn.(type) { + case *ConnWrap: + return conn.isTerminated() + case *Conn: + _, ok := encryptedSchemes[string(conn.relayTargetAddr.scheme)] + return !ok + } + return false + */ } // LocalAddr returns the local network address. diff --git a/mplexer/listener.go b/mplexer/listener.go index 8a0d9ad..412b616 100644 --- a/mplexer/listener.go +++ b/mplexer/listener.go @@ -126,7 +126,7 @@ func (l *Listener) RouteBytes(srcAddr, dstAddr Addr, b []byte) { // remember where the error message goes if "error" == string(dst.scheme) { pipe.Close() - delete(l.conns, src.Network()) + delete(l.conns, src.String()) fmt.Printf("a stream errored remotely: %v\n", src) } @@ -139,12 +139,12 @@ func (l *Listener) RouteBytes(srcAddr, dstAddr Addr, b []byte) { if "end" == string(dst.scheme) { fmt.Println("[debug] end") pipe.Close() - delete(l.conns, src.Network()) + delete(l.conns, src.String()) } } func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn { - connID := src.Network() + connID := src.String() pipe, ok := l.conns[connID] // Pipe exists @@ -157,8 +157,8 @@ func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn { rawPipe, pipe := net.Pipe() newconn := &Conn{ //updated: time.Now(), - relaySourceAddr: *src, - relayTargetAddr: *dst, + relaySourceAddr: *dst, + relayTargetAddr: *src, relay: rawPipe, } l.conns[connID] = pipe @@ -173,7 +173,11 @@ func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn { // In any case, we'll just close it all newconn.Close() pipe.Close() - fmt.Printf("a stream is done: %q\n", err) + if nil != err { + fmt.Printf("a stream is done: %q\n", err) + } else { + fmt.Printf("a stream is done\n") + } }() return pipe diff --git a/mplexer/routemux.go b/mplexer/routemux.go index 7222925..930d0f9 100644 --- a/mplexer/routemux.go +++ b/mplexer/routemux.go @@ -3,6 +3,8 @@ package telebit import ( "fmt" "net" + "strconv" + "strings" "time" ) @@ -31,13 +33,42 @@ func NewRouteMux() *RouteMux { // Serve dispatches the connection to the handler whose selectors matches the attributes. func (m *RouteMux) Serve(client net.Conn) error { - wconn := &ConnWrap{Conn: client} - servername := wconn.Servername() + var wconn *ConnWrap + switch conn := client.(type) { + case *ConnWrap: + wconn = conn + default: + wconn = &ConnWrap{Conn: client} + } + + var servername string + var port string + // TODO go back to Servername on conn, but with SNI + //servername := wconn.Servername() + fam := wconn.LocalAddr().Network() + if "tun" == fam { + switch laddr := wconn.LocalAddr().(type) { + case *Addr: + servername = laddr.Hostname() + port = ":" + strconv.Itoa(laddr.Port()) + default: + panic("impossible type switch: Addr is 'tun' but didn't match") + } + } else { + // TODO make an AddrWrap to do this switch + addr := wconn.LocalAddr().String() + parts := strings.Split(addr, ":") + port = ":" + parts[len(parts)-1] + servername = strings.Join(parts[:len(parts)-1], ":") + } + fmt.Println("Addr:", fam, servername, port) for _, meta := range m.routes { - if servername == meta.addr || "*" == meta.addr { + // TODO '*.example.com' + fmt.Println("Meta:", meta.addr) + if servername == meta.addr || "*" == meta.addr || port == meta.addr { //fmt.Println("[debug] test of route:", meta) - if err := meta.handler.Serve(client); nil != err { + if err := meta.handler.Serve(wconn); nil != err { // error should be EOF if successful return err } @@ -78,14 +109,22 @@ func (m *RouteMux) HandleTLS(servername string, acme *ACME, handler Handler) err addr: servername, terminate: true, handler: HandlerFunc(func(client net.Conn) error { - wrap := &ConnWrap{Conn: client} - if wrap.isTerminated() { + var wconn *ConnWrap + switch conn := client.(type) { + case *ConnWrap: + wconn = conn + default: + panic("HandleTLS is special in that it must receive &ConnWrap{ Conn: conn }") + } + + if wconn.isTerminated() { // nil to skip return nil } + //NewTerminator(acme, handler)(client) //return handler.Serve(client) - return handler.Serve(TerminateTLS(client, acme)) + return handler.Serve(TerminateTLS(wconn, acme)) }), }) return nil diff --git a/mplexer/telebit.go b/mplexer/telebit.go index f7085af..c2fd790 100644 --- a/mplexer/telebit.go +++ b/mplexer/telebit.go @@ -11,6 +11,7 @@ import ( "net" "net/http" "os" + "strings" "time" "github.com/caddyserver/certmagic" @@ -27,6 +28,11 @@ var defaultPeekerSize = 1024 // ErrBadGateway means that the target did not accept the connection var ErrBadGateway = errors.New("EBADGATEWAY") +// The proper handling of this error +// is still being debated as of Jun 9, 2020 +// https://github.com/golang/go/issues/4373 +var errNetClosing = "use of closed network connection" + // A Handler routes, proxies, terminates, or responds to a net.Conn. type Handler interface { Serve(net.Conn) error @@ -104,7 +110,13 @@ func Forward(client net.Conn, target net.Conn, timeout time.Duration) error { } }() - fmt.Println("[debug] forwarding tcp connection") + fmt.Println( + "[debug] forwarding tcp connection", + client.LocalAddr(), + client.RemoteAddr(), + target.LocalAddr(), + target.RemoteAddr(), + ) var err error = nil ForwardData: @@ -131,7 +143,7 @@ ForwardData: if nil == err { break ForwardData } - if io.EOF != err { + if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) { fmt.Printf("read from remote client failed: %q\n", err.Error()) } else { fmt.Printf("Connection closed (possibly by remote client)\n") @@ -141,7 +153,7 @@ ForwardData: if nil == err { break ForwardData } - if io.EOF != err { + if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) { fmt.Printf("read from local target failed: %q\n", err.Error()) } else { fmt.Printf("Connection closed (possibly by local target)\n") @@ -187,7 +199,11 @@ func TerminateTLS(client net.Conn, acme *ACME) net.Conn { var err error magic, err = newCertMagic(acme) if nil != err { - fmt.Fprintf(os.Stderr, "failed to initialize certificate management (discovery url? local folder perms?): %s\n", err) + fmt.Fprintf( + os.Stderr, + "failed to initialize certificate management (discovery url? local folder perms?): %s\n", + err, + ) os.Exit(1) } acmecert = magic diff --git a/mplexer/v1.go b/mplexer/v1.go index adbf1a2..cca0da0 100644 --- a/mplexer/v1.go +++ b/mplexer/v1.go @@ -126,6 +126,7 @@ func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) { return b, nil } parts := strings.Split(string(p.state.header), ",") + fmt.Println("[debug] Tun Header", string(p.state.header)) p.state.header = nil if len(parts) < 5 { return nil, errors.New("error unpacking header")