fix accidental early client.Close()

This commit is contained in:
AJ ONeal 2020-07-08 10:28:32 +00:00
parent c951ce1254
commit 2f1f138bda
5 changed files with 29 additions and 45 deletions

View File

@ -251,9 +251,9 @@ func main() {
fmt.Printf("[debug] Accepting API or WebSocket client %q\n", *apiHostname) fmt.Printf("[debug] Accepting API or WebSocket client %q\n", *apiHostname)
listener.Feed(client) listener.Feed(client)
fmt.Printf("[debug] done with %q client\n", *apiHostname) fmt.Printf("[debug] done with %q client\n", *apiHostname)
// TODO use a more correct non-error error? // nil now means handler in-progress (go routine)
// or perhaps (ok, error) or (handled, error)? // EOF now means handler finished
return io.EOF return nil
})) }))
} }
for _, fwd := range forwards { for _, fwd := range forwards {
@ -369,7 +369,8 @@ func routeSubscribersAndClients(client net.Conn) error {
labels := strings.Split(servername, ".") labels := strings.Split(servername, ".")
n := len(labels) n := len(labels)
if n < 3 { if n < 3 {
return nil // skip
return telebit.ErrNotHandled
} }
for i := 1; i < n-1; i++ { for i := 1; i < n-1; i++ {
wildname := "*." + strings.Join(labels[1:], ".") wildname := "*." + strings.Join(labels[1:], ".")
@ -379,7 +380,7 @@ func routeSubscribersAndClients(client net.Conn) error {
} }
// skip // skip
return nil return telebit.ErrNotHandled
} }
// tryToServeName picks the server tunnel with the least connections, if any // tryToServeName picks the server tunnel with the least connections, if any

View File

@ -117,28 +117,21 @@ func (c *ConnWrap) Servername() string {
} }
// this will get the servername // this will get the servername
c.isTerminated() _ = c.isEncrypted()
return c.servername return c.servername
} }
// isTerminated returns true if net.Conn is either a ConnWrap{ tls.Conn }, // isEncrypted returns true if peeking at net.Conn reveals that it is TLS-encrypted
// or a telebit.Conn with a non-encrypted `scheme` such as "tcp" or "http". func (c *ConnWrap) isEncrypted() bool {
func (c *ConnWrap) isTerminated() bool {
// TODO look at SNI, may need context for Peek() timeout
/*
if nil != c.Plain {
return true
}
*/
if nil != c.encrypted { if nil != c.encrypted {
return !*c.encrypted return *c.encrypted
} }
// how to know how many bytes to read? really needs timeout // TODO: how to allow / detect / handle protocols where the server hello happens first?
c.SetDeadline(time.Now().Add(5 * time.Second)) c.SetDeadline(time.Now().Add(5 * time.Second))
n := 6 n := 6
b, _ := c.Peek(n) b, _ := c.Peek(n)
fmt.Println("Peek(n)", b, string(b)) fmt.Println("[debug] Peek(n)", b, string(b))
defer c.SetDeadline(time.Time{}) defer c.SetDeadline(time.Time{})
var encrypted bool var encrypted bool
if len(b) >= n { if len(b) >= n {
@ -155,30 +148,16 @@ func (c *ConnWrap) isTerminated() bool {
b, err := c.Peek(n - 1 + length) b, err := c.Peek(n - 1 + length)
if nil != err { if nil != err {
c.encrypted = &encrypted c.encrypted = &encrypted
return !*c.encrypted return *c.encrypted
} }
c.servername, _ = sni.GetHostname(b) c.servername, _ = sni.GetHostname(b)
encrypted = true encrypted = true
c.encrypted = &encrypted c.encrypted = &encrypted
return !*c.encrypted return *c.encrypted
} }
} }
c.encrypted = &encrypted c.encrypted = &encrypted
return !*c.encrypted return *c.encrypted
/*
if nil != err {
return true
}
switch conn := c.Conn.(type) {
case *ConnWrap:
return conn.isTerminated()
case *Conn:
_, ok := encryptedSchemes[string(conn.relayTargetAddr.scheme)]
return !ok
}
return false
*/
} }
// LocalAddr returns the local network address. // LocalAddr returns the local network address.

View File

@ -76,13 +76,16 @@ func Serve(listener net.Listener, mux Handler) error {
} }
go func() { go func() {
err = mux.Serve(client) // nil means being handled
if nil != err { // non-nil means handled
// io.EOF means handled with success
if err := mux.Serve(client); nil != err {
if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) { if io.EOF != err && io.ErrClosedPipe != err && !strings.Contains(err.Error(), errNetClosing) {
fmt.Printf("client could not be served: %q\n", err.Error()) fmt.Printf("client could not be served: %q\n", err.Error())
} }
fmt.Println("[debug] closing original client", err)
client.Close()
} }
client.Close()
}() }()
} }
} }

View File

@ -1,6 +1,7 @@
package telebit package telebit
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
@ -18,6 +19,8 @@ type RouteMux struct {
routes []meta routes []meta
} }
var ErrNotHandled = errors.New("connection not handled")
type meta struct { type meta struct {
addr string addr string
handler Handler handler Handler
@ -72,11 +75,10 @@ func (m *RouteMux) Serve(client net.Conn) error {
fmt.Println("Meta:", meta.addr, servername) fmt.Println("Meta:", meta.addr, servername)
if servername == meta.addr || "*" == meta.addr || port == meta.addr { if servername == meta.addr || "*" == meta.addr || port == meta.addr {
//fmt.Println("[debug] test of route:", meta) //fmt.Println("[debug] test of route:", meta)
if err := meta.handler.Serve(wconn); nil != err { // Only keep trying handlers if ErrNotHandled was returned
// error should be EOF if successful if err := meta.handler.Serve(wconn); ErrNotHandled != err {
return err return err
} }
// nil err means skipped
} }
} }
@ -121,9 +123,8 @@ func (m *RouteMux) HandleTLS(servername string, acme *ACME, handler Handler) err
panic("HandleTLS is special in that it must receive &ConnWrap{ Conn: conn }") panic("HandleTLS is special in that it must receive &ConnWrap{ Conn: conn }")
} }
if wconn.isTerminated() { if !wconn.isEncrypted() {
// nil to skip return ErrNotHandled
return nil
} }
//NewTerminator(acme, handler)(client) //NewTerminator(acme, handler)(client)

View File

@ -253,7 +253,7 @@ func TerminateTLS(client net.Conn, acme *ACME) net.Conn {
wconn := &ConnWrap{ wconn := &ConnWrap{
Conn: client, Conn: client,
} }
wconn.isTerminated() _ = wconn.isEncrypted()
servername = wconn.Servername() servername = wconn.Servername()
scheme = wconn.Scheme() scheme = wconn.Scheme()
client = wconn client = wconn