golib/cmd/tcpfwd/main.go
AJ ONeal b08525b024
feat(cmd/tcpfwd): add multi-forward positional args and version/help
Accept any number of local-port:remote-host:remote-port forwards as
positional arguments. Backward-compatible with existing --port/--target
flags. Adds --version/-V/version and --help/help handling matching the
auth-proxy pattern, including printVersion printed to stderr at startup.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-02-27 21:54:51 -07:00

202 lines
5.1 KiB
Go

package main
import (
"flag"
"fmt"
"io"
"log"
"net"
"os"
"strings"
)
const (
name = "tcpfwd"
licenseYear = "2025"
licenseOwner = "AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)"
licenseType = "CC0-1.0"
)
// replaced by goreleaser / ldflags
var (
version = "0.0.0-dev"
commit = "0000000"
date = "0001-01-01"
)
// printVersion displays the version, commit, and build date.
func printVersion(w io.Writer) {
_, _ = fmt.Fprintf(w, "%s v%s %s (%s)\n", name, version, commit[:7], date)
_, _ = fmt.Fprintf(w, "Copyright (C) %s %s\n", licenseYear, licenseOwner)
_, _ = fmt.Fprintf(w, "Licensed under %s\n", licenseType)
}
type forward struct {
listenAddr string // e.g. ":12345"
target string // e.g. "example.com:2345"
}
// parseForward parses a "local-port:remote-host:remote-port" string.
func parseForward(s string) (forward, error) {
i := strings.Index(s, ":")
if i < 0 || !strings.Contains(s[i+1:], ":") {
return forward{}, fmt.Errorf("invalid forward %q: expected local-port:remote-host:remote-port", s)
}
return forward{listenAddr: ":" + s[:i], target: s[i+1:]}, nil
}
func main() {
var listenPort string
var target string
var showVersion bool
fs := flag.NewFlagSet(name, flag.ContinueOnError)
fs.StringVar(&listenPort, "port", "", "local port to listen on (use with --target)")
fs.StringVar(&target, "target", "", "target host:port (use with --port)")
fs.BoolVar(&showVersion, "version", false, "show version and exit")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "USAGE\n %s [flags] [local-port:remote-host:remote-port ...]\n\n", name)
fmt.Fprintf(os.Stderr, "FLAGS\n")
fs.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nEXAMPLES\n")
fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345\n", name)
fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345 22222:other.host:22\n", name)
fmt.Fprintf(os.Stderr, " %s --port 12345 --target example.com:2345\n", name)
}
// Special handling for version/help before full flag parse
if len(os.Args) > 1 {
arg := os.Args[1]
if arg == "-V" || arg == "--version" || arg == "version" {
printVersion(os.Stdout)
os.Exit(0)
}
if arg == "help" || arg == "-help" || arg == "--help" {
printVersion(os.Stdout)
_, _ = fmt.Fprintln(os.Stdout, "")
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
printVersion(os.Stderr)
fmt.Fprintln(os.Stderr, "")
if err := fs.Parse(os.Args[1:]); err != nil {
if err == flag.ErrHelp {
os.Exit(0)
}
log.Fatalf("flag parse error: %v", err)
}
if showVersion {
printVersion(os.Stdout)
os.Exit(0)
}
// Collect forwards
var forwards []forward
// Backward-compat: --port / --target flags
if target != "" {
port := listenPort
if port == "" {
i := strings.LastIndex(target, ":")
port = target[i+1:]
}
forwards = append(forwards, forward{listenAddr: ":" + port, target: target})
} else if listenPort != "" {
fmt.Fprintf(os.Stderr, "error: --port requires --target\n")
fs.Usage()
os.Exit(1)
}
// Positional args: local-port:remote-host:remote-port
for _, arg := range fs.Args() {
fwd, err := parseForward(arg)
if err != nil {
log.Fatalf("%v", err)
}
forwards = append(forwards, fwd)
}
if len(forwards) == 0 {
fs.Usage()
os.Exit(1)
}
// Note: allow unprivileged users to use this like so:
// echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf
// sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf
// Bind all listeners first (fail fast before starting any accept loops)
type boundListener struct {
net.Listener
target string
}
var listeners []boundListener
for _, fwd := range forwards {
l, err := net.Listen("tcp", fwd.listenAddr)
if err != nil {
log.Fatalf("Failed to bind %s: %v", fwd.listenAddr, err)
}
log.Printf("TCP bridge listening on %s → %s", fwd.listenAddr, fwd.target)
listeners = append(listeners, boundListener{l, fwd.target})
}
// Start accept loops
for _, bl := range listeners {
go func(bl boundListener) {
for {
client, err := bl.Accept()
if err != nil {
log.Printf("Accept error: %v", err)
continue
}
go handleConn(client, bl.target)
}
}(bl)
}
select {} // block forever
}
func handleConn(client net.Conn, target string) {
defer client.Close()
remote, err := net.Dial("tcp", target)
if err != nil {
log.Printf("Failed to connect to %s: %v", target, err)
return
}
defer remote.Close()
log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr())
// Bidirectional copy with error handling
go func() { _ = copyAndClose(remote, client) }()
func() { _ = copyAndClose(client, remote) }()
}
// copyAndClose copies until EOF or error, then closes dst
func copyAndClose(dst, src net.Conn) error {
_, err := io.Copy(dst, src)
dst.Close()
if err != nil && !isClosedConn(err) {
log.Printf("Copy error (%s → %s): %v",
src.RemoteAddr(), dst.RemoteAddr(), err)
}
return err
}
// isClosedConn detects common closed-connection errors
func isClosedConn(err error) bool {
if err == nil {
return false
}
opErr, ok := err.(*net.OpError)
return ok && opErr.Err.Error() == "use of closed network connection"
}