From 3f16e89f5cc8fab3e487e84d9475c618d0c31412 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Fri, 27 Feb 2026 22:08:31 -0700 Subject: [PATCH] ref(cmd/tcpfwd): remove implicit time.Now() for testability - connEntry.isIdle(now, threshold): caller supplies now instead of time.Since - connRegistry.closeIdle(now, threshold): passes now through to isIdle - trackingConn gains a clock func() time.Time field used in Read/Write - handleConn takes clock func() time.Time; uses it to init lastRead/lastWrite and passes it to trackingConn - Call sites in main pass time.Now or time.Now() explicitly Co-Authored-By: Claude Sonnet 4.6 --- cmd/tcpfwd/main.go | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/cmd/tcpfwd/main.go b/cmd/tcpfwd/main.go index 235d3c0..077b2af 100644 --- a/cmd/tcpfwd/main.go +++ b/cmd/tcpfwd/main.go @@ -69,8 +69,8 @@ func (e *connEntry) idleSince() time.Time { return lw } -func (e *connEntry) isIdle(threshold time.Duration) bool { - return time.Since(e.idleSince()) > threshold +func (e *connEntry) isIdle(now time.Time, threshold time.Duration) bool { + return now.Sub(e.idleSince()) > threshold } func (e *connEntry) close() { @@ -104,12 +104,12 @@ func (r *connRegistry) remove(e *connEntry) { } // closeIdle closes connections idle for longer than threshold and returns the count. -func (r *connRegistry) closeIdle(threshold time.Duration) int { +func (r *connRegistry) closeIdle(now time.Time, threshold time.Duration) int { r.mu.Lock() defer r.mu.Unlock() var n int for e := range r.conns { - if e.isIdle(threshold) { + if e.isIdle(now, threshold) { e.close() n++ } @@ -126,16 +126,18 @@ func (r *connRegistry) closeAll() { } // trackingConn wraps a net.Conn and updates shared lastRead/lastWrite atomics on I/O. +// clock is called to get the current time; use time.Now in production. type trackingConn struct { net.Conn lastRead *atomic.Int64 lastWrite *atomic.Int64 + clock func() time.Time } func (c *trackingConn) Read(b []byte) (int, error) { n, err := c.Conn.Read(b) if n > 0 { - c.lastRead.Store(time.Now().UnixNano()) + c.lastRead.Store(c.clock().UnixNano()) } return n, err } @@ -143,7 +145,7 @@ func (c *trackingConn) Read(b []byte) (int, error) { func (c *trackingConn) Write(b []byte) (int, error) { n, err := c.Conn.Write(b) if n > 0 { - c.lastWrite.Store(time.Now().UnixNano()) + c.lastWrite.Store(c.clock().UnixNano()) } return n, err } @@ -267,7 +269,7 @@ func main() { log.Printf("Accept error: %v", err) continue } - go handleConn(client, bl.target, reg) + go handleConn(client, bl.target, reg, time.Now) } }(bl) } @@ -284,7 +286,7 @@ func main() { } // Close connections that have been idle longer than idleTimeout - if n := reg.closeIdle(idleTimeout); n > 0 { + if n := reg.closeIdle(time.Now(), idleTimeout); n > 0 { log.Printf("Closed %d idle connection(s) (idle > %s)", n, idleTimeout) } @@ -305,7 +307,7 @@ func main() { } } -func handleConn(client net.Conn, target string, reg *connRegistry) { +func handleConn(client net.Conn, target string, reg *connRegistry, clock func() time.Time) { remote, err := net.Dial("tcp", target) if err != nil { log.Printf("Failed to connect to %s: %v", target, err) @@ -313,10 +315,9 @@ func handleConn(client net.Conn, target string, reg *connRegistry) { return } - now := time.Now().UnixNano() entry := &connEntry{client: client, remote: remote} - entry.lastRead.Store(now) - entry.lastWrite.Store(now) + entry.lastRead.Store(clock().UnixNano()) + entry.lastWrite.Store(clock().UnixNano()) reg.add(entry) defer reg.remove(entry) @@ -325,8 +326,8 @@ func handleConn(client net.Conn, target string, reg *connRegistry) { log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr()) - trackedClient := &trackingConn{Conn: client, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite} - trackedRemote := &trackingConn{Conn: remote, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite} + trackedClient := &trackingConn{Conn: client, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite, clock: clock} + trackedRemote := &trackingConn{Conn: remote, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite, clock: clock} // Bidirectional copy with error handling go func() { _ = copyAndClose(trackedRemote, trackedClient) }()