add Decoder and more Parser tests

This commit is contained in:
AJ ONeal 2020-05-19 01:06:10 -06:00
parent f67fc7324d
commit 7a6b5741a5
5 changed files with 224 additions and 37 deletions

77
mplexer/packer/decoder.go Normal file
View File

@ -0,0 +1,77 @@
package packer
import (
"context"
"errors"
"io"
)
// Decoder handles a ReadCloser stream containing mplexy-encoded clients.
// Construct with NewDecoder; the zero value is not usable.
type Decoder struct {
	ctx context.Context // cancels StreamDecode when Done
	r   io.ReadCloser   // underlying encoded stream; closed by StreamDecode on exit
}
// NewDecoder returns an initialized Decoder that reads the mplexy-encoded
// stream r and can be cancelled through ctx.
func NewDecoder(ctx context.Context, r io.ReadCloser) *Decoder {
	d := Decoder{
		ctx: ctx,
		r:   r,
	}
	return &d
}
// StreamDecode will call WriteMessage as often as addressable data exists,
// reading up to bufferSize (default 8192 when 0 is given) at a time
// (header + data, though headers are often sent separately from data).
// It returns nil on a clean EOF, the underlying error otherwise, and closes
// d.r before returning.
func (d *Decoder) StreamDecode(handler Handler, bufferSize int) error {
	p := NewParser(handler)
	rx := make(chan []byte)
	rxErr := make(chan error)

	if 0 == bufferSize {
		bufferSize = 8192
	}

	go func() {
		for {
			// A fresh buffer each iteration: sending a shared buffer and then
			// Read-ing into it again would race with the receiver, which may
			// still be parsing the previous chunk.
			b := make([]byte, bufferSize)
			n, err := d.r.Read(b)
			if n > 0 {
				select {
				case rx <- b[:n]:
				case <-d.ctx.Done():
					// Receiver has stopped selecting; don't leak this goroutine.
					return
				}
			}
			if nil != err {
				select {
				case rxErr <- err:
				case <-d.ctx.Done():
				}
				return
			}
		}
	}()

	for {
		select {
		// TODO: do we actually need ctx here, or would it be sufficient to
		// expect the reader to be closed by the caller instead?
		case <-d.ctx.Done():
			// Closing the reader makes the pending Read fail; together with
			// the ctx-aware sends above that releases the goroutine.
			d.r.Close()
			return errors.New("cancelled by context")
		case b := <-rx:
			if _, err := p.Write(b); nil != err {
				// A Parser write error is unrecoverable (protocol-level),
				// not just a downstream client error.
				d.r.Close()
				return err
			}
		case err := <-rxErr:
			d.r.Close()
			if io.EOF == err {
				// Clean end of stream.
				return nil
			}
			return err
		}
	}
}

View File

@ -0,0 +1,74 @@
package packer
import (
"context"
"net"
"testing"
)
// TestDecode1WholeBlock decodes a single encoded block delivered in one write.
func TestDecode1WholeBlock(t *testing.T) {
	testDecodeNBlocks(t, 1)
}
// testDecodeNBlocks encodes count blocks, writes them to one end of an
// in-memory pipe, and verifies that StreamDecode on the other end reports
// the expected number of connections, chunks, and payload bytes.
func testDecodeNBlocks(t *testing.T, count int) {
	wp, rp := net.Pipe()
	ctx := context.Background()
	decoder := NewDecoder(ctx, rp)

	// The first three blocks share an address; each later block bumps the
	// source port, so count > 2 yields count-2 distinct connections.
	nAddr := 1
	if count > 2 {
		nAddr = count - 2
	}

	raw := []byte{}
	for i := 0; i < count; i++ {
		if i > 2 {
			copied := src
			src = copied
			src.port += i
		}
		h, b, err := Encode(src, dst, domain, payload)
		if nil != err {
			t.Fatal(err)
		}
		raw = append(raw, h...)
		raw = append(raw, b...)
	}

	var nw int
	go func() {
		// very important: don't forget to close when done — even on a write
		// error — so that StreamDecode sees EOF and returns.
		defer wp.Close()

		var err error
		nw, err = wp.Write(raw)
		if nil != err {
			// t.Fatal must not be called from a non-test goroutine
			// (it calls runtime.Goexit); t.Error is the safe alternative.
			t.Error(err)
		}
	}()

	th := &testHandler{
		conns: map[string]*Conn{},
	}
	err := decoder.StreamDecode(th, 0)
	if nil != err {
		t.Fatalf("failed to decode stream: %s", err)
	}

	if nAddr != len(th.conns) {
		t.Fatalf("should have parsed %d connection(s)", nAddr)
	}
	if count != th.chunksParsed {
		t.Fatalf("should have parsed %d chunk(s)", count)
	}
	if count*len(payload) != th.bytesRead {
		t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", count*len(payload), th.bytesRead)
	}
	if nw != len(raw) {
		t.Fatalf("should have parsed all %d bytes, not just %d\n", len(raw), nw)
	}
}

