From c80c87c6677beb34dc5ed71a51ce5931e8334a18 Mon Sep 17 00:00:00 2001 From: tigerbot Date: Thu, 30 Mar 2017 18:04:28 -0600 Subject: [PATCH] changed the routine structure of the client --- go-rvpn-client/main.go | 7 +- rvpn/client/client.go | 28 ++---- rvpn/client/local_conns.go | 82 --------------- rvpn/client/ws_handler.go | 199 +++++++++++++++++++++++++++++++++++++ 4 files changed, 212 insertions(+), 104 deletions(-) delete mode 100644 rvpn/client/local_conns.go create mode 100644 rvpn/client/ws_handler.go diff --git a/go-rvpn-client/main.go b/go-rvpn-client/main.go index 03e4585..58a2a65 100644 --- a/go-rvpn-client/main.go +++ b/go-rvpn-client/main.go @@ -1,6 +1,8 @@ package main import ( + "context" + "git.daplie.com/Daplie/go-rvpn-server/rvpn/client" jwt "github.com/dgrijalva/jwt-go" ) @@ -18,11 +20,14 @@ func main() { panic(err) } + ctx, quit := context.WithCancel(context.Background()) + defer quit() + config := client.Config{ Server: "wss://localhost.daplie.me:9999", Services: map[string]int{"https": 8443}, Token: tokenStr, Insecure: true, } - panic(client.Run(&config)) + panic(client.Run(ctx, &config)) } diff --git a/rvpn/client/client.go b/rvpn/client/client.go index c9e2450..a0821f5 100644 --- a/rvpn/client/client.go +++ b/rvpn/client/client.go @@ -1,12 +1,11 @@ package client import ( + "context" "crypto/tls" + "fmt" "net/url" - "fmt" - - "git.daplie.com/Daplie/go-rvpn-server/rvpn/packer" "github.com/gorilla/websocket" ) @@ -17,7 +16,7 @@ type Config struct { Insecure bool } -func Run(config *Config) error { +func Run(ctx context.Context, config *Config) error { serverURL, err := url.Parse(config.Server) if err != nil { return fmt.Errorf("Invalid server URL: %v", err) @@ -36,26 +35,13 @@ func Run(config *Config) error { dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} } + handler := NewWsHandler(config.Services) + conn, _, err := dialer.Dial(serverURL.String(), nil) if err != nil { return fmt.Errorf("First connection to server failed - check auth: %v", err) } - localConns := newLocalConns(conn, config.Services) - for { - _, message, err := conn.ReadMessage() - if err != nil { - return fmt.Errorf("websocket read errored: %v", err) - } - - p, err := packer.ReadMessage(message) - if err != nil { - return fmt.Errorf("packer read failed: %v", err) - } - - err = localConns.Write(p) - if err != nil { - return fmt.Errorf("failed to write data: %v", err) - } - } + handler.HandleConn(ctx, conn) + return nil } diff --git a/rvpn/client/local_conns.go b/rvpn/client/local_conns.go deleted file mode 100644 index 704a98c..0000000 --- a/rvpn/client/local_conns.go +++ /dev/null @@ -1,82 +0,0 @@ -package client - -import ( - "fmt" - "net" - "sync" - - "github.com/gorilla/websocket" - - "io" - - "git.daplie.com/Daplie/go-rvpn-server/rvpn/packer" -) - -type localConns struct { - lock sync.RWMutex - locals map[string]net.Conn - services map[string]int - remote *websocket.Conn -} - -func newLocalConns(remote *websocket.Conn, services map[string]int) *localConns { - l := new(localConns) - l.services = services - l.remote = remote - l.locals = make(map[string]net.Conn) - return l -} - -func (l *localConns) Write(p *packer.Packer) error { - l.lock.RLock() - defer l.lock.RUnlock() - - key := fmt.Sprintf("%s:%d", p.Address(), p.Port()) - if conn := l.locals[key]; conn != nil { - _, err := conn.Write(p.Data.Data()) - return err - } - - go l.startConnection(p) - return nil -} - -func (l *localConns) startConnection(orig *packer.Packer) { - key := fmt.Sprintf("%s:%d", orig.Address(), orig.Port()) - addr := fmt.Sprintf("127.0.0.1:%d", l.services[orig.Service()]) - conn, err := net.Dial("tcp", addr) - if err != nil { - loginfo.Println("failed to open connection to", addr, err) - return - } - loginfo.Println("opened connection to", addr, "with key", key) - defer loginfo.Println("finished connection to", addr, "with key", key) - - conn.Write(orig.Data.Data()) - - l.lock.Lock() - l.locals[key] = conn - l.lock.Unlock() - defer func() { - l.lock.Lock() - delete(l.locals, key) - l.lock.Unlock() - conn.Close() - }() - - buf := make([]byte, 4096) - for { - size, err := conn.Read(buf) - if err != nil { - if err != io.EOF { - loginfo.Println("failed to read from local connection to", addr, err) - } - return - } - - p := packer.NewPacker(&orig.Header) - p.Data.AppendBytes(buf[:size]) - packed := p.PackV1() - l.remote.WriteMessage(websocket.BinaryMessage, packed.Bytes()) - } -} diff --git a/rvpn/client/ws_handler.go b/rvpn/client/ws_handler.go new file mode 100644 index 0000000..be3e59b --- /dev/null +++ b/rvpn/client/ws_handler.go @@ -0,0 +1,199 @@ +package client + +import ( + "context" + "fmt" + "io" + "net" + "sync" + "time" + + "github.com/gorilla/websocket" + + "git.daplie.com/Daplie/go-rvpn-server/rvpn/packer" +) + +// WsHandler handles all of reading and writing for the websocket connection to the RVPN server +// and the TCP connections to the local servers. +type WsHandler struct { + lock sync.Mutex + localConns map[string]net.Conn + + servicePorts map[string]int + + ctx context.Context + dataChan chan *packer.Packer +} + +// NewWsHandler creates a new handler ready to be given a websocket connection. The services +// argument specifies what port each service type should be directed to on the local interface. +func NewWsHandler(services map[string]int) *WsHandler { + h := new(WsHandler) + h.servicePorts = services + h.localConns = make(map[string]net.Conn) + return h +} + +// HandleConn handles all of the traffic on the provided websocket connection. The function +// will not return until the connection ends. +// +// The WsHandler is designed to handle exactly one connection at a time. If HandleConn is called +// again while the instance is still handling another connection (or if the previous connection +// failed to fully cleanup) calling HandleConn again will panic. +func (h *WsHandler) HandleConn(ctx context.Context, conn *websocket.Conn) { + if h.dataChan != nil { + panic("WsHandler.HandleConn called while handling a previous connection") + } + if len(h.localConns) > 0 { + panic(fmt.Sprintf("WsHandler has lingering local connections: %v", h.localConns)) + } + h.dataChan = make(chan *packer.Packer) + + // The sub context allows us to clean up all of the goroutines associated with this websocket + // if it closes at any point for any reason. + subCtx, socketQuit := context.WithCancel(ctx) + defer socketQuit() + h.ctx = subCtx + + // Start the routine that will write all of the data from the local connection to the + // remote websocket connection. + go h.writeRemote(conn) + + for { + _, message, err := conn.ReadMessage() + if err != nil { + loginfo.Println("failed to read message from websocket", err) + return + } + + p, err := packer.ReadMessage(message) + if err != nil { + loginfo.Println("failed to parse message from websocket", err) + return + } + + h.writeLocal(p) + } +} + +func (h *WsHandler) writeRemote(conn *websocket.Conn) { + defer h.closeConnections() + defer func() { h.dataChan = nil }() + + for { + select { + case <-h.ctx.Done(): + // We can't tell if this happened because the websocket is already closed/errored or + // if it happened because the main context closed (in which case it would be preferable + // to properly close the connection). As such we try to close the connection and ignore + // all errors if it doesn't work. + message := websocket.FormatCloseMessage(websocket.CloseGoingAway, "closing connection") + deadline := time.Now().Add(10 * time.Second) + conn.WriteControl(websocket.CloseMessage, message, deadline) + conn.Close() + return + + case p := <-h.dataChan: + packed := p.PackV1() + conn.WriteMessage(websocket.BinaryMessage, packed.Bytes()) + } + } +} + +func (h *WsHandler) sendSpecial(header *packer.Header, service string) { + p := packer.NewPacker(header) + p.SetService(service) + + // Avoid blocking on the data channel if the websocket is already closed + select { + case h.dataChan <- p: + case <-h.ctx.Done(): + } +} + +func (h *WsHandler) getLocalConn(p *packer.Packer) net.Conn { + h.lock.Lock() + defer h.lock.Unlock() + + key := fmt.Sprintf("%s:%d", p.Address(), p.Port()) + // Simplest case: it's already open, just return it. + if conn := h.localConns[key]; conn != nil { + return conn + } + + port := h.servicePorts[p.Service()] + if port == 0 { + loginfo.Println("cannot open connection for invalid service", p.Service()) + return nil + } + + conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + loginfo.Println("unable to open local connection on port", port, err) + return nil + } + + loginfo.Println("opened new connection to port", port, "for", key) + h.localConns[key] = conn + go h.readLocal(key, &p.Header) + return conn +} + +func (h *WsHandler) writeLocal(p *packer.Packer) { + conn := h.getLocalConn(p) + if conn == nil { + h.sendSpecial(&p.Header, "error") + return + } + + if p.Service() == "error" || p.Service() == "end" { + conn.Close() + return + } + + if _, err := conn.Write(p.Data.Data()); err != nil { + h.sendSpecial(&p.Header, "error") + loginfo.Println("failed to write to local connection", err) + } +} + +func (h *WsHandler) readLocal(key string, header *packer.Header) { + h.lock.Lock() + conn := h.localConns[key] + h.lock.Unlock() + + defer conn.Close() + defer func() { + h.lock.Lock() + delete(h.localConns, key) + h.lock.Unlock() + }() + defer loginfo.Println("finished with client", key) + + buf := make([]byte, 4096) + for { + size, err := conn.Read(buf) + if err != nil { + if err == io.EOF { + h.sendSpecial(header, "end") + } else { + loginfo.Println("failed to read from local connection for", key, err) + h.sendSpecial(header, "error") + } + return + } + + p := packer.NewPacker(header) + p.Data.AppendBytes(buf[:size]) + h.dataChan <- p + } +} + +func (h *WsHandler) closeConnections() { + h.lock.Lock() + defer h.lock.Unlock() + + for _, conn := range h.localConns { + conn.Close() + } +}