Correcting Critical Bug

- when testing streams to WSS client, I caused tunnel.js to abort in xfer.
- this caused a panic in go.
- found that connection was reaped and garbage collected during send routines.
- placed synchronize around a connection states.
- moved connection creation into connection table.
- allowed connections to hang around while in a false state…
- will have a go routine remove them after some idle time and connections being false.
This commit is contained in:
Henry Camacho 2017-02-19 14:05:06 -06:00
parent c261b5d3a3
commit ff3e63da8d
5 changed files with 148 additions and 25 deletions

View File

@ -70,15 +70,15 @@ func handleConnectionWebSocket(connectionTable *connection.Table, w http.Respons
loginfo.Println("before connection table") loginfo.Println("before connection table")
newConnection := connection.NewConnection(connectionTable, conn, r.RemoteAddr, domains) //newConnection := connection.NewConnection(connectionTable, conn, r.RemoteAddr, domains)
connectionTable.Register() <- newConnection
ok = <-newConnection.CommCh() newRegistration := connection.NewRegistration(conn, r.RemoteAddr, domains)
connectionTable.Register() <- newRegistration
ok = <-newRegistration.CommCh()
if !ok { if !ok {
loginfo.Println("connection registration failed ", newConnection) loginfo.Println("connection registration failed ", newRegistration)
return return
} }
loginfo.Println("connection registration accepted ", newConnection)
go newConnection.Writer() loginfo.Println("connection registration accepted ", newRegistration)
newConnection.Reader()
loginfo.Println("connection closing")
} }

View File

