feat(cmd/tcpfwd): add graceful shutdown with idle connection tracking

On SIGINT/SIGTERM:
- Stop accepting new connections
- Close connections idle longer than --idle-timeout (default 5s),
  determined by the lastRead/lastWrite timestamps tracked per connection pair
- Wait for active connections to drain up to --shutdown-timeout (default 30s)
- Force-close any remaining connections if the timeout is exceeded

Also switches isClosedConn to use errors.Is(err, net.ErrClosed) and
exits the accept loop cleanly when a listener is closed during shutdown.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
AJ ONeal 2026-02-27 22:03:19 -07:00
parent b08525b024
commit 89f6e04516
No known key found for this signature in database

View File

@ -1,13 +1,19 @@
package main package main
import ( import (
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"log" "log"
"net" "net"
"os" "os"
"os/signal"
"strings" "strings"
"sync"
"sync/atomic"
"syscall"
"time"
) )
const ( const (
@ -45,15 +51,116 @@ func parseForward(s string) (forward, error) {
return forward{listenAddr: ":" + s[:i], target: s[i+1:]}, nil return forward{listenAddr: ":" + s[:i], target: s[i+1:]}, nil
} }
// connEntry records one proxied connection pair together with the time of
// its most recent traffic. The timestamps are Unix nanoseconds held in
// atomics, so they can be stored and loaded without extra locking.
type connEntry struct {
	lastRead  atomic.Int64 // UnixNano of last read from either side
	lastWrite atomic.Int64 // UnixNano of last write to either side
	client    net.Conn
	remote    net.Conn
}

// idleSince reports when this pair last saw I/O: the later of the last-read
// and last-write timestamps, converted back to a time.Time.
func (e *connEntry) idleSince() time.Time {
	readNanos := e.lastRead.Load()
	latest := e.lastWrite.Load()
	if readNanos > latest {
		latest = readNanos
	}
	return time.Unix(0, latest)
}

// isIdle reports whether strictly more than threshold has elapsed since the
// most recent I/O on this pair.
func (e *connEntry) isIdle(threshold time.Duration) bool {
	elapsed := time.Since(e.idleSince())
	return elapsed > threshold
}

// close closes the client side and then the remote side of the pair.
func (e *connEntry) close() {
	for _, c := range [...]net.Conn{e.client, e.remote} {
		// Best effort: the Close error is deliberately discarded.
		_ = c.Close()
	}
}
// connRegistry is the set of currently registered connection entries.
//
// mu guards conns. wg counts registered entries: add increments it and
// remove decrements it, so wg.Wait returns once every added entry has been
// removed.
type connRegistry struct {
	mu    sync.Mutex
	conns map[*connEntry]struct{}
	wg    sync.WaitGroup
}

// newConnRegistry returns an empty registry that is ready for use.
func newConnRegistry() *connRegistry {
	reg := new(connRegistry)
	reg.conns = map[*connEntry]struct{}{}
	return reg
}

// add registers e and bumps the wait-group counter by one.
//
// NOTE(review): sync.WaitGroup requires an Add that starts from a zero
// counter to happen before Wait; confirm no connection can register while
// shutdown is already waiting on wg.
func (r *connRegistry) add(e *connEntry) {
	r.wg.Add(1)

	r.mu.Lock()
	defer r.mu.Unlock()
	r.conns[e] = struct{}{}
}

// remove unregisters e and then marks it done on the wait group. It must be
// called exactly once for every add.
func (r *connRegistry) remove(e *connEntry) {
	// Deferred calls run last-in-first-out: the mutex is released first,
	// then the wait group is decremented.
	defer r.wg.Done()

	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.conns, e)
}

// closeIdle closes every registered entry that has been idle for longer than
// threshold and returns how many were closed. The registry lock is held for
// the whole sweep. Entries are only closed here, not unregistered; remove
// does that.
func (r *connRegistry) closeIdle(threshold time.Duration) int {
	r.mu.Lock()
	defer r.mu.Unlock()

	closed := 0
	for entry := range r.conns {
		if !entry.isIdle(threshold) {
			continue
		}
		entry.close()
		closed++
	}
	return closed
}

// closeAll closes every registered entry regardless of activity. As with
// closeIdle, entries stay registered until remove is called for them.
func (r *connRegistry) closeAll() {
	r.mu.Lock()
	defer r.mu.Unlock()

	for entry := range r.conns {
		entry.close()
	}
}
// trackingConn decorates a net.Conn so that successful I/O is timestamped.
// lastRead and lastWrite point at counters owned elsewhere; each receives the
// current time in Unix nanoseconds whenever a Read or Write moves at least
// one byte. All other net.Conn methods come from the embedded Conn.
type trackingConn struct {
	net.Conn
	lastRead  *atomic.Int64
	lastWrite *atomic.Int64
}

// Read reads from the wrapped connection and, if any bytes arrived, records
// the current time in lastRead. The underlying result is returned unchanged.
func (c *trackingConn) Read(b []byte) (int, error) {
	got, err := c.Conn.Read(b)
	if got <= 0 {
		// Nothing was transferred, so the timestamp stays as it was.
		return got, err
	}
	c.lastRead.Store(time.Now().UnixNano())
	return got, err
}

