ref(cmd/tcpfwd): extract waitWithTimeout to simplify shutdown block

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
AJ ONeal 2026-02-27 22:51:40 -07:00
parent 0d356c0b26
commit 3e4ce3ac91
No known key found for this signature in database

View File

@ -293,22 +293,28 @@ func main() {
} }
// Wait for remaining active connections to drain, up to shutdownTimeout // Wait for remaining active connections to drain, up to shutdownTimeout
drained := make(chan struct{}) if waitWithTimeout(&reg.wg, shutdownTimeout) {
go func() {
reg.wg.Wait()
close(drained)
}()
select {
case <-drained:
log.Printf("All connections closed cleanly") log.Printf("All connections closed cleanly")
case <-time.After(shutdownTimeout): } else {
log.Printf("Shutdown timeout (%s) exceeded, force-closing remaining connections", shutdownTimeout) log.Printf("Shutdown timeout (%s) exceeded, force-closing remaining connections", shutdownTimeout)
reg.closeAll() reg.closeAll()
reg.wg.Wait() reg.wg.Wait()
} }
} }
// waitWithTimeout waits for wg to reach zero, returning true if it drained
// before the timeout and false if the timeout was exceeded.
func waitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
done := make(chan struct{})
go func() { wg.Wait(); close(done) }()
select {
case <-done:
return true
case <-time.After(timeout):
return false
}
}
func handleConn(client net.Conn, target string, reg *connRegistry, clock func() time.Time) { func handleConn(client net.Conn, target string, reg *connRegistry, clock func() time.Time) {
remote, err := net.Dial("tcp", target) remote, err := net.Dial("tcp", target)
if err != nil { if err != nil {