more parser tests

This commit is contained in:
AJ ONeal 2020-05-18 22:36:20 -06:00
parent c5df63b11d
commit f67fc7324d
4 changed files with 220 additions and 56 deletions

View File

@ -81,7 +81,7 @@ func (m *MultiplexLocal) listen(ctx context.Context, wsconn *websocket.Conn, lis
for {
time.Sleep(15 * time.Second)
deadline := time.Now().Add(45 * time.Second)
if err := wsconn.WriteControl(websocket.PingMessage, "", deadline); nil != err {
if err := wsconn.WriteControl(websocket.PingMessage, []byte(""), deadline); nil != err {
fmt.Fprintf(os.Stderr, "failed to write ping message to websocket: %s\n", err)
cancel()
break
@ -94,7 +94,7 @@ func (m *MultiplexLocal) listen(ctx context.Context, wsconn *websocket.Conn, lis
// TODO optimal buffer size
b := make([]byte, 128*1024)
for {
n, err := listener.packer.Read(b)
n, err := listener.parser.Read(b)
if n > 0 {
if err := wsconn.WriteMessage(websocket.BinaryMessage, b); nil != err {
fmt.Fprintf(os.Stderr, "failed to write packer message to websocket: %s\n", err)

View File

@ -3,7 +3,6 @@ package packer
import (
"context"
"errors"
"fmt"
)
type Parser struct {
@ -15,11 +14,11 @@ type Parser struct {
parseState State
dataReady chan struct{}
data []byte
written int
consumed int
}
type ParserState struct {
written int
consumed int
version byte
headerLen int
header []byte
@ -60,19 +59,21 @@ func (p *Parser) Write(b []byte) (int, error) {
return 0, errors.New("developer error: wrote 0 bytes")
}
// so that we can overwrite the main state
// as soon as a full message has completed
// but still keep the number of bytes written
if 0 == p.state.written {
p.written = 0
}
/*
// so that we can overwrite the main state
// as soon as a full message has completed
// but still keep the number of bytes written
if 0 == p.state.written {
p.written = 0
}
*/
switch p.parseState {
case VersionState:
fmt.Println("version state", b[0])
//fmt.Println("[debug] version state", b[0])
p.state.version = b[0]
b = b[1:]
p.state.written += 1
p.consumed += 1
p.parseState += 1
default:
// do nothing
@ -80,8 +81,8 @@ func (p *Parser) Write(b []byte) (int, error) {
switch p.state.version {
case V1:
fmt.Println("v1 unmarshal")
return p.written, p.unpackV1(b)
//fmt.Println("[debug] v1 unmarshal")
return p.unpackV1(b)
default:
return 0, errors.New("incorrect version or version not implemented")
}

View File

@ -2,12 +2,24 @@ package packer
import (
"context"
"fmt"
"net"
"testing"
"time"
)
var src = Addr{
family: "IPv4",
addr: "192.168.1.101",
port: 6743,
}
var dst = Addr{
family: "IPv4",
port: 80,
scheme: "http",
}
var domain = "ex1.telebit.io"
var payload = []byte("Hello, World!")
type testHandler struct {
conns map[string]*Conn
chunksParsed int
@ -15,6 +27,7 @@ type testHandler struct {
}
func (th *testHandler) WriteMessage(a Addr, b []byte) {
th.chunksParsed += 1
addr := &a
_, ok := th.conns[addr.Network()]
if !ok {
@ -27,11 +40,30 @@ func (th *testHandler) WriteMessage(a Addr, b []byte) {
}
th.conns[addr.Network()] = conn
}
th.chunksParsed += 1
th.bytesRead += len(b)
}
func TestParseWholeBlock(t *testing.T) {
func TestParse1WholeBlock(t *testing.T) {
testParseNBlocks(t, 1)
}
func TestParse2WholeBlocks(t *testing.T) {
testParseNBlocks(t, 2)
}
func TestParse3WholeBlocks(t *testing.T) {
testParseNBlocks(t, 3)
}
func TestParse2Addrs(t *testing.T) {
testParseNBlocks(t, 4)
}
func TestParse3Addrs(t *testing.T) {
testParseNBlocks(t, 5)
}
func TestParse1AndRest(t *testing.T) {
ctx := context.Background()
//ctx, cancel := context.WithCancel(ctx)
@ -40,25 +72,17 @@ func TestParseWholeBlock(t *testing.T) {
}
p := NewParser(ctx, th)
payload := []byte(`Hello, World!`)
fmt.Println("payload len", len(payload))
src := Addr{
family: "IPv4",
addr: "192.168.1.101",
port: 6743,
}
dst := Addr{
family: "IPv4",
port: 80,
scheme: "http",
}
domain := "ex1.telebit.io"
h, b, err := Encode(src, dst, domain, payload)
if nil != err {
t.Fatal(err)
}
raw := append(h, b...)
n, err := p.Write(raw)
n, err := p.Write(raw[:1])
if nil != err {
t.Fatal(err)
}
m, err := p.Write(raw[1:])
if nil != err {
t.Fatal(err)
}
@ -67,12 +91,135 @@ func TestParseWholeBlock(t *testing.T) {
t.Fatal("should have parsed one connection")
}
if 1 != th.chunksParsed {
t.Fatal("should have parsed one chunck")
t.Fatal("should have parsed 1 chunck(s)")
}
if len(payload) != th.bytesRead {
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
}
if n != len(raw) {
if n+m != len(raw) {
t.Fatalf("should have parsed all %d bytes, not just %d\n", n, len(raw))
}
}
func TestParseRestAnd1(t *testing.T) {
ctx := context.Background()
//ctx, cancel := context.WithCancel(ctx)
th := &testHandler{
conns: map[string]*Conn{},
}
p := NewParser(ctx, th)
h, b, err := Encode(src, dst, domain, payload)
if nil != err {
t.Fatal(err)
}
raw := append(h, b...)
i := len(raw)
n, err := p.Write(raw[:i-1])
if nil != err {
t.Fatal(err)
}
m, err := p.Write(raw[i-1:])
if nil != err {
t.Fatal(err)
}
if 1 != len(th.conns) {
t.Fatal("should have parsed one connection")
}
if 2 != th.chunksParsed {
t.Fatal("should have parsed 2 chunck(s)")
}
if len(payload) != th.bytesRead {
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
}
if n+m != len(raw) {
t.Fatalf("should have parsed all %d bytes, not just %d\n", n, len(raw))
}
}
func TestParse1By1(t *testing.T) {
ctx := context.Background()
//ctx, cancel := context.WithCancel(ctx)
th := &testHandler{
conns: map[string]*Conn{},
}
p := NewParser(ctx, th)
h, b, err := Encode(src, dst, domain, payload)
if nil != err {
t.Fatal(err)
}
raw := append(h, b...)
count := 0
for _, b := range raw {
n, err := p.Write([]byte{b})
if nil != err {
t.Fatal(err)
}
count += n
}
if 1 != len(th.conns) {
t.Fatal("should have parsed one connection")
}
if len(payload) != th.chunksParsed {
t.Fatalf("should have parsed %d chunck(s), not %d", len(payload), th.chunksParsed)
}
if len(payload) != th.bytesRead {
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
}
if count != len(raw) {
t.Fatalf("should have parsed all %d bytes, not just %d\n", len(raw), count)
}
}
func testParseNBlocks(t *testing.T, count int) {
ctx := context.Background()
//ctx, cancel := context.WithCancel(ctx)
th := &testHandler{
conns: map[string]*Conn{},
}
nAddr := 1
if count > 2 {
nAddr = count - 2
}
p := NewParser(ctx, th)
raw := []byte{}
for i := 0; i < count; i++ {
if i > 2 {
copied := src
src = copied
src.port += i
}
h, b, err := Encode(src, dst, domain, payload)
if nil != err {
t.Fatal(err)
}
raw = append(raw, h...)
raw = append(raw, b...)
}
n, err := p.Write(raw)
if nil != err {
t.Fatal(err)
}
if nAddr != len(th.conns) {
t.Fatalf("should have parsed %d connection(s)", nAddr)
}
if count != th.chunksParsed {
t.Fatalf("should have parsed %d chunk(s)", count)
}
if count*len(payload) != th.bytesRead {
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", count*len(payload), th.bytesRead)
}
if n != len(raw) {
t.Fatalf("should have parsed all %d bytes, not just %d\n", len(raw), n)
}
}

View File

@ -28,71 +28,86 @@ type Header struct {
Service string
}
func (p *Parser) unpackV1(b []byte) error {
func (p *Parser) unpackV1(b []byte) (int, error) {
z := 0
for {
if z > 10 {
if z > 20 {
panic("stuck in an infinite loop?")
}
z += 1
n := len(b)
// at least one loop
if z > 1 && n < 1 {
fmt.Println("v1 end", z, n)
if n < 1 {
//fmt.Println("[debug] v1 end", z, n)
break
}
var err error
switch p.parseState {
case VersionState:
//fmt.Println("[debug] version state", b[0])
p.state.version = b[0]
b = b[1:]
p.consumed += 1
p.parseState += 1
case HeaderLengthState:
fmt.Println("v1 h len")
//fmt.Println("[debug] v1 h len")
b = p.unpackV1HeaderLength(b)
case HeaderState:
fmt.Println("v1 header")
//fmt.Println("[debug] v1 header")
b, err = p.unpackV1Header(b, n)
if nil != err {
fmt.Println("v1 header err", err)
return err
//fmt.Println("[debug] v1 header err", err)
consumed := p.consumed
p.consumed = 0
return consumed, err
}
case PayloadState:
fmt.Println("v1 payload")
//fmt.Println("[debug] v1 payload")
// if this payload is complete, reset all state
if p.state.payloadWritten == p.state.payloadLen {
p.state = ParserState{}
p.parseState = 0
}
b, err = p.unpackV1Payload(b, n)
if nil != err {
return err
consumed := p.consumed
p.consumed = 0
return consumed, err
}
default:
fmt.Println("[debug] v1 unknown state")
// do nothing
return errors.New("error unpacking")
consumed := p.consumed
p.consumed = 0
return consumed, errors.New("error unpacking")
}
}
return nil
consumed := p.consumed
p.consumed = 0
return consumed, nil
}
func (p *Parser) unpackV1HeaderLength(b []byte) []byte {
p.state.headerLen = int(b[0])
fmt.Println("unpacked header len", p.state.headerLen)
//fmt.Println("[debug] unpacked header len", p.state.headerLen)
b = b[1:]
p.state.written += 1
p.consumed += 1
p.parseState += 1
return b
}
func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) {
fmt.Println("got", len(b), "bytes", string(b))
//fmt.Println("[debug] got", len(b), "bytes", string(b))
m := len(p.state.header)
k := p.state.headerLen - m
if n < k {
k = n
}
p.state.written += k
p.consumed += k
c := b[0:k]
b = b[k:]
fmt.Println("has", m, "want", k, "more and have", len(b), "more")
//fmt.Println("[debug] has", m, "want", k, "more and have", len(b), "more")
p.state.header = append(p.state.header, c...)
if p.state.headerLen != len(p.state.header) {
return b, nil
@ -162,12 +177,13 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) {
return b, nil
*/
//fmt.Printf("[debug] [2] payload written: %d | payload length: %d\n", p.state.payloadWritten, p.state.payloadLen)
p.handler.WriteMessage(p.state.addr, []byte{})
return b, nil
}
k := p.state.payloadLen - p.state.payloadWritten
if k < n {
if n < k {
k = n
}
c := b[0:k]
@ -176,7 +192,6 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) {
// and also put backpressure on just that connection
/*
m, err := p.state.conn.local.Write(c)
p.state.written += m
p.state.payloadWritten += m
if nil != err {
// TODO we want to surface this error somewhere, but not to the websocket
@ -184,13 +199,14 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) {
}
*/
p.handler.WriteMessage(p.state.addr, c)
p.state.written += k
p.consumed += k
p.state.payloadWritten += k
p.written = p.state.written
//fmt.Printf("[debug] [1] payload written: %d | payload length: %d\n", p.state.payloadWritten, p.state.payloadLen)
// if this payload is complete, reset all state
if p.state.payloadWritten == p.state.payloadLen {
p.state = ParserState{}
p.parseState = 0
}
return b, nil
}