From b08525b024a8266578575dfdf680d8456671d308 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Fri, 27 Feb 2026 21:52:41 -0700 Subject: [PATCH] feat(cmd/tcpfwd): add multi-forward positional args and version/help Accept any number of local-port:remote-host:remote-port forwards as positional arguments. Backward-compatible with existing --port/--target flags. Adds --version/-V/version and --help/help handling matching the auth-proxy pattern, including printVersion printed to stderr at startup. Co-Authored-By: Claude Sonnet 4.6 --- cmd/tcpfwd/main.go | 157 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 136 insertions(+), 21 deletions(-) diff --git a/cmd/tcpfwd/main.go b/cmd/tcpfwd/main.go index a8900f2..25ce006 100644 --- a/cmd/tcpfwd/main.go +++ b/cmd/tcpfwd/main.go @@ -2,6 +2,7 @@ package main import ( "flag" + "fmt" "io" "log" "net" @@ -9,43 +10,157 @@ import ( "strings" ) +const ( + name = "tcpfwd" + licenseYear = "2025" + licenseOwner = "AJ ONeal (https://therootcompany.com)" + licenseType = "CC0-1.0" +) + +// replaced by goreleaser / ldflags +var ( + version = "0.0.0-dev" + commit = "0000000" + date = "0001-01-01" +) + +// printVersion displays the version, commit, and build date. +func printVersion(w io.Writer) { + _, _ = fmt.Fprintf(w, "%s v%s %s (%s)\n", name, version, commit[:7], date) + _, _ = fmt.Fprintf(w, "Copyright (C) %s %s\n", licenseYear, licenseOwner) + _, _ = fmt.Fprintf(w, "Licensed under %s\n", licenseType) +} + +type forward struct { + listenAddr string // e.g. ":12345" + target string // e.g. "example.com:2345" +} + +// parseForward parses a "local-port:remote-host:remote-port" string. +func parseForward(s string) (forward, error) { + i := strings.Index(s, ":") + if i < 0 || !strings.Contains(s[i+1:], ":") { + return forward{}, fmt.Errorf("invalid forward %q: expected local-port:remote-host:remote-port", s) + } + return forward{listenAddr: ":" + s[:i], target: s[i+1:]}, nil +} + func main() { var listenPort string var target string - flag.StringVar(&listenPort, "port", "", "Local port to listen on (same as target by default)") - flag.StringVar(&target, "target", "", "Target host:port (required)") - flag.Parse() + var showVersion bool - if target == "" { - flag.Usage() + fs := flag.NewFlagSet(name, flag.ContinueOnError) + fs.StringVar(&listenPort, "port", "", "local port to listen on (use with --target)") + fs.StringVar(&target, "target", "", "target host:port (use with --port)") + fs.BoolVar(&showVersion, "version", false, "show version and exit") + + fs.Usage = func() { + fmt.Fprintf(os.Stderr, "USAGE\n %s [flags] [local-port:remote-host:remote-port ...]\n\n", name) + fmt.Fprintf(os.Stderr, "FLAGS\n") + fs.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nEXAMPLES\n") + fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345\n", name) + fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345 22222:other.host:22\n", name) + fmt.Fprintf(os.Stderr, " %s --port 12345 --target example.com:2345\n", name) + } + + // Special handling for version/help before full flag parse + if len(os.Args) > 1 { + arg := os.Args[1] + if arg == "-V" || arg == "--version" || arg == "version" { + printVersion(os.Stdout) + os.Exit(0) + } + if arg == "help" || arg == "-help" || arg == "--help" { + printVersion(os.Stdout) + _, _ = fmt.Fprintln(os.Stdout, "") + fs.SetOutput(os.Stdout) + fs.Usage() + os.Exit(0) + } + } + + printVersion(os.Stderr) + fmt.Fprintln(os.Stderr, "") + + if err := fs.Parse(os.Args[1:]); err != nil { + if err == flag.ErrHelp { + os.Exit(0) + } + log.Fatalf("flag parse error: %v", err) + } + + if showVersion { + printVersion(os.Stdout) + os.Exit(0) + } + + // Collect forwards + var forwards []forward + + // Backward-compat: --port / --target flags + if target != "" { + port := listenPort + if port == "" { + i := strings.LastIndex(target, ":") + port = target[i+1:] + } + forwards = append(forwards, forward{listenAddr: ":" + port, target: target}) + } else if listenPort != "" { + fmt.Fprintf(os.Stderr, "error: --port requires --target\n") + fs.Usage() os.Exit(1) } - if len(listenPort) == 0 { - i := strings.LastIndex(target, ":") - listenPort = target[i+1:] + // Positional args: local-port:remote-host:remote-port + for _, arg := range fs.Args() { + fwd, err := parseForward(arg) + if err != nil { + log.Fatalf("%v", err) + } + forwards = append(forwards, fwd) + } + + if len(forwards) == 0 { + fs.Usage() + os.Exit(1) } - listenAddr := ":" + listenPort - log.Printf("TCP bridge %s → %s", listenAddr, target) // Note: allow unprivileged users to use this like so: // echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf - // sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf - listener, err := net.Listen("tcp", listenAddr) + // sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf - if err != nil { - log.Fatalf("Failed to bind %s: %v", listenAddr, err) + // Bind all listeners first (fail fast before starting any accept loops) + type boundListener struct { + net.Listener + target string } - log.Printf("TCP bridge listening on %s → %s", listenAddr, target) - - for { - client, err := listener.Accept() + var listeners []boundListener + for _, fwd := range forwards { + l, err := net.Listen("tcp", fwd.listenAddr) if err != nil { - log.Printf("Accept error: %v", err) - continue + log.Fatalf("Failed to bind %s: %v", fwd.listenAddr, err) } - go handleConn(client, target) + log.Printf("TCP bridge listening on %s → %s", fwd.listenAddr, fwd.target) + listeners = append(listeners, boundListener{l, fwd.target}) } + + // Start accept loops + for _, bl := range listeners { + go func(bl boundListener) { + for { + client, err := bl.Accept() + if err != nil { + log.Printf("Accept error: %v", err) + continue + } + go handleConn(client, bl.target) + } + }(bl) + } + + select {} // block forever } func handleConn(client net.Conn, target string) {