more parser tests

This commit is contained in:
AJ ONeal 2020-05-18 22:36:20 -06:00
parent c5df63b11d
commit f67fc7324d
4 changed files with 220 additions and 56 deletions

View File

@ -81,7 +81,7 @@ func (m *MultiplexLocal) listen(ctx context.Context, wsconn *websocket.Conn, lis
for { for {
time.Sleep(15 * time.Second) time.Sleep(15 * time.Second)
deadline := time.Now().Add(45 * time.Second) deadline := time.Now().Add(45 * time.Second)
if err := wsconn.WriteControl(websocket.PingMessage, "", deadline); nil != err { if err := wsconn.WriteControl(websocket.PingMessage, []byte(""), deadline); nil != err {
fmt.Fprintf(os.Stderr, "failed to write ping message to websocket: %s\n", err) fmt.Fprintf(os.Stderr, "failed to write ping message to websocket: %s\n", err)
cancel() cancel()
break break
@ -94,7 +94,7 @@ func (m *MultiplexLocal) listen(ctx context.Context, wsconn *websocket.Conn, lis
// TODO optimal buffer size // TODO optimal buffer size
b := make([]byte, 128*1024) b := make([]byte, 128*1024)
for { for {
n, err := listener.packer.Read(b) n, err := listener.parser.Read(b)
if n > 0 { if n > 0 {
if err := wsconn.WriteMessage(websocket.BinaryMessage, b); nil != err { if err := wsconn.WriteMessage(websocket.BinaryMessage, b); nil != err {
fmt.Fprintf(os.Stderr, "failed to write packer message to websocket: %s\n", err) fmt.Fprintf(os.Stderr, "failed to write packer message to websocket: %s\n", err)

View File

@ -3,7 +3,6 @@ package packer
import ( import (
"context" "context"
"errors" "errors"
"fmt"
) )
type Parser struct { type Parser struct {
@ -15,11 +14,11 @@ type Parser struct {
parseState State parseState State
dataReady chan struct{} dataReady chan struct{}
data []byte data []byte
written int consumed int
} }
type ParserState struct { type ParserState struct {
written int consumed int
version byte version byte
headerLen int headerLen int
header []byte header []byte
@ -60,19 +59,21 @@ func (p *Parser) Write(b []byte) (int, error) {
return 0, errors.New("developer error: wrote 0 bytes") return 0, errors.New("developer error: wrote 0 bytes")
} }
// so that we can overwrite the main state /*
// as soon as a full message has completed // so that we can overwrite the main state
// but still keep the number of bytes written // as soon as a full message has completed
if 0 == p.state.written { // but still keep the number of bytes written
p.written = 0 if 0 == p.state.written {
} p.written = 0
}
*/
switch p.parseState { switch p.parseState {
case VersionState: case VersionState:
fmt.Println("version state", b[0]) //fmt.Println("[debug] version state", b[0])
p.state.version = b[0] p.state.version = b[0]
b = b[1:] b = b[1:]
p.state.written += 1 p.consumed += 1
p.parseState += 1 p.parseState += 1
default: default:
// do nothing // do nothing
@ -80,8 +81,8 @@ func (p *Parser) Write(b []byte) (int, error) {
switch p.state.version { switch p.state.version {
case V1: case V1:
fmt.Println("v1 unmarshal") //fmt.Println("[debug] v1 unmarshal")
return p.written, p.unpackV1(b) return p.unpackV1(b)
default: default:
return 0, errors.New("incorrect version or version not implemented") return 0, errors.New("incorrect version or version not implemented")
} }

View File

@ -2,12 +2,24 @@ package packer
import ( import (
"context" "context"
"fmt"
"net" "net"
"testing" "testing"
"time" "time"
) )
var src = Addr{
family: "IPv4",
addr: "192.168.1.101",
port: 6743,
}
var dst = Addr{
family: "IPv4",
port: 80,
scheme: "http",
}
var domain = "ex1.telebit.io"
var payload = []byte("Hello, World!")
type testHandler struct { type testHandler struct {
conns map[string]*Conn conns map[string]*Conn
chunksParsed int chunksParsed int
@ -15,6 +27,7 @@ type testHandler struct {
} }
func (th *testHandler) WriteMessage(a Addr, b []byte) { func (th *testHandler) WriteMessage(a Addr, b []byte) {
th.chunksParsed += 1
addr := &a addr := &a
_, ok := th.conns[addr.Network()] _, ok := th.conns[addr.Network()]
if !ok { if !ok {
@ -27,11 +40,30 @@ func (th *testHandler) WriteMessage(a Addr, b []byte) {
} }
th.conns[addr.Network()] = conn th.conns[addr.Network()] = conn
} }
th.chunksParsed += 1
th.bytesRead += len(b) th.bytesRead += len(b)
} }
func TestParseWholeBlock(t *testing.T) { func TestParse1WholeBlock(t *testing.T) {
testParseNBlocks(t, 1)
}
func TestParse2WholeBlocks(t *testing.T) {
testParseNBlocks(t, 2)
}
func TestParse3WholeBlocks(t *testing.T) {
testParseNBlocks(t, 3)
}
func TestParse2Addrs(t *testing.T) {
testParseNBlocks(t, 4)
}
func TestParse3Addrs(t *testing.T) {
testParseNBlocks(t, 5)
}
func TestParse1AndRest(t *testing.T) {
ctx := context.Background() ctx := context.Background()
//ctx, cancel := context.WithCancel(ctx) //ctx, cancel := context.WithCancel(ctx)
@ -40,25 +72,17 @@ func TestParseWholeBlock(t *testing.T) {
} }
p := NewParser(ctx, th) p := NewParser(ctx, th)
payload := []byte(`Hello, World!`)
fmt.Println("payload len", len(payload))
src := Addr{
family: "IPv4",
addr: "192.168.1.101",
port: 6743,
}
dst := Addr{
family: "IPv4",
port: 80,
scheme: "http",
}
domain := "ex1.telebit.io"
h, b, err := Encode(src, dst, domain, payload) h, b, err := Encode(src, dst, domain, payload)
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
raw := append(h, b...) raw := append(h, b...)
n, err := p.Write(raw) n, err := p.Write(raw[:1])
if nil != err {
t.Fatal(err)
}
m, err := p.Write(raw[1:])
if nil != err { if nil != err {
t.Fatal(err) t.Fatal(err)
} }
@ -67,12 +91,135 @@ func TestParseWholeBlock(t *testing.T) {
t.Fatal("should have parsed one connection") t.Fatal("should have parsed one connection")
} }
if 1 != th.chunksParsed { if 1 != th.chunksParsed {
t.Fatal("should have parsed one chunck") t.Fatal("should have parsed 1 chunck(s)")
} }
if len(payload) != th.bytesRead { if len(payload) != th.bytesRead {
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead) t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
} }
if n != len(raw) { if n+m != len(raw) {
t.Fatalf("should have parsed all %d bytes, not just %d\n", n, len(raw)) t.Fatalf("should have parsed all %d bytes, not just %d\n", n, len(raw))
} }
} }
func TestParseRestAnd1(t *testing.T) {
ctx := context.Background()
//ctx, cancel := context.WithCancel(ctx)
th := &testHandler{
conns: map[string]*Conn{},
}
p := NewParser(ctx, th)
h, b, err := Encode(src, dst, domain, payload)
if nil != err {
t.Fatal(err)
}
raw := append(h, b...)
i := len(raw)
n, err := p.Write(raw[:i-1])
if nil != err {
t.Fatal(err)
}
m, err := p.Write(raw[i-1:])
if nil != err {
t.Fatal(err)
}
if 1 != len(th.conns) {
t.Fatal("should have parsed one connection")
}
if 2 != th.chunksParsed {
t.Fatal("should have parsed 2 chunck(s)")
}
if len(payload) != th.bytesRead {
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
}
if n+m != len(raw) {
t.Fatalf("should have parsed all %d bytes, not just %d\n", n, len(raw))
}
}
func TestParse1By1(t *testing.T) {
ctx := context.Background()
//ctx, cancel := context.WithCancel(ctx)
th := &testHandler{
conns: map[string]*Conn{},
}
p := NewParser(ctx, th)
h, b, err := Encode(src, dst, domain, payload)
if nil != err {
t.Fatal(err)
}
raw := append(h, b...)
count := 0
for _, b := range raw {
n, err := p.Write([]byte{b})
if nil != err {
t.Fatal(err)
}
count += n
}
if 1 != len(th.conns) {
t.Fatal("should have parsed one connection")
}
if len(payload) != th.chunksParsed {
t.Fatalf("should have parsed %d chunck(s), not %d", len(payload), th.chunksParsed)
}
if len(payload) != th.bytesRead {
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
}
if count != len(raw) {
t.Fatalf("should have parsed all %d bytes, not just %d\n", len(raw), count)
}
}
func testParseNBlocks(t *testing.T, count int) {
ctx := context.Background()
//ctx, cancel := context.WithCancel(ctx)
th := &testHandler{
conns: map[string]*Conn{},
}
nAddr := 1
if count > 2 {
nAddr = count - 2
}
p := NewParser(ctx, th)
raw := []byte{}
for i := 0; i < count; i++ {
if i > 2 {
copied := src
src = copied
src.port += i
}
h, b, err := Encode(src, dst, domain, payload)
if nil != err {
t.Fatal(err)
}
raw = append(raw, h...)
raw = append(raw, b...)
}
n, err := p.Write(raw)
if nil != err {
t.Fatal(err)
}
if nAddr != len(th.conns) {
t.Fatalf("should have parsed %d connection(s)", nAddr)
}
if count != th.chunksParsed {
t.Fatalf("should have parsed %d chunk(s)", count)
}
if count*len(payload) != th.bytesRead {
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", count*len(payload), th.bytesRead)
}
if n != len(raw) {
t.Fatalf("should have parsed all %d bytes, not just %d\n", len(raw), n)
}
}

View File

@ -28,71 +28,86 @@ type Header struct {
Service string Service string
} }
func (p *Parser) unpackV1(b []byte) error { func (p *Parser) unpackV1(b []byte) (int, error) {
z := 0 z := 0
for { for {
if z > 10 { if z > 20 {
panic("stuck in an infinite loop?") panic("stuck in an infinite loop?")
} }
z += 1 z += 1
n := len(b) n := len(b)
// at least one loop if n < 1 {
if z > 1 && n < 1 { //fmt.Println("[debug] v1 end", z, n)
fmt.Println("v1 end", z, n)
break break
} }
var err error var err error
switch p.parseState { switch p.parseState {
case VersionState:
//fmt.Println("[debug] version state", b[0])
p.state.version = b[0]
b = b[1:]
p.consumed += 1
p.parseState += 1
case HeaderLengthState: case HeaderLengthState:
fmt.Println("v1 h len") //fmt.Println("[debug] v1 h len")
b = p.unpackV1HeaderLength(b) b = p.unpackV1HeaderLength(b)
case HeaderState: case HeaderState:
fmt.Println("v1 header") //fmt.Println("[debug] v1 header")
b, err = p.unpackV1Header(b, n) b, err = p.unpackV1Header(b, n)
if nil != err { if nil != err {
fmt.Println("v1 header err", err) //fmt.Println("[debug] v1 header err", err)
return err consumed := p.consumed
p.consumed = 0
return consumed, err
} }
case PayloadState: case PayloadState:
fmt.Println("v1 payload") //fmt.Println("[debug] v1 payload")
// if this payload is complete, reset all state // if this payload is complete, reset all state
if p.state.payloadWritten == p.state.payloadLen { if p.state.payloadWritten == p.state.payloadLen {
p.state = ParserState{} p.state = ParserState{}
p.parseState = 0
} }
b, err = p.unpackV1Payload(b, n) b, err = p.unpackV1Payload(b, n)
if nil != err { if nil != err {
return err consumed := p.consumed
p.consumed = 0
return consumed, err
} }
default: default:
fmt.Println("[debug] v1 unknown state")
// do nothing // do nothing
return errors.New("error unpacking") consumed := p.consumed
p.consumed = 0
return consumed, errors.New("error unpacking")
} }
} }
return nil consumed := p.consumed
p.consumed = 0
return consumed, nil
} }
func (p *Parser) unpackV1HeaderLength(b []byte) []byte { func (p *Parser) unpackV1HeaderLength(b []byte) []byte {
p.state.headerLen = int(b[0]) p.state.headerLen = int(b[0])
fmt.Println("unpacked header len", p.state.headerLen) //fmt.Println("[debug] unpacked header len", p.state.headerLen)
b = b[1:] b = b[1:]
p.state.written += 1 p.consumed += 1
p.parseState += 1 p.parseState += 1
return b return b
} }
func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) { func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) {
fmt.Println("got", len(b), "bytes", string(b)) //fmt.Println("[debug] got", len(b), "bytes", string(b))
m := len(p.state.header) m := len(p.state.header)
k := p.state.headerLen - m k := p.state.headerLen - m
if n < k { if n < k {
k = n k = n
} }
p.state.written += k p.consumed += k
c := b[0:k] c := b[0:k]
b = b[k:] b = b[k:]
fmt.Println("has", m, "want", k, "more and have", len(b), "more") //fmt.Println("[debug] has", m, "want", k, "more and have", len(b), "more")
p.state.header = append(p.state.header, c...) p.state.header = append(p.state.header, c...)
if p.state.headerLen != len(p.state.header) { if p.state.headerLen != len(p.state.header) {
return b, nil return b, nil
@ -162,12 +177,13 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) {
return b, nil return b, nil
*/ */
//fmt.Printf("[debug] [2] payload written: %d | payload length: %d\n", p.state.payloadWritten, p.state.payloadLen)
p.handler.WriteMessage(p.state.addr, []byte{}) p.handler.WriteMessage(p.state.addr, []byte{})
return b, nil return b, nil
} }
k := p.state.payloadLen - p.state.payloadWritten k := p.state.payloadLen - p.state.payloadWritten
if k < n { if n < k {
k = n k = n
} }
c := b[0:k] c := b[0:k]
@ -176,7 +192,6 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) {
// and also put backpressure on just that connection // and also put backpressure on just that connection
/* /*
m, err := p.state.conn.local.Write(c) m, err := p.state.conn.local.Write(c)
p.state.written += m
p.state.payloadWritten += m p.state.payloadWritten += m
if nil != err { if nil != err {
// TODO we want to surface this error somewhere, but not to the websocket // TODO we want to surface this error somewhere, but not to the websocket
@ -184,13 +199,14 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) {
} }
*/ */
p.handler.WriteMessage(p.state.addr, c) p.handler.WriteMessage(p.state.addr, c)
p.state.written += k p.consumed += k
p.state.payloadWritten += k p.state.payloadWritten += k
p.written = p.state.written
//fmt.Printf("[debug] [1] payload written: %d | payload length: %d\n", p.state.payloadWritten, p.state.payloadLen)
// if this payload is complete, reset all state // if this payload is complete, reset all state
if p.state.payloadWritten == p.state.payloadLen { if p.state.payloadWritten == p.state.payloadLen {
p.state = ParserState{} p.state = ParserState{}
p.parseState = 0
} }
return b, nil return b, nil
} }