diff --git a/cmd/sclient/main.go b/cmd/sclient/main.go index e3161f4..3eb8a14 100644 --- a/cmd/sclient/main.go +++ b/cmd/sclient/main.go @@ -33,6 +33,9 @@ func usage() { " ex: sclient example.com:8443 0.0.0.0:4080\n"+ "\n"+ " ex: sclient example.com:443 -\n"+ + "\n"+ + " ex: sclient --ssh 22 example.com 3000\n"+ + " (try TLS+ssh ALPN on 443, fall back to SSH on port 22)\n"+ "\n", ver()) flag.PrintDefaults() fmt.Println() @@ -51,10 +54,12 @@ func main() { var insecure bool var servername string var silent bool + var sshFallbackPort int flag.Usage = usage flag.StringVar(&alpnList, "alpn", "", "acceptable protocols, ex: 'h2,http/1.1' 'http/1.1' 'ssh'") + flag.IntVar(&sshFallbackPort, "ssh", 0, "enable ssh ALPN and fall back to direct SSH on if TLS+ssh fails (ex: 22)") flag.BoolVar(&insecure, "k", false, "alias for --insecure") flag.BoolVar(&silent, "s", false, "alias of --silent") flag.StringVar(&servername, "servername", "", "specify a servername different from (to disable SNI use an IP as and do not use this option)") @@ -64,6 +69,9 @@ func main() { flag.Parse() alpns := parseOptionList(alpnList) + if sshFallbackPort > 0 && len(alpns) == 0 { + alpns = []string{"ssh"} + } remotestr := flag.Arg(0) localstr := flag.Arg(1) @@ -85,6 +93,7 @@ func main() { ServerName: servername, Silent: silent, NextProtos: alpns, + SSHFallbackPort: sshFallbackPort, } remote := strings.Split(remotestr, ":") diff --git a/sclient.go b/sclient.go index 94da38d..fe583f5 100644 --- a/sclient.go +++ b/sclient.go @@ -20,23 +20,18 @@ type Tunnel struct { NextProtos []string ServerName string Silent bool + SSHFallbackPort int } // DialAndListen will create a test TLS connection to the remote address and then // begin listening locally. Each local connection will result in a separate remote connection. func (t *Tunnel) DialAndListen() error { remote := t.RemoteAddress + ":" + strconv.Itoa(t.RemotePort) - conn, err := tls.Dial("tcp", remote, - &tls.Config{ - ServerName: t.ServerName, - NextProtos: t.NextProtos, - InsecureSkipVerify: t.InsecureSkipVerify, - }) - - if err != nil { - fmt.Fprintf(os.Stderr, "[warn] '%s' may not be accepting connections: %s\n", remote, err) + testConn, _, testErr := t.dialRemote(remote) + if testErr != nil { + fmt.Fprintf(os.Stderr, "[warn] '%s' may not be accepting connections: %s\n", remote, testErr) } else { - _ = conn.Close() + _ = testConn.Close() } // use stdin/stdout @@ -142,13 +137,7 @@ func pipe(r netReadWriteCloser, w netReadWriteCloser, t string) { } func (t *Tunnel) handleConnection(remote string, conn netReadWriteCloser) { - sclient, err := tls.Dial("tcp", remote, - &tls.Config{ - ServerName: t.ServerName, - NextProtos: t.NextProtos, - InsecureSkipVerify: t.InsecureSkipVerify, - }) - + upstream, fallback, err := t.dialRemote(remote) if err != nil { fmt.Fprintf(os.Stderr, "[error] (remote) %s\n", err) _ = conn.Close() @@ -156,15 +145,46 @@ func (t *Tunnel) handleConnection(remote string, conn netReadWriteCloser) { } if !t.Silent { + target := fmt.Sprintf("%s:%d", t.RemoteAddress, t.RemotePort) + if fallback { + target = t.RemoteAddress + ":" + strconv.Itoa(t.SSHFallbackPort) + } if conn.RemoteAddr().Network() == "stdio" { - _, _ = fmt.Fprintf(os.Stdout, "(connected to %s:%d and reading from %s)\n", - t.RemoteAddress, t.RemotePort, conn.RemoteAddr().String()) + _, _ = fmt.Fprintf(os.Stdout, "(connected to %s and reading from %s)\n", + target, conn.RemoteAddr().String()) } else { - _, _ = fmt.Fprintf(os.Stdout, "[connect] %s => %s:%d\n", - strings.Replace(conn.RemoteAddr().String(), "[::1]:", "localhost:", 1), t.RemoteAddress, t.RemotePort) + _, _ = fmt.Fprintf(os.Stdout, "[connect] %s => %s\n", + strings.Replace(conn.RemoteAddr().String(), "[::1]:", "localhost:", 1), target) } } - go pipe(conn, sclient, "local") - pipe(sclient, conn, "remote") + go pipe(conn, upstream, "local") + pipe(upstream, conn, "remote") +} + +func (t *Tunnel) dialRemote(remote string) (netReadWriteCloser, bool, error) { + tlsConn, err := tls.Dial("tcp", remote, + &tls.Config{ + ServerName: t.ServerName, + NextProtos: t.NextProtos, + InsecureSkipVerify: t.InsecureSkipVerify, + }) + if err == nil { + return tlsConn, false, nil + } + + if t.SSHFallbackPort <= 0 { + return nil, false, err + } + + fallbackAddr := t.RemoteAddress + ":" + strconv.Itoa(t.SSHFallbackPort) + if !t.Silent { + fmt.Fprintf(os.Stderr, "[info] TLS+ssh failed (%s), falling back to %s\n", err, fallbackAddr) + } + + tcpConn, tcpErr := net.Dial("tcp", fallbackAddr) + if tcpErr != nil { + return nil, false, fmt.Errorf("TLS failed: %w; fallback to %s also failed: %v", err, fallbackAddr, tcpErr) + } + return tcpConn, true, nil }