diff --git a/cmd/tcpfwd/main.go b/cmd/tcpfwd/main.go index 25ce006..235d3c0 100644 --- a/cmd/tcpfwd/main.go +++ b/cmd/tcpfwd/main.go @@ -1,13 +1,19 @@ package main import ( + "errors" "flag" "fmt" "io" "log" "net" "os" + "os/signal" "strings" + "sync" + "sync/atomic" + "syscall" + "time" ) const ( @@ -45,15 +51,116 @@ func parseForward(s string) (forward, error) { return forward{listenAddr: ":" + s[:i], target: s[i+1:]}, nil } +// connEntry tracks an active proxied connection pair. +type connEntry struct { + lastRead atomic.Int64 // UnixNano of last read from either side + lastWrite atomic.Int64 // UnixNano of last write to either side + client net.Conn + remote net.Conn +} + +// idleSince returns the time of the most recent I/O on this connection. +func (e *connEntry) idleSince() time.Time { + lr := time.Unix(0, e.lastRead.Load()) + lw := time.Unix(0, e.lastWrite.Load()) + if lr.After(lw) { + return lr + } + return lw +} + +func (e *connEntry) isIdle(threshold time.Duration) bool { + return time.Since(e.idleSince()) > threshold +} + +func (e *connEntry) close() { + e.client.Close() + e.remote.Close() +} + +// connRegistry tracks all active connections. +type connRegistry struct { + mu sync.Mutex + conns map[*connEntry]struct{} + wg sync.WaitGroup +} + +func newConnRegistry() *connRegistry { + return &connRegistry{conns: make(map[*connEntry]struct{})} +} + +func (r *connRegistry) add(e *connEntry) { + r.wg.Add(1) + r.mu.Lock() + r.conns[e] = struct{}{} + r.mu.Unlock() +} + +func (r *connRegistry) remove(e *connEntry) { + r.mu.Lock() + delete(r.conns, e) + r.mu.Unlock() + r.wg.Done() +} + +// closeIdle closes connections idle for longer than threshold and returns the count. +func (r *connRegistry) closeIdle(threshold time.Duration) int { + r.mu.Lock() + defer r.mu.Unlock() + var n int + for e := range r.conns { + if e.isIdle(threshold) { + e.close() + n++ + } + } + return n +} + +func (r *connRegistry) closeAll() { + r.mu.Lock() + defer r.mu.Unlock() + for e := range r.conns { + e.close() + } +} + +// trackingConn wraps a net.Conn and updates shared lastRead/lastWrite atomics on I/O. +type trackingConn struct { + net.Conn + lastRead *atomic.Int64 + lastWrite *atomic.Int64 +} + +func (c *trackingConn) Read(b []byte) (int, error) { + n, err := c.Conn.Read(b) + if n > 0 { + c.lastRead.Store(time.Now().UnixNano()) + } + return n, err +} + +func (c *trackingConn) Write(b []byte) (int, error) { + n, err := c.Conn.Write(b) + if n > 0 { + c.lastWrite.Store(time.Now().UnixNano()) + } + return n, err +} + func main() { var listenPort string var target string var showVersion bool + var idleTimeout time.Duration + var shutdownTimeout time.Duration fs := flag.NewFlagSet(name, flag.ContinueOnError) fs.StringVar(&listenPort, "port", "", "local port to listen on (use with --target)") fs.StringVar(&target, "target", "", "target host:port (use with --port)") fs.BoolVar(&showVersion, "version", false, "show version and exit") + fs.DurationVar(&idleTimeout, "idle-timeout", 5*time.Second, "close idle connections after this duration on shutdown") + fs.DurationVar(&shutdownTimeout, "shutdown-timeout", 30*time.Second, "maximum time to wait for active connections to drain on shutdown") fs.Usage = func() { fmt.Fprintf(os.Stderr, "USAGE\n %s [flags] [local-port:remote-host:remote-port ...]\n\n", name) @@ -127,6 +234,8 @@ func main() { os.Exit(1) } + reg := newConnRegistry() + // Note: allow unprivileged users to use this like so: // echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf // sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf @@ -152,35 +261,79 @@ func main() { for { client, err := bl.Accept() if err != nil { + if isClosedConn(err) { + return + } log.Printf("Accept error: %v", err) continue } - go handleConn(client, bl.target) + go handleConn(client, bl.target, reg) } }(bl) } - select {} // block forever + // Wait for shutdown signal + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + sig := <-sigCh + log.Printf("Received %s, shutting down...", sig) + + // Stop accepting new connections + for _, l := range listeners { + l.Close() + } + + // Close connections that have been idle longer than idleTimeout + if n := reg.closeIdle(idleTimeout); n > 0 { + log.Printf("Closed %d idle connection(s) (idle > %s)", n, idleTimeout) + } + + // Wait for remaining active connections to drain, up to shutdownTimeout + drained := make(chan struct{}) + go func() { + reg.wg.Wait() + close(drained) + }() + + select { + case <-drained: + log.Printf("All connections closed cleanly") + case <-time.After(shutdownTimeout): + log.Printf("Shutdown timeout (%s) exceeded, force-closing remaining connections", shutdownTimeout) + reg.closeAll() + reg.wg.Wait() + } } -func handleConn(client net.Conn, target string) { - defer client.Close() - +func handleConn(client net.Conn, target string, reg *connRegistry) { remote, err := net.Dial("tcp", target) if err != nil { log.Printf("Failed to connect to %s: %v", target, err) + client.Close() return } + + now := time.Now().UnixNano() + entry := &connEntry{client: client, remote: remote} + entry.lastRead.Store(now) + entry.lastWrite.Store(now) + + reg.add(entry) + defer reg.remove(entry) + defer client.Close() defer remote.Close() log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr()) + trackedClient := &trackingConn{Conn: client, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite} + trackedRemote := &trackingConn{Conn: remote, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite} + // Bidirectional copy with error handling - go func() { _ = copyAndClose(remote, client) }() - func() { _ = copyAndClose(client, remote) }() + go func() { _ = copyAndClose(trackedRemote, trackedClient) }() + _ = copyAndClose(trackedClient, trackedRemote) } -// copyAndClose copies until EOF or error, then closes dst +// copyAndClose copies until EOF or error, then closes dst. func copyAndClose(dst, src net.Conn) error { _, err := io.Copy(dst, src) dst.Close() @@ -191,11 +344,7 @@ func copyAndClose(dst, src net.Conn) error { return err } -// isClosedConn detects common closed-connection errors +// isClosedConn detects closed-connection errors. func isClosedConn(err error) bool { - if err == nil { - return false - } - opErr, ok := err.(*net.OpError) - return ok && opErr.Err.Error() == "use of closed network connection" + return errors.Is(err, net.ErrClosed) }