// telebit/rvpn/connection/connection.go
// 180 lines, 3.8 KiB, Go

package connection
import (
"encoding/hex"
"time"
"github.com/gorilla/websocket"
)
// upgrader holds the package's websocket upgrade settings: the read and
// write buffer sizes are both set to 1024.
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
// Connection tracks a single websocket connection and facilitates moving data
// in and out of it, together with the traffic counters kept for it.
type Connection struct {
// The main connection table (there should be just one, created at startup).
connectionTable *Table
// Per-domain traffic tracking, keyed by domain name. Used only for
// tracking; not used for lookup or validation.
DomainTrack map[string]*DomainTrack
// The websocket connection.
conn *websocket.Conn
// Channel of outbound messages; Writer receives from it. NewConnection
// creates it unbuffered.
send chan *SendTrack
// Address of the remote end point.
source string
// Bytes in: Reader adds the length of every message it reads.
bytesIn int64
// Bytes out: Writer adds the length of every message it writes.
bytesOut int64
// Communications channel between goroutines; exposed through CommCh.
commCh chan bool
// Connect time, set by NewConnection when the Connection is created.
connectTime time.Time
// initialDomains is the list of domains from the JWT; NewConnection uses it
// to seed DomainTrack.
initialDomains []interface{}
}
// NewConnection builds a Connection for a websocket.
//
// connectionTable is the table the connection belongs to, conn is the
// underlying websocket, remoteAddress is recorded as the connection's source
// and initialDomains is the domain list taken from the JWT. The connect time
// is set to the moment of the call and the byte counters start at zero.
//
// Every string entry of initialDomains is registered for traffic tracking.
// Entries that are not strings are skipped: the list originates from decoded
// token data, and an unchecked type assertion would panic on a malformed
// entry.
func NewConnection(connectionTable *Table, conn *websocket.Conn, remoteAddress string, initialDomains []interface{}) (p *Connection) {
	p = &Connection{
		connectionTable: connectionTable,
		conn:            conn,
		source:          remoteAddress,
		send:            make(chan *SendTrack),
		commCh:          make(chan bool),
		connectTime:     time.Now(),
		initialDomains:  initialDomains,
		DomainTrack:     make(map[string]*DomainTrack),
	}

	for _, domain := range initialDomains {
		name, ok := domain.(string)
		if !ok {
			// Not a domain name; nothing sensible to track.
			continue
		}
		p.AddTrackedDomain(name)
	}
	return p
}
// AddTrackedDomain starts traffic tracking for domain by storing a fresh
// DomainTrack under that name, replacing any entry already stored for it.
func (c *Connection) AddTrackedDomain(domain string) {
	c.DomainTrack[domain] = &DomainTrack{DomainName: domain}
}
// InitialDomains returns the domain list stored on the connection. The
// stored slice itself is returned, not a copy.
func (c *Connection) InitialDomains() (i []interface{}) {
	return c.initialDomains
}
// ConnectTime returns the connect time stored on the connection.
func (c *Connection) ConnectTime() (t time.Time) {
	return c.connectTime
}
// BytesIn returns the current value of the inbound byte counter.
func (c *Connection) BytesIn() (b int64) {
	return c.bytesIn
}
// BytesOut returns the current value of the outbound byte counter.
func (c *Connection) BytesOut() (b int64) {
	return c.bytesOut
}
// SendCh returns the channel of outbound messages; Writer receives from this
// channel and writes each message to the websocket.
func (c *Connection) SendCh() chan *SendTrack {
return c.send
}
// addIn adds num to the inbound byte counter.
func (c *Connection) addIn(num int64) {
	c.bytesIn += num
}
// addOut adds num to the outbound byte counter.
func (c *Connection) addOut(num int64) {
	c.bytesOut += num
}
// ConnectionTable returns the connection table this connection was created
// with.
func (c *Connection) ConnectionTable() (table *Table) {
	return c.connectionTable
}
// CommCh returns the connection's communications channel between goroutines.
func (c *Connection) CommCh() chan bool {
return c.commCh
}
// Reader reads messages from the websocket in a loop until a read fails.
//
// Each message is hex-dumped to the info log, its length is added to the
// inbound byte counter and the connection is logged. A read error ends the
// loop; the error is logged only when it is a close error other than
// CloseGoingAway. On exit the connection is sent to the table's unregister
// channel and the websocket is closed.
func (c *Connection) Reader() {
	// Cleanup order matters: unregister first, then close the socket.
	defer func() {
		c.connectionTable.unregister <- c
		c.conn.Close()
	}()

	c.conn.SetReadLimit(1024)

	for {
		_, payload, readErr := c.conn.ReadMessage()
		if readErr == nil {
			loginfo.Println(hex.Dump(payload))
			c.addIn(int64(len(payload)))
			loginfo.Println(c)
			continue
		}

		if websocket.IsUnexpectedCloseError(readErr, websocket.CloseGoingAway) {
			loginfo.Printf("error: %v", readErr)
		}
		return
	}
}
// Writer receives messages from the send channel and writes each one to the
// websocket as a text message.
//
// After a successful write the message length is added to the outbound byte
// counter and, when the message's domain is tracked, to that domain's
// counter; an untracked domain is reported on the debug log.
//
// Writer returns, closing the websocket, when the send channel is closed or
// when obtaining, writing to, or closing the message writer fails. A nil
// message is ignored.
func (c *Connection) Writer() {
	defer c.conn.Close()

	for {
		message, ok := <-c.send
		if !ok {
			// A receive from a closed channel yields nil; stop instead of
			// dereferencing it.
			return
		}
		if message == nil {
			continue
		}

		w, err := c.conn.NextWriter(websocket.TextMessage)
		if err != nil {
			return
		}
		// A failed write must not be counted as traffic sent.
		if _, err := w.Write(message.data); err != nil {
			return
		}
		if err := w.Close(); err != nil {
			return
		}

		messageLen := int64(len(message.data))
		c.addOut(messageLen)

		// Support for tracking outbound traffic based on domain.
		if domainTrack, ok := c.DomainTrack[message.domain]; ok {
			domainTrack.AddOut(messageLen)
			loginfo.Println("adding ", messageLen, " to ", message.domain)
		} else {
			// The domain should have been registered; warn that something
			// is wrong.
			logdebug.Println("attempting to add bytes to ", message.domain, "it does not exist")
			logdebug.Println(c.DomainTrack)
		}

		loginfo.Println(c)
	}
}