telebit/table/table.go

250 lines
5.6 KiB
Go

package table
import (
"fmt"
"net"
"sync"
"io"
"strconv"
"strings"
"git.coolaj86.com/coolaj86/go-telebitd/dbg"
telebit "git.coolaj86.com/coolaj86/go-telebitd/mplexer"
"github.com/gorilla/websocket"
)
// Package-level registries shared by every function in this file.
var (
	// Servers holds the actual connections, keyed by grant subject. Each
	// value is a *sync.Map of RemoteAddr -> *SubscriberConn.
	Servers *sync.Map

	// Table is the domain name routing matrix, keyed by domain name. Each
	// value is a *sync.Map of RemoteAddr -> *SubscriberConn. It makes sense
	// for it to be in-memory, but it could be serialized if needed.
	Table *sync.Map
)

// init allocates both registries so that they are usable as soon as the
// package is loaded.
func init() {
	Servers, Table = new(sync.Map), new(sync.Map)
}
// Add registers server in both registries: under its grant subject in
// Servers, and under every granted domain name in Table. Within each inner
// map the server is keyed by its RemoteAddr, so adding the same RemoteAddr
// again replaces the previous entry.
func Add(server *SubscriberConn) {
	// register inserts server into the inner map stored at index[key],
	// creating that inner map if needed. LoadOrStore makes the get-or-create
	// step atomic: a separate Load followed by Store would let two concurrent
	// Adds each create a fresh inner map, with the later Store silently
	// discarding the other's server. The candidate map is populated before it
	// is published so readers never observe an empty inner map.
	register := func(index *sync.Map, key interface{}) {
		fresh := &sync.Map{}
		fresh.Store(server.RemoteAddr, server)
		if existing, loaded := index.LoadOrStore(key, fresh); loaded {
			existing.(*sync.Map).Store(server.RemoteAddr, server)
		}
	}

	register(Servers, server.Grants.Subject)

	// Add this server to the domain name matrix
	for _, domainname := range server.Grants.Domains {
		register(Table, domainname)
	}
}
// RemoveServer unregisters the given server. It currently delegates to
// Remove with the server's grant subject, which means every connection
// registered under that subject is removed, not only this one; a warning is
// printed on each call as a reminder. It returns whatever Remove returns.
func RemoveServer(server *SubscriberConn) bool {
	// TODO remove by RemoteAddr
	fmt.Print("[warn] RemoveServer() still calls Remove(subject) instead of removing by RemoteAddr\n")
	subject := server.Grants.Subject
	return Remove(subject)
}
// Remove tears down every server registered under subject. For each one it
// closes all of its client connections and its websocket, then removes it
// from the per-domain maps in Table (deleting a domain entry once its map is
// empty). Finally the subject itself is deleted from Servers.
//
// It returns false when no server is registered for subject, true otherwise.
func Remove(subject string) bool {
	srvMapX, ok := Servers.Load(subject)
	fmt.Printf("[debug] has server for %s? %t\n", subject, ok)
	if !ok {
		return false
	}

	srvMapX.(*sync.Map).Range(func(_, v interface{}) bool {
		srv := v.(*SubscriberConn)

		srv.Clients.Range(func(_, v interface{}) bool {
			// Best effort: the client is going away either way.
			_ = v.(net.Conn).Close()
			return true
		})
		// Best effort for the same reason.
		_ = srv.WSConn.Close()

		for _, domainname := range srv.Grants.Domains {
			// The per-domain map gets its own names so that it cannot
			// overwrite the per-subject map being iterated by the outer Range.
			domainMapX, found := Table.Load(domainname)
			if !found {
				continue
			}
			domainMap := domainMapX.(*sync.Map)
			domainMap.Delete(srv.RemoteAddr)

			// Only emptiness matters, so stop at the first remaining entry
			// instead of counting them all.
			empty := true
			domainMap.Range(func(_, _ interface{}) bool {
				empty = false
				return false
			})
			if empty {
				// TODO comment out to handle the bad case of 0 servers / empty map
				Table.Delete(domainname)
			}
		}
		return true
	})

	Servers.Delete(subject)
	return true
}
// SubscriberConn represents a tunneled server, its grants, and its clients.
// Instances are registered in Servers and Table by Add and torn down by
// Remove.
type SubscriberConn struct {
// RemoteAddr is the key under which Add stores this connection in the
// per-subject and per-domain maps.
RemoteAddr string
// WSConn is the underlying websocket connection; Remove closes it.
WSConn *websocket.Conn
// WSTun is not referenced by the functions in this file.
WSTun net.Conn // *telebit.WebsocketTunnel
// Grants supplies the Subject and Domains that Add and Remove index by.
Grants *telebit.Grants
// Clients maps a client id (the client's RemoteAddr().String(), as stored
// by Serve) to its net.Conn. RouteBytes looks clients up here.
Clients *sync.Map
// TODO is this the right codec type?
// MultiEncoder is used by Serve to encode a client connection.
MultiEncoder *telebit.Encoder
// MultiDecoder is not referenced by the functions in this file.
MultiDecoder *telebit.Decoder
// NOTE: the RouteBytes method exists to fulfill the Router interface
// (the interface itself is declared outside this file).
}
// RouteBytes delivers payload to the client connection identified by src.
//
// The client is looked up in s.Clients by "host:port" built from src. If no
// such client is registered the payload is dropped. If dst's scheme is "end"
// the client connection is closed instead of being written to. Otherwise the
// payload is written to the client, retrying after short writes. Write
// failures are logged and not reported to the caller.
func (s *SubscriberConn) RouteBytes(src, dst telebit.Addr, payload []byte) {
	// Serve keys s.Clients by net.Addr.String(), which brackets IPv6 literals
	// ("[::1]:1234"). JoinHostPort reproduces that format; for hosts without
	// a colon it yields the plain "host:port".
	id := net.JoinHostPort(src.Hostname(), strconv.Itoa(src.Port()))
	if dbg.Debug {
		fmt.Println("[debug] Routing some more bytes:", dbg.Trunc(payload, len(payload)))
	}
	fmt.Printf("id %s\nsrc %+v\n", id, src)
	fmt.Printf("dst %s %+v\n", dst.Scheme(), dst)

	clientX, ok := s.Clients.Load(id)
	if !ok {
		// TODO send back closed client error
		fmt.Println("[debug] no client found for", id)
		return
	}
	client, ok := clientX.(net.Conn)
	if !ok {
		// Guard against a foreign value in the map rather than calling
		// methods on a nil interface below.
		fmt.Printf("[warn] client %s is not a net.Conn (%T)\n", id, clientX)
		return
	}

	if "end" == dst.Scheme() {
		fmt.Println("[debug] closing client", id)
		_ = client.Close()
		return
	}

	for {
		n, err := client.Write(payload)
		if dbg.Debug {
			fmt.Println("[debug] table Write", dbg.Trunc(payload, len(payload)))
		}
		if nil == err || io.EOF == err {
			return
		}
		if n > 0 && io.ErrShortWrite == err {
			// Partial progress: retry with the unwritten remainder.
			payload = payload[n:]
			continue
		}
		// TODO send back closed client error
		fmt.Printf("[debug] write to client %s failed: %v\n", id, err)
		return
	}
}
// Serve encodes a single client connection onto the tunnel and blocks until
// the encoder is done with it.
//
// client must be a *telebit.ConnWrap; anything else panics. The client is
// registered in s.Clients under its RemoteAddr().String() for the duration of
// the call so that RouteBytes can find it, and it is always closed before
// Serve returns.
//
// The source address is the client's remote address and the destination is
// its local address. When the wrapped connection reports a servername, that
// name replaces the destination host. The scheme is HTTPS when a servername
// is present, then overridden by port: 80 is HTTP and 443 is HTTPS.
//
// It returns an error if either address cannot be split into host and
// numeric port, otherwise whatever the encoder returns.
func (s *SubscriberConn) Serve(client net.Conn) error {
	wconn, ok := client.(*telebit.ConnWrap)
	if !ok {
		// this probably isn't strictly necessary
		panic("*SubscriberConn.Serve is special in that it must receive &ConnWrap{ Conn: conn }")
	}

	// splitAddr separates an address into host and numeric port. It splits on
	// the LAST colon and strips the surrounding brackets so that IPv6
	// literals such as "[::1]:443" come out as ("::1", 443) instead of being
	// chopped at their first colon.
	splitAddr := func(addr net.Addr) (string, int, error) {
		hostport := addr.String()
		i := strings.LastIndex(hostport, ":")
		if i < 0 {
			return "", 0, fmt.Errorf("address %q has no port", hostport)
		}
		port, err := strconv.Atoi(hostport[i+1:])
		if nil != err {
			return "", 0, fmt.Errorf("address %q has an invalid port: %w", hostport, err)
		}
		return strings.Trim(hostport[:i], "[]"), port, nil
	}

	id := client.RemoteAddr().String()
	fmt.Printf("[DEBUG] NEW ID (ip:port) %s\n", id)

	// Parse both addresses before registering the client so that a failure
	// leaves no stale entry behind in s.Clients.
	// TODO which order is remote / local?
	srcAddr, srcPort, err := splitAddr(client.RemoteAddr())
	if nil != err {
		_ = wconn.Close()
		return fmt.Errorf("parsing client remote address: %w", err)
	}
	fmt.Printf("[debug] srcParts [%s %d]\n", srcAddr, srcPort)

	dstAddr, dstPort, err := splitAddr(client.LocalAddr())
	if nil != err {
		_ = wconn.Close()
		return fmt.Errorf("parsing client local address: %w", err)
	}
	fmt.Printf("[debug] dstParts [%s %d]\n", dstAddr, dstPort)

	s.Clients.Store(id, client)

	//fmt.Println("[debug] immediately cancel client to simplify testing / debugging")
	//_ = client.Close()

	// TODO
	// - Encode each client to the tunnel
	// - Find the right client for decoded messages

	servername := wconn.Servername()
	termination := telebit.Unknown
	scheme := telebit.None
	if "" != servername {
		dstAddr = servername
		//scheme = telebit.TLS
		scheme = telebit.HTTPS
	}
	// Well-known ports decide the scheme: plain HTTP on 80, HTTPS on 443.
	switch dstPort {
	case 80:
		scheme = telebit.HTTP
	case 443:
		// TODO dstAddr = wconn.Servername()
		scheme = telebit.HTTPS
	}

	src := telebit.NewAddr(
		scheme,
		termination,
		srcAddr,
		srcPort,
	)
	dst := telebit.NewAddr(
		scheme,
		termination,
		dstAddr,
		dstPort,
	)
	fmt.Printf("[debug] NewAddr src %+v\n", src)
	fmt.Printf("[debug] NewAddr dst %+v\n", dst)

	encErr := s.MultiEncoder.Encode(wconn, *src, *dst)
	// The client is finished either way; a close error adds nothing here.
	_ = wconn.Close()
	fmt.Printf("[debug] Encoder Complete %+v %+v\n", id, encErr)
	s.Clients.Delete(id)
	return encErr
}
// GetServer returns the least loaded server registered for servername, where
// load is the number of entries in a server's Clients map.
//
// The boolean is true only when a server was actually selected. It is false
// both when servername has no entry in Table and when its entry exists but
// holds no servers, so a true result never comes with a nil server.
func GetServer(servername string) (*SubscriberConn, bool) {
	// TODO match *.whatever.com
	srvMapX, ok := Table.Load(servername)
	if !ok {
		return nil, false
	}

	var best *SubscriberConn
	bestLoad := -1
	srvMapX.(*sync.Map).Range(func(_, v interface{}) bool {
		candidate := v.(*SubscriberConn)

		load := 0
		candidate.Clients.Range(func(_, _ interface{}) bool {
			load++
			return true
		})

		// pick the least loaded server
		if -1 == bestLoad || load < bestLoad {
			best, bestLoad = candidate, load
		}
		return true
	})

	return best, nil != best
}