From 399e8dd65114592aed01ccc14d181f1758474346 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 20 Jul 2020 16:07:14 -0600 Subject: [PATCH] add ping/pong and read/write deadlines websockets --- mplexer/websockettunnel.go | 92 +++++++++++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 16 deletions(-) diff --git a/mplexer/websockettunnel.go b/mplexer/websockettunnel.go index d68a98a..1f424e6 100644 --- a/mplexer/websockettunnel.go +++ b/mplexer/websockettunnel.go @@ -15,11 +15,15 @@ import ( "github.com/gorilla/websocket" ) +var defaultReadWait = 20 * time.Second +var defaultWriteWait = 20 * time.Second + // WebsocketTunnel wraps a websocket.Conn instance to behave like net.Conn. -// TODO make conform. type WebsocketTunnel struct { - wsconn WSConn - tmpr io.Reader + wsconn WSConn + readWait time.Duration + writeWait time.Duration + tmpr io.Reader //w io.WriteCloser //pingCh chan struct{} } @@ -30,6 +34,7 @@ type WSConn interface { NextWriter(messageType int) (io.WriteCloser, error) WriteControl(messageType int, data []byte, deadline time.Time) error WriteMessage(messageType int, data []byte) error + SetPongHandler(h func(appData string) error) SetReadDeadline(t time.Time) error SetWriteDeadline(t time.Time) error Close() error @@ -39,9 +44,44 @@ type WSConn interface { // NewWebsocketTunnel allocates a new websocket connection wrapper func NewWebsocketTunnel(wsconn WSConn) net.Conn { + // TODO only set ping when SetReadDeadline would otherwise fail + // See https://github.com/gorilla/websocket/blob/a6870891/examples/chat/conn.go#L86 + writeWait := defaultWriteWait + readWait := defaultReadWait + go func() { + // Ping every 15 seconds, or stop listening + for { + time.Sleep(15 * time.Second) + deadline := time.Now().Add(writeWait) + // https://www.gorillatoolkit.org/pkg/websocket + // "The Close and WriteControl methods can be called concurrently with all other methods." + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] sending ping (set write deadline %s)\n", writeWait) + } + if err := wsconn.WriteControl(websocket.PingMessage, []byte(""), deadline); nil != err { + wsconn.Close() + fmt.Fprintf(os.Stderr, "failed to write ping message to websocket: %s\n", err) + break + } + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] sent ping (cleared write deadline)\n") + } + } + }() + + wsconn.SetPongHandler(func(pong string) error { + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] received pong (reset read deadline %s): %q\n", readWait, pong) + } + wsconn.SetReadDeadline(time.Now().Add(readWait)) + return nil + }) + return &WebsocketTunnel{ - wsconn: wsconn, - tmpr: nil, + wsconn: wsconn, + readWait: readWait, + writeWait: writeWait, + tmpr: nil, } } @@ -57,16 +97,21 @@ func DialWebsocketTunnel(ctx context.Context, relay, authz string) (net.Conn, er } wsconn, _, err := wsd.DialContext(ctx, relay+sep+"access_token="+authz+"&versions=v1", headers) if nil != err { - fmt.Println("[debug] [wstun] simple dial failed", err, wsconn, ctx) + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] simple dial failed %q %v %v\n", err, wsconn, ctx) + } } return NewWebsocketTunnel(wsconn), err } func (wsw *WebsocketTunnel) Read(b []byte) (int, error) { + wsw.wsconn.SetReadDeadline(time.Now().Add(wsw.readWait)) if nil == wsw.tmpr { _, msgr, err := wsw.wsconn.NextReader() if nil != err { - fmt.Println("[debug] [wstun] NextReader err:", err) + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] NextReader err: %q\n", err) + } return 0, err } wsw.tmpr = msgr @@ -74,11 +119,11 @@ func (wsw *WebsocketTunnel) Read(b []byte) (int, error) { n, err := wsw.tmpr.Read(b) if dbg.Debug { - fmt.Println("[debug] [wstun] Read", n, dbg.Trunc(b, n)) + fmt.Fprintf(os.Stderr, "[debug] [wstun] Read %d %v\n", n, dbg.Trunc(b, n)) } if nil != err { if dbg.Debug { - fmt.Println("[debug] [wstun] Read (EOF=WS packet complete) err:", err) + fmt.Fprintf(os.Stderr, "[debug] [wstun] Read (EOF=WS packet complete) err: %q\n", err) } if io.EOF == err { wsw.tmpr = nil @@ -90,21 +135,30 @@ func (wsw *WebsocketTunnel) Read(b []byte) (int, error) { } func (wsw *WebsocketTunnel) Write(b []byte) (int, error) { - fmt.Println("[debug] [wstun] Write", len(b)) + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] Write %d\n", len(b)) + } // TODO create or reset ping deadline // TODO document that more complete writes are preferred? + wsw.wsconn.SetWriteDeadline(time.Now().Add(wsw.writeWait)) msgw, err := wsw.wsconn.NextWriter(websocket.BinaryMessage) if nil != err { - fmt.Println("[debug] [wstun] NextWriter err:", err) + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] NextWriter err: %q\n", err) + } return 0, err } n, err := msgw.Write(b) if nil != err { - fmt.Println("[debug] [wstun] Write err:", err) + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] Write err: %q\n", err) + } return n, err } - fmt.Println("[debug] [wstun] Write n", n, "=", len(b)) + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] Write n %d = %d\n", n, len(b)) + } // if the message error fails, we can assume the websocket is damaged return n, msgw.Close() @@ -112,7 +166,9 @@ func (wsw *WebsocketTunnel) Write(b []byte) (int, error) { // Close will close the websocket with a control message func (wsw *WebsocketTunnel) Close() error { - fmt.Println("[debug] [wstun] closing the websocket.Conn") + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] closing the websocket.Conn\n") + } // TODO handle EOF as websocket.CloseNormal? message := websocket.FormatCloseMessage(websocket.CloseGoingAway, "closing connection") @@ -150,12 +206,16 @@ func (wsw *WebsocketTunnel) SetDeadline(t time.Time) error { // SetReadDeadline sets the deadline for future Read calls func (wsw *WebsocketTunnel) SetReadDeadline(t time.Time) error { - fmt.Println("[debug] [wstun] read deadline") + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] read deadline\n") + } return wsw.wsconn.SetReadDeadline(t) } // SetWriteDeadline sets the deadline for future Write calls func (wsw *WebsocketTunnel) SetWriteDeadline(t time.Time) error { - fmt.Println("[debug] [wstun] write deadline") + if dbg.Debug { + fmt.Fprintf(os.Stderr, "[debug] [wstun] write deadline\n") + } return wsw.wsconn.SetWriteDeadline(t) }