diff --git a/go-rvpn-server.yaml b/go-rvpn-server.yaml index 90e137e..2cd7ad6 100644 --- a/go-rvpn-server.yaml +++ b/go-rvpn-server.yaml @@ -4,7 +4,7 @@ rvpn: admindomain: rvpn.daplie.invalid genericlistener: 9999 deadtime: - dwell: 600 + dwell: 120 idle: 60 cancelcheck: 10 domains: diff --git a/main.go b/main.go index 4922e33..0ca1e81 100644 --- a/main.go +++ b/main.go @@ -98,7 +98,7 @@ func main() { connectionTable = genericlistener.NewTable(dwell, idle) serverStatus.ConnectionTable = connectionTable - go connectionTable.Run(ctx) + go connectionTable.Run(ctx, lbDefaultMethod) genericListeners := genericlistener.NewGenerListeners(ctx, secretKey, certbundle, serverStatus) serverStatus.GenericListeners = genericListeners diff --git a/rvpn/genericlistener/api_interface.go b/rvpn/genericlistener/api_interface.go index 5f3ec96..a661477 100644 --- a/rvpn/genericlistener/api_interface.go +++ b/rvpn/genericlistener/api_interface.go @@ -101,9 +101,13 @@ func getDomainsEndpoint(w http.ResponseWriter, r *http.Request) { domainsContainer := NewDomainsAPIContainer() for domain := range connectionTable.domains { - conn := connectionTable.domains[domain] - domainAPI := NewDomainsAPI(conn, conn.DomainTrack[domain]) - domainsContainer.Domains = append(domainsContainer.Domains, domainAPI) + domainLB := connectionTable.domains[domain] + conns := domainLB.Connections() + for pos := range conns { + conn := conns[pos] + domainAPI := NewDomainsAPI(conn, conn.DomainTrack[domain]) + domainsContainer.Domains = append(domainsContainer.Domains, domainAPI) + } } @@ -130,14 +134,19 @@ func getDomainEndpoint(w http.ResponseWriter, r *http.Request) { env.ErrorDescription = "domain API requires a domain-name" } else { domainName := id - if conn, ok := connectionTable.domains[domainName]; !ok { + if domainLB, ok := connectionTable.domains[domainName]; !ok { env.Error = "domain-name was not found" env.ErrorURI = r.RequestURI env.ErrorDescription = "domain-name not found" } else { - - domainAPI := NewDomainAPI(conn, conn.DomainTrack[domainName]) - env.Result = domainAPI + var domainAPIContainer []*DomainAPI + conns := domainLB.Connections() + for pos := range conns { + conn := conns[pos] + domainAPI := NewDomainAPI(conn, conn.DomainTrack[domainName]) + domainAPIContainer = append(domainAPIContainer, domainAPI) + } + env.Result = domainAPIContainer } } w.Header().Set("Content-Type", "application/json; charset=UTF-8") diff --git a/rvpn/genericlistener/connection.go b/rvpn/genericlistener/connection.go index 1149240..8c779e7 100755 --- a/rvpn/genericlistener/connection.go +++ b/rvpn/genericlistener/connection.go @@ -279,6 +279,7 @@ func (c *Connection) Writer() { w, err := c.NextWriter(websocket.BinaryMessage) loginfo.Println("next writer ", w) if err != nil { + c.SetState(false) return } diff --git a/rvpn/genericlistener/connection_table.go b/rvpn/genericlistener/connection_table.go index 94ae5ee..14c461b 100755 --- a/rvpn/genericlistener/connection_table.go +++ b/rvpn/genericlistener/connection_table.go @@ -14,7 +14,7 @@ const ( //Table maintains the set of connections type Table struct { connections map[*Connection][]string - domains map[string]*Connection + domains map[string]*DomainLoadBalance register chan *Registration unregister chan *Connection domainAnnounce chan *DomainMapping @@ -27,7 +27,7 @@ type Table struct { func NewTable(dwell, idle int) (p *Table) { p = new(Table) p.connections = make(map[*Connection][]string) - p.domains = make(map[string]*Connection) + p.domains = make(map[string]*DomainLoadBalance) p.register = make(chan *Registration) p.unregister = make(chan *Connection) p.domainAnnounce = make(chan *DomainMapping) @@ -42,10 +42,19 @@ func (c *Table) Connections() map[*Connection][]string { return c.connections } -//ConnByDomain -- Obtains a connection from a domain announcement. +//ConnByDomain -- Obtains a connection from a domain announcement. A domain may be announced more than once +//if that is the case the system stores these connections and then sends traffic back round-robin +//back to the WSS connections func (c *Table) ConnByDomain(domain string) (*Connection, bool) { - conn, ok := c.domains[domain] - return conn, ok + for dn := range c.domains { + loginfo.Println(dn, domain) + } + if domainsLB, ok := c.domains[domain]; ok { + loginfo.Println("found") + conn := domainsLB.NextMember() + return conn, ok + } + return nil, false } //reaper -- @@ -79,7 +88,7 @@ func (c *Table) GetConnection(serverID int64) (*Connection, error) { } //Run -- Execute -func (c *Table) Run(ctx context.Context) { +func (c *Table) Run(ctx context.Context, defaultMethod string) { loginfo.Println("ConnectionTable starting") go c.reaper(c.dwell, c.idle) @@ -104,7 +113,17 @@ func (c *Table) Run(ctx context.Context) { newDomain := string(domain.(string)) loginfo.Println("adding domain ", newDomain, " to connection ", connection.conn.RemoteAddr().String()) - c.domains[newDomain] = connection + + //check to see if domain is already present. + if _, ok := c.domains[newDomain]; ok { + + //append to a list of connections for that domain + c.domains[newDomain].AddConnection(connection) + } else { + //if not, then add as the 1st to the list of connections + c.domains[newDomain] = NewDomainLoadBalance(defaultMethod) + c.domains[newDomain].AddConnection(connection) + } // add to the connection domain list s := c.connections[connection] @@ -115,11 +134,26 @@ func (c *Table) Run(ctx context.Context) { case connection := <-c.unregister: loginfo.Println("closing connection ", connection.conn.RemoteAddr().String()) + + //does connection exist in the connection table -- should never be an issue if _, ok := c.connections[connection]; ok { + + //iterate over the connections for the domain for _, domain := range c.connections[connection] { - fmt.Println("removing domain ", domain) + loginfo.Println("remove domain", domain) + + //removing domain, make sure it is present (should never be a problem) if _, ok := c.domains[domain]; ok { - delete(c.domains, domain) + + domainLB := c.domains[domain] + domainLB.RemoveConnection(connection) + + //check to see if domain is free of connections, if yes, delete map entry + if domainLB.count > 0 { + //ignore...perhaps we will do something here dealing wtih the lb method + } else { + delete(c.domains, domain) + } } } diff --git a/rvpn/genericlistener/domain_loadbalance.go b/rvpn/genericlistener/domain_loadbalance.go new file mode 100644 index 0000000..ccbef20 --- /dev/null +++ b/rvpn/genericlistener/domain_loadbalance.go @@ -0,0 +1,107 @@ +package genericlistener + +import ( + "fmt" + "sync" +) + +const ( + lbmUnSupported string = "unsuported" + lbmRoundRobin string = "round-robin" + lbmLeastConnections string = "least-connections" +) + +//DomainLoadBalance -- Use as a structure for domain connections +//and load balancing those connections. Initial modes are round-robin +//but suspect we will need least-connections, and sticky +type DomainLoadBalance struct { + mutex sync.Mutex + + //lb method, supported round robin. + method string + + //the last connection based on calculation + lastmember int + + // a list of connections in this load balancing context + connections []*Connection + + //a counter to track total connections, so we aren't calling len all the time + count int + + //true if the system belives a recalcuation is required + recalc bool +} + +//NewDomainLoadBalance -- Constructor +func NewDomainLoadBalance(defaultMethod string) (p *DomainLoadBalance) { + p = new(DomainLoadBalance) + p.method = defaultMethod + p.lastmember = 0 + p.count = 0 + return +} + +//Connections -- Access connections +func (p *DomainLoadBalance) Connections() []*Connection { + return p.connections +} + +//NextMember -- increments the lastmember, and then checks if >= to count, if true +//the last is reset to 0 +func (p *DomainLoadBalance) NextMember() (conn *Connection) { + p.mutex.Lock() + defer p.mutex.Unlock() + + //check for round robin, if not RR then drop out and call calculate + loginfo.Println("NextMember:", p) + if p.method == lbmRoundRobin { + p.lastmember++ + if p.lastmember >= p.count { + p.lastmember = 0 + } + nextConn := p.connections[p.lastmember] + return nextConn + } + + // Not round robin + switch method := p.method; method { + default: + panic(fmt.Errorf("fatal unsupported loadbalance method %s", method)) + } +} + +//AddConnection -- Add an additional connection to the list of connections for this domain +//this should not affect the next member calculation in RR. However it many in other +//methods +func (p *DomainLoadBalance) AddConnection(conn *Connection) []*Connection { + loginfo.Println("AddConnection", fmt.Sprintf("%p", conn)) + p.mutex.Lock() + defer p.mutex.Unlock() + p.connections = append(p.connections, conn) + p.count++ + loginfo.Println("AddConnection", p) + return p.connections +} + +//RemoveConnection -- removes a matching connection from the list. This may +//affect the nextmember calculation if found so the recalc flag is set. +func (p *DomainLoadBalance) RemoveConnection(conn *Connection) { + loginfo.Println("RemoveConnection", fmt.Sprintf("%p", conn)) + + p.mutex.Lock() + defer p.mutex.Unlock() + + //scan all the connections + for pos := range p.connections { + loginfo.Println("RemoveConnection", pos, len(p.connections), p.count) + if p.connections[pos] == conn { + //found connection remove it + loginfo.Printf("found connection %p", conn) + p.connections[pos], p.connections = p.connections[len(p.connections)-1], p.connections[:len(p.connections)-1] + p.count-- + break + } + } + loginfo.Println("RemoveConnection:", p) +}