// Source: telebit/internal/telebit/table.go (272 lines, 5.8 KiB, Go)

package telebit
import (
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"git.rootprojects.org/root/telebit/internal/dbg"
"github.com/gorilla/websocket"
)
// Package-level routing state.
//
// Both maps are initialized at declaration (rather than in an init function)
// so that they are already non-nil while other package-level variables are
// being initialized and while any init function in the package runs.
var (
	// Servers represents actual connections: Add stores, per grant subject,
	// a *sync.Map of RemoteAddr -> *SubscriberConn.
	Servers = &sync.Map{}

	// Table is the domain name matrix: Add stores, per granted domain name,
	// a *sync.Map of RemoteAddr -> *SubscriberConn.
	// It makes sense for it to be in-memory, but it could be serialized if
	// needed.
	Table = &sync.Map{}
)
// Add registers server in both lookup structures:
//
//   - Servers, keyed by server.Grants.Subject
//   - Table, keyed by each name in server.Grants.Domains
//
// In both cases the value is a *sync.Map of RemoteAddr -> *SubscriberConn,
// so several connections may share the same subject or domain name.
func Add(server *SubscriberConn) {
	// register performs an atomic "get or create" of the inner map for key.
	// The candidate map is populated BEFORE it is published, so an empty
	// inner map is never made visible, and LoadOrStore guarantees that two
	// concurrent calls for the same new key end up sharing one inner map
	// instead of overwriting each other's.
	register := func(index *sync.Map, key interface{}) {
		fresh := &sync.Map{}
		fresh.Store(server.RemoteAddr, server)

		existingX, loaded := index.LoadOrStore(key, fresh)
		if loaded {
			existingX.(*sync.Map).Store(server.RemoteAddr, server)
		}
	}

	register(Servers, server.Grants.Subject)

	// Add this server to the domain name matrix
	for _, domainname := range server.Grants.Domains {
		register(Table, domainname)
	}
}
// RemoveServer removes the given server by delegating to Remove with the
// server's grant subject, and returns Remove's result.
//
// Known limitation (announced on stderr at every call): this removes
// everything registered under the subject, not just the one connection
// identified by server.RemoteAddr.
func RemoveServer(server *SubscriberConn) bool {
	// TODO remove by RemoteAddr
	//return false
	const warning = "[warn] RemoveServer() still calls Remove(subject) instead of removing by RemoteAddr\n"
	fmt.Fprint(os.Stderr, warning)

	subject := server.Grants.Subject
	return Remove(subject)
}
// Remove tears down every server registered under subject.
//
// For each such server it closes all of the server's client connections and
// its websocket connection, then removes the server from the Table entry of
// each of its granted domain names (deleting the entry once no server is
// left for that name). Finally the subject itself is deleted from Servers.
//
// It returns false if nothing is registered under subject, true otherwise.
func Remove(subject string) bool {
	connsX, ok := Servers.Load(subject)
	fmt.Printf("Remove(%s): exists? %t\n", subject, ok)
	if !ok {
		return false
	}

	conns := connsX.(*sync.Map)
	conns.Range(func(_, v interface{}) bool {
		srv := v.(*SubscriberConn)

		// Close errors are deliberately ignored: this is best-effort teardown.
		srv.Clients.Range(func(_, c interface{}) bool {
			_ = c.(net.Conn).Close()
			return true
		})
		_ = srv.WSConn.Close()

		for _, domainname := range srv.Grants.Domains {
			byDomainX, ok := Table.Load(domainname)
			if !ok {
				continue
			}

			// A dedicated variable is used here on purpose: the per-domain
			// map must not overwrite the per-subject map being ranged over.
			byDomain := byDomainX.(*sync.Map)
			byDomain.Delete(srv.RemoteAddr)

			// Emptiness check: stop at the first entry found instead of
			// counting all of them.
			empty := true
			byDomain.Range(func(_, _ interface{}) bool {
				empty = false
				return false
			})
			if empty {
				// TODO comment out to handle the bad case of 0 servers / empty map
				Table.Delete(domainname)
			}
		}
		return true
	})

	Servers.Delete(subject)
	return true
}
// SubscriberConn represents a tunneled server, its grants, and its clients
type SubscriberConn struct {
// Since is presumably the time the connection was established — confirm
// against the code that constructs SubscriberConn.
Since *time.Time
// RemoteAddr is the key under which Add stores this connection in the
// per-subject and per-domain maps.
RemoteAddr string
// WSConn is the websocket connection; Remove closes it.
WSConn *websocket.Conn
WSTun net.Conn // *WebsocketTunnel
// Grants supplies the Subject and Domains that Add and Remove index by.
Grants *Grants
// Clients maps a client id (the client's remote address string, as stored
// by Serve) to its net.Conn.
Clients *sync.Map
// TODO is this the right codec type?
// MultiEncoder is what Serve uses to encode a client connection.
MultiEncoder *Encoder
// MultiDecoder is not referenced elsewhere in this file.
MultiDecoder *Decoder
// NOTE(review): the RouteBytes method defined on *SubscriberConn is
// presumably what is meant "to fulfill Router interface" — confirm.
}
// RouteBytes delivers payload to the client connection identified by src.
//
// The client is looked up in s.Clients under the id "<src hostname>:<src
// port>". If no such client is registered the payload is dropped (and the
// fact is printed). If dst has the scheme "end", the client connection is
// closed instead of being written to.
//
// Write errors are not reported to the caller; a short write that made
// progress is retried with the remaining bytes.
func (s *SubscriberConn) RouteBytes(src, dst Addr, payload []byte) {
	id := fmt.Sprintf("%s:%d", src.Hostname(), src.Port())
	if dbg.Debug {
		fmt.Fprintf(
			os.Stderr,
			"[debug] Routing some more bytes: %s\n",
			dbg.Trunc(payload, len(payload)),
		)
		fmt.Printf("\tid %s\nsrc %+v\n", id, src)
		fmt.Printf("\tdst %s %+v\n", dst.Scheme(), dst)
	}

	clientX, ok := s.Clients.Load(id)
	if !ok {
		// TODO send back closed client error
		fmt.Printf("RouteBytes({ %s }, %v, ...) [debug] no client found for %s\n", id, dst)
		return
	}

	// Check the assertion instead of discarding its result: a failed
	// assertion would leave client as a nil interface and the next method
	// call on it would panic.
	client, ok := clientX.(net.Conn)
	if !ok {
		fmt.Fprintf(
			os.Stderr,
			"[warn] RouteBytes({ %s }): registered client is a %T, not a net.Conn\n",
			id, clientX,
		)
		return
	}

	if "end" == dst.Scheme() {
		fmt.Printf("RouteBytes: { %s }.Close(): %v\n", id, dst)
		_ = client.Close()
		return
	}

	for {
		n, err := client.Write(payload)
		if dbg.Debug {
			fmt.Fprintf(os.Stderr, "[debug] table Write %s\n", dbg.Trunc(payload, len(payload)))
		}

		// Only a short write that made progress is worth retrying.
		if io.ErrShortWrite == err && n > 0 {
			payload = payload[n:]
			continue
		}

		// Either everything was written, or the write failed in a way that
		// a retry will not fix.
		// TODO send back closed client error
		//return err
		return
	}
}
// Serve encodes a single client connection into the tunnel.
//
// client MUST be a *ConnWrap; anything else is a programming error and
// panics. The client is registered in s.Clients under its remote address
// string, handed to s.MultiEncoder.Encode together with source and
// destination addresses derived from the connection, and - once Encode
// returns - closed and unregistered again.
//
// The destination hostname is the connection's Servername() when that is
// non-empty, otherwise the host part of the local address. The scheme is
// HTTP for destination port 80 and HTTPS for destination port 443; for any
// other port it is HTTPS when a servername is known and None otherwise.
//
// It returns the error from Encode, or a parse error (after closing the
// connection) if the remote or local address is not of the form host:port.
func (s *SubscriberConn) Serve(client net.Conn) error {
	wconn, ok := client.(*ConnWrap)
	if !ok {
		// this probably isn't strictly necessary
		panic("*SubscriberConn.Serve is special in that it must receive &ConnWrap{ Conn: conn }")
	}

	// splitHostPort splits "host:port" on the LAST colon, so that a
	// bracketed IPv6 address such as "[::1]:443" keeps its host intact
	// (the brackets are stripped), and reports malformed input as an error
	// rather than panicking or silently producing port 0.
	splitHostPort := func(addr string) (string, int, error) {
		i := strings.LastIndex(addr, ":")
		if i < 0 {
			return "", 0, fmt.Errorf("address %q has no port", addr)
		}
		port, err := strconv.Atoi(addr[i+1:])
		if nil != err {
			return "", 0, fmt.Errorf("address %q has an invalid port: %w", addr, err)
		}
		host := strings.TrimSuffix(strings.TrimPrefix(addr[:i], "["), "]")
		return host, port, nil
	}

	id := client.RemoteAddr().String()
	if dbg.Debug {
		fmt.Fprintf(os.Stderr, "[debug] NEW ID (ip:port) %s\n", id)
	}

	// TODO which order is remote / local?
	srcAddr, srcPort, err := splitHostPort(id)
	if nil != err {
		_ = wconn.Close()
		return err
	}
	dstAddr, dstPort, err := splitHostPort(client.LocalAddr().String())
	if nil != err {
		_ = wconn.Close()
		return err
	}
	if dbg.Debug {
		fmt.Fprintf(os.Stderr, "[debug] srcParts %v\n", []interface{}{srcAddr, srcPort})
		fmt.Fprintf(os.Stderr, "[debug] dstParts %v\n", []interface{}{dstAddr, dstPort})
	}

	s.Clients.Store(id, client)

	//fmt.Fprintf(os.Stderr, "[debug] immediately cancel client to simplify testing / debugging\n")
	//_ = client.Close()

	// TODO
	// - Encode each client to the tunnel
	// - Find the right client for decoded messages

	servername := wconn.Servername()
	termination := Unknown
	scheme := None
	if "" != servername {
		dstAddr = servername
		//scheme = TLS
		scheme = HTTPS
	}
	// The well-known ports take precedence over the servername-based guess.
	switch dstPort {
	case 80:
		scheme = HTTP
	case 443:
		// TODO dstAddr = wconn.Servername()
		scheme = HTTPS
	}

	src := NewAddr(
		scheme,
		termination,
		srcAddr,
		srcPort,
	)
	dst := NewAddr(
		scheme,
		termination,
		dstAddr,
		dstPort,
	)
	if dbg.Debug {
		fmt.Fprintf(os.Stderr, "[debug] NewAddr src %+v\n", src)
		fmt.Fprintf(os.Stderr, "[debug] NewAddr dst %+v\n", dst)
	}

	encErr := s.MultiEncoder.Encode(wconn, *src, *dst)
	// The close error is ignored: the connection is finished either way.
	_ = wconn.Close()
	if dbg.Debug {
		fmt.Fprintf(os.Stderr, "[debug] Encoder Complete %+v %+v\n", id, encErr)
	}
	s.Clients.Delete(id)
	return encErr
}
// GetServer returns the least loaded server registered in Table for
// servername, where a server's load is the number of entries in its Clients
// map.
//
// The boolean result is true only when a server is actually returned: it is
// false both when servername has no entry in Table and when its entry
// exists but contains no servers, so a true result never comes with a nil
// *SubscriberConn.
func GetServer(servername string) (*SubscriberConn, bool) {
	// TODO match *.whatever.com
	candidatesX, ok := Table.Load(servername)
	if !ok {
		return nil, false
	}

	var best *SubscriberConn
	bestLoad := -1

	candidatesX.(*sync.Map).Range(func(_, v interface{}) bool {
		srv := v.(*SubscriberConn)

		load := 0
		srv.Clients.Range(func(_, _ interface{}) bool {
			load++
			return true
		})

		// pick the least loaded server
		if -1 == bestLoad || load < bestLoad {
			bestLoad = load
			best = srv
		}
		return true
	})

	return best, nil != best
}