ref(cmd/tcpfwd): remove implicit time.Now() for testability

- connEntry.isIdle(now, threshold): caller supplies now instead of time.Since
- connRegistry.closeIdle(now, threshold): passes now through to isIdle
- trackingConn gains a clock func() time.Time field used in Read/Write
- handleConn takes clock func() time.Time; uses it to init lastRead/lastWrite
  and passes it to trackingConn
- Call sites in main pass time.Now or time.Now() explicitly

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
AJ ONeal 2026-02-27 22:08:31 -07:00
parent 89f6e04516
commit 3f16e89f5c
No known key found for this signature in database

View File

@ -69,8 +69,8 @@ func (e *connEntry) idleSince() time.Time {
return lw return lw
} }
func (e *connEntry) isIdle(threshold time.Duration) bool { func (e *connEntry) isIdle(now time.Time, threshold time.Duration) bool {
return time.Since(e.idleSince()) > threshold return now.Sub(e.idleSince()) > threshold
} }
func (e *connEntry) close() { func (e *connEntry) close() {
@ -104,12 +104,12 @@ func (r *connRegistry) remove(e *connEntry) {
} }
// closeIdle closes connections idle for longer than threshold and returns the count. // closeIdle closes connections idle for longer than threshold and returns the count.
func (r *connRegistry) closeIdle(threshold time.Duration) int { func (r *connRegistry) closeIdle(now time.Time, threshold time.Duration) int {
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
var n int var n int
for e := range r.conns { for e := range r.conns {
if e.isIdle(threshold) { if e.isIdle(now, threshold) {
e.close() e.close()
n++ n++
} }
@ -126,16 +126,18 @@ func (r *connRegistry) closeAll() {
} }
// trackingConn wraps a net.Conn and updates shared lastRead/lastWrite atomics on I/O. // trackingConn wraps a net.Conn and updates shared lastRead/lastWrite atomics on I/O.
// clock is called to get the current time; use time.Now in production.
type trackingConn struct { type trackingConn struct {
net.Conn net.Conn
lastRead *atomic.Int64 lastRead *atomic.Int64
lastWrite *atomic.Int64 lastWrite *atomic.Int64
clock func() time.Time
} }
func (c *trackingConn) Read(b []byte) (int, error) { func (c *trackingConn) Read(b []byte) (int, error) {
n, err := c.Conn.Read(b) n, err := c.Conn.Read(b)
if n > 0 { if n > 0 {
c.lastRead.Store(time.Now().UnixNano()) c.lastRead.Store(c.clock().UnixNano())
} }
return n, err return n, err
} }
@ -143,7 +145,7 @@ func (c *trackingConn) Read(b []byte) (int, error) {
func (c *trackingConn) Write(b []byte) (int, error) { func (c *trackingConn) Write(b []byte) (int, error) {
n, err := c.Conn.Write(b) n, err := c.Conn.Write(b)
if n > 0 { if n > 0 {
c.lastWrite.Store(time.Now().UnixNano()) c.lastWrite.Store(c.clock().UnixNano())
} }
return n, err return n, err
} }
@ -267,7 +269,7 @@ func main() {
log.Printf("Accept error: %v", err) log.Printf("Accept error: %v", err)
continue continue
} }
go handleConn(client, bl.target, reg) go handleConn(client, bl.target, reg, time.Now)
} }
}(bl) }(bl)
} }
@ -284,7 +286,7 @@ func main() {
} }
// Close connections that have been idle longer than idleTimeout // Close connections that have been idle longer than idleTimeout
if n := reg.closeIdle(idleTimeout); n > 0 { if n := reg.closeIdle(time.Now(), idleTimeout); n > 0 {
log.Printf("Closed %d idle connection(s) (idle > %s)", n, idleTimeout) log.Printf("Closed %d idle connection(s) (idle > %s)", n, idleTimeout)
} }
@ -305,7 +307,7 @@ func main() {
} }
} }
func handleConn(client net.Conn, target string, reg *connRegistry) { func handleConn(client net.Conn, target string, reg *connRegistry, clock func() time.Time) {
remote, err := net.Dial("tcp", target) remote, err := net.Dial("tcp", target)
if err != nil { if err != nil {
log.Printf("Failed to connect to %s: %v", target, err) log.Printf("Failed to connect to %s: %v", target, err)
@ -313,10 +315,9 @@ func handleConn(client net.Conn, target string, reg *connRegistry) {
return return
} }
now := time.Now().UnixNano()
entry := &connEntry{client: client, remote: remote} entry := &connEntry{client: client, remote: remote}
entry.lastRead.Store(now) entry.lastRead.Store(clock().UnixNano())
entry.lastWrite.Store(now) entry.lastWrite.Store(clock().UnixNano())
reg.add(entry) reg.add(entry)
defer reg.remove(entry) defer reg.remove(entry)
@ -325,8 +326,8 @@ func handleConn(client net.Conn, target string, reg *connRegistry) {
log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr()) log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr())
trackedClient := &trackingConn{Conn: client, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite} trackedClient := &trackingConn{Conn: client, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite, clock: clock}
trackedRemote := &trackingConn{Conn: remote, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite} trackedRemote := &trackingConn{Conn: remote, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite, clock: clock}
// Bidirectional copy with error handling // Bidirectional copy with error handling
go func() { _ = copyAndClose(trackedRemote, trackedClient) }() go func() { _ = copyAndClose(trackedRemote, trackedClient) }()