From 3e4ce3ac91bc96ea1d9ec19d9d22c4d0c3ae192a Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Fri, 27 Feb 2026 22:51:40 -0700 Subject: [PATCH] ref(cmd/tcpfwd): extract waitWithTimeout to simplify shutdown block Co-Authored-By: Claude Sonnet 4.6 --- cmd/tcpfwd/main.go | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/cmd/tcpfwd/main.go b/cmd/tcpfwd/main.go index 3a03a36..30c8c68 100644 --- a/cmd/tcpfwd/main.go +++ b/cmd/tcpfwd/main.go @@ -293,22 +293,28 @@ func main() { } // Wait for remaining active connections to drain, up to shutdownTimeout - drained := make(chan struct{}) - go func() { - reg.wg.Wait() - close(drained) - }() - - select { - case <-drained: + if waitWithTimeout(®.wg, shutdownTimeout) { log.Printf("All connections closed cleanly") - case <-time.After(shutdownTimeout): + } else { log.Printf("Shutdown timeout (%s) exceeded, force-closing remaining connections", shutdownTimeout) reg.closeAll() reg.wg.Wait() } } +// waitWithTimeout waits for wg to reach zero, returning true if it drained +// before the timeout and false if the timeout was exceeded. +func waitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + return true + case <-time.After(timeout): + return false + } +} + func handleConn(client net.Conn, target string, reg *connRegistry, clock func() time.Time) { remote, err := net.Dial("tcp", target) if err != nil {