@ -4,6 +4,10 @@ import (
"encoding/hex" "encoding/hex"
"time" "time"
"sync"
"io"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
@ -14,6 +18,8 @@ var upgrader = websocket.Upgrader{
// Connection track websocket and faciliates in and out data // Connection track websocket and faciliates in and out data
type Connection struct { type Connection struct {
mutex *sync.Mutex
// The main connection table (should be just one of these created at startup) // The main connection table (should be just one of these created at startup)
connectionTable *Table connectionTable *Table
@ -26,6 +32,9 @@ type Connection struct {
// Buffered channel of outbound messages. // Buffered channel of outbound messages.
send chan *SendTrack send chan *SendTrack
// WssState channel
// Must check state via channel before xmit
// Address of the Remote End Point // Address of the Remote End Point
source string source string
@ -41,13 +50,20 @@ type Connection struct {
// Connect Time // Connect Time
connectTime time.Time connectTime time.Time
//lastUpdate
lastUpdate time.Time
//initialDomains - a list of domains from the JWT //initialDomains - a list of domains from the JWT
initialDomains []interface{} initialDomains []interface{}
///wssState tracks a highlevel status of the connection, false means do nothing.
wssState bool
} }
//NewConnection -- Constructor //NewConnection -- Constructor
func NewConnection(connectionTable *Table, conn *websocket.Conn, remoteAddress string, initialDomains []interface{}) (p *Connection) { func NewConnection(connectionTable *Table, conn *websocket.Conn, remoteAddress string, initialDomains []interface{}) (p *Connection) {
p = new(Connection) p = new(Connection)
p.mutex = &sync.Mutex{}
p.connectionTable = connectionTable p.connectionTable = connectionTable
p.conn = conn p.conn = conn
p.source = remoteAddress p.source = remoteAddress
@ -62,6 +78,8 @@ func NewConnection(connectionTable *Table, conn *websocket.Conn, remoteAddress s
for _, domain := range initialDomains { for _, domain := range initialDomains {
p.AddTrackedDomain(string(domain.(string))) p.AddTrackedDomain(string(domain.(string)))
} }
p.State(true)
return return
} }
@ -120,18 +138,77 @@ func (c *Connection) CommCh() chan bool {
return c.commCh return c.commCh
} }
//GetState -- Get state of Socket...this is a high level state.
func (c *Connection) GetState() bool {
defer func() {
c.mutex.Unlock()
}()
c.mutex.Lock()
return c.wssState
}
//State -- Set the set of the high level connection
func (c *Connection) State(state bool) {
defer func() {
c.mutex.Unlock()
}()
c.mutex.Lock()
c.wssState = state
}
//Update -- updates the lastUpdate property tracking idle time
func (c *Connection) Update() {
defer func() {
c.mutex.Unlock()
}()
c.mutex.Lock()
c.lastUpdate = time.Now()
}
//NextWriter -- Wrapper to allow a high level state check before offering NextWriter
//The libary failes if client abends during write-cycle. a fast moving write is not caught before socket state bubbles up
//A synchronised state is maintained
func (c Connection) NextWriter(wssMessageType int) (w io.WriteCloser, err error) {
if c.GetState() == true {
w, err = c.conn.NextWriter(wssMessageType)
} else {
loginfo.Println("NextWriter aborted, state is not true")
}
return
}
//Write -- Wrapper to allow a high level state check before allowing a write to the socket.
func (c *Connection) Write(w io.WriteCloser, message []byte) (cnt int, err error) {
if c.GetState() == true {
cnt, err = w.Write(message)
}
return
}
//Reader -- export the reader function //Reader -- export the reader function
func (c *Connection) Reader() { func (c *Connection) Reader() {
defer func() { defer func() {
c.connectionTable.unregister <- c c.connectionTable.unregister <- c
c.conn.Close() c.conn.Close()
loginfo.Println("reader defer", c)
}() }()
loginfo.Println("Reader Start ", c)
c.conn.SetReadLimit(1024) c.conn.SetReadLimit(1024)
for { for {
_, message, err := c.conn.ReadMessage() _, message, err := c.conn.ReadMessage()
loginfo.Println("ReadMessage")
c.Update()
if err != nil { if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
c.State(false)
loginfo.Printf("error: %v", err) loginfo.Printf("error: %v", err)
loginfo.Println(c.conn)
} }
break break
} }
@ -146,15 +223,24 @@ func (c *Connection) Writer() {
defer func() { defer func() {
c.conn.Close() c.conn.Close()
}() }()
loginfo.Println("Writer Start ", c)
for { for {
select { select {
case message := <-c.send: case message := <-c.send:
w, err := c.conn.NextWriter(websocket.TextMessage) w, err := c.NextWriter(websocket.BinaryMessage)
loginfo.Println("next writer ", w)
if err != nil { if err != nil {
return return
} }
w.Write(message.data) loginfo.Println(c)
loginfo.Println(w)
c.Update()
_, err = c.Write(w, message.data)
if err := w.Close(); err != nil { if err := w.Close(); err != nil {
return return

View File

@ -0,0 +1,35 @@
package connection
import "github.com/gorilla/websocket"
//Registration -- A connection registration structure used to bring up a connection
//connection table will then handle additing and sdtarting up the various readers
//else error.
type Registration struct {
// The websocket connection.
conn *websocket.Conn
// Address of the Remote End Point
source string
// communications channel between go routines
commCh chan bool
//initialDomains - a list of domains from the JWT
initialDomains []interface{}
}
//NewRegistration -- Constructor
func NewRegistration(conn *websocket.Conn, remoteAddress string, initialDomains []interface{}) (p *Registration) {
p = new(Registration)
p.conn = conn
p.source = remoteAddress
p.commCh = make(chan bool)
p.initialDomains = initialDomains
return
}
//CommCh -- Property
func (c *Registration) CommCh() chan bool {
return c.commCh
}

View File

@ -11,7 +11,7 @@ const (
type Table struct { type Table struct {
connections map[*Connection][]string connections map[*Connection][]string
domains map[string]*Connection domains map[string]*Connection
register chan *Connection register chan *Registration
unregister chan *Connection unregister chan *Connection
domainAnnounce chan *DomainMapping domainAnnounce chan *DomainMapping
domainRevoke chan *DomainMapping domainRevoke chan *DomainMapping
@ -22,7 +22,7 @@ func NewTable() (p *Table) {
p = new(Table) p = new(Table)
p.connections = make(map[*Connection][]string) p.connections = make(map[*Connection][]string)
p.domains = make(map[string]*Connection) p.domains = make(map[string]*Connection)
p.register = make(chan *Connection) p.register = make(chan *Registration)
p.unregister = make(chan *Connection) p.unregister = make(chan *Connection)
p.domainAnnounce = make(chan *DomainMapping) p.domainAnnounce = make(chan *DomainMapping)
p.domainRevoke = make(chan *DomainMapping) p.domainRevoke = make(chan *DomainMapping)
@ -46,10 +46,12 @@ func (c *Table) Run() {
loginfo.Println("ConnectionTable starting") loginfo.Println("ConnectionTable starting")
for { for {
select { select {
case connection := <-c.register: case registration := <-c.register:
loginfo.Println("register fired") loginfo.Println("register fired")
connection := NewConnection(c, registration.conn, registration.source, registration.initialDomains)
c.connections[connection] = make([]string, initialDomains) c.connections[connection] = make([]string, initialDomains)
connection.commCh <- true registration.commCh <- true
// handle initial domain additions // handle initial domain additions
for _, domain := range connection.initialDomains { for _, domain := range connection.initialDomains {
@ -63,7 +65,8 @@ func (c *Table) Run() {
s := c.connections[connection] s := c.connections[connection]
c.connections[connection] = append(s, newDomain) c.connections[connection] = append(s, newDomain)
} }
go connection.Writer()
go connection.Reader()
loginfo.Println("register exiting") loginfo.Println("register exiting")
case connection := <-c.unregister: case connection := <-c.unregister:
@ -76,8 +79,8 @@ func (c *Table) Run() {
} }
} }
delete(c.connections, connection) //delete(c.connections, connection)
close(connection.send) //close(connection.send)
} }
case domainMapping := <-c.domainAnnounce: case domainMapping := <-c.domainAnnounce:
@ -96,7 +99,7 @@ func (c *Table) Run() {
} }
//Register -- Property //Register -- Property
func (c *Table) Register() (r chan *Connection) { func (c *Table) Register() (r chan *Registration) {
r = c.register r = c.register
return return
} }

View File

@ -2,7 +2,6 @@ package packer
import ( import (
"bytes" "bytes"
"encoding/hex"
"fmt" "fmt"
) )
@ -44,12 +43,12 @@ func (p *Packer) PackV1() (b bytes.Buffer) {
buf.Write(headerBuf.Bytes()) buf.Write(headerBuf.Bytes())
buf.Write(p.Data.buffer.Bytes()) buf.Write(p.Data.buffer.Bytes())
fmt.Println("header: ", headerBuf.String()) //fmt.Println("header: ", headerBuf.String())
fmt.Println("meta: ", metaBuf) //fmt.Println("meta: ", metaBuf)
fmt.Println("Data: ", p.Data.buffer) //fmt.Println("Data: ", p.Data.buffer)
fmt.Println("Buffer: ", buf.Bytes()) //fmt.Println("Buffer: ", buf.Bytes())
fmt.Println("Buffer: ", hex.Dump(buf.Bytes())) //fmt.Println("Buffer: ", hex.Dump(buf.Bytes()))
fmt.Printf("Buffer %s", buf.Bytes()) //fmt.Printf("Buffer %s", buf.Bytes())
b = buf b = buf