From f67fc7324d435374715c8c125330eb7772e0aef4 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 18 May 2020 22:36:20 -0600 Subject: [PATCH] more parser tests --- mplexer/listener.go | 4 +- mplexer/packer/parser.go | 27 ++--- mplexer/packer/parser_test.go | 185 ++++++++++++++++++++++++++++++---- mplexer/packer/v1.go | 60 +++++++---- 4 files changed, 220 insertions(+), 56 deletions(-) diff --git a/mplexer/listener.go b/mplexer/listener.go index c3b3b09..cd77335 100644 --- a/mplexer/listener.go +++ b/mplexer/listener.go @@ -81,7 +81,7 @@ func (m *MultiplexLocal) listen(ctx context.Context, wsconn *websocket.Conn, lis for { time.Sleep(15 * time.Second) deadline := time.Now().Add(45 * time.Second) - if err := wsconn.WriteControl(websocket.PingMessage, "", deadline); nil != err { + if err := wsconn.WriteControl(websocket.PingMessage, []byte(""), deadline); nil != err { fmt.Fprintf(os.Stderr, "failed to write ping message to websocket: %s\n", err) cancel() break @@ -94,7 +94,7 @@ func (m *MultiplexLocal) listen(ctx context.Context, wsconn *websocket.Conn, lis // TODO optimal buffer size b := make([]byte, 128*1024) for { - n, err := listener.packer.Read(b) + n, err := listener.parser.Read(b) if n > 0 { if err := wsconn.WriteMessage(websocket.BinaryMessage, b); nil != err { fmt.Fprintf(os.Stderr, "failed to write packer message to websocket: %s\n", err) diff --git a/mplexer/packer/parser.go b/mplexer/packer/parser.go index 7e3b73f..9688e1f 100644 --- a/mplexer/packer/parser.go +++ b/mplexer/packer/parser.go @@ -3,7 +3,6 @@ package packer import ( "context" "errors" - "fmt" ) type Parser struct { @@ -15,11 +14,11 @@ type Parser struct { parseState State dataReady chan struct{} data []byte - written int + consumed int } type ParserState struct { - written int + consumed int version byte headerLen int header []byte @@ -60,19 +59,21 @@ func (p *Parser) Write(b []byte) (int, error) { return 0, errors.New("developer error: wrote 0 bytes") } - // so that we can overwrite the main state - // as soon as a full message has completed - // but still keep the number of bytes written - if 0 == p.state.written { - p.written = 0 - } + /* + // so that we can overwrite the main state + // as soon as a full message has completed + // but still keep the number of bytes written + if 0 == p.state.written { + p.written = 0 + } + */ switch p.parseState { case VersionState: - fmt.Println("version state", b[0]) + //fmt.Println("[debug] version state", b[0]) p.state.version = b[0] b = b[1:] - p.state.written += 1 + p.consumed += 1 p.parseState += 1 default: // do nothing @@ -80,8 +81,8 @@ func (p *Parser) Write(b []byte) (int, error) { switch p.state.version { case V1: - fmt.Println("v1 unmarshal") - return p.written, p.unpackV1(b) + //fmt.Println("[debug] v1 unmarshal") + return p.unpackV1(b) default: return 0, errors.New("incorrect version or version not implemented") } diff --git a/mplexer/packer/parser_test.go b/mplexer/packer/parser_test.go index 93f046b..c636b98 100644 --- a/mplexer/packer/parser_test.go +++ b/mplexer/packer/parser_test.go @@ -2,12 +2,24 @@ package packer import ( "context" - "fmt" "net" "testing" "time" ) +var src = Addr{ + family: "IPv4", + addr: "192.168.1.101", + port: 6743, +} +var dst = Addr{ + family: "IPv4", + port: 80, + scheme: "http", +} +var domain = "ex1.telebit.io" +var payload = []byte("Hello, World!") + type testHandler struct { conns map[string]*Conn chunksParsed int @@ -15,6 +27,7 @@ type testHandler struct { } func (th *testHandler) WriteMessage(a Addr, b []byte) { + th.chunksParsed += 1 addr := &a _, ok := th.conns[addr.Network()] if !ok { @@ -27,11 +40,30 @@ func (th *testHandler) WriteMessage(a Addr, b []byte) { } th.conns[addr.Network()] = conn } - th.chunksParsed += 1 th.bytesRead += len(b) } -func TestParseWholeBlock(t *testing.T) { +func TestParse1WholeBlock(t *testing.T) { + testParseNBlocks(t, 1) +} + +func TestParse2WholeBlocks(t *testing.T) { + testParseNBlocks(t, 2) +} + +func TestParse3WholeBlocks(t *testing.T) { + testParseNBlocks(t, 3) +} + +func TestParse2Addrs(t *testing.T) { + testParseNBlocks(t, 4) +} + +func TestParse3Addrs(t *testing.T) { + testParseNBlocks(t, 5) +} + +func TestParse1AndRest(t *testing.T) { ctx := context.Background() //ctx, cancel := context.WithCancel(ctx) @@ -40,25 +72,17 @@ func TestParseWholeBlock(t *testing.T) { } p := NewParser(ctx, th) - payload := []byte(`Hello, World!`) - fmt.Println("payload len", len(payload)) - src := Addr{ - family: "IPv4", - addr: "192.168.1.101", - port: 6743, - } - dst := Addr{ - family: "IPv4", - port: 80, - scheme: "http", - } - domain := "ex1.telebit.io" + h, b, err := Encode(src, dst, domain, payload) if nil != err { t.Fatal(err) } raw := append(h, b...) - n, err := p.Write(raw) + n, err := p.Write(raw[:1]) + if nil != err { + t.Fatal(err) + } + m, err := p.Write(raw[1:]) if nil != err { t.Fatal(err) } @@ -67,12 +91,135 @@ func TestParseWholeBlock(t *testing.T) { t.Fatal("should have parsed one connection") } if 1 != th.chunksParsed { - t.Fatal("should have parsed one chunck") + t.Fatal("should have parsed 1 chunck(s)") } if len(payload) != th.bytesRead { t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead) } - if n != len(raw) { + if n+m != len(raw) { t.Fatalf("should have parsed all %d bytes, not just %d\n", n, len(raw)) } } + +func TestParseRestAnd1(t *testing.T) { + ctx := context.Background() + //ctx, cancel := context.WithCancel(ctx) + + th := &testHandler{ + conns: map[string]*Conn{}, + } + + p := NewParser(ctx, th) + + h, b, err := Encode(src, dst, domain, payload) + if nil != err { + t.Fatal(err) + } + raw := append(h, b...) + i := len(raw) + n, err := p.Write(raw[:i-1]) + if nil != err { + t.Fatal(err) + } + m, err := p.Write(raw[i-1:]) + if nil != err { + t.Fatal(err) + } + + if 1 != len(th.conns) { + t.Fatal("should have parsed one connection") + } + if 2 != th.chunksParsed { + t.Fatal("should have parsed 2 chunck(s)") + } + if len(payload) != th.bytesRead { + t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead) + } + if n+m != len(raw) { + t.Fatalf("should have parsed all %d bytes, not just %d\n", n, len(raw)) + } +} + +func TestParse1By1(t *testing.T) { + ctx := context.Background() + //ctx, cancel := context.WithCancel(ctx) + + th := &testHandler{ + conns: map[string]*Conn{}, + } + + p := NewParser(ctx, th) + + h, b, err := Encode(src, dst, domain, payload) + if nil != err { + t.Fatal(err) + } + raw := append(h, b...) + count := 0 + for _, b := range raw { + n, err := p.Write([]byte{b}) + if nil != err { + t.Fatal(err) + } + count += n + } + + if 1 != len(th.conns) { + t.Fatal("should have parsed one connection") + } + if len(payload) != th.chunksParsed { + t.Fatalf("should have parsed %d chunck(s), not %d", len(payload), th.chunksParsed) + } + if len(payload) != th.bytesRead { + t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead) + } + if count != len(raw) { + t.Fatalf("should have parsed all %d bytes, not just %d\n", len(raw), count) + } +} + +func testParseNBlocks(t *testing.T, count int) { + ctx := context.Background() + //ctx, cancel := context.WithCancel(ctx) + + th := &testHandler{ + conns: map[string]*Conn{}, + } + + nAddr := 1 + if count > 2 { + nAddr = count - 2 + } + p := NewParser(ctx, th) + raw := []byte{} + for i := 0; i < count; i++ { + if i > 2 { + copied := src + src = copied + src.port += i + } + h, b, err := Encode(src, dst, domain, payload) + if nil != err { + t.Fatal(err) + } + raw = append(raw, h...) + raw = append(raw, b...) + } + n, err := p.Write(raw) + if nil != err { + t.Fatal(err) + } + + if nAddr != len(th.conns) { + t.Fatalf("should have parsed %d connection(s)", nAddr) + } + if count != th.chunksParsed { + t.Fatalf("should have parsed %d chunk(s)", count) + } + if count*len(payload) != th.bytesRead { + t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", count*len(payload), th.bytesRead) + } + if n != len(raw) { + t.Fatalf("should have parsed all %d bytes, not just %d\n", len(raw), n) + } +} diff --git a/mplexer/packer/v1.go b/mplexer/packer/v1.go index a5dc220..97aed4f 100644 --- a/mplexer/packer/v1.go +++ b/mplexer/packer/v1.go @@ -28,71 +28,86 @@ type Header struct { Service string } -func (p *Parser) unpackV1(b []byte) error { +func (p *Parser) unpackV1(b []byte) (int, error) { z := 0 for { - if z > 10 { + if z > 20 { panic("stuck in an infinite loop?") } z += 1 n := len(b) - // at least one loop - if z > 1 && n < 1 { - fmt.Println("v1 end", z, n) + if n < 1 { + //fmt.Println("[debug] v1 end", z, n) break } var err error switch p.parseState { + case VersionState: + //fmt.Println("[debug] version state", b[0]) + p.state.version = b[0] + b = b[1:] + p.consumed += 1 + p.parseState += 1 case HeaderLengthState: - fmt.Println("v1 h len") + //fmt.Println("[debug] v1 h len") b = p.unpackV1HeaderLength(b) case HeaderState: - fmt.Println("v1 header") + //fmt.Println("[debug] v1 header") b, err = p.unpackV1Header(b, n) if nil != err { - fmt.Println("v1 header err", err) - return err + //fmt.Println("[debug] v1 header err", err) + consumed := p.consumed + p.consumed = 0 + return consumed, err } case PayloadState: - fmt.Println("v1 payload") + //fmt.Println("[debug] v1 payload") // if this payload is complete, reset all state if p.state.payloadWritten == p.state.payloadLen { p.state = ParserState{} + p.parseState = 0 } b, err = p.unpackV1Payload(b, n) if nil != err { - return err + consumed := p.consumed + p.consumed = 0 + return consumed, err } default: + fmt.Println("[debug] v1 unknown state") // do nothing - return errors.New("error unpacking") + consumed := p.consumed + p.consumed = 0 + return consumed, errors.New("error unpacking") } } - return nil + consumed := p.consumed + p.consumed = 0 + return consumed, nil } func (p *Parser) unpackV1HeaderLength(b []byte) []byte { p.state.headerLen = int(b[0]) - fmt.Println("unpacked header len", p.state.headerLen) + //fmt.Println("[debug] unpacked header len", p.state.headerLen) b = b[1:] - p.state.written += 1 + p.consumed += 1 p.parseState += 1 return b } func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) { - fmt.Println("got", len(b), "bytes", string(b)) + //fmt.Println("[debug] got", len(b), "bytes", string(b)) m := len(p.state.header) k := p.state.headerLen - m if n < k { k = n } - p.state.written += k + p.consumed += k c := b[0:k] b = b[k:] - fmt.Println("has", m, "want", k, "more and have", len(b), "more") + //fmt.Println("[debug] has", m, "want", k, "more and have", len(b), "more") p.state.header = append(p.state.header, c...) if p.state.headerLen != len(p.state.header) { return b, nil @@ -162,12 +177,13 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) { return b, nil */ + //fmt.Printf("[debug] [2] payload written: %d | payload length: %d\n", p.state.payloadWritten, p.state.payloadLen) p.handler.WriteMessage(p.state.addr, []byte{}) return b, nil } k := p.state.payloadLen - p.state.payloadWritten - if k < n { + if n < k { k = n } c := b[0:k] @@ -176,7 +192,6 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) { // and also put backpressure on just that connection /* m, err := p.state.conn.local.Write(c) - p.state.written += m p.state.payloadWritten += m if nil != err { // TODO we want to surface this error somewhere, but not to the websocket @@ -184,13 +199,14 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) { } */ p.handler.WriteMessage(p.state.addr, c) - p.state.written += k + p.consumed += k p.state.payloadWritten += k - p.written = p.state.written + //fmt.Printf("[debug] [1] payload written: %d | payload length: %d\n", p.state.payloadWritten, p.state.payloadLen) // if this payload is complete, reset all state if p.state.payloadWritten == p.state.payloadLen { p.state = ParserState{} + p.parseState = 0 } return b, nil }