250 lines
5.6 KiB
Go
250 lines
5.6 KiB
Go
package table
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"sync"
|
|
|
|
"io"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"git.coolaj86.com/coolaj86/go-telebitd/dbg"
|
|
telebit "git.coolaj86.com/coolaj86/go-telebitd/mplexer"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Package-level registries. Both start out nil and are allocated by init.
var (
	// Servers represent actual connections.
	//
	// Add keys it by grant subject; each value is a *sync.Map from a
	// connection's RemoteAddr to its *SubscriberConn.
	Servers *sync.Map

	// Table makes sense to be in-memory, but it could be serialized if needed.
	//
	// Add keys it by domain name; each value is a *sync.Map from a
	// connection's RemoteAddr to the *SubscriberConn serving that domain.
	Table *sync.Map
)
|
|
|
|
func init() {
|
|
Servers = &sync.Map{}
|
|
Table = &sync.Map{}
|
|
}
|
|
|
|
func Add(server *SubscriberConn) {
|
|
var srvMap *sync.Map
|
|
srvMapX, ok := Servers.Load(server.Grants.Subject)
|
|
if ok {
|
|
srvMap = srvMapX.(*sync.Map)
|
|
} else {
|
|
srvMap = &sync.Map{}
|
|
}
|
|
srvMap.Store(server.RemoteAddr, server)
|
|
Servers.Store(server.Grants.Subject, srvMap)
|
|
|
|
// Add this server to the domain name matrix
|
|
for _, domainname := range server.Grants.Domains {
|
|
var srvMap *sync.Map
|
|
srvMapX, ok := Table.Load(domainname)
|
|
if ok {
|
|
srvMap = srvMapX.(*sync.Map)
|
|
} else {
|
|
srvMap = &sync.Map{}
|
|
}
|
|
srvMap.Store(server.RemoteAddr, server)
|
|
Table.Store(domainname, srvMap)
|
|
}
|
|
}
|
|
|
|
func RemoveServer(server *SubscriberConn) bool {
|
|
// TODO remove by RemoteAddr
|
|
//return false
|
|
fmt.Printf("[warn] RemoveServer() still calls Remove(subject) instead of removing by RemoteAddr\n")
|
|
return Remove(server.Grants.Subject)
|
|
}
|
|
|
|
func Remove(subject string) bool {
|
|
srvMapX, ok := Servers.Load(subject)
|
|
fmt.Printf("[debug] has server for %s? %t\n", subject, ok)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
srvMap := srvMapX.(*sync.Map)
|
|
srvMap.Range(func(k, v interface{}) bool {
|
|
srv := v.(*SubscriberConn)
|
|
srv.Clients.Range(func(k, v interface{}) bool {
|
|
conn := v.(net.Conn)
|
|
_ = conn.Close()
|
|
return true
|
|
})
|
|
srv.WSConn.Close()
|
|
for _, domainname := range srv.Grants.Domains {
|
|
srvMapX, ok := Table.Load(domainname)
|
|
if !ok {
|
|
continue
|
|
}
|
|
srvMap = srvMapX.(*sync.Map)
|
|
srvMap.Delete(srv.RemoteAddr)
|
|
n := 0
|
|
srvMap.Range(func(k, v interface{}) bool {
|
|
n++
|
|
return true
|
|
})
|
|
if 0 == n {
|
|
// TODO comment out to handle the bad case of 0 servers / empty map
|
|
Table.Delete(domainname)
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
Servers.Delete(subject)
|
|
|
|
return true
|
|
}
|
|
|
|
// SubscriberConn represents a tunneled server, its grants, and its clients
|
|
type SubscriberConn struct {
|
|
RemoteAddr string
|
|
WSConn *websocket.Conn
|
|
WSTun net.Conn // *telebit.WebsocketTunnel
|
|
Grants *telebit.Grants
|
|
Clients *sync.Map
|
|
|
|
// TODO is this the right codec type?
|
|
MultiEncoder *telebit.Encoder
|
|
MultiDecoder *telebit.Decoder
|
|
|
|
// to fulfill Router interface
|
|
}
|
|
|
|
func (s *SubscriberConn) RouteBytes(src, dst telebit.Addr, payload []byte) {
|
|
id := fmt.Sprintf("%s:%d", src.Hostname(), src.Port())
|
|
if dbg.Debug {
|
|
fmt.Println("[debug] Routing some more bytes:", dbg.Trunc(payload, len(payload)))
|
|
}
|
|
fmt.Printf("id %s\nsrc %+v\n", id, src)
|
|
fmt.Printf("dst %s %+v\n", dst.Scheme(), dst)
|
|
clientX, ok := s.Clients.Load(id)
|
|
if !ok {
|
|
// TODO send back closed client error
|
|
fmt.Println("[debug] no client found for", id)
|
|
return
|
|
}
|
|
|
|
client, _ := clientX.(net.Conn)
|
|
if "end" == dst.Scheme() {
|
|
fmt.Println("[debug] closing client", id)
|
|
_ = client.Close()
|
|
return
|
|
}
|
|
|
|
for {
|
|
n, err := client.Write(payload)
|
|
if dbg.Debug {
|
|
fmt.Println("[debug] table Write", dbg.Trunc(payload, len(payload)))
|
|
}
|
|
if nil == err || io.EOF == err {
|
|
break
|
|
}
|
|
if n > 0 && io.ErrShortWrite == err {
|
|
payload = payload[n:]
|
|
continue
|
|
}
|
|
break
|
|
// TODO send back closed client error
|
|
//return err
|
|
}
|
|
}
|
|
|
|
func (s *SubscriberConn) Serve(client net.Conn) error {
|
|
var wconn *telebit.ConnWrap
|
|
switch conn := client.(type) {
|
|
case *telebit.ConnWrap:
|
|
wconn = conn
|
|
default:
|
|
// this probably isn't strictly necessary
|
|
panic("*SubscriberConn.Serve is special in that it must receive &ConnWrap{ Conn: conn }")
|
|
}
|
|
|
|
id := client.RemoteAddr().String()
|
|
fmt.Printf("[DEBUG] NEW ID (ip:port) %s\n", id)
|
|
s.Clients.Store(id, client)
|
|
|
|
//fmt.Println("[debug] immediately cancel client to simplify testing / debugging")
|
|
//_ = client.Close()
|
|
|
|
// TODO
|
|
// - Encode each client to the tunnel
|
|
// - Find the right client for decoded messages
|
|
|
|
// TODO which order is remote / local?
|
|
srcParts := strings.Split(client.RemoteAddr().String(), ":")
|
|
srcAddr := srcParts[0]
|
|
srcPort, _ := strconv.Atoi(srcParts[1])
|
|
fmt.Println("[debug] srcParts", srcParts)
|
|
|
|
dstParts := strings.Split(client.LocalAddr().String(), ":")
|
|
dstAddr := dstParts[0]
|
|
dstPort, _ := strconv.Atoi(dstParts[1])
|
|
fmt.Println("[debug] dstParts", dstParts)
|
|
servername := wconn.Servername()
|
|
|
|
termination := telebit.Unknown
|
|
scheme := telebit.None
|
|
if "" != servername {
|
|
dstAddr = servername
|
|
//scheme = telebit.TLS
|
|
scheme = telebit.HTTPS
|
|
}
|
|
if 80 == dstPort {
|
|
scheme = telebit.HTTPS
|
|
} else if 443 == dstPort {
|
|
// TODO dstAddr = wconn.Servername()
|
|
scheme = telebit.HTTP
|
|
}
|
|
|
|
src := telebit.NewAddr(
|
|
scheme,
|
|
termination,
|
|
srcAddr,
|
|
srcPort,
|
|
)
|
|
dst := telebit.NewAddr(
|
|
scheme,
|
|
termination,
|
|
dstAddr,
|
|
dstPort,
|
|
)
|
|
fmt.Printf("[debug] NewAddr src %+v\n", src)
|
|
fmt.Printf("[debug] NewAddr dst %+v\n", dst)
|
|
|
|
err := s.MultiEncoder.Encode(wconn, *src, *dst)
|
|
_ = wconn.Close()
|
|
fmt.Printf("[debug] Encoder Complete %+v %+v\n", id, err)
|
|
s.Clients.Delete(id)
|
|
return err
|
|
}
|
|
|
|
func GetServer(servername string) (*SubscriberConn, bool) {
|
|
var srv *SubscriberConn
|
|
load := -1
|
|
// TODO match *.whatever.com
|
|
srvMapX, ok := Table.Load(servername)
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
|
|
srvMap := srvMapX.(*sync.Map)
|
|
srvMap.Range(func(k, v interface{}) bool {
|
|
myLoad := 0
|
|
mySrv := v.(*SubscriberConn)
|
|
mySrv.Clients.Range(func(k, v interface{}) bool {
|
|
myLoad += 1
|
|
return true
|
|
})
|
|
// pick the least loaded server
|
|
if -1 == load || myLoad < load {
|
|
load = myLoad
|
|
srv = mySrv
|
|
}
|
|
return true
|
|
})
|
|
|
|
return srv, true
|
|
}
|