mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-02 23:57:59 +00:00
feat(cmd/tcpfwd): add graceful shutdown with idle connection tracking
On SIGINT/SIGTERM: - Stop accepting new connections - Close connections idle longer than --idle-timeout (default 5s), determined by LastRead/LastWrite timestamps tracked per connection pair - Wait for active connections to drain up to --shutdown-timeout (default 30s) - Force-close any remaining connections if the timeout is exceeded Also switches isClosedConn to use errors.Is(err, net.ErrClosed) and exits the accept loop cleanly when a listener is closed during shutdown. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
b08525b024
commit
89f6e04516
@ -1,13 +1,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -45,15 +51,116 @@ func parseForward(s string) (forward, error) {
|
||||
return forward{listenAddr: ":" + s[:i], target: s[i+1:]}, nil
|
||||
}
|
||||
|
||||
// connEntry tracks an active proxied connection pair.
|
||||
type connEntry struct {
|
||||
lastRead atomic.Int64 // UnixNano of last read from either side
|
||||
lastWrite atomic.Int64 // UnixNano of last write to either side
|
||||
client net.Conn
|
||||
remote net.Conn
|
||||
}
|
||||
|
||||
// idleSince returns the time of the most recent I/O on this connection.
|
||||
func (e *connEntry) idleSince() time.Time {
|
||||
lr := time.Unix(0, e.lastRead.Load())
|
||||
lw := time.Unix(0, e.lastWrite.Load())
|
||||
if lr.After(lw) {
|
||||
return lr
|
||||
}
|
||||
return lw
|
||||
}
|
||||
|
||||
func (e *connEntry) isIdle(threshold time.Duration) bool {
|
||||
return time.Since(e.idleSince()) > threshold
|
||||
}
|
||||
|
||||
func (e *connEntry) close() {
|
||||
e.client.Close()
|
||||
e.remote.Close()
|
||||
}
|
||||
|
||||
// connRegistry tracks all active connections.
|
||||
type connRegistry struct {
|
||||
mu sync.Mutex
|
||||
conns map[*connEntry]struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func newConnRegistry() *connRegistry {
|
||||
return &connRegistry{conns: make(map[*connEntry]struct{})}
|
||||
}
|
||||
|
||||
func (r *connRegistry) add(e *connEntry) {
|
||||
r.wg.Add(1)
|
||||
r.mu.Lock()
|
||||
r.conns[e] = struct{}{}
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *connRegistry) remove(e *connEntry) {
|
||||
r.mu.Lock()
|
||||
delete(r.conns, e)
|
||||
r.mu.Unlock()
|
||||
r.wg.Done()
|
||||
}
|
||||
|
||||
// closeIdle closes connections idle for longer than threshold and returns the count.
|
||||
func (r *connRegistry) closeIdle(threshold time.Duration) int {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
var n int
|
||||
for e := range r.conns {
|
||||
if e.isIdle(threshold) {
|
||||
e.close()
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func (r *connRegistry) closeAll() {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
for e := range r.conns {
|
||||
e.close()
|
||||
}
|
||||
}
|
||||
|
||||
// trackingConn wraps a net.Conn and updates shared lastRead/lastWrite atomics on I/O.
|
||||
type trackingConn struct {
|
||||
net.Conn
|
||||
lastRead *atomic.Int64
|
||||
lastWrite *atomic.Int64
|
||||
}
|
||||
|
||||
func (c *trackingConn) Read(b []byte) (int, error) {
|
||||
n, err := c.Conn.Read(b)
|
||||
if n > 0 {
|
||||
c.lastRead.Store(time.Now().UnixNano())
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (c *trackingConn) Write(b []byte) (int, error) {
|
||||
n, err := c.Conn.Write(b)
|
||||
if n > 0 {
|
||||
c.lastWrite.Store(time.Now().UnixNano())
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func main() {
|
||||
var listenPort string
|
||||
var target string
|
||||
var showVersion bool
|
||||
var idleTimeout time.Duration
|
||||
var shutdownTimeout time.Duration
|
||||
|
||||
fs := flag.NewFlagSet(name, flag.ContinueOnError)
|
||||
fs.StringVar(&listenPort, "port", "", "local port to listen on (use with --target)")
|
||||
fs.StringVar(&target, "target", "", "target host:port (use with --port)")
|
||||
fs.BoolVar(&showVersion, "version", false, "show version and exit")
|
||||
fs.DurationVar(&idleTimeout, "idle-timeout", 5*time.Second, "close idle connections after this duration on shutdown")
|
||||
fs.DurationVar(&shutdownTimeout, "shutdown-timeout", 30*time.Second, "maximum time to wait for active connections to drain on shutdown")
|
||||
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "USAGE\n %s [flags] [local-port:remote-host:remote-port ...]\n\n", name)
|
||||
@ -127,6 +234,8 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
reg := newConnRegistry()
|
||||
|
||||
// Note: allow unprivileged users to use this like so:
|
||||
// echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf
|
||||
// sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf
|
||||
@ -152,35 +261,79 @@ func main() {
|
||||
for {
|
||||
client, err := bl.Accept()
|
||||
if err != nil {
|
||||
if isClosedConn(err) {
|
||||
return
|
||||
}
|
||||
log.Printf("Accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
go handleConn(client, bl.target)
|
||||
go handleConn(client, bl.target, reg)
|
||||
}
|
||||
}(bl)
|
||||
}
|
||||
|
||||
select {} // block forever
|
||||
// Wait for shutdown signal
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||
sig := <-sigCh
|
||||
log.Printf("Received %s, shutting down...", sig)
|
||||
|
||||
// Stop accepting new connections
|
||||
for _, l := range listeners {
|
||||
l.Close()
|
||||
}
|
||||
|
||||
func handleConn(client net.Conn, target string) {
|
||||
defer client.Close()
|
||||
// Close connections that have been idle longer than idleTimeout
|
||||
if n := reg.closeIdle(idleTimeout); n > 0 {
|
||||
log.Printf("Closed %d idle connection(s) (idle > %s)", n, idleTimeout)
|
||||
}
|
||||
|
||||
// Wait for remaining active connections to drain, up to shutdownTimeout
|
||||
drained := make(chan struct{})
|
||||
go func() {
|
||||
reg.wg.Wait()
|
||||
close(drained)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-drained:
|
||||
log.Printf("All connections closed cleanly")
|
||||
case <-time.After(shutdownTimeout):
|
||||
log.Printf("Shutdown timeout (%s) exceeded, force-closing remaining connections", shutdownTimeout)
|
||||
reg.closeAll()
|
||||
reg.wg.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func handleConn(client net.Conn, target string, reg *connRegistry) {
|
||||
remote, err := net.Dial("tcp", target)
|
||||
if err != nil {
|
||||
log.Printf("Failed to connect to %s: %v", target, err)
|
||||
client.Close()
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now().UnixNano()
|
||||
entry := &connEntry{client: client, remote: remote}
|
||||
entry.lastRead.Store(now)
|
||||
entry.lastWrite.Store(now)
|
||||
|
||||
reg.add(entry)
|
||||
defer reg.remove(entry)
|
||||
defer client.Close()
|
||||
defer remote.Close()
|
||||
|
||||
log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr())
|
||||
|
||||
trackedClient := &trackingConn{Conn: client, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite}
|
||||
trackedRemote := &trackingConn{Conn: remote, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite}
|
||||
|
||||
// Bidirectional copy with error handling
|
||||
go func() { _ = copyAndClose(remote, client) }()
|
||||
func() { _ = copyAndClose(client, remote) }()
|
||||
go func() { _ = copyAndClose(trackedRemote, trackedClient) }()
|
||||
_ = copyAndClose(trackedClient, trackedRemote)
|
||||
}
|
||||
|
||||
// copyAndClose copies until EOF or error, then closes dst
|
||||
// copyAndClose copies until EOF or error, then closes dst.
|
||||
func copyAndClose(dst, src net.Conn) error {
|
||||
_, err := io.Copy(dst, src)
|
||||
dst.Close()
|
||||
@ -191,11 +344,7 @@ func copyAndClose(dst, src net.Conn) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// isClosedConn detects common closed-connection errors
|
||||
// isClosedConn detects closed-connection errors.
|
||||
func isClosedConn(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
opErr, ok := err.(*net.OpError)
|
||||
return ok && opErr.Err.Error() == "use of closed network connection"
|
||||
return errors.Is(err, net.ErrClosed)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user