mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-02 23:57:59 +00:00
feat(cmd/tcpfwd): add multi-forward positional args and version/help
Accept any number of local-port:remote-host:remote-port forwards as positional arguments. Backward-compatible with existing --port/--target flags. Adds --version/-V/version and --help/help handling matching the auth-proxy pattern, including printVersion printed to stderr at startup. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
ec64afe390
commit
b08525b024
@ -2,6 +2,7 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
@ -9,43 +10,157 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
name = "tcpfwd"
|
||||||
|
licenseYear = "2025"
|
||||||
|
licenseOwner = "AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)"
|
||||||
|
licenseType = "CC0-1.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
// replaced by goreleaser / ldflags
|
||||||
|
var (
|
||||||
|
version = "0.0.0-dev"
|
||||||
|
commit = "0000000"
|
||||||
|
date = "0001-01-01"
|
||||||
|
)
|
||||||
|
|
||||||
|
// printVersion displays the version, commit, and build date.
|
||||||
|
func printVersion(w io.Writer) {
|
||||||
|
_, _ = fmt.Fprintf(w, "%s v%s %s (%s)\n", name, version, commit[:7], date)
|
||||||
|
_, _ = fmt.Fprintf(w, "Copyright (C) %s %s\n", licenseYear, licenseOwner)
|
||||||
|
_, _ = fmt.Fprintf(w, "Licensed under %s\n", licenseType)
|
||||||
|
}
|
||||||
|
|
||||||
|
type forward struct {
|
||||||
|
listenAddr string // e.g. ":12345"
|
||||||
|
target string // e.g. "example.com:2345"
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseForward parses a "local-port:remote-host:remote-port" string.
|
||||||
|
func parseForward(s string) (forward, error) {
|
||||||
|
i := strings.Index(s, ":")
|
||||||
|
if i < 0 || !strings.Contains(s[i+1:], ":") {
|
||||||
|
return forward{}, fmt.Errorf("invalid forward %q: expected local-port:remote-host:remote-port", s)
|
||||||
|
}
|
||||||
|
return forward{listenAddr: ":" + s[:i], target: s[i+1:]}, nil
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var listenPort string
|
var listenPort string
|
||||||
var target string
|
var target string
|
||||||
flag.StringVar(&listenPort, "port", "", "Local port to listen on (same as target by default)")
|
var showVersion bool
|
||||||
flag.StringVar(&target, "target", "", "Target host:port (required)")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
if target == "" {
|
fs := flag.NewFlagSet(name, flag.ContinueOnError)
|
||||||
flag.Usage()
|
fs.StringVar(&listenPort, "port", "", "local port to listen on (use with --target)")
|
||||||
|
fs.StringVar(&target, "target", "", "target host:port (use with --port)")
|
||||||
|
fs.BoolVar(&showVersion, "version", false, "show version and exit")
|
||||||
|
|
||||||
|
fs.Usage = func() {
|
||||||
|
fmt.Fprintf(os.Stderr, "USAGE\n %s [flags] [local-port:remote-host:remote-port ...]\n\n", name)
|
||||||
|
fmt.Fprintf(os.Stderr, "FLAGS\n")
|
||||||
|
fs.PrintDefaults()
|
||||||
|
fmt.Fprintf(os.Stderr, "\nEXAMPLES\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345\n", name)
|
||||||
|
fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345 22222:other.host:22\n", name)
|
||||||
|
fmt.Fprintf(os.Stderr, " %s --port 12345 --target example.com:2345\n", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special handling for version/help before full flag parse
|
||||||
|
if len(os.Args) > 1 {
|
||||||
|
arg := os.Args[1]
|
||||||
|
if arg == "-V" || arg == "--version" || arg == "version" {
|
||||||
|
printVersion(os.Stdout)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
if arg == "help" || arg == "-help" || arg == "--help" {
|
||||||
|
printVersion(os.Stdout)
|
||||||
|
_, _ = fmt.Fprintln(os.Stdout, "")
|
||||||
|
fs.SetOutput(os.Stdout)
|
||||||
|
fs.Usage()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
printVersion(os.Stderr)
|
||||||
|
fmt.Fprintln(os.Stderr, "")
|
||||||
|
|
||||||
|
if err := fs.Parse(os.Args[1:]); err != nil {
|
||||||
|
if err == flag.ErrHelp {
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
log.Fatalf("flag parse error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if showVersion {
|
||||||
|
printVersion(os.Stdout)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Collect forwards
|
||||||
|
var forwards []forward
|
||||||
|
|
||||||
|
// Backward-compat: --port / --target flags
|
||||||
|
if target != "" {
|
||||||
|
port := listenPort
|
||||||
|
if port == "" {
|
||||||
|
i := strings.LastIndex(target, ":")
|
||||||
|
port = target[i+1:]
|
||||||
|
}
|
||||||
|
forwards = append(forwards, forward{listenAddr: ":" + port, target: target})
|
||||||
|
} else if listenPort != "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: --port requires --target\n")
|
||||||
|
fs.Usage()
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(listenPort) == 0 {
|
// Positional args: local-port:remote-host:remote-port
|
||||||
i := strings.LastIndex(target, ":")
|
for _, arg := range fs.Args() {
|
||||||
listenPort = target[i+1:]
|
fwd, err := parseForward(arg)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("%v", err)
|
||||||
|
}
|
||||||
|
forwards = append(forwards, fwd)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(forwards) == 0 {
|
||||||
|
fs.Usage()
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
listenAddr := ":" + listenPort
|
|
||||||
log.Printf("TCP bridge %s → %s", listenAddr, target)
|
|
||||||
|
|
||||||
// Note: allow unprivileged users to use this like so:
|
// Note: allow unprivileged users to use this like so:
|
||||||
// echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf
|
// echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf
|
||||||
// sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf
|
// sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf
|
||||||
listener, err := net.Listen("tcp", listenAddr)
|
|
||||||
|
|
||||||
if err != nil {
|
// Bind all listeners first (fail fast before starting any accept loops)
|
||||||
log.Fatalf("Failed to bind %s: %v", listenAddr, err)
|
type boundListener struct {
|
||||||
|
net.Listener
|
||||||
|
target string
|
||||||
}
|
}
|
||||||
log.Printf("TCP bridge listening on %s → %s", listenAddr, target)
|
var listeners []boundListener
|
||||||
|
for _, fwd := range forwards {
|
||||||
for {
|
l, err := net.Listen("tcp", fwd.listenAddr)
|
||||||
client, err := listener.Accept()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Accept error: %v", err)
|
log.Fatalf("Failed to bind %s: %v", fwd.listenAddr, err)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
go handleConn(client, target)
|
log.Printf("TCP bridge listening on %s → %s", fwd.listenAddr, fwd.target)
|
||||||
|
listeners = append(listeners, boundListener{l, fwd.target})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Start accept loops
|
||||||
|
for _, bl := range listeners {
|
||||||
|
go func(bl boundListener) {
|
||||||
|
for {
|
||||||
|
client, err := bl.Accept()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Accept error: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
go handleConn(client, bl.target)
|
||||||
|
}
|
||||||
|
}(bl)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {} // block forever
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleConn(client net.Conn, target string) {
|
func handleConn(client net.Conn, target string) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user