mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-02 23:57:59 +00:00
Accept any number of local-port:remote-host:remote-port forwards as positional arguments. Backward-compatible with existing --port/--target flags. Adds --version/-V/version and --help/help handling matching the auth-proxy pattern, including printVersion printed to stderr at startup. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
202 lines
5.1 KiB
Go
202 lines
5.1 KiB
Go
package main
|
|
|
|
import (
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"os"
|
|
"strings"
|
|
)
|
|
|
|
const (
|
|
name = "tcpfwd"
|
|
licenseYear = "2025"
|
|
licenseOwner = "AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)"
|
|
licenseType = "CC0-1.0"
|
|
)
|
|
|
|
// replaced by goreleaser / ldflags
|
|
var (
|
|
version = "0.0.0-dev"
|
|
commit = "0000000"
|
|
date = "0001-01-01"
|
|
)
|
|
|
|
// printVersion displays the version, commit, and build date.
|
|
func printVersion(w io.Writer) {
|
|
_, _ = fmt.Fprintf(w, "%s v%s %s (%s)\n", name, version, commit[:7], date)
|
|
_, _ = fmt.Fprintf(w, "Copyright (C) %s %s\n", licenseYear, licenseOwner)
|
|
_, _ = fmt.Fprintf(w, "Licensed under %s\n", licenseType)
|
|
}
|
|
|
|
type forward struct {
|
|
listenAddr string // e.g. ":12345"
|
|
target string // e.g. "example.com:2345"
|
|
}
|
|
|
|
// parseForward parses a "local-port:remote-host:remote-port" string.
|
|
func parseForward(s string) (forward, error) {
|
|
i := strings.Index(s, ":")
|
|
if i < 0 || !strings.Contains(s[i+1:], ":") {
|
|
return forward{}, fmt.Errorf("invalid forward %q: expected local-port:remote-host:remote-port", s)
|
|
}
|
|
return forward{listenAddr: ":" + s[:i], target: s[i+1:]}, nil
|
|
}
|
|
|
|
func main() {
|
|
var listenPort string
|
|
var target string
|
|
var showVersion bool
|
|
|
|
fs := flag.NewFlagSet(name, flag.ContinueOnError)
|
|
fs.StringVar(&listenPort, "port", "", "local port to listen on (use with --target)")
|
|
fs.StringVar(&target, "target", "", "target host:port (use with --port)")
|
|
fs.BoolVar(&showVersion, "version", false, "show version and exit")
|
|
|
|
fs.Usage = func() {
|
|
fmt.Fprintf(os.Stderr, "USAGE\n %s [flags] [local-port:remote-host:remote-port ...]\n\n", name)
|
|
fmt.Fprintf(os.Stderr, "FLAGS\n")
|
|
fs.PrintDefaults()
|
|
fmt.Fprintf(os.Stderr, "\nEXAMPLES\n")
|
|
fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345\n", name)
|
|
fmt.Fprintf(os.Stderr, " %s 12345:example.com:2345 22222:other.host:22\n", name)
|
|
fmt.Fprintf(os.Stderr, " %s --port 12345 --target example.com:2345\n", name)
|
|
}
|
|
|
|
// Special handling for version/help before full flag parse
|
|
if len(os.Args) > 1 {
|
|
arg := os.Args[1]
|
|
if arg == "-V" || arg == "--version" || arg == "version" {
|
|
printVersion(os.Stdout)
|
|
os.Exit(0)
|
|
}
|
|
if arg == "help" || arg == "-help" || arg == "--help" {
|
|
printVersion(os.Stdout)
|
|
_, _ = fmt.Fprintln(os.Stdout, "")
|
|
fs.SetOutput(os.Stdout)
|
|
fs.Usage()
|
|
os.Exit(0)
|
|
}
|
|
}
|
|
|
|
printVersion(os.Stderr)
|
|
fmt.Fprintln(os.Stderr, "")
|
|
|
|
if err := fs.Parse(os.Args[1:]); err != nil {
|
|
if err == flag.ErrHelp {
|
|
os.Exit(0)
|
|
}
|
|
log.Fatalf("flag parse error: %v", err)
|
|
}
|
|
|
|
if showVersion {
|
|
printVersion(os.Stdout)
|
|
os.Exit(0)
|
|
}
|
|
|
|
// Collect forwards
|
|
var forwards []forward
|
|
|
|
// Backward-compat: --port / --target flags
|
|
if target != "" {
|
|
port := listenPort
|
|
if port == "" {
|
|
i := strings.LastIndex(target, ":")
|
|
port = target[i+1:]
|
|
}
|
|
forwards = append(forwards, forward{listenAddr: ":" + port, target: target})
|
|
} else if listenPort != "" {
|
|
fmt.Fprintf(os.Stderr, "error: --port requires --target\n")
|
|
fs.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Positional args: local-port:remote-host:remote-port
|
|
for _, arg := range fs.Args() {
|
|
fwd, err := parseForward(arg)
|
|
if err != nil {
|
|
log.Fatalf("%v", err)
|
|
}
|
|
forwards = append(forwards, fwd)
|
|
}
|
|
|
|
if len(forwards) == 0 {
|
|
fs.Usage()
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Note: allow unprivileged users to use this like so:
|
|
// echo 'net.ipv4.ip_unprivileged_port_start=1' | sudo tee /etc/sysctl.d/01-deprivilege-ports.conf
|
|
// sudo sysctl -p /etc/sysctl.d/01-deprivilege-ports.conf
|
|
|
|
// Bind all listeners first (fail fast before starting any accept loops)
|
|
type boundListener struct {
|
|
net.Listener
|
|
target string
|
|
}
|
|
var listeners []boundListener
|
|
for _, fwd := range forwards {
|
|
l, err := net.Listen("tcp", fwd.listenAddr)
|
|
if err != nil {
|
|
log.Fatalf("Failed to bind %s: %v", fwd.listenAddr, err)
|
|
}
|
|
log.Printf("TCP bridge listening on %s → %s", fwd.listenAddr, fwd.target)
|
|
listeners = append(listeners, boundListener{l, fwd.target})
|
|
}
|
|
|
|
// Start accept loops
|
|
for _, bl := range listeners {
|
|
go func(bl boundListener) {
|
|
for {
|
|
client, err := bl.Accept()
|
|
if err != nil {
|
|
log.Printf("Accept error: %v", err)
|
|
continue
|
|
}
|
|
go handleConn(client, bl.target)
|
|
}
|
|
}(bl)
|
|
}
|
|
|
|
select {} // block forever
|
|
}
|
|
|
|
func handleConn(client net.Conn, target string) {
|
|
defer client.Close()
|
|
|
|
remote, err := net.Dial("tcp", target)
|
|
if err != nil {
|
|
log.Printf("Failed to connect to %s: %v", target, err)
|
|
return
|
|
}
|
|
defer remote.Close()
|
|
|
|
log.Printf("New connection %s ↔ %s", client.RemoteAddr(), remote.RemoteAddr())
|
|
|
|
// Bidirectional copy with error handling
|
|
go func() { _ = copyAndClose(remote, client) }()
|
|
func() { _ = copyAndClose(client, remote) }()
|
|
}
|
|
|
|
// copyAndClose copies until EOF or error, then closes dst
|
|
func copyAndClose(dst, src net.Conn) error {
|
|
_, err := io.Copy(dst, src)
|
|
dst.Close()
|
|
if err != nil && !isClosedConn(err) {
|
|
log.Printf("Copy error (%s → %s): %v",
|
|
src.RemoteAddr(), dst.RemoteAddr(), err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
// isClosedConn detects common closed-connection errors
|
|
func isClosedConn(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
opErr, ok := err.(*net.OpError)
|
|
return ok && opErr.Err.Error() == "use of closed network connection"
|
|
}
|