better TLS termination, port-forward support

This commit is contained in:
AJ ONeal 2020-06-09 02:42:56 -06:00
parent 05fdd8bb45
commit 1872302d66
8 changed files with 163 additions and 45 deletions

View File

@ -48,6 +48,7 @@ type Forward struct {
func main() { func main() {
var domains []string var domains []string
var forwards []Forward var forwards []Forward
var portForwards []Forward
// TODO replace the websocket connection with a mock server // TODO replace the websocket connection with a mock server
appID := flag.String("app-id", "telebit.io", "a unique identifier for a deploy target environment") appID := flag.String("app-id", "telebit.io", "a unique identifier for a deploy target environment")
@ -65,6 +66,7 @@ func main() {
token := flag.String("token", "", "a pre-generated token to give the server (instead of generating one with --secret)") token := flag.String("token", "", "a pre-generated token to give the server (instead of generating one with --secret)")
bindAddrsStr := flag.String("listen", "", "list of bind addresses on which to listen, such as localhost:80, or :443") bindAddrsStr := flag.String("listen", "", "list of bind addresses on which to listen, such as localhost:80, or :443")
locals := flag.String("locals", "", "a list of <from-domain>:<to-port>") locals := flag.String("locals", "", "a list of <from-domain>:<to-port>")
portToPorts := flag.String("port-forward", "", "a list of <from-port>:<to-port> for raw port-forwarding")
flag.Parse() flag.Parse()
if len(os.Args) >= 2 { if len(os.Args) >= 2 {
@ -110,6 +112,16 @@ func main() {
domains = append(domains, domain) domains = append(domains, domain)
} }
if 0 == len(*portToPorts) {
*portToPorts = os.Getenv("PORT_FORWARDS")
}
portForwards, err := parsePortForwards(portToPorts)
if nil != err {
fmt.Fprintf(os.Stderr, "%s", err)
os.Exit(1)
return
}
bindAddrs, err := parseBindAddrs(*bindAddrsStr) bindAddrs, err := parseBindAddrs(*bindAddrsStr)
if nil != err { if nil != err {
fmt.Fprintf(os.Stderr, "invalid bind address(es) given to --listen\n") fmt.Fprintf(os.Stderr, "invalid bind address(es) given to --listen\n")
@ -188,6 +200,12 @@ func main() {
//mux := telebit.NewRouteMux(acme) //mux := telebit.NewRouteMux(acme)
mux := telebit.NewRouteMux() mux := telebit.NewRouteMux()
// Port forward without TerminatingTLS
for _, fwd := range portForwards {
fmt.Println("Fwd:", fwd.pattern, fwd.port)
mux.ForwardTCP(fwd.pattern, "localhost:"+fwd.port, 120*time.Second)
}
mux.HandleTLS("*", acme, mux) mux.HandleTLS("*", acme, mux)
for _, fwd := range forwards { for _, fwd := range forwards {
mux.ForwardTCP("*", "localhost:"+fwd.port, 120*time.Second) mux.ForwardTCP("*", "localhost:"+fwd.port, 120*time.Second)
@ -260,6 +278,31 @@ func main() {
} }
} }
func parsePortForwards(portToPorts *string) ([]Forward, error) {
var portForwards []Forward
for _, cfg := range strings.Fields(strings.ReplaceAll(*portToPorts, ",", " ")) {
parts := strings.Split(cfg, ":")
if 2 != len(parts) {
return nil, fmt.Errorf("--port-forward should be in the format 1234:5678, not %q", cfg)
}
if _, err := strconv.Atoi(parts[0]); nil != err {
return nil, fmt.Errorf("couldn't parse port %q of %q", parts[0], cfg)
}
if _, err := strconv.Atoi(parts[1]); nil != err {
return nil, fmt.Errorf("couldn't parse port %q of %q", parts[1], cfg)
}
portForwards = append(portForwards, Forward{
pattern: ":" + parts[0],
port: parts[1],
})
}
return portForwards, nil
}
func parseBindAddrs(bindAddrsStr string) ([]string, error) { func parseBindAddrs(bindAddrsStr string) ([]string, error) {
bindAddrs := []string{} bindAddrs := []string{}
@ -386,8 +429,8 @@ func newAPIDNSProvider(baseURL string, token string) (*dns01.DNSProvider, error)
t.ListenAndServe("wss://example.com", mux) t.ListenAndServe("wss://example.com", mux)
*/ */
func getToken(secret string, domains []string) (token string, err error) { func getToken(secret string, domains, ports []string) (token string, err error) {
tokenData := jwt.MapClaims{"domains": domains} tokenData := jwt.MapClaims{"domains": domains, "ports": ports}
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData) jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData)
if token, err = jwtToken.SignedString([]byte(secret)); err != nil { if token, err = jwtToken.SignedString([]byte(secret)); err != nil {

View File

@ -1,9 +1,6 @@
package telebit package telebit
import ( import "fmt"
"fmt"
"strconv"
)
type Scheme string type Scheme string
@ -39,11 +36,15 @@ func NewAddr(s Scheme, t Termination, a string, p int) *Addr {
} }
func (a *Addr) String() string { func (a *Addr) String() string {
return fmt.Sprintf("%s:%s:%s:%d", a.family, a.Scheme(), a.addr, a.port) //return a.addr + ":" + strconv.Itoa(a.port)
return fmt.Sprintf("%s+%s:%s:%d", a.family, a.Scheme(), a.addr, a.port)
} }
// Network s typically network "family", such as "tcp" or "ip",
// but in this case will be "tun", which is a cue to do a `switch`
// to actually use the specific features of a telebit.Addr
func (a *Addr) Network() string { func (a *Addr) Network() string {
return a.addr + ":" + strconv.Itoa(a.port) return "tun"
} }
func (a *Addr) Port() int { func (a *Addr) Port() int {

View File

@ -47,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
func (c *Conn) Peek(n int) (b []byte, err error) { func (c *Conn) Peek(n int) (b []byte, err error) {
if nil == c.peeker { if nil == c.peeker {
c.peeker = bufio.NewReaderSize(c, defaultPeekerSize) c.peeker = bufio.NewReaderSize(c.relay, defaultPeekerSize)
} }
return c.peeker.Peek(n) return c.peeker.Peek(n)
} }

View File

@ -8,6 +8,7 @@ import (
// ConnWrap is just a cheap way to DRY up some switch conn.(type) statements to handle special features of Conn // ConnWrap is just a cheap way to DRY up some switch conn.(type) statements to handle special features of Conn
type ConnWrap struct { type ConnWrap struct {
// TODO use io.MultiReader to unbuffer the peeker
//Conn net.Conn //Conn net.Conn
peeker *bufio.Reader peeker *bufio.Reader
Conn net.Conn Conn net.Conn
@ -31,7 +32,7 @@ func (c *ConnWrap) Peek(n int) ([]byte, error) {
default: default:
// *net.UDPConn,*net.TCPConn,*net.IPConn,*net.UnixConn // *net.UDPConn,*net.TCPConn,*net.IPConn,*net.UnixConn
if nil == c.peeker { if nil == c.peeker {
c.peeker = bufio.NewReaderSize(c, defaultPeekerSize) c.peeker = bufio.NewReaderSize(c.Conn, defaultPeekerSize)
} }
return c.peeker.Peek(n) return c.peeker.Peek(n)
} }
@ -93,19 +94,31 @@ func (c *ConnWrap) Servername() string {
// or a telebit.Conn with a non-encrypted `scheme` such as "tcp" or "http". // or a telebit.Conn with a non-encrypted `scheme` such as "tcp" or "http".
func (c *ConnWrap) isTerminated() bool { func (c *ConnWrap) isTerminated() bool {
// TODO look at SNI, may need context for Peek() timeout // TODO look at SNI, may need context for Peek() timeout
/*
if nil != c.Plain { if nil != c.Plain {
return true return true
} }
*/
// how to know how many bytes to read? really needs timeout // how to know how many bytes to read? really needs timeout
b, err := c.Peek(2) c.SetDeadline(time.Now().Add(5 * time.Second))
if len(b) >= 2 { n := 6
// TODO better detection? b, _ := c.Peek(n)
defer c.SetDeadline(time.Time{})
if len(b) >= n {
// SSL v3.x / TLS v1.x // SSL v3.x / TLS v1.x
if 0x16 == b[0] && 0x03 == b[1] { // 0: TLS Byte
// 1: Major Version
// 2: 0-Indexed Minor Version
// 3-4: Header Length
// 5: TLS Client Hello Marker Byte
if 0x16 == b[0] && 0x03 == b[1] && 0x01 == b[5] {
//length := (int(b[3]) << 8) + int(b[4])
return false return false
} }
} }
return true
/*
if nil != err { if nil != err {
return true return true
} }
@ -118,6 +131,7 @@ func (c *ConnWrap) isTerminated() bool {
return !ok return !ok
} }
return false return false
*/
} }
// LocalAddr returns the local network address. // LocalAddr returns the local network address.

View File

@ -126,7 +126,7 @@ func (l *Listener) RouteBytes(srcAddr, dstAddr Addr, b []byte) {
// remember where the error message goes // remember where the error message goes
if "error" == string(dst.scheme) { if "error" == string(dst.scheme) {
pipe.Close() pipe.Close()
delete(l.conns, src.Network()) delete(l.conns, src.String())
fmt.Printf("a stream errored remotely: %v\n", src) fmt.Printf("a stream errored remotely: %v\n", src)
} }
@ -139,12 +139,12 @@ func (l *Listener) RouteBytes(srcAddr, dstAddr Addr, b []byte) {
if "end" == string(dst.scheme) { if "end" == string(dst.scheme) {
fmt.Println("[debug] end") fmt.Println("[debug] end")
pipe.Close() pipe.Close()
delete(l.conns, src.Network()) delete(l.conns, src.String())
} }
} }
func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn { func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn {
connID := src.Network() connID := src.String()
pipe, ok := l.conns[connID] pipe, ok := l.conns[connID]
// Pipe exists // Pipe exists
@ -157,8 +157,8 @@ func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn {
rawPipe, pipe := net.Pipe() rawPipe, pipe := net.Pipe()
newconn := &Conn{ newconn := &Conn{
//updated: time.Now(), //updated: time.Now(),
relaySourceAddr: *src, relaySourceAddr: *dst,
relayTargetAddr: *dst, relayTargetAddr: *src,
relay: rawPipe, relay: rawPipe,
} }
l.conns[connID] = pipe l.conns[connID] = pipe
@ -173,7 +173,11 @@ func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn {
// In any case, we'll just close it all // In any case, we'll just close it all
newconn.Close() newconn.Close()
pipe.Close() pipe.Close()
if nil != err {
fmt.Printf("a stream is done: %q\n", err) fmt.Printf("a stream is done: %q\n", err)
} else {
fmt.Printf("a stream is done\n")
}
}() }()
return pipe return pipe

View File

@ -3,6 +3,8 @@ package telebit
import ( import (
"fmt" "fmt"
"net" "net"
"strconv"
"strings"
"time" "time"
) )
@ -31,13 +33,42 @@ func NewRouteMux() *RouteMux {
// Serve dispatches the connection to the handler whose selectors matches the attributes. // Serve dispatches the connection to the handler whose selectors matches the attributes.
func (m *RouteMux) Serve(client net.Conn) error { func (m *RouteMux) Serve(client net.Conn) error {
wconn := &ConnWrap{Conn: client} var wconn *ConnWrap
servername := wconn.Servername() switch conn := client.(type) {
case *ConnWrap:
wconn = conn
default:
wconn = &ConnWrap{Conn: client}
}
var servername string
var port string
// TODO go back to Servername on conn, but with SNI
//servername := wconn.Servername()
fam := wconn.LocalAddr().Network()
if "tun" == fam {
switch laddr := wconn.LocalAddr().(type) {
case *Addr:
servername = laddr.Hostname()
port = ":" + strconv.Itoa(laddr.Port())
default:
panic("impossible type switch: Addr is 'tun' but didn't match")
}
} else {
// TODO make an AddrWrap to do this switch
addr := wconn.LocalAddr().String()
parts := strings.Split(addr, ":")
port = ":" + parts[len(parts)-1]
servername = strings.Join(parts[:len(parts)-1], ":")
}
fmt.Println("Addr:", fam, servername, port)
for _, meta := range m.routes { for _, meta := range m.routes {
if servername == meta.addr || "*" == meta.addr { // TODO '*.example.com'
fmt.Println("Meta:", meta.addr)
if servername == meta.addr || "*" == meta.addr || port == meta.addr {
//fmt.Println("[debug] test of route:", meta) //fmt.Println("[debug] test of route:", meta)
if err := meta.handler.Serve(client); nil != err { if err := meta.handler.Serve(wconn); nil != err {
// error should be EOF if successful // error should be EOF if successful
return err return err
} }
@ -78,14 +109,22 @@ func (m *RouteMux) HandleTLS(servername string, acme *ACME, handler Handler) err
addr: servername, addr: servername,
terminate: true, terminate: true,
handler: HandlerFunc(func(client net.Conn) error { handler: HandlerFunc(func(client net.Conn) error {
wrap := &ConnWrap{Conn: client} var wconn *ConnWrap
if wrap.isTerminated() { switch conn := client.(type) {
case *ConnWrap:
wconn = conn
default:
panic("HandleTLS is special in that it must receive &ConnWrap{ Conn: conn }")
}
if wconn.isTerminated() {
// nil to skip // nil to skip
return nil return nil
} }
//NewTerminator(acme, handler)(client) //NewTerminator(acme, handler)(client)
//return handler.Serve(client) //return handler.Serve(client)
return handler.Serve(TerminateTLS(client, acme)) return handler.Serve(TerminateTLS(wconn, acme))
}), }),
}) })
return nil return nil

View File

@ -11,6 +11,7 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"strings"
"time" "time"
"github.com/caddyserver/certmagic" "github.com/caddyserver/certmagic"
@ -27,6 +28,11 @@ var defaultPeekerSize = 1024
// ErrBadGateway means that the target did not accept the connection // ErrBadGateway means that the target did not accept the connection
var ErrBadGateway = errors.New("EBADGATEWAY") var ErrBadGateway = errors.New("EBADGATEWAY")
// The proper handling of this error
// is still being debated as of Jun 9, 2020
// https://github.com/golang/go/issues/4373
var errNetClosing = "use of closed network connection"
// A Handler routes, proxies, terminates, or responds to a net.Conn. // A Handler routes, proxies, terminates, or responds to a net.Conn.
type Handler interface { type Handler interface {
Serve(net.Conn) error Serve(net.Conn) error
@ -104,7 +110,13 @@ func Forward(client net.Conn, target net.Conn, timeout time.Duration) error {
} }
}() }()
fmt.Println("[debug] forwarding tcp connection") fmt.Println(
"[debug] forwarding tcp connection",
client.LocalAddr(),
client.RemoteAddr(),
target.LocalAddr(),
target.RemoteAddr(),
)
var err error = nil var err error = nil
ForwardData: ForwardData:
@ -131,7 +143,7 @@ ForwardData:
if nil == err { if nil == err {
break ForwardData break ForwardData
} }
if io.EOF != err { if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) {
fmt.Printf("read from remote client failed: %q\n", err.Error()) fmt.Printf("read from remote client failed: %q\n", err.Error())
} else { } else {
fmt.Printf("Connection closed (possibly by remote client)\n") fmt.Printf("Connection closed (possibly by remote client)\n")
@ -141,7 +153,7 @@ ForwardData:
if nil == err { if nil == err {
break ForwardData break ForwardData
} }
if io.EOF != err { if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) {
fmt.Printf("read from local target failed: %q\n", err.Error()) fmt.Printf("read from local target failed: %q\n", err.Error())
} else { } else {
fmt.Printf("Connection closed (possibly by local target)\n") fmt.Printf("Connection closed (possibly by local target)\n")
@ -187,7 +199,11 @@ func TerminateTLS(client net.Conn, acme *ACME) net.Conn {
var err error var err error
magic, err = newCertMagic(acme) magic, err = newCertMagic(acme)
if nil != err { if nil != err {
fmt.Fprintf(os.Stderr, "failed to initialize certificate management (discovery url? local folder perms?): %s\n", err) fmt.Fprintf(
os.Stderr,
"failed to initialize certificate management (discovery url? local folder perms?): %s\n",
err,
)
os.Exit(1) os.Exit(1)
} }
acmecert = magic acmecert = magic

View File

@ -126,6 +126,7 @@ func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) {
return b, nil return b, nil
} }
parts := strings.Split(string(p.state.header), ",") parts := strings.Split(string(p.state.header), ",")
fmt.Println("[debug] Tun Header", string(p.state.header))
p.state.header = nil p.state.header = nil
if len(parts) < 5 { if len(parts) < 5 {
return nil, errors.New("error unpacking header") return nil, errors.New("error unpacking header")