mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-02 23:57:59 +00:00
ref(cmd/tcpfwd): remove implicit time.Now() for testability
- connEntry.isIdle(now, threshold): caller supplies now instead of time.Since - connRegistry.closeIdle(now, threshold): passes now through to isIdle - trackingConn gains a clock func() time.Time field used in Read/Write - handleConn takes clock func() time.Time; uses it to init lastRead/lastWrite and passes it to trackingConn - Call sites in main pass time.Now or time.Now() explicitly Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
89f6e04516
commit
3f16e89f5c
@ -69,8 +69,8 @@ func (e *connEntry) idleSince() time.Time {
|
|||||||
return lw
|
return lw
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *connEntry) isIdle(threshold time.Duration) bool {
|
func (e *connEntry) isIdle(now time.Time, threshold time.Duration) bool {
|
||||||
return time.Since(e.idleSince()) > threshold
|
return now.Sub(e.idleSince()) > threshold
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *connEntry) close() {
|
func (e *connEntry) close() {
|
||||||
@ -104,12 +104,12 @@ func (r *connRegistry) remove(e *connEntry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// closeIdle closes connections idle for longer than threshold and returns the count.
|
// closeIdle closes connections idle for longer than threshold and returns the count.
|
||||||
func (r *connRegistry) closeIdle(threshold time.Duration) int {
|
func (r *connRegistry) closeIdle(now time.Time, threshold time.Duration) int {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
var n int
|
var n int
|
||||||
for e := range r.conns {
|
for e := range r.conns {
|
||||||
if e.isIdle(threshold) {
|
if e.isIdle(now, threshold) {
|
||||||
e.close()
|
e.close()
|
||||||
n++
|
n++
|
||||||
}
|
}
|
||||||
@ -126,16 +126,18 @@ func (r *connRegistry) closeAll() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// trackingConn wraps a net.Conn and updates shared lastRead/lastWrite atomics on I/O.
|
// trackingConn wraps a net.Conn and updates shared lastRead/lastWrite atomics on I/O.
|
||||||
|
// clock is called to get the current time; use time.Now in production.
|
||||||
type trackingConn struct {
|
type trackingConn struct {
|
||||||
net.Conn
|
net.Conn
|
||||||
lastRead *atomic.Int64
|
lastRead *atomic.Int64
|
||||||
lastWrite *atomic.Int64
|
lastWrite *atomic.Int64
|
||||||
|
clock func() time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *trackingConn) Read(b []byte) (int, error) {
|
func (c *trackingConn) Read(b []byte) (int, error) {
|
||||||
n, err := c.Conn.Read(b)
|
n, err := c.Conn.Read(b)
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
c.lastRead.Store(time.Now().UnixNano())
|
c.lastRead.Store(c.clock().UnixNano())
|
||||||
}
|
}
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
@ -143,7 +145,7 @@ func (c *trackingConn) Read(b []byte) (int, error) {
|
|||||||
func (c *trackingConn) Write(b []byte) (int, error) {
|
func (c *trackingConn) Write(b []byte) (int, error) {
|
||||||
n, err := c.Conn.Write(b)
|
n, err := c.Conn.Write(b)
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
c.lastWrite.Store(time.Now().UnixNano())
|
c.lastWrite.Store(c.clock().UnixNano())
|
||||||
}
|
}
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
@ -267,7 +269,7 @@ func main() {
|
|||||||
log.Printf("Accept error: %v", err)
|
log.Printf("Accept error: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
go handleConn(client, bl.target, reg)
|
go handleConn(client, bl.target, reg, time.Now)
|
||||||
}
|
}
|
||||||
}(bl)
|
}(bl)
|
||||||
}
|
}
|
||||||
@ -284,7 +286,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close connections that have been idle longer than idleTimeout
|
// Close connections that have been idle longer than idleTimeout
|
||||||
if n := reg.closeIdle(idleTimeout); n > 0 {
|
if n := reg.closeIdle(time.Now(), idleTimeout); n > 0 {
|
||||||
log.Printf("Closed %d idle connection(s) (idle > %s)", n, idleTimeout)
|
log.Printf("Closed %d idle connection(s) (idle > %s)", n, idleTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -305,7 +307,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleConn(client net.Conn, target string, reg *connRegistry) {
|
func handleConn(client net.Conn, target string, reg *connRegistry, clock func() time.Time) {
|
||||||
remote, err := net.Dial("tcp", target)
|
remote, err := net.Dial("tcp", target)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to connect to %s: %v", target, err)
|
log.Printf("Failed to connect to %s: %v", target, err)
|
||||||
@ -313,10 +315,9 @@ func handleConn(client net.Conn, target string, reg *connRegistry) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now().UnixNano()
|
|
||||||
entry := &connEntry{client: client, remote: remote}
|
entry := &connEntry{client: client, remote: remote}
|
||||||
entry.lastRead.Store(now)
|
entry.lastRead.Store(clock().UnixNano())
|
||||||
entry.lastWrite.Store(now)
|
entry.lastWrite.Store(clock().UnixNano())
|
||||||
|
|
||||||
reg.add(entry)
|
reg.add(entry)
|
||||||
defer reg.remove(entry)
|
defer reg.remove(entry)
|
||||||
@ -325,8 +326,8 @@ func handleConn(client net.Conn, target string, reg *connRegistry) {
|
|||||||
|
|
||||||
log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr())
|
log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr())
|
||||||
|
|
||||||
trackedClient := &trackingConn{Conn: client, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite}
|
trackedClient := &trackingConn{Conn: client, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite, clock: clock}
|
||||||
trackedRemote := &trackingConn{Conn: remote, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite}
|
trackedRemote := &trackingConn{Conn: remote, lastRead: &entry.lastRead, lastWrite: &entry.lastWrite, clock: clock}
|
||||||
|
|
||||||
// Bidirectional copy with error handling
|
// Bidirectional copy with error handling
|
||||||
go func() { _ = copyAndClose(trackedRemote, trackedClient) }()
|
go func() { _ = copyAndClose(trackedRemote, trackedClient) }()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user