better TLS termination, port-forward support
This commit is contained in:
parent
05fdd8bb45
commit
1872302d66
|
@ -48,6 +48,7 @@ type Forward struct {
|
|||
func main() {
|
||||
var domains []string
|
||||
var forwards []Forward
|
||||
var portForwards []Forward
|
||||
|
||||
// TODO replace the websocket connection with a mock server
|
||||
appID := flag.String("app-id", "telebit.io", "a unique identifier for a deploy target environment")
|
||||
|
@ -65,6 +66,7 @@ func main() {
|
|||
token := flag.String("token", "", "a pre-generated token to give the server (instead of generating one with --secret)")
|
||||
bindAddrsStr := flag.String("listen", "", "list of bind addresses on which to listen, such as localhost:80, or :443")
|
||||
locals := flag.String("locals", "", "a list of <from-domain>:<to-port>")
|
||||
portToPorts := flag.String("port-forward", "", "a list of <from-port>:<to-port> for raw port-forwarding")
|
||||
flag.Parse()
|
||||
|
||||
if len(os.Args) >= 2 {
|
||||
|
@ -110,6 +112,16 @@ func main() {
|
|||
domains = append(domains, domain)
|
||||
}
|
||||
|
||||
if 0 == len(*portToPorts) {
|
||||
*portToPorts = os.Getenv("PORT_FORWARDS")
|
||||
}
|
||||
portForwards, err := parsePortForwards(portToPorts)
|
||||
if nil != err {
|
||||
fmt.Fprintf(os.Stderr, "%s", err)
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
bindAddrs, err := parseBindAddrs(*bindAddrsStr)
|
||||
if nil != err {
|
||||
fmt.Fprintf(os.Stderr, "invalid bind address(es) given to --listen\n")
|
||||
|
@ -188,6 +200,12 @@ func main() {
|
|||
|
||||
//mux := telebit.NewRouteMux(acme)
|
||||
mux := telebit.NewRouteMux()
|
||||
|
||||
// Port forward without TerminatingTLS
|
||||
for _, fwd := range portForwards {
|
||||
fmt.Println("Fwd:", fwd.pattern, fwd.port)
|
||||
mux.ForwardTCP(fwd.pattern, "localhost:"+fwd.port, 120*time.Second)
|
||||
}
|
||||
mux.HandleTLS("*", acme, mux)
|
||||
for _, fwd := range forwards {
|
||||
mux.ForwardTCP("*", "localhost:"+fwd.port, 120*time.Second)
|
||||
|
@ -260,6 +278,31 @@ func main() {
|
|||
}
|
||||
}
|
||||
|
||||
func parsePortForwards(portToPorts *string) ([]Forward, error) {
|
||||
var portForwards []Forward
|
||||
|
||||
for _, cfg := range strings.Fields(strings.ReplaceAll(*portToPorts, ",", " ")) {
|
||||
parts := strings.Split(cfg, ":")
|
||||
if 2 != len(parts) {
|
||||
return nil, fmt.Errorf("--port-forward should be in the format 1234:5678, not %q", cfg)
|
||||
}
|
||||
|
||||
if _, err := strconv.Atoi(parts[0]); nil != err {
|
||||
return nil, fmt.Errorf("couldn't parse port %q of %q", parts[0], cfg)
|
||||
}
|
||||
if _, err := strconv.Atoi(parts[1]); nil != err {
|
||||
return nil, fmt.Errorf("couldn't parse port %q of %q", parts[1], cfg)
|
||||
}
|
||||
|
||||
portForwards = append(portForwards, Forward{
|
||||
pattern: ":" + parts[0],
|
||||
port: parts[1],
|
||||
})
|
||||
}
|
||||
|
||||
return portForwards, nil
|
||||
}
|
||||
|
||||
func parseBindAddrs(bindAddrsStr string) ([]string, error) {
|
||||
bindAddrs := []string{}
|
||||
|
||||
|
@ -386,8 +429,8 @@ func newAPIDNSProvider(baseURL string, token string) (*dns01.DNSProvider, error)
|
|||
t.ListenAndServe("wss://example.com", mux)
|
||||
*/
|
||||
|
||||
func getToken(secret string, domains []string) (token string, err error) {
|
||||
tokenData := jwt.MapClaims{"domains": domains}
|
||||
func getToken(secret string, domains, ports []string) (token string, err error) {
|
||||
tokenData := jwt.MapClaims{"domains": domains, "ports": ports}
|
||||
|
||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData)
|
||||
if token, err = jwtToken.SignedString([]byte(secret)); err != nil {
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
package telebit
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
import "fmt"
|
||||
|
||||
type Scheme string
|
||||
|
||||
|
@ -39,11 +36,15 @@ func NewAddr(s Scheme, t Termination, a string, p int) *Addr {
|
|||
}
|
||||
|
||||
func (a *Addr) String() string {
|
||||
return fmt.Sprintf("%s:%s:%s:%d", a.family, a.Scheme(), a.addr, a.port)
|
||||
//return a.addr + ":" + strconv.Itoa(a.port)
|
||||
return fmt.Sprintf("%s+%s:%s:%d", a.family, a.Scheme(), a.addr, a.port)
|
||||
}
|
||||
|
||||
// Network s typically network "family", such as "tcp" or "ip",
|
||||
// but in this case will be "tun", which is a cue to do a `switch`
|
||||
// to actually use the specific features of a telebit.Addr
|
||||
func (a *Addr) Network() string {
|
||||
return a.addr + ":" + strconv.Itoa(a.port)
|
||||
return "tun"
|
||||
}
|
||||
|
||||
func (a *Addr) Port() int {
|
||||
|
|
|
@ -47,7 +47,7 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||
|
||||
func (c *Conn) Peek(n int) (b []byte, err error) {
|
||||
if nil == c.peeker {
|
||||
c.peeker = bufio.NewReaderSize(c, defaultPeekerSize)
|
||||
c.peeker = bufio.NewReaderSize(c.relay, defaultPeekerSize)
|
||||
}
|
||||
return c.peeker.Peek(n)
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
|
||||
// ConnWrap is just a cheap way to DRY up some switch conn.(type) statements to handle special features of Conn
|
||||
type ConnWrap struct {
|
||||
// TODO use io.MultiReader to unbuffer the peeker
|
||||
//Conn net.Conn
|
||||
peeker *bufio.Reader
|
||||
Conn net.Conn
|
||||
|
@ -31,7 +32,7 @@ func (c *ConnWrap) Peek(n int) ([]byte, error) {
|
|||
default:
|
||||
// *net.UDPConn,*net.TCPConn,*net.IPConn,*net.UnixConn
|
||||
if nil == c.peeker {
|
||||
c.peeker = bufio.NewReaderSize(c, defaultPeekerSize)
|
||||
c.peeker = bufio.NewReaderSize(c.Conn, defaultPeekerSize)
|
||||
}
|
||||
return c.peeker.Peek(n)
|
||||
}
|
||||
|
@ -93,31 +94,44 @@ func (c *ConnWrap) Servername() string {
|
|||
// or a telebit.Conn with a non-encrypted `scheme` such as "tcp" or "http".
|
||||
func (c *ConnWrap) isTerminated() bool {
|
||||
// TODO look at SNI, may need context for Peek() timeout
|
||||
if nil != c.Plain {
|
||||
return true
|
||||
}
|
||||
/*
|
||||
if nil != c.Plain {
|
||||
return true
|
||||
}
|
||||
*/
|
||||
|
||||
// how to know how many bytes to read? really needs timeout
|
||||
b, err := c.Peek(2)
|
||||
if len(b) >= 2 {
|
||||
// TODO better detection?
|
||||
c.SetDeadline(time.Now().Add(5 * time.Second))
|
||||
n := 6
|
||||
b, _ := c.Peek(n)
|
||||
defer c.SetDeadline(time.Time{})
|
||||
if len(b) >= n {
|
||||
// SSL v3.x / TLS v1.x
|
||||
if 0x16 == b[0] && 0x03 == b[1] {
|
||||
// 0: TLS Byte
|
||||
// 1: Major Version
|
||||
// 2: 0-Indexed Minor Version
|
||||
// 3-4: Header Length
|
||||
// 5: TLS Client Hello Marker Byte
|
||||
if 0x16 == b[0] && 0x03 == b[1] && 0x01 == b[5] {
|
||||
//length := (int(b[3]) << 8) + int(b[4])
|
||||
return false
|
||||
}
|
||||
}
|
||||
if nil != err {
|
||||
return true
|
||||
}
|
||||
return true
|
||||
/*
|
||||
if nil != err {
|
||||
return true
|
||||
}
|
||||
|
||||
switch conn := c.Conn.(type) {
|
||||
case *ConnWrap:
|
||||
return conn.isTerminated()
|
||||
case *Conn:
|
||||
_, ok := encryptedSchemes[string(conn.relayTargetAddr.scheme)]
|
||||
return !ok
|
||||
}
|
||||
return false
|
||||
switch conn := c.Conn.(type) {
|
||||
case *ConnWrap:
|
||||
return conn.isTerminated()
|
||||
case *Conn:
|
||||
_, ok := encryptedSchemes[string(conn.relayTargetAddr.scheme)]
|
||||
return !ok
|
||||
}
|
||||
return false
|
||||
*/
|
||||
}
|
||||
|
||||
// LocalAddr returns the local network address.
|
||||
|
|
|
@ -126,7 +126,7 @@ func (l *Listener) RouteBytes(srcAddr, dstAddr Addr, b []byte) {
|
|||
// remember where the error message goes
|
||||
if "error" == string(dst.scheme) {
|
||||
pipe.Close()
|
||||
delete(l.conns, src.Network())
|
||||
delete(l.conns, src.String())
|
||||
fmt.Printf("a stream errored remotely: %v\n", src)
|
||||
}
|
||||
|
||||
|
@ -139,12 +139,12 @@ func (l *Listener) RouteBytes(srcAddr, dstAddr Addr, b []byte) {
|
|||
if "end" == string(dst.scheme) {
|
||||
fmt.Println("[debug] end")
|
||||
pipe.Close()
|
||||
delete(l.conns, src.Network())
|
||||
delete(l.conns, src.String())
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn {
|
||||
connID := src.Network()
|
||||
connID := src.String()
|
||||
pipe, ok := l.conns[connID]
|
||||
|
||||
// Pipe exists
|
||||
|
@ -157,8 +157,8 @@ func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn {
|
|||
rawPipe, pipe := net.Pipe()
|
||||
newconn := &Conn{
|
||||
//updated: time.Now(),
|
||||
relaySourceAddr: *src,
|
||||
relayTargetAddr: *dst,
|
||||
relaySourceAddr: *dst,
|
||||
relayTargetAddr: *src,
|
||||
relay: rawPipe,
|
||||
}
|
||||
l.conns[connID] = pipe
|
||||
|
@ -173,7 +173,11 @@ func (l *Listener) getPipe(src, dst *Addr, count int) net.Conn {
|
|||
// In any case, we'll just close it all
|
||||
newconn.Close()
|
||||
pipe.Close()
|
||||
fmt.Printf("a stream is done: %q\n", err)
|
||||
if nil != err {
|
||||
fmt.Printf("a stream is done: %q\n", err)
|
||||
} else {
|
||||
fmt.Printf("a stream is done\n")
|
||||
}
|
||||
}()
|
||||
|
||||
return pipe
|
||||
|
|
|
@ -3,6 +3,8 @@ package telebit
|
|||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
|
@ -31,13 +33,42 @@ func NewRouteMux() *RouteMux {
|
|||
|
||||
// Serve dispatches the connection to the handler whose selectors matches the attributes.
|
||||
func (m *RouteMux) Serve(client net.Conn) error {
|
||||
wconn := &ConnWrap{Conn: client}
|
||||
servername := wconn.Servername()
|
||||
var wconn *ConnWrap
|
||||
switch conn := client.(type) {
|
||||
case *ConnWrap:
|
||||
wconn = conn
|
||||
default:
|
||||
wconn = &ConnWrap{Conn: client}
|
||||
}
|
||||
|
||||
var servername string
|
||||
var port string
|
||||
// TODO go back to Servername on conn, but with SNI
|
||||
//servername := wconn.Servername()
|
||||
fam := wconn.LocalAddr().Network()
|
||||
if "tun" == fam {
|
||||
switch laddr := wconn.LocalAddr().(type) {
|
||||
case *Addr:
|
||||
servername = laddr.Hostname()
|
||||
port = ":" + strconv.Itoa(laddr.Port())
|
||||
default:
|
||||
panic("impossible type switch: Addr is 'tun' but didn't match")
|
||||
}
|
||||
} else {
|
||||
// TODO make an AddrWrap to do this switch
|
||||
addr := wconn.LocalAddr().String()
|
||||
parts := strings.Split(addr, ":")
|
||||
port = ":" + parts[len(parts)-1]
|
||||
servername = strings.Join(parts[:len(parts)-1], ":")
|
||||
}
|
||||
fmt.Println("Addr:", fam, servername, port)
|
||||
|
||||
for _, meta := range m.routes {
|
||||
if servername == meta.addr || "*" == meta.addr {
|
||||
// TODO '*.example.com'
|
||||
fmt.Println("Meta:", meta.addr)
|
||||
if servername == meta.addr || "*" == meta.addr || port == meta.addr {
|
||||
//fmt.Println("[debug] test of route:", meta)
|
||||
if err := meta.handler.Serve(client); nil != err {
|
||||
if err := meta.handler.Serve(wconn); nil != err {
|
||||
// error should be EOF if successful
|
||||
return err
|
||||
}
|
||||
|
@ -78,14 +109,22 @@ func (m *RouteMux) HandleTLS(servername string, acme *ACME, handler Handler) err
|
|||
addr: servername,
|
||||
terminate: true,
|
||||
handler: HandlerFunc(func(client net.Conn) error {
|
||||
wrap := &ConnWrap{Conn: client}
|
||||
if wrap.isTerminated() {
|
||||
var wconn *ConnWrap
|
||||
switch conn := client.(type) {
|
||||
case *ConnWrap:
|
||||
wconn = conn
|
||||
default:
|
||||
panic("HandleTLS is special in that it must receive &ConnWrap{ Conn: conn }")
|
||||
}
|
||||
|
||||
if wconn.isTerminated() {
|
||||
// nil to skip
|
||||
return nil
|
||||
}
|
||||
|
||||
//NewTerminator(acme, handler)(client)
|
||||
//return handler.Serve(client)
|
||||
return handler.Serve(TerminateTLS(client, acme))
|
||||
return handler.Serve(TerminateTLS(wconn, acme))
|
||||
}),
|
||||
})
|
||||
return nil
|
||||
|
|
|
@ -11,6 +11,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
|
@ -27,6 +28,11 @@ var defaultPeekerSize = 1024
|
|||
// ErrBadGateway means that the target did not accept the connection
|
||||
var ErrBadGateway = errors.New("EBADGATEWAY")
|
||||
|
||||
// The proper handling of this error
|
||||
// is still being debated as of Jun 9, 2020
|
||||
// https://github.com/golang/go/issues/4373
|
||||
var errNetClosing = "use of closed network connection"
|
||||
|
||||
// A Handler routes, proxies, terminates, or responds to a net.Conn.
|
||||
type Handler interface {
|
||||
Serve(net.Conn) error
|
||||
|
@ -104,7 +110,13 @@ func Forward(client net.Conn, target net.Conn, timeout time.Duration) error {
|
|||
}
|
||||
}()
|
||||
|
||||
fmt.Println("[debug] forwarding tcp connection")
|
||||
fmt.Println(
|
||||
"[debug] forwarding tcp connection",
|
||||
client.LocalAddr(),
|
||||
client.RemoteAddr(),
|
||||
target.LocalAddr(),
|
||||
target.RemoteAddr(),
|
||||
)
|
||||
var err error = nil
|
||||
|
||||
ForwardData:
|
||||
|
@ -131,7 +143,7 @@ ForwardData:
|
|||
if nil == err {
|
||||
break ForwardData
|
||||
}
|
||||
if io.EOF != err {
|
||||
if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) {
|
||||
fmt.Printf("read from remote client failed: %q\n", err.Error())
|
||||
} else {
|
||||
fmt.Printf("Connection closed (possibly by remote client)\n")
|
||||
|
@ -141,7 +153,7 @@ ForwardData:
|
|||
if nil == err {
|
||||
break ForwardData
|
||||
}
|
||||
if io.EOF != err {
|
||||
if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) {
|
||||
fmt.Printf("read from local target failed: %q\n", err.Error())
|
||||
} else {
|
||||
fmt.Printf("Connection closed (possibly by local target)\n")
|
||||
|
@ -187,7 +199,11 @@ func TerminateTLS(client net.Conn, acme *ACME) net.Conn {
|
|||
var err error
|
||||
magic, err = newCertMagic(acme)
|
||||
if nil != err {
|
||||
fmt.Fprintf(os.Stderr, "failed to initialize certificate management (discovery url? local folder perms?): %s\n", err)
|
||||
fmt.Fprintf(
|
||||
os.Stderr,
|
||||
"failed to initialize certificate management (discovery url? local folder perms?): %s\n",
|
||||
err,
|
||||
)
|
||||
os.Exit(1)
|
||||
}
|
||||
acmecert = magic
|
||||
|
|
|
@ -126,6 +126,7 @@ func (p *Parser) unpackV1Header(b []byte, n int) ([]byte, error) {
|
|||
return b, nil
|
||||
}
|
||||
parts := strings.Split(string(p.state.header), ",")
|
||||
fmt.Println("[debug] Tun Header", string(p.state.header))
|
||||
p.state.header = nil
|
||||
if len(parts) < 5 {
|
||||
return nil, errors.New("error unpacking header")
|
||||
|
|
Loading…
Reference in New Issue