feat(cmd/tcpfwd): add multi-forward positional args and version/help

Accept any number of local-port:remote-host:remote-port forwards as
positional arguments. Backward-compatible with existing --port/--target
flags. Adds --version/-V/version and --help/help handling matching the
auth-proxy pattern, including printVersion printed to stderr at startup.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
AJ ONeal 2026-02-27 21:52:41 -07:00
parent ec64afe390
commit b08525b024
No known key found for this signature in database

View File

@ -2,6 +2,7 @@ package main
import ( import (
"flag" "flag"
"fmt"
"io" "io"
"log" "log"
"net" "net"
@ -9,43 +10,157 @@ import (
"strings" "strings"
) )
const (
name = "tcpfwd"
licenseYear = "2025"
licenseOwner = "AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)"
licenseType = "CC0-1.0"
)
// replaced by goreleaser / ldflags
var (
version = "0.0.0-dev"
commit = "0000000"
date = "0001-01-01"
)
// printVersion displays the version, commit, and build date.
func printVersion(w io.Writer) {
_, _ = fmt.Fprintf(w, "%s v%s %s (%s)\n", name, version, commit[:7], date)
_, _ = fmt.Fprintf(w, "Copyright (C) %s %s\n", licenseYear, licenseOwner)
_, _ = fmt.Fprintf(w, "Licensed under %s\n", licenseType)
}
type forward struct {
listenAddr string // e.g. ":12345"
target string // e.g. "example.com:2345"
}
// parseForward parses a "local-port:remote-host:remote-port" string.
func parseForward(s string) (forward, error) {
i := strings.Index(s, ":")
if i < 0 || !strings.Contains(s[i+1:], ":") {
return forward{}, fmt.Errorf("invalid forward %q: expected local-port:remote-host:remote-port", s)
}
return forward{listenAddr: ":" + s[:i], target: s[i+1:]}, nil
}
func main() { func main() {
var listenPort string var listenPort string
var target string var target string
flag.StringVar(&listenPort, "port", "", "Local port to listen on (same as target by default)") var showVersion bool
flag.StringVar(&target, "target", "", "Target host:port (required)")
flag.Parse()
if target == "" { fs := flag.NewFlagSet(name, flag.ContinueOnError)
flag.Usage() fs.StringVar(&listenPort, "port", "", "local port to listen on (use with --target)")
fs.StringVar(&target, "target", "", "target host:port (use with --port)")
fs.BoolVar(&showVersion, "version", false, "show version and exit")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "USAGE\n %s [flags] [local-port:remote-host:remote-port ...]\n\n", name)
fmt.Fprintf(os.Stderr, "FLAGS\n")
fs.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nEXAMPLES\n")
fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345\n", name)
fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345 22222:other.host:22\n", name)
fmt.Fprintf(os.Stderr, " %s --port 12345 --target example.com:2345\n", name)
}
// Special handling for version/help before full flag parse
if len(os.Args) > 1 {
arg := os.Args[1]
if arg == "-V" || arg == "--version" || arg == "version" {
printVersion(os.Stdout)
os.Exit(0)
}
if arg == "help" || arg == "-help" || arg == "--help" {
printVersion(os.Stdout)
_, _ = fmt.Fprintln(os.Stdout, "")
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
printVersion(os.Stderr)
fmt.Fprintln(os.Stderr, "")
if err := fs.Parse(os.Args[1:]); err != nil {
if err == flag.ErrHelp {
os.Exit(0)
}
log.Fatalf("flag parse error: %v", err)
}
if showVersion {
printVersion(os.Stdout)
os.Exit(0)
}
// Collect forwards
var forwards []forward
// Backward-compat: --port / --target flags
if target != "" {
port := listenPort
if port == "" {
i := strings.LastIndex(target, ":")
port = target[i+1:]
}
forwards = append(forwards, forward{listenAddr: ":" + port, target: target})
} else if listenPort != "" {
fmt.Fprintf(os.Stderr, "error: --port requires --target\n")
fs.Usage()
os.Exit(1) os.Exit(1)
} }
if len(listenPort) == 0 { // Positional args: local-port:remote-host:remote-port
i := strings.LastIndex(target, ":") for _, arg := range fs.Args() {
listenPort = target[i+1:] fwd, err := parseForward(arg)
if err != nil {
log.Fatalf("%v", err)
}
forwards = append(forwards, fwd)
}
if len(forwards) == 0 {
fs.Usage()
os.Exit(1)
} }
listenAddr := ":" + listenPort
log.Printf("TCP bridge %s → %s", listenAddr, target)
// Note: allow unprivileged users to use this like so: // Note: allow unprivileged users to use this like so:
// echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf // echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf
// sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf // sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf
listener, err := net.Listen("tcp", listenAddr)
if err != nil { // Bind all listeners first (fail fast before starting any accept loops)
log.Fatalf("Failed to bind %s: %v", listenAddr, err) type boundListener struct {
net.Listener
target string
}
var listeners []boundListener
for _, fwd := range forwards {
l, err := net.Listen("tcp", fwd.listenAddr)
if err != nil {
log.Fatalf("Failed to bind %s: %v", fwd.listenAddr, err)
}
log.Printf("TCP bridge listening on %s → %s", fwd.listenAddr, fwd.target)
listeners = append(listeners, boundListener{l, fwd.target})
} }
log.Printf("TCP bridge listening on %s → %s", listenAddr, target)
// Start accept loops
for _, bl := range listeners {
go func(bl boundListener) {
for { for {
client, err := listener.Accept() client, err := bl.Accept()
if err != nil { if err != nil {
log.Printf("Accept error: %v", err) log.Printf("Accept error: %v", err)
continue continue
} }
go handleConn(client, target) go handleConn(client, bl.target)
} }
}(bl)
}
select {} // block forever
} }
func handleConn(client net.Conn, target string) { func handleConn(client net.Conn, target string) {