diff --git a/mplexer/cmd/telebit/telebit.go b/mplexer/cmd/telebit/telebit.go new file mode 100644 index 0000000..c4651b7 --- /dev/null +++ b/mplexer/cmd/telebit/telebit.go @@ -0,0 +1,87 @@ +package main + +import ( + "context" + "fmt" + "log" + "net/http" + "os" + "strings" + "time" + + "git.coolaj86.com/coolaj86/go-telebitd/mplexer/packer" + + jwt "github.com/dgrijalva/jwt-go" + "github.com/gorilla/websocket" + + _ "github.com/joho/godotenv/autoload" +) + +func main() { + // TODO replace the websocket connection with a mock server + + relay := os.Getenv("RELAY") // "wss://roottest.duckdns.org:8443" + authz, err := getToken(os.Getenv("SECRET")) + if nil != err { + panic(err) + } + + ctx := context.Background() + wsd := websocket.Dialer{} + headers := http.Header{} + headers.Set("Authorization", fmt.Sprintf("Bearer %s", authz)) + // *http.Response + sep := "?" + if strings.Contains(relay, sep) { + sep = "&" + } + wsconn, _, err := wsd.DialContext(ctx, relay+sep+"access_token="+authz, headers) + if nil != err { + fmt.Println("relay:", relay) + log.Fatal(err) + return + } + + /* + // TODO for http proxy + return mplexer.TargetOptions { + Hostname // default localhost + Termination // default TLS + XFWD // default... no? + Port // default 0 + Conn // should be dialed beforehand + }, nil + */ + + /* + t := telebit.New(token) + mux := telebit.RouteMux{} + mux.HandleTLS("*", mux) // go back to itself + mux.HandleProxy("example.com", "localhost:3000") + mux.HandleTCP("example.com", func (c *telebit.Conn) { + return httpmux.Serve() + }) + + l := t.Listen("wss://example.com") + conn := l.Accept() + telebit.Serve(listener, mux) + t.ListenAndServe("wss://example.com", mux) + */ + + mux := packer.NewRouteMux() + //mux.HandleTLS("*", mux.TerminateTLS(mux)) + mux.ForwardTCP("*", "localhost:3000", 120*time.Second) + // TODO set failure + log.Fatal("Closed server: ", packer.ListenAndServe(wsconn, mux)) +} + +func getToken(secret string) (token string, err error) { + domains := []string{"dandel.duckdns.org"} + tokenData := jwt.MapClaims{"domains": domains} + + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData) + if token, err = jwtToken.SignedString([]byte(secret)); err != nil { + return "", err + } + return token, nil +} diff --git a/mplexer/packer/addr.go b/mplexer/packer/addr.go index 2dc8fd3..acde363 100644 --- a/mplexer/packer/addr.go +++ b/mplexer/packer/addr.go @@ -39,7 +39,7 @@ func NewAddr(s Scheme, t Termination, a string, p int) *Addr { } func (a *Addr) String() string { - return fmt.Sprintf("%s:%s:%s:%d", a.Network(), a.Scheme(), a.addr, a.port) + return fmt.Sprintf("%s:%s:%s:%d", a.family, a.Scheme(), a.addr, a.port) } func (a *Addr) Network() string { diff --git a/mplexer/packer/decoder_test.go b/mplexer/packer/decoder_test.go index eaa2207..ded7c70 100644 --- a/mplexer/packer/decoder_test.go +++ b/mplexer/packer/decoder_test.go @@ -1,7 +1,6 @@ package packer import ( - "context" "net" "testing" ) @@ -13,8 +12,8 @@ func TestDecode1WholeBlock(t *testing.T) { func testDecodeNBlocks(t *testing.T, count int) { wp, rp := net.Pipe() - ctx := context.Background() - decoder := NewDecoder(ctx, rp) + //ctx := context.Background() + decoder := NewDecoder(rp) nAddr := 1 if count > 2 { nAddr = count - 2 @@ -23,11 +22,11 @@ func testDecodeNBlocks(t *testing.T, count int) { raw := []byte{} for i := 0; i < count; i++ { if i > 2 { - copied := src - src = copied - src.port += i + copied := srcTestAddr + srcTestAddr = copied + srcTestAddr.port += i } - h, b, err := Encode(src, dst, domain, payload) + h, b, err := Encode(payload, srcTestAddr, dstTestAddr) if nil != err { t.Fatal(err) } @@ -53,7 +52,7 @@ func testDecodeNBlocks(t *testing.T, count int) { conns: map[string]*Conn{}, } //fmt.Println("streamers gonna stream") - err := decoder.StreamDecode(th, 0) + err := decoder.Decode(th) if nil != err { t.Fatalf("failed to decode stream: %s", err) } diff --git a/mplexer/packer/encoder.go b/mplexer/packer/encoder.go index b65818f..ca14508 100644 --- a/mplexer/packer/encoder.go +++ b/mplexer/packer/encoder.go @@ -3,6 +3,7 @@ package packer import ( "context" "errors" + "fmt" "io" "sync" ) @@ -55,10 +56,12 @@ func (enc *Encoder) Run() error { } // Encode adds MPLEXY headers to raw net traffic, and is intended to be used on each client connection -func (enc *Encoder) Encode(rin io.Reader, src Addr) error { +func (enc *Encoder) Encode(rin io.Reader, src, dst Addr) error { rx := make(chan []byte) rxErr := make(chan error) + fmt.Println("what's the source to encode?", src) + go func() { for { b := make([]byte, enc.bufferSize) @@ -87,11 +90,14 @@ func (enc *Encoder) Encode(rin io.Reader, src Addr) error { //rin.Close() return errors.New("cancelled by context") case b := <-rx: - header, _, err := Encode(src, Addr{}, "", b) + header, _, err := Encode(b, src, Addr{scheme: src.scheme, addr: "", port: -1}) if nil != err { //rin.Close() return err } + fmt.Println("[debug] encode header:", string(header)) + fmt.Println("[debug] encode payload:", string(b)) + _, err = enc.write(header, b) if nil != err { //rin.Close() @@ -101,12 +107,14 @@ func (enc *Encoder) Encode(rin io.Reader, src Addr) error { // it can be assumed that err will close though, right? //rin.Close() if io.EOF == err { - header, _, _ := Encode(src, Addr{scheme: "end"}, "", nil) + header, _, _ := Encode(nil, src, Addr{scheme: "end"}) + fmt.Println("[debug] encode end: ", header) // ignore err, which may have already closed _, _ = enc.write(header, nil) return nil } - header, _, _ := Encode(src, Addr{scheme: "error"}, "", []byte(err.Error())) + // TODO transmit message , []byte(err.Error()) + header, _, _ := Encode(nil, src, Addr{scheme: "error"}) // ignore err, which may have already closed _, _ = enc.write(header, nil) return err @@ -119,10 +127,13 @@ func (enc *Encoder) write(h, b []byte) (int, error) { // mutex here so that we can get back error info enc.mux.Lock() var m int - n, err := enc.out.Write(h) - if nil == err && len(b) > 0 { - m, err = enc.out.Write(b) - } + n, err := enc.out.Write(append(h, b...)) + /* + n, err := enc.out.Write(h) + if nil == err && len(b) > 0 { + m, err = enc.out.Write(b) + } + */ enc.mux.Unlock() if nil != err { enc.outErr <- err diff --git a/mplexer/packer/encoder_test.go b/mplexer/packer/encoder_test.go index 4620fa6..57a8569 100644 --- a/mplexer/packer/encoder_test.go +++ b/mplexer/packer/encoder_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "strings" "testing" "time" ) @@ -29,7 +30,11 @@ func TestEncodeWholeBlock(t *testing.T) { // TODO check the headers too if len(str) > 0 && 0xFE == str[0] { fmt.Printf("TODO header: %q\n", str) - continue + parts := strings.Split(str, "\n") + if len(parts) <= 1 { + continue + } + str = parts[1] } b, ok := m[str] @@ -89,6 +94,10 @@ func TestEncodeWholeBlock(t *testing.T) { family: "IPv4", addr: "192.168.1.102", port: 4834, + }, Addr{ + scheme: "https", + addr: "example.com", + port: 443, }) if nil != err { t.Fatalf("Enc Err 1: %q\n", err) @@ -108,6 +117,10 @@ func TestEncodeWholeBlock(t *testing.T) { family: "IPv4", addr: "192.168.1.103", port: 4834, + }, Addr{ + scheme: "https", + addr: "example.com", + port: 443, }) if nil != err { t.Fatalf("Enc Err 2: %q\n", err) diff --git a/mplexer/packer/listener.go b/mplexer/packer/listener.go new file mode 100644 index 0000000..59fb8e4 --- /dev/null +++ b/mplexer/packer/listener.go @@ -0,0 +1,184 @@ +package packer + +import ( + "context" + "fmt" + "io" + "net" + "net/http" +) + +// A Listener transforms a multiplexed websocket connection into individual net.Conn-like connections. +type Listener struct { + //wsconn *websocket.Conn + wsw *WSWrap + incoming chan *Conn + close chan struct{} + encoder *Encoder + chunksParsed int + bytesRead int + conns map[string]net.Conn + //conns map[string]*Conn +} + +// Listen creates a new Listener and sets it up to receive and distribute connections. +func Listen(wsconn WSConn) *Listener { + ctx := context.TODO() + + // Wrap the websocket and feed it into the Encoder and Decoder + wsw := &WSWrap{wsconn: wsconn, tmpr: nil} + listener := &Listener{ + //wsconn: wsconn, + wsw: wsw, + incoming: make(chan *Conn, 1), // buffer ever so slightly + close: make(chan struct{}), + encoder: NewEncoder(ctx, wsw), + conns: map[string]net.Conn{}, + //conns: map[string]*Conn{}, + } + + // TODO perhaps the wrapper should have a mutex + // rather than having a goroutine in the encoder + go func() { + err := listener.encoder.Run() + fmt.Printf("encoder stopped entirely: %q", err) + wsw.wsconn.Close() + }() + + // Decode the stream as it comes in + decoder := NewDecoder(wsw) + go func() { + // TODO pass error to Accept() + err := decoder.Decode(listener) + + // The listener itself must be closed explicitly because + // there's an encoder with a callback between the websocket + // and the multiplexer, so it doesn't know to stop listening otherwise + listener.Close() + fmt.Printf("the main stream is done: %q\n", err) + }() + + return listener +} + +// ListenAndServe listens on a websocket and handles the incomming net.Conn-like connections with a Handler +func ListenAndServe(wsconn WSConn, mux Handler) error { + listener := Listen(wsconn) + return Serve(listener, mux) +} + +// Serve Accept()s connections which have already been unwrapped and serves them with the given Handler +func Serve(listener *Listener, mux Handler) error { + for { + client, err := listener.Accept() + if nil != err { + return err + } + + go func() { + err = mux.Serve(client) + if nil != err { + if io.EOF != err { + fmt.Printf("client could not be served: %q\n", err.Error()) + } + } + client.Close() + }() + } +} + +// Accept returns a tunneled network connection +func (l *Listener) Accept() (*Conn, error) { + select { + case rconn, ok := <-l.incoming: + if ok { + return rconn, nil + } + return nil, io.EOF + + case <-l.close: + return nil, http.ErrServerClosed + } +} + +// Close stops accepting new connections and closes the underlying websocket. +// TODO return errors. +func (l *Listener) Close() error { + l.wsw.Close() + close(l.incoming) + l.close <- struct{}{} + return nil +} + +// RouteBytes receives address information and a buffer and creates or re-uses a pipe that can be Accept()ed. +func (l *Listener) RouteBytes(srcAddr, dstAddr Addr, b []byte) { + // TODO use context to be able to cancel many at once? + l.chunksParsed++ + + src := &srcAddr + dst := &dstAddr + pipe := l.getPipe(src, dst) + + fmt.Printf("Forwarding bytes to %#v:\n", dst) + fmt.Printf("%s\n", b) + + // handle errors before data writes because I don't + // remember where the error message goes + if "error" == string(dst.scheme) { + pipe.Close() + delete(l.conns, src.Network()) + fmt.Printf("a stream errored remotely: %v\n", src) + } + + // write data, if any + if len(b) > 0 { + l.bytesRead += len(b) + pipe.Write(b) + } + // EOF, if needed + if "end" == string(dst.scheme) { + fmt.Println("[debug] end") + pipe.Close() + delete(l.conns, src.Network()) + } +} + +//func (l *Listener) getPipe(addr *Addr) *Conn { +func (l *Listener) getPipe(src, dst *Addr) net.Conn { + connID := src.Network() + pipe, ok := l.conns[connID] + + // Pipe exists + if ok { + return pipe + } + + // Create pipe + rawPipe, pipe := net.Pipe() + newconn := &Conn{ + //updated: time.Now(), + relaySourceAddr: *src, + /* + relayRemoteAddr: Addr{ + scheme: addr.scheme, + }, + */ + relay: rawPipe, + } + l.conns[connID] = pipe + l.incoming <- newconn + + // Handle encoding + go func() { + // TODO handle err + err := l.encoder.Encode(pipe, *src, *dst) + // the error may be EOF or ErrServerClosed or ErrGoingAwawy or some such + // or it might be an actual error + // In any case, we'll just close it all + newconn.Close() + pipe.Close() + fmt.Printf("a stream is done: %q\n", err) + }() + + return pipe +} diff --git a/mplexer/packer/listener_test.go b/mplexer/packer/listener_test.go index 31d4a52..792aa61 100644 --- a/mplexer/packer/listener_test.go +++ b/mplexer/packer/listener_test.go @@ -1,416 +1,106 @@ package packer import ( - "context" "errors" - "fmt" "io" "net" - "net/http" - "os" - "strings" "testing" "time" - - jwt "github.com/dgrijalva/jwt-go" - - "github.com/gorilla/websocket" ) func TestDialServer(t *testing.T) { // TODO replace the websocket connection with a mock server - relay := "wss://roottest.duckdns.org:8443" - authz, err := getToken("xxxxyyyyssss8347") - if nil != err { - panic(err) + //ctx := context.Background() + wsconn := &WSTestConn{ + rwt: &RWTest{}, } - ctx := context.Background() - wsd := websocket.Dialer{} - headers := http.Header{} - headers.Set("Authorization", fmt.Sprintf("Bearer %s", authz)) - // *http.Response - sep := "?" - if strings.Contains(relay, sep) { - sep = "&" - } - wsconn, _, err := wsd.DialContext(ctx, relay+sep+"access_token="+authz, headers) - if nil != err { - fmt.Println("relay:", relay) - t.Fatal(err) - return - } - - /* - t := telebit.New(token) - mux := telebit.RouteMux{} - mux.HandleTLS("*", mux) // go back to itself - mux.HandleProxy("example.com", "localhost:3000") - mux.HandleTCP("example.com", func (c *telebit.Conn) { - return httpmux.Serve() - }) - - l := t.Listen("wss://example.com") - conn := l.Accept() - telebit.Serve(listener, mux) - t.ListenAndServe("wss://example.com", mux) - */ - mux := NewRouteMux() - // TODO set failure t.Fatal(ListenAndServe(wsconn, mux)) } -func getToken(secret string) (token string, err error) { - domains := []string{"dandel.duckdns.org"} - tokenData := jwt.MapClaims{"domains": domains} +var ErrNoImpl error = errors.New("not implemented") - jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData) - if token, err = jwtToken.SignedString([]byte(secret)); err != nil { - return "", err +// WSTestConn is a fake websocket connection +type WSTestConn struct { + closed bool + rwt *RWTest +} + +func (wst *WSTestConn) NextReader() (messageType int, r io.Reader, err error) { + return 0, nil, ErrNoImpl +} +func (wst *WSTestConn) NextWriter(messageType int) (io.WriteCloser, error) { + return nil, ErrNoImpl +} +func (wst *WSTestConn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if wst.closed { + return io.EOF } - return token, nil + return nil } - -type Listener struct { - ws *websocket.Conn - incoming chan *Conn - close chan struct{} - encoder *Encoder - conns map[string]*Conn - chunksParsed int - bytesRead int -} - -func ListenAndServe(ws *websocket.Conn, mux Handler) error { - listener := Listen(ws) - return Serve(listener, mux) -} - -func Listen(ws *websocket.Conn) *Listener { - ctx := context.TODO() - - // Wrap the websocket and feed it into the Encoder and Decoder - rw := &WSConn{c: ws, nr: nil} - listener := &Listener{ - ws: ws, - conns: map[string]*Conn{}, - incoming: make(chan *Conn, 1), // buffer ever so slightly - close: make(chan struct{}), - encoder: NewEncoder(ctx, rw), +func (wst *WSTestConn) WriteMessage(messageType int, data []byte) error { + if wst.closed { + return io.EOF } - // TODO perhaps the wrapper should have a mutex - // rather than having a goroutine in the encoder - go func() { - err := listener.encoder.Run() - fmt.Printf("encoder stopped entirely: %q", err) - rw.c.Close() - }() - - // Decode the stream as it comes in - decoder := NewDecoder(rw) - go func() { - // TODO pass error to Accept() - err := decoder.Decode(listener) - rw.Close() - fmt.Printf("the main stream is done: %q\n", err) - }() - - return listener + return nil +} +func (wst *WSTestConn) SetReadDeadline(t time.Time) error { + return ErrNoImpl +} +func (wst *WSTestConn) Close() error { + wst.closed = true + return nil +} +func (wst *WSTestConn) RemoteAddr() net.Addr { + addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8443") + return addr } -func (l *Listener) RouteBytes(a Addr, b []byte) { - // TODO use context to be able to cancel many at once? - l.chunksParsed++ +// RWTest is a fake buffer +type RWTest struct { + closed bool + tmpr []byte +} - addr := &a - pipe := l.getPipe(addr) - - // handle errors before data writes because I don't - // remember where the error message goes - if "error" == string(addr.scheme) { - pipe.Close() - delete(l.conns, addr.Network()) - fmt.Printf("a stream errored remotely: %v\n", addr) +func (rwt *RWTest) Read(dst []byte) (int, error) { + if rwt.closed { + return 0, io.EOF } - // write data, if any - if len(b) > 0 { - l.bytesRead += len(b) - pipe.Write(b) + id := Addr{ + scheme: "http", + addr: "192.168.1.108", + port: 6732, } - // EOF, if needed - if "end" == string(addr.scheme) { - pipe.Close() - delete(l.conns, addr.Network()) - } -} - -func (l *Listener) getPipe(addr *Addr) *Conn { - connID := addr.Network() - pipe, ok := l.conns[connID] - - // Pipe exists - if ok { - return pipe + tun := Addr{ + scheme: "http", + termination: TLS, + addr: "abc.example.com", + port: 443, } - // Create pipe - rawPipe, encodable := net.Pipe() - pipe = &Conn{ - //updated: time.Now(), - relayRemoteAddr: *addr, - relay: rawPipe, - } - l.conns[connID] = pipe - l.incoming <- pipe - - // Handle encoding - go func() { - // TODO handle err - err := l.encoder.Encode(encodable, *pipe.LocalAddr()) - // the error may be EOF or ErrServerClosed or ErrGoingAwawy or some such - // or it might be an actual error - // In any case, we'll just close it all - encodable.Close() - pipe.Close() - fmt.Printf("a stream is done: %q\n", err) - }() - - return pipe -} - -func Serve(listener *Listener, mux Handler) error { - for { - client, err := listener.Accept() - if nil != err { - return err - } - - go func() { - err = mux.Serve(client) - if nil != err { - if io.EOF != err { - fmt.Printf("client could not be served: %q\n", err.Error()) - } - } - client.Close() - }() - } -} - -func (l *Listener) Accept() (*Conn, error) { - select { - case rconn, ok := <-l.incoming: - if ok { - return rconn, nil - } - return nil, io.EOF - - case <-l.close: - l.ws.Close() - // TODO is another error more suitable? - return nil, http.ErrServerClosed - } -} - -type Handler interface { - Serve(*Conn) error - GetTargetConn(*Addr) (net.Conn, error) -} - -type RouteMux struct { - defaultTimeout time.Duration -} - -func NewRouteMux() *RouteMux { - mux := &RouteMux{ - defaultTimeout: 45 * time.Second, - } - return mux -} - -func (m *RouteMux) Serve(client *Conn) error { - // TODO could proxy or handle directly, etc - target, err := m.GetTargetConn(client.RemoteAddr()) - if nil != err { - return err + if 0 == len(rwt.tmpr) { + b := []byte("Hello, World!") + h, _, _ := Encode(b, id, tun) + rwt.tmpr = append(h, b...) } - return Forward(client, target, m.defaultTimeout) + n := copy(dst, rwt.tmpr) + rwt.tmpr = rwt.tmpr[n:] + + return n, nil } -// Forward port-forwards a relay (websocket) client to a target (local) server -func Forward(client *Conn, target net.Conn, timeout time.Duration) error { - - // Something like ReadAhead(size) should signal - // to read and send up to `size` bytes without waiting - // for a response - since we can't signal 'non-read' as - // is the normal operation of tcp... or can we? - // And how do we distinguish idle from dropped? - // Maybe this should have been a udp protocol??? - - defer client.Close() - defer target.Close() - - srcCh := make(chan []byte) - dstCh := make(chan []byte) - srcErrCh := make(chan error) - dstErrCh := make(chan error) - - // Source (Relay) Read Channel - go func() { - for { - b := make([]byte, defaultBufferSize) - n, err := client.Read(b) - if n > 0 { - srcCh <- b - } - if nil != err { - // TODO let client log this server-side error (unless EOF) - // (nil here because we probably can't send the error to the relay) - srcErrCh <- err - break - } - } - }() - - // Target (Local) Read Channel - go func() { - for { - b := make([]byte, defaultBufferSize) - n, err := target.Read(b) - if n > 0 { - dstCh <- b - } - if nil != err { - if io.EOF == err { - err = nil - } - dstErrCh <- err - break - } - } - }() - - var err error = nil - for { - select { - // TODO do we need a context here? - //case <-ctx.Done(): - // break - case b := <-srcCh: - client.SetDeadline(time.Now().Add(timeout)) - _, err = target.Write(b) - if nil != err { - fmt.Printf("write to target failed: %q", err.Error()) - break - } - case b := <-dstCh: - target.SetDeadline(time.Now().Add(timeout)) - _, err = client.Write(b) - if nil != err { - fmt.Printf("write to remote failed: %q", err.Error()) - break - } - case err = <-srcErrCh: - if nil != err { - fmt.Printf("read from remote failed: %q", err.Error()) - } - break - case err = <-dstErrCh: - if nil != err { - fmt.Printf("read from target failed: %q", err.Error()) - } - break - - } +func (rwt *RWTest) Write(int, []byte) error { + if rwt.closed { + return io.EOF } - - client.Close() - return err + return nil } -// this function is very client-specific logic -func (m *RouteMux) GetTargetConn(paddr *Addr) (net.Conn, error) { - //if target := GetTargetByPort(paddr.Port()); nil != target { } - if target := m.GetTargetByServername(paddr.Hostname()); nil != target { - tconn, err := net.Dial(target.Network(), target.Hostname()) - if nil != err { - return nil, err - } - /* - // TODO for http proxy - return mplexer.TargetOptions { - Hostname // default localhost - Termination // default TLS - XFWD // default... no? - Port // default 0 - Conn // should be dialed beforehand - }, nil - */ - return tconn, nil - } - // TODO - return nil, errors.New("Bad Gateway") -} - -func (m *RouteMux) GetTargetByServername(servername string) *Addr { - return NewAddr( - HTTPS, - TCP, // TCP -> termination.None? / Plain? - "localhost", - 3000, - ) -} - -type WSConn struct { - c *websocket.Conn - nr io.Reader - //w io.WriteCloser - //pingCh chan struct{} -} - -func (ws *WSConn) Read(b []byte) (int, error) { - if nil == ws.nr { - _, r, err := ws.c.NextReader() - if nil != err { - return 0, err - } - ws.nr = r - } - n, err := ws.nr.Read(b) - if io.EOF == err { - err = nil - } - return n, err -} - -func (ws *WSConn) Write(b []byte) (int, error) { - // TODO create or reset ping deadline - // TODO document that more complete writes are preferred? - - w, err := ws.c.NextWriter(websocket.BinaryMessage) - if nil != err { - return 0, err - } - n, err := w.Write(b) - if nil != err { - return n, err - } - err = w.Close() - return n, err -} - -func (ws *WSConn) Close() error { - // TODO handle EOF as websocket.CloseNormal? - message := websocket.FormatCloseMessage(websocket.CloseGoingAway, "closing connection") - deadline := time.Now().Add(10 * time.Second) - err := ws.c.WriteControl(websocket.CloseMessage, message, deadline) - if nil != err { - fmt.Fprintf(os.Stderr, "failed to write close message to websocket: %s\n", err) - } - _ = ws.c.Close() - return err +func (rwt *RWTest) Close() error { + rwt.closed = true + return nil } diff --git a/mplexer/packer/packer.go b/mplexer/packer/packer.go index f0ce543..93ab897 100644 --- a/mplexer/packer/packer.go +++ b/mplexer/packer/packer.go @@ -5,8 +5,9 @@ import ( ) // Encode creates an MPLEXY V1 header for the given addresses and payload -func Encode(id, tun Addr, domain string, payload []byte) ([]byte, []byte, error) { +func Encode(payload []byte, id, tun Addr) ([]byte, []byte, error) { n := len(payload) + domain := tun.addr header := []byte(fmt.Sprintf( "%s,%s,%d,%d,%s,%d,%s,\n", id.family, id.addr, id.port, diff --git a/mplexer/packer/packer_test.go b/mplexer/packer/packer_test.go index e59606c..82292f5 100644 --- a/mplexer/packer/packer_test.go +++ b/mplexer/packer/packer_test.go @@ -13,17 +13,17 @@ func TestEncodeDataMessage(t *testing.T) { } tun := Addr{ family: id.family, + addr: "ex1.telebit.io", port: 80, scheme: "http", } - domain := "ex1.telebit.io" payload := []byte("Hello, World!") header := []byte("IPv4,192.168.1.101,6743," + strconv.Itoa(len(payload)) + ",http,80,ex1.telebit.io,\n") //header = append([]byte{V1, byte(len(header))}, header...) header = append([]byte{254, byte(len(header))}, header...) - h, b, err := Encode(id, tun, domain, payload) + h, b, err := Encode(payload, id, tun) if nil != err { t.Fatal(err) } diff --git a/mplexer/packer/parser.go b/mplexer/packer/parser.go index 3a515ce..ed2f329 100644 --- a/mplexer/packer/parser.go +++ b/mplexer/packer/parser.go @@ -21,7 +21,8 @@ type ParserState struct { headerLen int header []byte payloadLen int - addr Addr + srcAddr Addr + dstAddr Addr payloadWritten int } @@ -47,7 +48,7 @@ func NewParser(handler Router) *Parser { } type Router interface { - RouteBytes(Addr, []byte) + RouteBytes(src, dst Addr, payload []byte) } // Write receives tunnel data and creates or writes to connections diff --git a/mplexer/packer/parser_test.go b/mplexer/packer/parser_test.go index a6ca7e9..64de1b6 100644 --- a/mplexer/packer/parser_test.go +++ b/mplexer/packer/parser_test.go @@ -4,15 +4,14 @@ import ( "math/rand" "net" "testing" - "time" ) -var src = Addr{ +var srcTestAddr = Addr{ family: "IPv4", addr: "192.168.1.101", port: 6743, } -var dst = Addr{ +var dstTestAddr = Addr{ family: "IPv4", port: 80, scheme: "http", @@ -26,19 +25,21 @@ type testHandler struct { bytesRead int } -func (th *testHandler) WriteMessage(a Addr, b []byte) { +func (th *testHandler) RouteBytes(srcAddr, dstAddr Addr, b []byte) { th.chunksParsed++ - addr := &a - _, ok := th.conns[addr.Network()] + src := &srcAddr + dst := &dstAddr + _, ok := th.conns[src.Network()] if !ok { rconn, wconn := net.Pipe() conn := &Conn{ - updated: time.Now(), - relayRemoteAddr: *addr, + //updated: time.Now(), + relaySourceAddr: *src, + relayRemoteAddr: *dst, relay: rconn, local: wconn, } - th.conns[addr.Network()] = conn + th.conns[src.Network()] = conn } th.bytesRead += len(b) } @@ -96,7 +97,7 @@ func TestParse1AndRest(t *testing.T) { p := NewParser(th) - h, b, err := Encode(src, dst, domain, payload) + h, b, err := Encode(payload, srcTestAddr, dstTestAddr) if nil != err { t.Fatal(err) } @@ -131,7 +132,7 @@ func TestParseRestAnd1(t *testing.T) { p := NewParser(th) - h, b, err := Encode(src, dst, domain, payload) + h, b, err := Encode(payload, srcTestAddr, dstTestAddr) if nil != err { t.Fatal(err) } @@ -168,7 +169,7 @@ func testParseByN(t *testing.T, n int) { p := NewParser(th) - h, b, err := Encode(src, dst, domain, payload) + h, b, err := Encode(payload, srcTestAddr, dstTestAddr) if nil != err { t.Fatal(err) } @@ -233,11 +234,11 @@ func testParseNBlocks(t *testing.T, count int) { raw := []byte{} for i := 0; i < count; i++ { if i > 2 { - copied := src - src = copied - src.port += i + copied := srcTestAddr + srcTestAddr = copied + srcTestAddr.port += i } - h, b, err := Encode(src, dst, domain, payload) + h, b, err := Encode(payload, srcTestAddr, dstTestAddr) if nil != err { t.Fatal(err) } diff --git a/mplexer/packer/routemux.go b/mplexer/packer/routemux.go new file mode 100644 index 0000000..7bf4444 --- /dev/null +++ b/mplexer/packer/routemux.go @@ -0,0 +1,68 @@ +package packer + +import ( + "errors" + "time" +) + +// A RouteMux is a net.Conn multiplexer. +// +// It matches the port, domain, or connection type of a connection +// and selects the matching handler. +type RouteMux struct { + defaultTimeout time.Duration + list []meta +} + +type meta struct { + addr string + handler Handler +} + +// NewRouteMux allocates and returns a new RouteMux. +func NewRouteMux() *RouteMux { + mux := &RouteMux{ + defaultTimeout: 45 * time.Second, + } + return mux +} + +// Serve dispatches the connection to the handler whose selectors matches the attributes. +func (m *RouteMux) Serve(client *Conn) error { + addr := client.RemoteAddr() + + for _, meta := range m.list { + if addr.addr == meta.addr || "*" == meta.addr { + if err := meta.handler.Serve(client); nil != err { + return err + } + } + } + + return client.Close() +} + +// ForwardTCP creates and returns a connection to a local handler target. +func (m *RouteMux) ForwardTCP(servername string, target string, timeout time.Duration) error { + // TODO check servername + m.list = append(m.list, meta{ + addr: servername, + handler: NewForwarder(target, timeout), + }) + return nil +} + +// HandleTCP creates and returns a connection to a local handler target. +func (m *RouteMux) HandleTCP(servername string, handler Handler) error { + // TODO check servername + m.list = append(m.list, meta{ + addr: servername, + handler: handler, + }) + return nil +} + +// HandleTLS creates and returns a connection to a local handler target. +func (m *RouteMux) HandleTLS(servername string, serve Handler) error { + return errors.New("not implemented") +} diff --git a/mplexer/packer/telebit.go b/mplexer/packer/telebit.go index 9295a99..654fae8 100644 --- a/mplexer/packer/telebit.go +++ b/mplexer/packer/telebit.go @@ -1,6 +1,12 @@ package packer -import "errors" +import ( + "errors" + "fmt" + "io" + "net" + "time" +) // Note: 64k is the TCP max, but 1460b is the 100mbit Ethernet max (1500 MTU - overhead), // but 1Gbit Ethernet (Jumbo frame) has an 9000b MTU @@ -10,3 +16,117 @@ var defaultBufferSize = 8192 // ErrBadGateway means that the target did not accept the connection var ErrBadGateway = errors.New("EBADGATEWAY") + +// A Handler routes, proxies, terminates, or responds to a net.Conn. +type Handler interface { + Serve(*Conn) error +} + +type HandlerFunc func(*Conn) error + +// Serve calls f(conn). +func (f HandlerFunc) Serve(conn *Conn) error { + return f(conn) +} + +// NewForwarder creates a handler that port-forwards to a target +func NewForwarder(target string, timeout time.Duration) HandlerFunc { + return func(client *Conn) error { + tconn, err := net.Dial("tcp", target) + if nil != err { + return err + } + return Forward(client, tconn, timeout) + } +} + +// Forward port-forwards a relay (websocket) client to a target (local) server +func Forward(client *Conn, target net.Conn, timeout time.Duration) error { + + // Something like ReadAhead(size) should signal + // to read and send up to `size` bytes without waiting + // for a response - since we can't signal 'non-read' as + // is the normal operation of tcp... or can we? + // And how do we distinguish idle from dropped? + // Maybe this should have been a udp protocol??? + + defer client.Close() + defer target.Close() + + srcCh := make(chan []byte) + dstCh := make(chan []byte) + srcErrCh := make(chan error) + dstErrCh := make(chan error) + + // Source (Relay) Read Channel + go func() { + for { + b := make([]byte, defaultBufferSize) + n, err := client.Read(b) + if n > 0 { + srcCh <- b[:n] + } + if nil != err { + // TODO let client log this server-side error (unless EOF) + // (nil here because we probably can't send the error to the relay) + srcErrCh <- err + break + } + } + }() + + // Target (Local) Read Channel + go func() { + for { + b := make([]byte, defaultBufferSize) + n, err := target.Read(b) + if n > 0 { + dstCh <- b[:n] + } + if nil != err { + if io.EOF == err { + err = nil + } + dstErrCh <- err + break + } + } + }() + + var err error = nil + for { + select { + // TODO do we need a context here? + //case <-ctx.Done(): + // break + case b := <-srcCh: + client.SetDeadline(time.Now().Add(timeout)) + _, err = target.Write(b) + if nil != err { + fmt.Printf("write to target failed: %q", err.Error()) + break + } + case b := <-dstCh: + target.SetDeadline(time.Now().Add(timeout)) + _, err = client.Write(b) + if nil != err { + fmt.Printf("write to remote failed: %q", err.Error()) + break + } + case err = <-srcErrCh: + if nil != err { + fmt.Printf("read from remote failed: %q", err.Error()) + } + break + case err = <-dstErrCh: + if nil != err { + fmt.Printf("read from target failed: %q", err.Error()) + } + break + + } + } + + client.Close() + return err +} diff --git a/mplexer/packer/v1.go b/mplexer/packer/v1.go index 5a36858..7d49d0c 100644 --- a/mplexer/packer/v1.go +++ b/mplexer/packer/v1.go @@ -27,6 +27,10 @@ const ( LengthIndex // ServiceIndex is the 5th (4) address element, the Scheme or Control message type ServiceIndex + // RelayPortIndex is the 6th (5) address element, the port on which the connection was established + RelayPortIndex + // ServernameIndex is the 7th (6) address element, the SNI Servername or Hostname + ServernameIndex ) // Header is the MPLEXY address/control meta data that comes before a packet @@ -140,13 +144,24 @@ func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) { return nil, errors.New("'control' messages not implemented") } - addr := Addr{ + src := Addr{ family: parts[FamilyIndex], addr: parts[AddressIndex], port: port, + //scheme: Scheme(service), + } + dst := Addr{ scheme: Scheme(service), } - p.state.addr = addr + if len(parts) > RelayPortIndex { + port, _ := strconv.Atoi(parts[RelayPortIndex]) + dst.port = port + } + if len(parts) > ServernameIndex { + dst.addr = parts[ServernameIndex] + } + p.state.srcAddr = src + p.state.dstAddr = dst /* p.state.conn = p.conns[addr.Network()] if nil == p.state.conn { @@ -187,7 +202,7 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) { */ //fmt.Printf("[debug] [2] payload written: %d | payload length: %d\n", p.state.payloadWritten, p.state.payloadLen) - p.handler.RouteBytes(p.state.addr, []byte{}) + p.handler.RouteBytes(p.state.srcAddr, p.state.dstAddr, []byte{}) return b, nil } @@ -207,7 +222,7 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) { return b, nil } */ - p.handler.RouteBytes(p.state.addr, c) + p.handler.RouteBytes(p.state.srcAddr, p.state.dstAddr, c) p.consumed += k p.state.payloadWritten += k diff --git a/mplexer/packer/wswrap.go b/mplexer/packer/wswrap.go new file mode 100644 index 0000000..02ce5a9 --- /dev/null +++ b/mplexer/packer/wswrap.go @@ -0,0 +1,88 @@ +package packer + +import ( + "fmt" + "io" + "net" + "os" + "time" + + "github.com/gorilla/websocket" +) + +// WSWrap wraps a websocket.Conn instance to behave like net.Conn. +// TODO make conform. +type WSWrap struct { + wsconn WSConn + tmpr io.Reader + //w io.WriteCloser + //pingCh chan struct{} +} + +// WSConn defines a interface for gorilla websockets for the purpose of testing +type WSConn interface { + NextReader() (messageType int, r io.Reader, err error) + NextWriter(messageType int) (io.WriteCloser, error) + WriteControl(messageType int, data []byte, deadline time.Time) error + WriteMessage(messageType int, data []byte) error + SetReadDeadline(t time.Time) error + Close() error + RemoteAddr() net.Addr + // LocalAddr() net.Addr +} + +func (wsw *WSWrap) Read(b []byte) (int, error) { + if nil == wsw.tmpr { + _, msgr, err := wsw.wsconn.NextReader() + if nil != err { + fmt.Println("debug wsw NextReader err:", err) + return 0, err + } + wsw.tmpr = msgr + } + + n, err := wsw.tmpr.Read(b) + if nil != err { + fmt.Println("debug wsw Read err:", err) + if io.EOF == err { + wsw.tmpr = nil + // ignore the message EOF because it's not the websocket EOF + err = nil + } + } + return n, err +} + +func (wsw *WSWrap) Write(b []byte) (int, error) { + // TODO create or reset ping deadline + // TODO document that more complete writes are preferred? + + msgw, err := wsw.wsconn.NextWriter(websocket.BinaryMessage) + if nil != err { + fmt.Println("debug wsw NextWriter err:", err) + return 0, err + } + n, err := msgw.Write(b) + if nil != err { + fmt.Println("debug wsw Write err:", err) + return n, err + } + + // if the message error fails, we can assume the websocket is damaged + return n, msgw.Close() +} + +// Close will close the websocket with a control message +func (wsw *WSWrap) Close() error { + fmt.Println("[debug] closing the websocket.Conn") + + // TODO handle EOF as websocket.CloseNormal? + message := websocket.FormatCloseMessage(websocket.CloseGoingAway, "closing connection") + deadline := time.Now().Add(10 * time.Second) + err := wsw.wsconn.WriteControl(websocket.CloseMessage, message, deadline) + if nil != err { + fmt.Fprintf(os.Stderr, "failed to write close message to websocket: %s\n", err) + } + _ = wsw.wsconn.Close() + return err +}