// Write writes to the wrapped connection and, if any bytes were sent, records
// the current time in lastWrite. The underlying result is returned unchanged.
func (c *trackingConn) Write(b []byte) (int, error) {
	sent, err := c.Conn.Write(b)
	if sent <= 0 {
		// Nothing was transferred, so the timestamp stays as it was.
		return sent, err
	}
	c.lastWrite.Store(time.Now().UnixNano())
	return sent, err
}
func main() { func main() {
var listenPort string var listenPort string
var target string var target string
var showVersion bool var showVersion bool
var idleTimeout time.Duration
var shutdownTimeout time.Duration
fs := flag.NewFlagSet(name, flag.ContinueOnError) fs := flag.NewFlagSet(name, flag.ContinueOnError)
fs.StringVar(&listenPort, "port", "", "local port to listen on (use with --target)") fs.StringVar(&listenPort, "port", "", "local port to listen on (use with --target)")
fs.StringVar(&target, "target", "", "target host:port (use with --port)") fs.StringVar(&target, "target", "", "target host:port (use with --port)")
fs.BoolVar(&showVersion, "version", false, "show version and exit") fs.BoolVar(&showVersion, "version", false, "show version and exit")
fs.DurationVar(&idleTimeout, "idle-timeout", 5*time.Second, "close idle connections after this duration on shutdown")
fs.DurationVar(&shutdownTimeout, "shutdown-timeout", 30*time.Second, "maximum time to wait for active connections to drain on shutdown")
fs.Usage = func() { fs.Usage = func() {
fmt.Fprintf(os.Stderr, "USAGE\n %s [flags] [local-port:remote-host:remote-port ...]\n\n", name) fmt.Fprintf(os.Stderr, "USAGE\n %s [flags] [local-port:remote-host:remote-port ...]\n\n", name)
@ -127,6 +234,8 @@ func main() {
os.Exit(1) os.Exit(1)
} }
reg := newConnRegistry()
// Note: allow unprivileged users to use this like so: // Note: allow unprivileged users to use this like so:
// echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf // echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf
// sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf // sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf
@ -152,35 +261,79 @@ func main() {
for { for {
client, err := bl.Accept() client, err := bl.Accept()
if err != nil { if err != nil {
if isClosedConn(err) {
return
}
log.Printf("Accept error: %v", err) log.Printf("Accept error: %v", err)
continue continue
} }
go handleConn(client, bl.target) go handleConn(client, bl.target, reg)
} }
}(bl) }(bl)
} }
select {} // block forever // Wait for shutdown signal
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
sig := <-sigCh
log.Printf("Received %s, shutting down...", sig)
// Stop accepting new connections
for _, l := range listeners {
l.Close()
}
// Close connections that have been idle longer than idleTimeout
if n := reg.closeIdle(idleTimeout); n > 0 {
log.Printf("Closed %d idle connection(s) (idle > %s)", n, idleTimeout)
}
// Wait for remaining active connections to drain, up to shutdownTimeout
drained := make(chan struct{})
go func() {
reg.wg.Wait()
close(drained)
}()
select {
case <-drained:
log.Printf("All connections closed cleanly")
case <-time.After(shutdownTimeout):
log.Printf("Shutdown timeout (%s) exceeded, force-closing remaining connections", shutdownTimeout)
reg.closeAll()
reg.wg.Wait()
}
} }
func handleConn(client net.Conn, target string) { func handleConn(client net.Conn, target string, reg *connRegistry) {
defer client.Close()
remote, err := net.Dial("tcp", target) remote, err := net.Dial("tcp", target)
if err != nil { if err != nil {
log.Printf("Failed to connect to %s: %v", target, err) log.Printf("Failed to connect to %s: %v", target, err)
client.Close()
return return
} }
now := time.Now().UnixNano()
entry := &connEntry{client: client, remote: remote}
entry.lastRead.Store(now)
entry.lastWrite.Store(now)
reg.add(entry)
defer reg.remove(entry)
defer client.Close()
defer remote.Close() defer remote.Close()
log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr()) log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr())
trackedClient := &trackingConn{Conn: client, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite}
trackedRemote := &trackingConn{Conn: remote, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite}
// Bidirectional copy with error handling // Bidirectional copy with error handling
go func() { _ = copyAndClose(remote, client) }() go func() { _ = copyAndClose(trackedRemote, trackedClient) }()
func() { _ = copyAndClose(client, remote) }() _ = copyAndClose(trackedClient, trackedRemote)
} }
// copyAndClose copies until EOF or error, then closes dst // copyAndClose copies until EOF or error, then closes dst.
func copyAndClose(dst, src net.Conn) error { func copyAndClose(dst, src net.Conn) error {
_, err := io.Copy(dst, src) _, err := io.Copy(dst, src)
dst.Close() dst.Close()
@ -191,11 +344,7 @@ func copyAndClose(dst, src net.Conn) error {
return err return err
} }
// isClosedConn detects common closed-connection errors // isClosedConn detects closed-connection errors.
func isClosedConn(err error) bool { func isClosedConn(err error) bool {
if err == nil { return errors.Is(err, net.ErrClosed)
return false
}
opErr, ok := err.(*net.OpError)
return ok && opErr.Err.Error() == "use of closed network connection"
} }