more parser tests
This commit is contained in:
parent
c5df63b11d
commit
f67fc7324d
|
@ -81,7 +81,7 @@ func (m *MultiplexLocal) listen(ctx context.Context, wsconn *websocket.Conn, lis
|
|||
for {
|
||||
time.Sleep(15 * time.Second)
|
||||
deadline := time.Now().Add(45 * time.Second)
|
||||
if err := wsconn.WriteControl(websocket.PingMessage, "", deadline); nil != err {
|
||||
if err := wsconn.WriteControl(websocket.PingMessage, []byte(""), deadline); nil != err {
|
||||
fmt.Fprintf(os.Stderr, "failed to write ping message to websocket: %s\n", err)
|
||||
cancel()
|
||||
break
|
||||
|
@ -94,7 +94,7 @@ func (m *MultiplexLocal) listen(ctx context.Context, wsconn *websocket.Conn, lis
|
|||
// TODO optimal buffer size
|
||||
b := make([]byte, 128*1024)
|
||||
for {
|
||||
n, err := listener.packer.Read(b)
|
||||
n, err := listener.parser.Read(b)
|
||||
if n > 0 {
|
||||
if err := wsconn.WriteMessage(websocket.BinaryMessage, b); nil != err {
|
||||
fmt.Fprintf(os.Stderr, "failed to write packer message to websocket: %s\n", err)
|
||||
|
|
|
@ -3,7 +3,6 @@ package packer
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type Parser struct {
|
||||
|
@ -15,11 +14,11 @@ type Parser struct {
|
|||
parseState State
|
||||
dataReady chan struct{}
|
||||
data []byte
|
||||
written int
|
||||
consumed int
|
||||
}
|
||||
|
||||
type ParserState struct {
|
||||
written int
|
||||
consumed int
|
||||
version byte
|
||||
headerLen int
|
||||
header []byte
|
||||
|
@ -60,19 +59,21 @@ func (p *Parser) Write(b []byte) (int, error) {
|
|||
return 0, errors.New("developer error: wrote 0 bytes")
|
||||
}
|
||||
|
||||
/*
|
||||
// so that we can overwrite the main state
|
||||
// as soon as a full message has completed
|
||||
// but still keep the number of bytes written
|
||||
if 0 == p.state.written {
|
||||
p.written = 0
|
||||
}
|
||||
*/
|
||||
|
||||
switch p.parseState {
|
||||
case VersionState:
|
||||
fmt.Println("version state", b[0])
|
||||
//fmt.Println("[debug] version state", b[0])
|
||||
p.state.version = b[0]
|
||||
b = b[1:]
|
||||
p.state.written += 1
|
||||
p.consumed += 1
|
||||
p.parseState += 1
|
||||
default:
|
||||
// do nothing
|
||||
|
@ -80,8 +81,8 @@ func (p *Parser) Write(b []byte) (int, error) {
|
|||
|
||||
switch p.state.version {
|
||||
case V1:
|
||||
fmt.Println("v1 unmarshal")
|
||||
return p.written, p.unpackV1(b)
|
||||
//fmt.Println("[debug] v1 unmarshal")
|
||||
return p.unpackV1(b)
|
||||
default:
|
||||
return 0, errors.New("incorrect version or version not implemented")
|
||||
}
|
||||
|
|
|
@ -2,12 +2,24 @@ package packer
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var src = Addr{
|
||||
family: "IPv4",
|
||||
addr: "192.168.1.101",
|
||||
port: 6743,
|
||||
}
|
||||
var dst = Addr{
|
||||
family: "IPv4",
|
||||
port: 80,
|
||||
scheme: "http",
|
||||
}
|
||||
var domain = "ex1.telebit.io"
|
||||
var payload = []byte("Hello, World!")
|
||||
|
||||
type testHandler struct {
|
||||
conns map[string]*Conn
|
||||
chunksParsed int
|
||||
|
@ -15,6 +27,7 @@ type testHandler struct {
|
|||
}
|
||||
|
||||
func (th *testHandler) WriteMessage(a Addr, b []byte) {
|
||||
th.chunksParsed += 1
|
||||
addr := &a
|
||||
_, ok := th.conns[addr.Network()]
|
||||
if !ok {
|
||||
|
@ -27,11 +40,30 @@ func (th *testHandler) WriteMessage(a Addr, b []byte) {
|
|||
}
|
||||
th.conns[addr.Network()] = conn
|
||||
}
|
||||
th.chunksParsed += 1
|
||||
th.bytesRead += len(b)
|
||||
}
|
||||
|
||||
func TestParseWholeBlock(t *testing.T) {
|
||||
func TestParse1WholeBlock(t *testing.T) {
|
||||
testParseNBlocks(t, 1)
|
||||
}
|
||||
|
||||
func TestParse2WholeBlocks(t *testing.T) {
|
||||
testParseNBlocks(t, 2)
|
||||
}
|
||||
|
||||
func TestParse3WholeBlocks(t *testing.T) {
|
||||
testParseNBlocks(t, 3)
|
||||
}
|
||||
|
||||
func TestParse2Addrs(t *testing.T) {
|
||||
testParseNBlocks(t, 4)
|
||||
}
|
||||
|
||||
func TestParse3Addrs(t *testing.T) {
|
||||
testParseNBlocks(t, 5)
|
||||
}
|
||||
|
||||
func TestParse1AndRest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
//ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
|
@ -40,25 +72,17 @@ func TestParseWholeBlock(t *testing.T) {
|
|||
}
|
||||
|
||||
p := NewParser(ctx, th)
|
||||
payload := []byte(`Hello, World!`)
|
||||
fmt.Println("payload len", len(payload))
|
||||
src := Addr{
|
||||
family: "IPv4",
|
||||
addr: "192.168.1.101",
|
||||
port: 6743,
|
||||
}
|
||||
dst := Addr{
|
||||
family: "IPv4",
|
||||
port: 80,
|
||||
scheme: "http",
|
||||
}
|
||||
domain := "ex1.telebit.io"
|
||||
|
||||
h, b, err := Encode(src, dst, domain, payload)
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
raw := append(h, b...)
|
||||
n, err := p.Write(raw)
|
||||
n, err := p.Write(raw[:1])
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m, err := p.Write(raw[1:])
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -67,12 +91,135 @@ func TestParseWholeBlock(t *testing.T) {
|
|||
t.Fatal("should have parsed one connection")
|
||||
}
|
||||
if 1 != th.chunksParsed {
|
||||
t.Fatal("should have parsed one chunck")
|
||||
t.Fatal("should have parsed 1 chunck(s)")
|
||||
}
|
||||
if len(payload) != th.bytesRead {
|
||||
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
|
||||
}
|
||||
if n != len(raw) {
|
||||
if n+m != len(raw) {
|
||||
t.Fatalf("should have parsed all %d bytes, not just %d\n", n, len(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseRestAnd1(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
//ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
th := &testHandler{
|
||||
conns: map[string]*Conn{},
|
||||
}
|
||||
|
||||
p := NewParser(ctx, th)
|
||||
|
||||
h, b, err := Encode(src, dst, domain, payload)
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
raw := append(h, b...)
|
||||
i := len(raw)
|
||||
n, err := p.Write(raw[:i-1])
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
m, err := p.Write(raw[i-1:])
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if 1 != len(th.conns) {
|
||||
t.Fatal("should have parsed one connection")
|
||||
}
|
||||
if 2 != th.chunksParsed {
|
||||
t.Fatal("should have parsed 2 chunck(s)")
|
||||
}
|
||||
if len(payload) != th.bytesRead {
|
||||
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
|
||||
}
|
||||
if n+m != len(raw) {
|
||||
t.Fatalf("should have parsed all %d bytes, not just %d\n", n, len(raw))
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse1By1(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
//ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
th := &testHandler{
|
||||
conns: map[string]*Conn{},
|
||||
}
|
||||
|
||||
p := NewParser(ctx, th)
|
||||
|
||||
h, b, err := Encode(src, dst, domain, payload)
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
raw := append(h, b...)
|
||||
count := 0
|
||||
for _, b := range raw {
|
||||
n, err := p.Write([]byte{b})
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
count += n
|
||||
}
|
||||
|
||||
if 1 != len(th.conns) {
|
||||
t.Fatal("should have parsed one connection")
|
||||
}
|
||||
if len(payload) != th.chunksParsed {
|
||||
t.Fatalf("should have parsed %d chunck(s), not %d", len(payload), th.chunksParsed)
|
||||
}
|
||||
if len(payload) != th.bytesRead {
|
||||
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", len(payload), th.bytesRead)
|
||||
}
|
||||
if count != len(raw) {
|
||||
t.Fatalf("should have parsed all %d bytes, not just %d\n", len(raw), count)
|
||||
}
|
||||
}
|
||||
|
||||
func testParseNBlocks(t *testing.T, count int) {
|
||||
ctx := context.Background()
|
||||
//ctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
th := &testHandler{
|
||||
conns: map[string]*Conn{},
|
||||
}
|
||||
|
||||
nAddr := 1
|
||||
if count > 2 {
|
||||
nAddr = count - 2
|
||||
}
|
||||
p := NewParser(ctx, th)
|
||||
raw := []byte{}
|
||||
for i := 0; i < count; i++ {
|
||||
if i > 2 {
|
||||
copied := src
|
||||
src = copied
|
||||
src.port += i
|
||||
}
|
||||
h, b, err := Encode(src, dst, domain, payload)
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
raw = append(raw, h...)
|
||||
raw = append(raw, b...)
|
||||
}
|
||||
n, err := p.Write(raw)
|
||||
if nil != err {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if nAddr != len(th.conns) {
|
||||
t.Fatalf("should have parsed %d connection(s)", nAddr)
|
||||
}
|
||||
if count != th.chunksParsed {
|
||||
t.Fatalf("should have parsed %d chunk(s)", count)
|
||||
}
|
||||
if count*len(payload) != th.bytesRead {
|
||||
t.Fatalf("should have parsed a payload of %d bytes, but saw %d\n", count*len(payload), th.bytesRead)
|
||||
}
|
||||
if n != len(raw) {
|
||||
t.Fatalf("should have parsed all %d bytes, not just %d\n", len(raw), n)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -28,71 +28,86 @@ type Header struct {
|
|||
Service string
|
||||
}
|
||||
|
||||
func (p *Parser) unpackV1(b []byte) error {
|
||||
func (p *Parser) unpackV1(b []byte) (int, error) {
|
||||
z := 0
|
||||
for {
|
||||
if z > 10 {
|
||||
if z > 20 {
|
||||
panic("stuck in an infinite loop?")
|
||||
}
|
||||
z += 1
|
||||
n := len(b)
|
||||
// at least one loop
|
||||
if z > 1 && n < 1 {
|
||||
fmt.Println("v1 end", z, n)
|
||||
if n < 1 {
|
||||
//fmt.Println("[debug] v1 end", z, n)
|
||||
break
|
||||
}
|
||||
|
||||
var err error
|
||||
switch p.parseState {
|
||||
case VersionState:
|
||||
//fmt.Println("[debug] version state", b[0])
|
||||
p.state.version = b[0]
|
||||
b = b[1:]
|
||||
p.consumed += 1
|
||||
p.parseState += 1
|
||||
case HeaderLengthState:
|
||||
fmt.Println("v1 h len")
|
||||
//fmt.Println("[debug] v1 h len")
|
||||
b = p.unpackV1HeaderLength(b)
|
||||
case HeaderState:
|
||||
fmt.Println("v1 header")
|
||||
//fmt.Println("[debug] v1 header")
|
||||
b, err = p.unpackV1Header(b, n)
|
||||
if nil != err {
|
||||
fmt.Println("v1 header err", err)
|
||||
return err
|
||||
//fmt.Println("[debug] v1 header err", err)
|
||||
consumed := p.consumed
|
||||
p.consumed = 0
|
||||
return consumed, err
|
||||
}
|
||||
case PayloadState:
|
||||
fmt.Println("v1 payload")
|
||||
//fmt.Println("[debug] v1 payload")
|
||||
// if this payload is complete, reset all state
|
||||
if p.state.payloadWritten == p.state.payloadLen {
|
||||
p.state = ParserState{}
|
||||
p.parseState = 0
|
||||
}
|
||||
b, err = p.unpackV1Payload(b, n)
|
||||
if nil != err {
|
||||
return err
|
||||
consumed := p.consumed
|
||||
p.consumed = 0
|
||||
return consumed, err
|
||||
}
|
||||
default:
|
||||
fmt.Println("[debug] v1 unknown state")
|
||||
// do nothing
|
||||
return errors.New("error unpacking")
|
||||
consumed := p.consumed
|
||||
p.consumed = 0
|
||||
return consumed, errors.New("error unpacking")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
consumed := p.consumed
|
||||
p.consumed = 0
|
||||
return consumed, nil
|
||||
}
|
||||
|
||||
func (p *Parser) unpackV1HeaderLength(b []byte) []byte {
|
||||
p.state.headerLen = int(b[0])
|
||||
fmt.Println("unpacked header len", p.state.headerLen)
|
||||
//fmt.Println("[debug] unpacked header len", p.state.headerLen)
|
||||
b = b[1:]
|
||||
p.state.written += 1
|
||||
p.consumed += 1
|
||||
p.parseState += 1
|
||||
return b
|
||||
}
|
||||
|
||||
func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) {
|
||||
fmt.Println("got", len(b), "bytes", string(b))
|
||||
//fmt.Println("[debug] got", len(b), "bytes", string(b))
|
||||
m := len(p.state.header)
|
||||
k := p.state.headerLen - m
|
||||
if n < k {
|
||||
k = n
|
||||
}
|
||||
p.state.written += k
|
||||
p.consumed += k
|
||||
c := b[0:k]
|
||||
b = b[k:]
|
||||
fmt.Println("has", m, "want", k, "more and have", len(b), "more")
|
||||
//fmt.Println("[debug] has", m, "want", k, "more and have", len(b), "more")
|
||||
p.state.header = append(p.state.header, c...)
|
||||
if p.state.headerLen != len(p.state.header) {
|
||||
return b, nil
|
||||
|
@ -162,12 +177,13 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) {
|
|||
return b, nil
|
||||
*/
|
||||
|
||||
//fmt.Printf("[debug] [2] payload written: %d | payload length: %d\n", p.state.payloadWritten, p.state.payloadLen)
|
||||
p.handler.WriteMessage(p.state.addr, []byte{})
|
||||
return b, nil
|
||||
}
|
||||
|
||||
k := p.state.payloadLen - p.state.payloadWritten
|
||||
if k < n {
|
||||
if n < k {
|
||||
k = n
|
||||
}
|
||||
c := b[0:k]
|
||||
|
@ -176,7 +192,6 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) {
|
|||
// and also put backpressure on just that connection
|
||||
/*
|
||||
m, err := p.state.conn.local.Write(c)
|
||||
p.state.written += m
|
||||
p.state.payloadWritten += m
|
||||
if nil != err {
|
||||
// TODO we want to surface this error somewhere, but not to the websocket
|
||||
|
@ -184,13 +199,14 @@ func (p *Parser) unpackV1Payload(b []byte, n int) ([]byte, error) {
|
|||
}
|
||||
*/
|
||||
p.handler.WriteMessage(p.state.addr, c)
|
||||
p.state.written += k
|
||||
p.consumed += k
|
||||
p.state.payloadWritten += k
|
||||
p.written = p.state.written
|
||||
|
||||
//fmt.Printf("[debug] [1] payload written: %d | payload length: %d\n", p.state.payloadWritten, p.state.payloadLen)
|
||||
// if this payload is complete, reset all state
|
||||
if p.state.payloadWritten == p.state.payloadLen {
|
||||
p.state = ParserState{}
|
||||
p.parseState = 0
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue