diff --git a/mplexer/packer/decoder.go b/mplexer/packer/decoder.go
new file mode 100644
index 0000000..c968e60
--- /dev/null
+++ b/mplexer/packer/decoder.go
@@ -0,0 +1,77 @@
+package packer
+
+import (
+	"context"
+	"errors"
+	"io"
+)
+
+// Decoder handles a ReadCloser stream containing mplexy-encoded clients
+type Decoder struct {
+	ctx context.Context
+	r   io.ReadCloser
+}
+
+// NewDecoder returns an initialized Decoder
+func NewDecoder(ctx context.Context, r io.ReadCloser) *Decoder {
+	return &Decoder{
+		ctx: ctx,
+		r:   r,
+	}
+}
+
+// StreamDecode will call WriteMessage as often as addressable data exists,
+// reading up to bufferSize (default 8192) at a time
+// (header + data, though headers are often sent separately from data).
+func (d *Decoder) StreamDecode(handler Handler, bufferSize int) error {
+	p := NewParser(handler)
+	rx := make(chan []byte)
+	rxErr := make(chan error)
+
+	if 0 == bufferSize {
+		bufferSize = 8192
+	}
+
+	go func() {
+		b := make([]byte, bufferSize)
+		for {
+			//fmt.Println("loopers gonna loop")
+			n, err := d.r.Read(b)
+			if n > 0 {
+				rx <- b[:n]
+			}
+			if nil != err {
+				rxErr <- err
+				return
+			}
+		}
+	}()
+
+	for {
+		//fmt.Println("poopers gonna poop")
+		select {
+		// TODO, do we actually need ctx here?
+		// would it be sufficient to expect the reader to be closed by the caller instead?
+		case <-d.ctx.Done():
+			// TODO: verify that closing the reader will cause the goroutine to be released
+			d.r.Close()
+			return errors.New("cancelled by context")
+		case b := <-rx:
+			_, err := p.Write(b)
+			if nil != err {
+				// an error to write represents an unrecoverable error,
+				// not just a downstream client error
+				d.r.Close()
+				return err
+			}
+		case err := <-rxErr:
+			d.r.Close()
+			if io.EOF == err {
+				// it can be assumed that err will close though, right
+				return nil
+			}
+			return err
+		}
+
+	}
+}
diff --git a/mplexer/packer/decoder_test.go b/mplexer/packer/decoder_test.go
new file mode 100644
index 0000000..eaa2207
--- /dev/null
+++ b/mplexer/packer/decoder_test.go
@@ -0,0 +1,74 @@
+package packer
+
+import (
+	"context"
+	"net"
+	"testing"
+)
+
+func TestDecode1WholeBlock(t *testing.T) {
+	testDecodeNBlocks(t, 1)
+}
+
+func testDecodeNBlocks(t *testing.T, count int) {
+	wp, rp := net.Pipe()
+
+	ctx := context.Background()
+	decoder := NewDecoder(ctx, rp)
+	nAddr := 1
+	if count > 2 {
+		nAddr = count - 2
+	}
+
+	raw := []byte{}
+	for i := 0; i < count; i++ {
+		if i > 2 {
+			copied := src
+			src = copied
+			src.port += i
+		}
+		h, b, err := Encode(src, dst, domain, payload)
+		if nil != err {
+			t.Fatal(err)
+		}
+		raw = append(raw, h...)
+		raw = append(raw, b...)
+	}
+
+	var nw int
+	go func() {
+		var err error
+		//fmt.Println("writers gonna write")
+		nw, err = wp.Write(raw)
+		if nil != err {
+			//fmt.Println("writer died")
+			t.Fatal(err)
+		}
+		// very important: don't forget to close when done!
+		wp.Close()
+		//fmt.Println("writer done wrote")
+	}()
+
+	th := &testHandler{
+		conns: map[string]*Conn{},
+	}
+	//fmt.Println("streamers gonna stream")
+	err := decoder.StreamDecode(th, 0)
+	if nil != err {
+		t.Fatalf("failed to decode stream: %s", err)
+	}
+	//fmt.Println("streamer done streamed")
+
+	if nAddr != len(th.conns) {
+		t.Fatalf("should have parsed %d connection(s)", nAddr)
+	}
+	if count != th.chunksParsed {
+		t.Fatalf("should have parsed %d chunk(s)", count)
+	}
+	if count*len(payload) != th.bytesRead {
+		t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", count*len(payload), th.bytesRead)
+	}
+	if nw != len(raw) {
+		t.Fatalf("should have parsed all %d bytes, not just %d\n", len(raw), nw)
+	}
+}
diff --git a/mplexer/packer/parser.go b/mplexer/packer/parser.go
index 9688e1f..780e39a 100644
--- a/mplexer/packer/parser.go
+++ b/mplexer/packer/parser.go
@@ -1,12 +1,10 @@
 package packer
 
 import (
-	"context"
 	"errors"
 )
 
 type Parser struct {
-	ctx      context.Context
 	handler  Handler
 	newConns chan *Conn
 	conns    map[string]*Conn
@@ -38,9 +36,8 @@ const (
 	VersionState State = 0
 )
 
-func NewParser(ctx context.Context, handler Handler) *Parser {
+func NewParser(handler Handler) *Parser {
 	return &Parser{
-		ctx:       ctx,
 		conns:     make(map[string]*Conn),
 		newConns:  make(chan *Conn, 2), // Buffered to make testing easier
 		dataReady: make(chan struct{}, 2),
@@ -73,8 +70,8 @@ func (p *Parser) Write(b []byte) (int, error) {
 			//fmt.Println("[debug] version state", b[0])
 			p.state.version = b[0]
 			b = b[1:]
-			p.consumed += 1
-			p.parseState += 1
+			p.consumed++
+			p.parseState++
 		default:
 			// do nothing
 		}
diff --git a/mplexer/packer/parser_test.go b/mplexer/packer/parser_test.go
index c636b98..a6ca7e9 100644
--- a/mplexer/packer/parser_test.go
+++ b/mplexer/packer/parser_test.go
@@ -1,7 +1,7 @@
 package packer
 
 import (
-	"context"
+	"math/rand"
 	"net"
 	"testing"
 	"time"
@@ -27,7 +27,7 @@ type testHandler struct {
 }
 
 func (th *testHandler) WriteMessage(a Addr, b []byte) {
-	th.chunksParsed += 1
+	th.chunksParsed++
 	addr := &a
 	_, ok := th.conns[addr.Network()]
 	if !ok {
@@ -63,15 +63,38 @@ func TestParse3Addrs(t *testing.T) {
 	testParseNBlocks(t, 5)
 }
 
-func TestParse1AndRest(t *testing.T) {
-	ctx := context.Background()
-	//ctx, cancel := context.WithCancel(ctx)
+func TestParseBy1(t *testing.T) {
+	testParseByN(t, 1)
+}
 
+func TestParseByPrimes(t *testing.T) {
+	testParseByN(t, 2)
+	testParseByN(t, 3)
+	testParseByN(t, 5)
+	testParseByN(t, 7)
+	testParseByN(t, 11)
+	testParseByN(t, 13)
+	testParseByN(t, 17)
+	testParseByN(t, 19)
+	testParseByN(t, 23)
+	testParseByN(t, 29)
+	testParseByN(t, 31)
+	testParseByN(t, 37)
+	testParseByN(t, 41)
+	testParseByN(t, 43)
+	testParseByN(t, 47)
+}
+
+func TestParseByRand(t *testing.T) {
+	testParseByN(t, 0)
+}
+
+func TestParse1AndRest(t *testing.T) {
 	th := &testHandler{
 		conns: map[string]*Conn{},
 	}
-	p := NewParser(ctx, th)
+	p := NewParser(th)
 
 	h, b, err := Encode(src, dst, domain, payload)
 	if nil != err {
@@ -102,14 +125,11 @@
 }
 
 func TestParseRestAnd1(t *testing.T) {
-	ctx := context.Background()
-	//ctx, cancel := context.WithCancel(ctx)
-
 	th := &testHandler{
 		conns: map[string]*Conn{},
 	}
-	p := NewParser(ctx, th)
+	p := NewParser(th)
 
 	h, b, err := Encode(src, dst, domain, payload)
 	if nil != err {
@@ -140,15 +160,13 @@
 	}
 }
 
-func TestParse1By1(t *testing.T) {
-	ctx := context.Background()
-	//ctx, cancel := context.WithCancel(ctx)
-
+func testParseByN(t *testing.T, n int) {
+	//fmt.Printf("[debug] parse by %d\n", n)
 	th := &testHandler{
 		conns: map[string]*Conn{},
 	}
-	p := NewParser(ctx, th)
+	p := NewParser(th)
 
 	h, b, err := Encode(src, dst, domain, payload)
 	if nil != err {
@@ -156,19 +174,43 @@
 	}
 
 	raw := append(h, b...)
 	count := 0
-	for _, b := range raw {
-		n, err := p.Write([]byte{b})
+	nChunk := 0
+	b = raw
+	for {
+		r := 24
+		c := len(b)
+		if 0 == c {
+			break
+		}
+		i := n
+		if 0 == n {
+			if c < r {
+				r = c
+			}
+			i = 1 + rand.Intn(r+1)
+		}
+		if c < i {
+			i = c
+		}
+		// TODO shouldn't this cause an error?
+		//a := b[:i][0]
+		a := b[:i]
+		b = b[i:]
+		nw, err := p.Write(a)
 		if nil != err {
 			t.Fatal(err)
 		}
-		count += n
+		count += nw
+		if count > len(h) {
+			nChunk++
+		}
 	}
 
 	if 1 != len(th.conns) {
-		t.Fatal("should have parsed one connection")
+		t.Fatalf("should have parsed one connection, not %d", len(th.conns))
 	}
-	if len(payload) != th.chunksParsed {
-		t.Fatalf("should have parsed %d chunck(s), not %d", len(payload), th.chunksParsed)
+	if nChunk != th.chunksParsed {
+		t.Fatalf("should have parsed %d chunk(s), not %d", nChunk, th.chunksParsed)
 	}
 	if len(payload) != th.bytesRead {
 		t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
@@ -179,9 +221,6 @@
 	}
 }
 
 func testParseNBlocks(t *testing.T, count int) {
-	ctx := context.Background()
-	//ctx, cancel := context.WithCancel(ctx)
-
 	th := &testHandler{
 		conns: map[string]*Conn{},
 	}
@@ -190,7 +229,7 @@
 	if count > 2 {
 		nAddr = count - 2
 	}
-	p := NewParser(ctx, th)
+	p := NewParser(th)
 	raw := []byte{}
 	for i := 0; i < count; i++ {
 		if i > 2 {
diff --git a/mplexer/packer/v1.go b/mplexer/packer/v1.go
index 97aed4f..6949f9e 100644
--- a/mplexer/packer/v1.go
+++ b/mplexer/packer/v1.go
@@ -34,7 +34,7 @@ func (p *Parser) unpackV1(b []byte) (int, error) {
 		if z > 20 {
 			panic("stuck in an infinite loop?")
 		}
-		z += 1
+		z++
 		n := len(b)
 		if n < 1 {
 			//fmt.Println("[debug] v1 end", z, n)
@@ -47,8 +47,8 @@
 		case VersionState:
 			//fmt.Println("[debug] version state", b[0])
 			p.state.version = b[0]
 			b = b[1:]
-			p.consumed += 1
-			p.parseState += 1
+			p.consumed++
+			p.parseState++
 		case HeaderLengthState:
 			//fmt.Println("[debug] v1 h len")
 			b = p.unpackV1HeaderLength(b)
@@ -92,8 +92,8 @@ func (p *Parser) unpackV1HeaderLength(b []byte) []byte {
 	p.state.headerLen = int(b[0])
 	//fmt.Println("[debug] unpacked header len", p.state.headerLen)
 	b = b[1:]
-	p.consumed += 1
-	p.parseState += 1
+	p.consumed++
+	p.parseState++
 	return b
 }
 
@@ -154,7 +154,7 @@ func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) {
 			p.newConns <- p.state.conn
 		}
 	*/
-	p.parseState += 1
+	p.parseState++
 	return b, nil
 }