diff --git a/rvpn/client/listener_client.go b/rvpn/client/listener_client.go index f5afde6..c53c346 100644 --- a/rvpn/client/listener_client.go +++ b/rvpn/client/listener_client.go @@ -70,15 +70,15 @@ func handleConnectionWebSocket(connectionTable *connection.Table, w http.Respons loginfo.Println("before connection table") - newConnection := connection.NewConnection(connectionTable, conn, r.RemoteAddr, domains) - connectionTable.Register() <- newConnection - ok = <-newConnection.CommCh() + //newConnection := connection.NewConnection(connectionTable, conn, r.RemoteAddr, domains) + + newRegistration := connection.NewRegistration(conn, r.RemoteAddr, domains) + connectionTable.Register() <- newRegistration + ok = <-newRegistration.CommCh() if !ok { - loginfo.Println("connection registration failed ", newConnection) + loginfo.Println("connection registration failed ", newRegistration) return } - loginfo.Println("connection registration accepted ", newConnection) - go newConnection.Writer() - newConnection.Reader() - loginfo.Println("connection closing") + + loginfo.Println("connection registration accepted ", newRegistration) } diff --git a/rvpn/connection/connection.go b/rvpn/connection/connection.go index b65784a..1faf09e 100755 --- a/rvpn/connection/connection.go +++ b/rvpn/connection/connection.go @@ -4,6 +4,10 @@ import ( "encoding/hex" "time" + "sync" + + "io" + "github.com/gorilla/websocket" ) @@ -14,6 +18,8 @@ var upgrader = websocket.Upgrader{ // Connection track websocket and faciliates in and out data type Connection struct { + mutex *sync.Mutex + // The main connection table (should be just one of these created at startup) connectionTable *Table @@ -26,6 +32,9 @@ type Connection struct { // Buffered channel of outbound messages. send chan *SendTrack + // WssState channel + // Must check state via channel before xmit + // Address of the Remote End Point source string @@ -41,13 +50,20 @@ type Connection struct { // Connect Time connectTime time.Time + //lastUpdate + lastUpdate time.Time + //initialDomains - a list of domains from the JWT initialDomains []interface{} + + ///wssState tracks a highlevel status of the connection, false means do nothing. + wssState bool } //NewConnection -- Constructor func NewConnection(connectionTable *Table, conn *websocket.Conn, remoteAddress string, initialDomains []interface{}) (p *Connection) { p = new(Connection) + p.mutex = &sync.Mutex{} p.connectionTable = connectionTable p.conn = conn p.source = remoteAddress @@ -62,6 +78,8 @@ func NewConnection(connectionTable *Table, conn *websocket.Conn, remoteAddress s for _, domain := range initialDomains { p.AddTrackedDomain(string(domain.(string))) } + + p.State(true) return } @@ -120,18 +138,77 @@ func (c *Connection) CommCh() chan bool { return c.commCh } +//GetState -- Get state of Socket...this is a high level state. +func (c *Connection) GetState() bool { + defer func() { + c.mutex.Unlock() + }() + c.mutex.Lock() + return c.wssState +} + +//State -- Set the set of the high level connection +func (c *Connection) State(state bool) { + defer func() { + c.mutex.Unlock() + }() + + c.mutex.Lock() + c.wssState = state +} + +//Update -- updates the lastUpdate property tracking idle time +func (c *Connection) Update() { + defer func() { + c.mutex.Unlock() + }() + + c.mutex.Lock() + c.lastUpdate = time.Now() +} + +//NextWriter -- Wrapper to allow a high level state check before offering NextWriter +//The libary failes if client abends during write-cycle. a fast moving write is not caught before socket state bubbles up +//A synchronised state is maintained +func (c Connection) NextWriter(wssMessageType int) (w io.WriteCloser, err error) { + if c.GetState() == true { + w, err = c.conn.NextWriter(wssMessageType) + } else { + loginfo.Println("NextWriter aborted, state is not true") + } + return +} + +//Write -- Wrapper to allow a high level state check before allowing a write to the socket. +func (c *Connection) Write(w io.WriteCloser, message []byte) (cnt int, err error) { + if c.GetState() == true { + cnt, err = w.Write(message) + } + return +} + //Reader -- export the reader function func (c *Connection) Reader() { defer func() { c.connectionTable.unregister <- c c.conn.Close() + loginfo.Println("reader defer", c) }() + + loginfo.Println("Reader Start ", c) + c.conn.SetReadLimit(1024) for { _, message, err := c.conn.ReadMessage() + + loginfo.Println("ReadMessage") + c.Update() + if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) { + c.State(false) loginfo.Printf("error: %v", err) + loginfo.Println(c.conn) } break } @@ -146,15 +223,24 @@ func (c *Connection) Writer() { defer func() { c.conn.Close() }() + + loginfo.Println("Writer Start ", c) + for { select { case message := <-c.send: - w, err := c.conn.NextWriter(websocket.TextMessage) + w, err := c.NextWriter(websocket.BinaryMessage) + loginfo.Println("next writer ", w) if err != nil { return } - w.Write(message.data) + loginfo.Println(c) + loginfo.Println(w) + + c.Update() + + _, err = c.Write(w, message.data) if err := w.Close(); err != nil { return diff --git a/rvpn/connection/connection_registration.go b/rvpn/connection/connection_registration.go new file mode 100644 index 0000000..0ec8660 --- /dev/null +++ b/rvpn/connection/connection_registration.go @@ -0,0 +1,35 @@ +package connection + +import "github.com/gorilla/websocket" + +//Registration -- A connection registration structure used to bring up a connection +//connection table will then handle additing and sdtarting up the various readers +//else error. +type Registration struct { + // The websocket connection. + conn *websocket.Conn + + // Address of the Remote End Point + source string + + // communications channel between go routines + commCh chan bool + + //initialDomains - a list of domains from the JWT + initialDomains []interface{} +} + +//NewRegistration -- Constructor +func NewRegistration(conn *websocket.Conn, remoteAddress string, initialDomains []interface{}) (p *Registration) { + p = new(Registration) + p.conn = conn + p.source = remoteAddress + p.commCh = make(chan bool) + p.initialDomains = initialDomains + return +} + +//CommCh -- Property +func (c *Registration) CommCh() chan bool { + return c.commCh +} diff --git a/rvpn/connection/connection_table.go b/rvpn/connection/connection_table.go index 46c787e..b732453 100755 --- a/rvpn/connection/connection_table.go +++ b/rvpn/connection/connection_table.go @@ -11,7 +11,7 @@ const ( type Table struct { connections map[*Connection][]string domains map[string]*Connection - register chan *Connection + register chan *Registration unregister chan *Connection domainAnnounce chan *DomainMapping domainRevoke chan *DomainMapping @@ -22,7 +22,7 @@ func NewTable() (p *Table) { p = new(Table) p.connections = make(map[*Connection][]string) p.domains = make(map[string]*Connection) - p.register = make(chan *Connection) + p.register = make(chan *Registration) p.unregister = make(chan *Connection) p.domainAnnounce = make(chan *DomainMapping) p.domainRevoke = make(chan *DomainMapping) @@ -46,10 +46,12 @@ func (c *Table) Run() { loginfo.Println("ConnectionTable starting") for { select { - case connection := <-c.register: + case registration := <-c.register: loginfo.Println("register fired") + + connection := NewConnection(c, registration.conn, registration.source, registration.initialDomains) c.connections[connection] = make([]string, initialDomains) - connection.commCh <- true + registration.commCh <- true // handle initial domain additions for _, domain := range connection.initialDomains { @@ -63,7 +65,8 @@ func (c *Table) Run() { s := c.connections[connection] c.connections[connection] = append(s, newDomain) } - + go connection.Writer() + go connection.Reader() loginfo.Println("register exiting") case connection := <-c.unregister: @@ -76,8 +79,8 @@ func (c *Table) Run() { } } - delete(c.connections, connection) - close(connection.send) + //delete(c.connections, connection) + //close(connection.send) } case domainMapping := <-c.domainAnnounce: @@ -96,7 +99,7 @@ func (c *Table) Run() { } //Register -- Property -func (c *Table) Register() (r chan *Connection) { +func (c *Table) Register() (r chan *Registration) { r = c.register return } diff --git a/rvpn/packer/packer.go b/rvpn/packer/packer.go index 349556c..36cf083 100644 --- a/rvpn/packer/packer.go +++ b/rvpn/packer/packer.go @@ -2,7 +2,6 @@ package packer import ( "bytes" - "encoding/hex" "fmt" ) @@ -44,12 +43,12 @@ func (p *Packer) PackV1() (b bytes.Buffer) { buf.Write(headerBuf.Bytes()) buf.Write(p.Data.buffer.Bytes()) - fmt.Println("header: ", headerBuf.String()) - fmt.Println("meta: ", metaBuf) - fmt.Println("Data: ", p.Data.buffer) - fmt.Println("Buffer: ", buf.Bytes()) - fmt.Println("Buffer: ", hex.Dump(buf.Bytes())) - fmt.Printf("Buffer %s", buf.Bytes()) + //fmt.Println("header: ", headerBuf.String()) + //fmt.Println("meta: ", metaBuf) + //fmt.Println("Data: ", p.Data.buffer) + //fmt.Println("Buffer: ", buf.Bytes()) + //fmt.Println("Buffer: ", hex.Dump(buf.Bytes())) + //fmt.Printf("Buffer %s", buf.Bytes()) b = buf