better TLS termination, port-forward support

This commit is contained in:
AJ ONeal 2020-06-09 02:42:56 -06:00
parent 05fdd8bb45
commit 1872302d66
8 changed files with 163 additions and 45 deletions

View File

@ -48,6 +48,7 @@ type Forward struct {
func main() {
var domains []string
var forwards []Forward
var portForwards []Forward
// TODO replace the websocket connection with a mock server
appID := flag.String("app-id", "telebit.io", "a unique identifier for a deploy target environment")
@ -65,6 +66,7 @@ func main() {
token := flag.String("token", "", "a pre-generated token to give the server (instead of generating one with --secret)")
bindAddrsStr := flag.String("listen", "", "list of bind addresses on which to listen, such as localhost:80, or :443")
locals := flag.String("locals", "", "a list of <from-domain>:<to-port>")
portToPorts := flag.String("port-forward", "", "a list of <from-port>:<to-port> for raw port-forwarding")
flag.Parse()
if len(os.Args) >= 2 {
@ -110,6 +112,16 @@ func main() {
domains = append(domains, domain)
}
if 0 == len(*portToPorts) {
*portToPorts = os.Getenv("PORT_FORWARDS")
}
portForwards, err := parsePortForwards(portToPorts)
if nil != err {
fmt.Fprintf(os.Stderr, "%s", err)
os.Exit(1)
return
}
bindAddrs, err := parseBindAddrs(*bindAddrsStr)
if nil != err {
fmt.Fprintf(os.Stderr, "invalid bind address(es) given to --listen\n")
@ -188,6 +200,12 @@ func main() {
//mux := telebit.NewRouteMux(acme)
mux := telebit.NewRouteMux()
// Port forward without TerminatingTLS
for _, fwd := range portForwards {
fmt.Println("Fwd:", fwd.pattern, fwd.port)
mux.ForwardTCP(fwd.pattern, "localhost:"+fwd.port, 120*time.Second)
}
mux.HandleTLS("*", acme, mux)
for _, fwd := range forwards {
mux.ForwardTCP("*", "localhost:"+fwd.port, 120*time.Second)
@ -260,6 +278,31 @@ func main() {
}
}
func parsePortForwards(portToPorts *string) ([]Forward, error) {
var portForwards []Forward
for _, cfg := range strings.Fields(strings.ReplaceAll(*portToPorts, ",", " ")) {
parts := strings.Split(cfg, ":")
if 2 != len(parts) {
return nil, fmt.Errorf("--port-forward should be in the format 1234:5678, not %q", cfg)
}
if _, err := strconv.Atoi(parts[0]); nil != err {
return nil, fmt.Errorf("couldn't parse port %q of %q", parts[0], cfg)
}
if _, err := strconv.Atoi(parts[1]); nil != err {
return nil, fmt.Errorf("couldn't parse port %q of %q", parts[1], cfg)
}
portForwards = append(portForwards, Forward{
pattern: ":" + parts[0],
port: parts[1],
})
}
return portForwards, nil
}
func parseBindAddrs(bindAddrsStr string) ([]string, error) {
bindAddrs := []string{}
@ -386,8 +429,8 @@ func newAPIDNSProvider(baseURL string, token string) (*dns01.DNSProvider, error)
t.ListenAndServe("wss://example.com", mux)
*/
func getToken(secret string, domains []string) (token string, err error) {
tokenData := jwt.MapClaims{"domains": domains}
func getToken(secret string, domains, ports []string) (token string, err error) {
tokenData := jwt.MapClaims{"domains": domains, "ports": ports}
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData)
if token, err = jwtToken.SignedString([]byte(secret)); err != nil {

View File

@ -1,9 +1,6 @@
package telebit
import (
"fmt"
"strconv"
)
import "fmt"
type Scheme string
@ -39,11 +36,15 @@ func NewAddr(s Scheme, t Termination, a string, p int) *Addr {
}
func (a *Addr) String() string {
return fmt.Sprintf("%s:%s:%s:%d", a.family, a.Scheme(), a.addr, a.port)
//return a.addr + ":" + strconv.Itoa(a.port)
return fmt.Sprintf("%s+%s:%s:%d", a.family, a.Scheme(), a.addr, a.port)
}
// Network s typically network "family", such as "tcp" or "ip",
// but in this case will be "tun", which is a cue to do a `switch`
// to actually use the specific features of a telebit.Addr
func (a *Addr) Network() string {
return a.addr + ":" + strconv.Itoa(a.port)
return "tun"
}
func (a *Addr) Port() int {

View File

@ -47,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
func (c *Conn) Peek(n int) (b []byte, err error) {
if nil == c.peeker {
c.peeker = bufio.NewReaderSize(c, defaultPeekerSize)
c.peeker = bufio.NewReaderSize(c.relay, defaultPeekerSize)
}
return c.peeker.Peek(n)
}

View File

@ -8,6 +8,7 @@ import (
// ConnWrap is just a cheap way to DRY up some switch conn.(type) statements to handle special features of Conn
type ConnWrap struct {
// TODO use io.MultiReader to unbuffer the peeker
//Conn net.Conn
peeker *bufio.Reader
Conn net.Conn
@ -31,7 +32,7 @@ func (c *ConnWrap) Peek(n int) ([]byte, error) {
default:
// *net.UDPConn,*net.TCPConn,*net.IPConn,*net.UnixConn
if nil == c.peeker {
c.peeker = bufio.NewReaderSize(c, defaultPeekerSize)
c.peeker = bufio.NewReaderSize(c.Conn, defaultPeekerSize)
}
return c.peeker.Peek(n)
}
@ -93,31 +94,44 @@ func (c *ConnWrap) Servername() string {
// or a telebit.Conn with a non-encrypted `scheme` such as "tcp" or "http".
func (c *ConnWrap) isTerminated() bool {
// TODO look at SNI, may need context for Peek() timeout
if nil != c.Plain {
return true
}
/*
if nil != c.Plain {
return true
}
*/
// how to know how many bytes to read? really needs timeout
b, err := c.Peek(2)
if len(b) >= 2 {
// TODO better detection?
c.SetDeadline(time.Now().Add(5 * time.Second))
n := 6
b, _ := c.Peek(n)
defer c.SetDeadline(time.Time{})
if len(b) >= n {
// SSL v3.x / TLS v1.x
if 0x16 == b[0] && 0x03 == b[1] {
// 0: TLS Byte
// 1: Major Version
// 2: 0-Indexed Minor Version
// 3-4: Header Length
// 5: TLS Client Hello Marker Byte
if 0x16 == b[0] && 0x03 == b[1] && 0x01 == b[5] {
//length := (int(b[3]) << 8) + int(b[4])
return false
}
}
if nil != err {
return true
}
return true
/*
if nil != err {
return true
}
switch conn := c.Conn.(type) {
case *ConnWrap:
return conn.isTerminated()
case *Conn:
_, ok := encryptedSchemes[string(conn.relayTargetAddr.scheme)]
return !ok
}
return false
switch conn := c.Conn.(type) {
case *ConnWrap:
return conn.isTerminated()
case *Conn:
_, ok := encryptedSchemes[string(conn.relayTargetAddr.scheme)]
return !ok
}
return false
*/
}
// LocalAddr returns the local network address.

View File

@ -126,7 +126,7 @@ func (l *Listener) RouteBytes(srcAddr, dstAddr Addr, b []byte) {
// remember where the error message goes
if "error" == string(dst.scheme) {
pipe.Close()
delete(l.conns, src.Network())
delete(l.conns, src.String())
fmt.Printf("a stream errored remotely: %v\n", src)
}
@ -139,12 +139,12 @@ func (l *Listener) RouteBytes(srcAddr, dstAddr Addr, b []byte) {
if "end" == string(dst.scheme) {
fmt.Println("[debug] end")
pipe.Close()
delete(l.conns, src.Network())
delete(l.conns, src.String())
}
}
func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn {
connID := src.Network()
connID := src.String()
pipe, ok := l.conns[connID]
// Pipe exists
@ -157,8 +157,8 @@ func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn {
rawPipe, pipe := net.Pipe()
newconn := &Conn{
//updated: time.Now(),
relaySourceAddr: *src,
relayTargetAddr: *dst,
relaySourceAddr: *dst,
relayTargetAddr: *src,
relay: rawPipe,
}
l.conns[connID] = pipe
@ -173,7 +173,11 @@ func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn {
// In any case, we'll just close it all
newconn.Close()
pipe.Close()
fmt.Printf("a stream is done: %q\n", err)
if nil != err {
fmt.Printf("a stream is done: %q\n", err)
} else {
fmt.Printf("a stream is done\n")
}
}()
return pipe

View File

@ -3,6 +3,8 @@ package telebit
import (
"fmt"
"net"
"strconv"
"strings"
"time"
)
@ -31,13 +33,42 @@ func NewRouteMux() *RouteMux {
// Serve dispatches the connection to the handler whose selectors matches the attributes.
func (m *RouteMux) Serve(client net.Conn) error {
wconn := &ConnWrap{Conn: client}
servername := wconn.Servername()
var wconn *ConnWrap
switch conn := client.(type) {
case *ConnWrap:
wconn = conn
default:
wconn = &ConnWrap{Conn: client}
}
var servername string
var port string
// TODO go back to Servername on conn, but with SNI
//servername := wconn.Servername()
fam := wconn.LocalAddr().Network()
if "tun" == fam {
switch laddr := wconn.LocalAddr().(type) {
case *Addr:
servername = laddr.Hostname()
port = ":" + strconv.Itoa(laddr.Port())
default:
panic("impossible type switch: Addr is 'tun' but didn't match")
}
} else {
// TODO make an AddrWrap to do this switch
addr := wconn.LocalAddr().String()
parts := strings.Split(addr, ":")
port = ":" + parts[len(parts)-1]
servername = strings.Join(parts[:len(parts)-1], ":")
}
fmt.Println("Addr:", fam, servername, port)
for _, meta := range m.routes {
if servername == meta.addr || "*" == meta.addr {
// TODO '*.example.com'
fmt.Println("Meta:", meta.addr)
if servername == meta.addr || "*" == meta.addr || port == meta.addr {
//fmt.Println("[debug] test of route:", meta)
if err := meta.handler.Serve(client); nil != err {
if err := meta.handler.Serve(wconn); nil != err {
// error should be EOF if successful
return err
}
@ -78,14 +109,22 @@ func (m *RouteMux) HandleTLS(servername string, acme *ACME, handler Handler) err
addr: servername,
terminate: true,
handler: HandlerFunc(func(client net.Conn) error {
wrap := &ConnWrap{Conn: client}
if wrap.isTerminated() {
var wconn *ConnWrap
switch conn := client.(type) {
case *ConnWrap:
wconn = conn
default:
panic("HandleTLS is special in that it must receive &ConnWrap{ Conn: conn }")
}
if wconn.isTerminated() {
// nil to skip
return nil
}
//NewTerminator(acme, handler)(client)
//return handler.Serve(client)
return handler.Serve(TerminateTLS(client, acme))
return handler.Serve(TerminateTLS(wconn, acme))
}),
})
return nil

View File

@ -11,6 +11,7 @@ import (
"net"
"net/http"
"os"
"strings"
"time"
"github.com/caddyserver/certmagic"
@ -27,6 +28,11 @@ var defaultPeekerSize = 1024
// ErrBadGateway means that the target did not accept the connection
var ErrBadGateway = errors.New("EBADGATEWAY")
// The proper handling of this error
// is still being debated as of Jun 9, 2020
// https://github.com/golang/go/issues/4373
var errNetClosing = "use of closed network connection"
// A Handler routes, proxies, terminates, or responds to a net.Conn.
type Handler interface {
Serve(net.Conn) error
@ -104,7 +110,13 @@ func Forward(client net.Conn, target net.Conn, timeout time.Duration) error {
}
}()
fmt.Println("[debug] forwarding tcp connection")
fmt.Println(
"[debug] forwarding tcp connection",
client.LocalAddr(),
client.RemoteAddr(),
target.LocalAddr(),
target.RemoteAddr(),
)
var err error = nil
ForwardData:
@ -131,7 +143,7 @@ ForwardData:
if nil == err {
break ForwardData
}
if io.EOF != err {
if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) {
fmt.Printf("read from remote client failed: %q\n", err.Error())
} else {
fmt.Printf("Connection closed (possibly by remote client)\n")
@ -141,7 +153,7 @@ ForwardData:
if nil == err {
break ForwardData
}
if io.EOF != err {
if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) {
fmt.Printf("read from local target failed: %q\n", err.Error())
} else {
fmt.Printf("Connection closed (possibly by local target)\n")
@ -187,7 +199,11 @@ func TerminateTLS(client net.Conn, acme *ACME) net.Conn {
var err error
magic, err = newCertMagic(acme)
if nil != err {
fmt.Fprintf(os.Stderr, "failed to initialize certificate management (discovery url? local folder perms?): %s\n", err)
fmt.Fprintf(
os.Stderr,
"failed to initialize certificate management (discovery url? local folder perms?): %s\n",
err,
)
os.Exit(1)
}
acmecert = magic

View File

@ -126,6 +126,7 @@ func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) {
return b, nil
}
parts := strings.Split(string(p.state.header), ",")
fmt.Println("[debug] Tun Header", string(p.state.header))
p.state.header = nil
if len(parts) < 5 {
return nil, errors.New("error unpacking header")