View File

@ -1,12 +1,10 @@
package packer package packer
import ( import (
"context"
"errors" "errors"
) )
type Parser struct { type Parser struct {
ctx context.Context
handler Handler handler Handler
newConns chan *Conn newConns chan *Conn
conns map[string]*Conn conns map[string]*Conn
@ -38,9 +36,8 @@ const (
VersionState State = 0 VersionState State = 0
) )
func NewParser(ctx context.Context, handler Handler) *Parser { func NewParser(handler Handler) *Parser {
return &Parser{ return &Parser{
ctx: ctx,
conns: make(map[string]*Conn), conns: make(map[string]*Conn),
newConns: make(chan *Conn, 2), // Buffered to make testing easier newConns: make(chan *Conn, 2), // Buffered to make testing easier
dataReady: make(chan struct{}, 2), dataReady: make(chan struct{}, 2),
@ -73,8 +70,8 @@ func (p *Parser) Write(b []byte) (int, error) {
//fmt.Println("[debug] version state", b[0]) //fmt.Println("[debug] version state", b[0])
p.state.version = b[0] p.state.version = b[0]
b = b[1:] b = b[1:]
p.consumed += 1 p.consumed++
p.parseState += 1 p.parseState++
default: default:
// do nothing // do nothing
} }

View File

@ -1,7 +1,7 @@
package packer package packer
import ( import (
"context" "math/rand"
"net" "net"
"testing" "testing"
"time" "time"
@ -27,7 +27,7 @@ type testHandler struct {
} }
func (th *testHandler) WriteMessage(a Addr, b []byte) { func (th *testHandler) WriteMessage(a Addr, b []byte) {
th.chunksParsed += 1 th.chunksParsed++
addr := &a addr := &a
_, ok := th.conns[addr.Network()] _, ok := th.conns[addr.Network()]
if !ok { if !ok {
@ -63,15 +63,38 @@ func TestParse3Addrs(t *testing.T) {
testParseNBlocks(t, 5) testParseNBlocks(t, 5)
} }
func TestParse1AndRest(t *testing.T) { func TestParseBy1(t *testing.T) {
ctx := context.Background() testParseByN(t, 1)
//ctx, cancel := context.WithCancel(ctx) }
func TestParseByPrimes(t *testing.T) {
testParseByN(t, 2)
testParseByN(t, 3)
testParseByN(t, 5)
testParseByN(t, 7)
testParseByN(t, 11)
testParseByN(t, 13)
testParseByN(t, 17)
testParseByN(t, 19)
testParseByN(t, 23)
testParseByN(t, 29)
testParseByN(t, 31)
testParseByN(t, 37)
testParseByN(t, 41)
testParseByN(t, 43)
testParseByN(t, 47)
}
func TestParseByRand(t *testing.T) {
testParseByN(t, 0)
}
func TestParse1AndRest(t *testing.T) {
th := &testHandler{ th := &testHandler{
conns: map[string]*Conn{}, conns: map[string]*Conn{},
} }
p := NewParser(ctx, th) p := NewParser(th)
h, b, err := Encode(src, dst, domain, payload) h, b, err := Encode(src, dst, domain, payload)
if nil != err { if nil != err {
@ -102,14 +125,11 @@ func TestParse1AndRest(t *testing.T) {
} }
func TestParseRestAnd1(t *testing.T) { func TestParseRestAnd1(t *testing.T) {
ctx := context.Background()
//ctx, cancel := context.WithCancel(ctx)
th := &testHandler{ th := &testHandler{
conns: map[string]*Conn{}, conns: map[string]*Conn{},
} }
p := NewParser(ctx, th) p := NewParser(th)
h, b, err := Encode(src, dst, domain, payload) h, b, err := Encode(src, dst, domain, payload)
if nil != err { if nil != err {
@ -140,15 +160,13 @@ func TestParseRestAnd1(t *testing.T) {
} }
} }
func TestParse1By1(t *testing.T) { func testParseByN(t *testing.T, n int) {
ctx := context.Background() //fmt.Printf("[debug] parse by %d\n", n)
//ctx, cancel := context.WithCancel(ctx)
th := &testHandler{ th := &testHandler{
conns: map[string]*Conn{}, conns: map[string]*Conn{},
} }
p := NewParser(ctx, th) p := NewParser(th)
h, b, err := Encode(src, dst, domain, payload) h, b, err := Encode(src, dst, domain, payload)
if nil != err { if nil != err {
@ -156,19 +174,43 @@ func TestParse1By1(t *testing.T) {
} }
raw := append(h, b...) raw := append(h, b...)
count := 0 count := 0
for _, b := range raw { nChunk := 0
n, err := p.Write([]byte{b}) b = raw
for {
r := 24
c := len(b)
if 0 == c {
break
}
i := n
if 0 == n {
if c < r {
r = c
}
i = 1 + rand.Intn(r+1)
}
if c < i {
i = c
}
// TODO shouldn't this cause an error?
//a := b[:i][0]
a := b[:i]
b = b[i:]
nw, err := p.Write(a)
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
count += n count += nw
if count > len(h) {
nChunk++
}
} }
if 1 != len(th.conns) { if 1 != len(th.conns) {
t.Fatal("should have parsed one connection") t.Fatalf("should have parsed one connection, not %d", len(th.conns))
} }
if len(payload) != th.chunksParsed { if nChunk != th.chunksParsed {
t.Fatalf("should have parsed %d chunck(s), not %d", len(payload), th.chunksParsed) t.Fatalf("should have parsed %d chunk(s), not %d", nChunk, th.chunksParsed)
} }
if len(payload) != th.bytesRead { if len(payload) != th.bytesRead {
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead) t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
@ -179,9 +221,6 @@ func TestParse1By1(t *testing.T) {
} }
func testParseNBlocks(t *testing.T, count int) { func testParseNBlocks(t *testing.T, count int) {
ctx := context.Background()
//ctx, cancel := context.WithCancel(ctx)
th := &testHandler{ th := &testHandler{
conns: map[string]*Conn{}, conns: map[string]*Conn{},
} }
@ -190,7 +229,7 @@ func testParseNBlocks(t *testing.T, count int) {
if count > 2 { if count > 2 {
nAddr = count - 2 nAddr = count - 2
} }
p := NewParser(ctx, th) p := NewParser(th)
raw := []byte{} raw := []byte{}
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
if i > 2 { if i > 2 {

View File

@ -34,7 +34,7 @@ func (p *Parser) unpackV1(b []byte) (int, error) {
if z > 20 { if z > 20 {
panic("stuck in an infinite loop?") panic("stuck in an infinite loop?")
} }
z += 1 z++
n := len(b) n := len(b)
if n < 1 { if n < 1 {
//fmt.Println("[debug] v1 end", z, n) //fmt.Println("[debug] v1 end", z, n)
@ -47,8 +47,8 @@ func (p *Parser) unpackV1(b []byte) (int, error) {
//fmt.Println("[debug] version state", b[0]) //fmt.Println("[debug] version state", b[0])
p.state.version = b[0] p.state.version = b[0]
b = b[1:] b = b[1:]
p.consumed += 1 p.consumed++
p.parseState += 1 p.parseState++
case HeaderLengthState: case HeaderLengthState:
//fmt.Println("[debug] v1 h len") //fmt.Println("[debug] v1 h len")
b = p.unpackV1HeaderLength(b) b = p.unpackV1HeaderLength(b)
@ -92,8 +92,8 @@ func (p *Parser) unpackV1HeaderLength(b []byte) []byte {
p.state.headerLen = int(b[0]) p.state.headerLen = int(b[0])
//fmt.Println("[debug] unpacked header len", p.state.headerLen) //fmt.Println("[debug] unpacked header len", p.state.headerLen)
b = b[1:] b = b[1:]
p.consumed += 1 p.consumed++
p.parseState += 1 p.parseState++
return b return b
} }
@ -154,7 +154,7 @@ func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) {
p.newConns <- p.state.conn p.newConns <- p.state.conn
} }
*/ */
p.parseState += 1 p.parseState++
return b, nil return b, nil
} }