package table import ( "fmt" "net" "sync" "io" "strconv" "strings" "git.rootprojects.org/root/telebit/dbg" telebit "git.rootprojects.org/root/telebit" "github.com/gorilla/websocket" ) // Servers represent actual connections var Servers *sync.Map // Table makes sense to be in-memory, but it could be serialized if needed var Table *sync.Map func init() { Servers = &sync.Map{} Table = &sync.Map{} } func Add(server *SubscriberConn) { var srvMap *sync.Map srvMapX, ok := Servers.Load(server.Grants.Subject) if ok { srvMap = srvMapX.(*sync.Map) } else { srvMap = &sync.Map{} } srvMap.Store(server.RemoteAddr, server) Servers.Store(server.Grants.Subject, srvMap) // Add this server to the domain name matrix for _, domainname := range server.Grants.Domains { var srvMap *sync.Map srvMapX, ok := Table.Load(domainname) if ok { srvMap = srvMapX.(*sync.Map) } else { srvMap = &sync.Map{} } srvMap.Store(server.RemoteAddr, server) Table.Store(domainname, srvMap) } } func RemoveServer(server *SubscriberConn) bool { // TODO remove by RemoteAddr //return false fmt.Printf("[warn] RemoveServer() still calls Remove(subject) instead of removing by RemoteAddr\n") return Remove(server.Grants.Subject) } func Remove(subject string) bool { srvMapX, ok := Servers.Load(subject) fmt.Printf("[debug] has server for %s? %t\n", subject, ok) if !ok { return false } srvMap := srvMapX.(*sync.Map) srvMap.Range(func(k, v interface{}) bool { srv := v.(*SubscriberConn) srv.Clients.Range(func(k, v interface{}) bool { conn := v.(net.Conn) _ = conn.Close() return true }) srv.WSConn.Close() for _, domainname := range srv.Grants.Domains { srvMapX, ok := Table.Load(domainname) if !ok { continue } srvMap = srvMapX.(*sync.Map) srvMap.Delete(srv.RemoteAddr) n := 0 srvMap.Range(func(k, v interface{}) bool { n++ return true }) if 0 == n { // TODO comment out to handle the bad case of 0 servers / empty map Table.Delete(domainname) } } return true }) Servers.Delete(subject) return true } // SubscriberConn represents a tunneled server, its grants, and its clients type SubscriberConn struct { RemoteAddr string WSConn *websocket.Conn WSTun net.Conn // *telebit.WebsocketTunnel Grants *telebit.Grants Clients *sync.Map // TODO is this the right codec type? MultiEncoder *telebit.Encoder MultiDecoder *telebit.Decoder // to fulfill Router interface } func (s *SubscriberConn) RouteBytes(src, dst telebit.Addr, payload []byte) { id := fmt.Sprintf("%s:%d", src.Hostname(), src.Port()) if dbg.Debug { fmt.Println("[debug] Routing some more bytes:", dbg.Trunc(payload, len(payload))) } fmt.Printf("id %s\nsrc %+v\n", id, src) fmt.Printf("dst %s %+v\n", dst.Scheme(), dst) clientX, ok := s.Clients.Load(id) if !ok { // TODO send back closed client error fmt.Println("[debug] no client found for", id) return } client, _ := clientX.(net.Conn) if "end" == dst.Scheme() { fmt.Println("[debug] closing client", id) _ = client.Close() return } for { n, err := client.Write(payload) if dbg.Debug { fmt.Println("[debug] table Write", dbg.Trunc(payload, len(payload))) } if nil == err || io.EOF == err { break } if n > 0 && io.ErrShortWrite == err { payload = payload[n:] continue } break // TODO send back closed client error //return err } } func (s *SubscriberConn) Serve(client net.Conn) error { var wconn *telebit.ConnWrap switch conn := client.(type) { case *telebit.ConnWrap: wconn = conn default: // this probably isn't strictly necessary panic("*SubscriberConn.Serve is special in that it must receive &ConnWrap{ Conn: conn }") } id := client.RemoteAddr().String() fmt.Printf("[DEBUG] NEW ID (ip:port) %s\n", id) s.Clients.Store(id, client) //fmt.Println("[debug] immediately cancel client to simplify testing / debugging") //_ = client.Close() // TODO // - Encode each client to the tunnel // - Find the right client for decoded messages // TODO which order is remote / local? srcParts := strings.Split(client.RemoteAddr().String(), ":") srcAddr := srcParts[0] srcPort, _ := strconv.Atoi(srcParts[1]) fmt.Println("[debug] srcParts", srcParts) dstParts := strings.Split(client.LocalAddr().String(), ":") dstAddr := dstParts[0] dstPort, _ := strconv.Atoi(dstParts[1]) fmt.Println("[debug] dstParts", dstParts) servername := wconn.Servername() termination := telebit.Unknown scheme := telebit.None if "" != servername { dstAddr = servername //scheme = telebit.TLS scheme = telebit.HTTPS } if 80 == dstPort { scheme = telebit.HTTPS } else if 443 == dstPort { // TODO dstAddr = wconn.Servername() scheme = telebit.HTTP } src := telebit.NewAddr( scheme, termination, srcAddr, srcPort, ) dst := telebit.NewAddr( scheme, termination, dstAddr, dstPort, ) fmt.Printf("[debug] NewAddr src %+v\n", src) fmt.Printf("[debug] NewAddr dst %+v\n", dst) err := s.MultiEncoder.Encode(wconn, *src, *dst) _ = wconn.Close() fmt.Printf("[debug] Encoder Complete %+v %+v\n", id, err) s.Clients.Delete(id) return err } func GetServer(servername string) (*SubscriberConn, bool) { var srv *SubscriberConn load := -1 // TODO match *.whatever.com srvMapX, ok := Table.Load(servername) if !ok { return nil, false } srvMap := srvMapX.(*sync.Map) srvMap.Range(func(k, v interface{}) bool { myLoad := 0 mySrv := v.(*SubscriberConn) mySrv.Clients.Range(func(k, v interface{}) bool { myLoad += 1 return true }) // pick the least loaded server if -1 == load || myLoad < load { load = myLoad srv = mySrv } return true }) return srv, true }