mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-02 23:57:59 +00:00
feat(cmd/tcpfwd): add multi-forward positional args and version/help
Accept any number of local-port:remote-host:remote-port forwards as positional arguments. Backward-compatible with existing --port/--target flags. Adds --version/-V/version and --help/help handling matching the auth-proxy pattern, including printVersion printed to stderr at startup. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ec64afe390
commit
b08525b024
@ -2,6 +2,7 @@ package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
@ -9,43 +10,157 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
name = "tcpfwd"
|
||||
licenseYear = "2025"
|
||||
licenseOwner = "AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)"
|
||||
licenseType = "CC0-1.0"
|
||||
)
|
||||
|
||||
// replaced by goreleaser / ldflags
|
||||
var (
|
||||
version = "0.0.0-dev"
|
||||
commit = "0000000"
|
||||
date = "0001-01-01"
|
||||
)
|
||||
|
||||
// printVersion displays the version, commit, and build date.
|
||||
func printVersion(w io.Writer) {
|
||||
_, _ = fmt.Fprintf(w, "%s v%s %s (%s)\n", name, version, commit[:7], date)
|
||||
_, _ = fmt.Fprintf(w, "Copyright (C) %s %s\n", licenseYear, licenseOwner)
|
||||
_, _ = fmt.Fprintf(w, "Licensed under %s\n", licenseType)
|
||||
}
|
||||
|
||||
type forward struct {
|
||||
listenAddr string // e.g. ":12345"
|
||||
target string // e.g. "example.com:2345"
|
||||
}
|
||||
|
||||
// parseForward parses a "local-port:remote-host:remote-port" string.
|
||||
func parseForward(s string) (forward, error) {
|
||||
i := strings.Index(s, ":")
|
||||
if i < 0 || !strings.Contains(s[i+1:], ":") {
|
||||
return forward{}, fmt.Errorf("invalid forward %q: expected local-port:remote-host:remote-port", s)
|
||||
}
|
||||
return forward{listenAddr: ":" + s[:i], target: s[i+1:]}, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
var listenPort string
|
||||
var target string
|
||||
flag.StringVar(&listenPort, "port", "", "Local port to listen on (same as target by default)")
|
||||
flag.StringVar(&target, "target", "", "Target host:port (required)")
|
||||
flag.Parse()
|
||||
var showVersion bool
|
||||
|
||||
if target == "" {
|
||||
flag.Usage()
|
||||
fs := flag.NewFlagSet(name, flag.ContinueOnError)
|
||||
fs.StringVar(&listenPort, "port", "", "local port to listen on (use with --target)")
|
||||
fs.StringVar(&target, "target", "", "target host:port (use with --port)")
|
||||
fs.BoolVar(&showVersion, "version", false, "show version and exit")
|
||||
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "USAGE\n %s [flags] [local-port:remote-host:remote-port ...]\n\n", name)
|
||||
fmt.Fprintf(os.Stderr, "FLAGS\n")
|
||||
fs.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, "\nEXAMPLES\n")
|
||||
fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345\n", name)
|
||||
fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345 22222:other.host:22\n", name)
|
||||
fmt.Fprintf(os.Stderr, " %s --port 12345 --target example.com:2345\n", name)
|
||||
}
|
||||
|
||||
// Special handling for version/help before full flag parse
|
||||
if len(os.Args) > 1 {
|
||||
arg := os.Args[1]
|
||||
if arg == "-V" || arg == "--version" || arg == "version" {
|
||||
printVersion(os.Stdout)
|
||||
os.Exit(0)
|
||||
}
|
||||
if arg == "help" || arg == "-help" || arg == "--help" {
|
||||
printVersion(os.Stdout)
|
||||
_, _ = fmt.Fprintln(os.Stdout, "")
|
||||
fs.SetOutput(os.Stdout)
|
||||
fs.Usage()
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
|
||||
printVersion(os.Stderr)
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
|
||||
if err := fs.Parse(os.Args[1:]); err != nil {
|
||||
if err == flag.ErrHelp {
|
||||
os.Exit(0)
|
||||
}
|
||||
log.Fatalf("flag parse error: %v", err)
|
||||
}
|
||||
|
||||
if showVersion {
|
||||
printVersion(os.Stdout)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// Collect forwards
|
||||
var forwards []forward
|
||||
|
||||
// Backward-compat: --port / --target flags
|
||||
if target != "" {
|
||||
port := listenPort
|
||||
if port == "" {
|
||||
i := strings.LastIndex(target, ":")
|
||||
port = target[i+1:]
|
||||
}
|
||||
forwards = append(forwards, forward{listenAddr: ":" + port, target: target})
|
||||
} else if listenPort != "" {
|
||||
fmt.Fprintf(os.Stderr, "error: --port requires --target\n")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(listenPort) == 0 {
|
||||
i := strings.LastIndex(target, ":")
|
||||
listenPort = target[i+1:]
|
||||
// Positional args: local-port:remote-host:remote-port
|
||||
for _, arg := range fs.Args() {
|
||||
fwd, err := parseForward(arg)
|
||||
if err != nil {
|
||||
log.Fatalf("%v", err)
|
||||
}
|
||||
forwards = append(forwards, fwd)
|
||||
}
|
||||
|
||||
if len(forwards) == 0 {
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
listenAddr := ":" + listenPort
|
||||
log.Printf("TCP bridge %s → %s", listenAddr, target)
|
||||
|
||||
// Note: allow unprivileged users to use this like so:
|
||||
// echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf
|
||||
// sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf
|
||||
listener, err := net.Listen("tcp", listenAddr)
|
||||
// sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf
|
||||
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to bind %s: %v", listenAddr, err)
|
||||
// Bind all listeners first (fail fast before starting any accept loops)
|
||||
type boundListener struct {
|
||||
net.Listener
|
||||
target string
|
||||
}
|
||||
log.Printf("TCP bridge listening on %s → %s", listenAddr, target)
|
||||
|
||||
for {
|
||||
client, err := listener.Accept()
|
||||
var listeners []boundListener
|
||||
for _, fwd := range forwards {
|
||||
l, err := net.Listen("tcp", fwd.listenAddr)
|
||||
if err != nil {
|
||||
log.Printf("Accept error: %v", err)
|
||||
continue
|
||||
log.Fatalf("Failed to bind %s: %v", fwd.listenAddr, err)
|
||||
}
|
||||
go handleConn(client, target)
|
||||
log.Printf("TCP bridge listening on %s → %s", fwd.listenAddr, fwd.target)
|
||||
listeners = append(listeners, boundListener{l, fwd.target})
|
||||
}
|
||||
|
||||
// Start accept loops
|
||||
for _, bl := range listeners {
|
||||
go func(bl boundListener) {
|
||||
for {
|
||||
client, err := bl.Accept()
|
||||
if err != nil {
|
||||
log.Printf("Accept error: %v", err)
|
||||
continue
|
||||
}
|
||||
go handleConn(client, bl.target)
|
||||
}
|
||||
}(bl)
|
||||
}
|
||||
|
||||
select {} // block forever
|
||||
}
|
||||
|
||||
func handleConn(client net.Conn, target string) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user