got non-terminating traffic identified, and used SNI to figure direction
This commit is contained in:
parent
f3bb9cb584
commit
5334649fba
3
main.go
3
main.go
|
@ -25,6 +25,7 @@ var (
|
|||
argDeadTime int
|
||||
connectionTable *genericlistener.Table
|
||||
secretKey = "abc123"
|
||||
wssHostName = "localhost.daplie.me"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
@ -70,7 +71,7 @@ func main() {
|
|||
connectionTable = genericlistener.NewTable()
|
||||
go connectionTable.Run(ctx)
|
||||
|
||||
genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime)
|
||||
genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime, wssHostName)
|
||||
go genericListeners.Run(ctx, 9999)
|
||||
|
||||
//Run for 10 minutes and then shutdown cleanly
|
||||
|
|
|
@ -34,6 +34,7 @@ const (
|
|||
ctxDeadTime contextKey = "deadtime"
|
||||
ctxListenerRegistration contextKey = "listenerRegistration"
|
||||
ctxConnectionTrack contextKey = "connectionTrack"
|
||||
ctxWssHostName contextKey = "wsshostname"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -146,6 +147,35 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) {
|
|||
|
||||
} else if encryptMode != encryptNone {
|
||||
loginfo.Println("Handle Encryption")
|
||||
|
||||
// check SNI heading
|
||||
// if matched, then looks like a WSS connection
|
||||
// else external don't pull off TLS.
|
||||
|
||||
peek, err := wConn.PeekAll()
|
||||
if err != nil {
|
||||
loginfo.Println("error while peeking")
|
||||
loginfo.Println(hex.Dump(peek[0:]))
|
||||
return
|
||||
}
|
||||
|
||||
wssHostName := ctx.Value(ctxWssHostName).(string)
|
||||
sniHostName, err := getHello(peek)
|
||||
if err != nil {
|
||||
loginfo.Println(err)
|
||||
return
|
||||
}
|
||||
|
||||
loginfo.Println("sni:", sniHostName)
|
||||
|
||||
if wssHostName != sniHostName {
|
||||
//traffic not terminating on the rvpn do not decrypt
|
||||
loginfo.Println("processing non terminating traffic")
|
||||
handleExternalHTTPRequest(ctx, wConn, sniHostName)
|
||||
}
|
||||
|
||||
loginfo.Println("processing traffic terminating on RVPN")
|
||||
|
||||
tlsListener := tls.NewListener(oneConn, config)
|
||||
|
||||
conn, err := tlsListener.Accept()
|
||||
|
@ -175,11 +205,11 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) {
|
|||
// - handle other?
|
||||
func handleStream(ctx context.Context, wConn *WedgeConn) {
|
||||
loginfo.Println("handle Stream")
|
||||
loginfo.Println("conn", wConn, wConn.LocalAddr().String(), wConn.RemoteAddr().String())
|
||||
loginfo.Println("conn", wConn.LocalAddr().String(), wConn.RemoteAddr().String())
|
||||
|
||||
peek, err := wConn.PeekAll()
|
||||
if err != nil {
|
||||
loginfo.Println("error while peeking")
|
||||
loginfo.Println("error while peeking", err)
|
||||
loginfo.Println(hex.Dump(peek[0:]))
|
||||
return
|
||||
}
|
||||
|
@ -217,8 +247,8 @@ func handleStream(ctx context.Context, wConn *WedgeConn) {
|
|||
return
|
||||
|
||||
} else {
|
||||
loginfo.Println("default connection")
|
||||
handleExternalHTTPRequest(ctx, wConn)
|
||||
loginfo.Println("unsupported")
|
||||
loginfo.Println(hex.Dump(peek))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
@ -227,7 +257,7 @@ func handleStream(ctx context.Context, wConn *WedgeConn) {
|
|||
|
||||
//handleExternalHTTPRequest -
|
||||
// - get a wConn and start processing requests
|
||||
func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
|
||||
func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn, hostname string) {
|
||||
connectionTracking := ctx.Value(ctxConnectionTrack).(*Tracking)
|
||||
|
||||
defer func() {
|
||||
|
@ -236,37 +266,6 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
|
|||
}()
|
||||
|
||||
connectionTable := ctx.Value(ctxConnectionTable).(*Table)
|
||||
|
||||
var buffer [512]byte
|
||||
for {
|
||||
cnt, err := extConn.Read(buffer[0:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
readBuffer := bytes.NewBuffer(buffer[0:cnt])
|
||||
reader := bufio.NewReader(readBuffer)
|
||||
r, err := http.ReadRequest(reader)
|
||||
|
||||
if err != nil {
|
||||
loginfo.Println("error parsing request")
|
||||
return
|
||||
}
|
||||
|
||||
hostname := r.Host
|
||||
loginfo.Println("Host: ", hostname)
|
||||
|
||||
if strings.Contains(hostname, ":") {
|
||||
arr := strings.Split(hostname, ":")
|
||||
hostname = arr[0]
|
||||
}
|
||||
|
||||
loginfo.Println("Remote: ", extConn.RemoteAddr().String())
|
||||
|
||||
remoteSplit := strings.Split(extConn.RemoteAddr().String(), ":")
|
||||
rAddr := remoteSplit[0]
|
||||
rPort := remoteSplit[1]
|
||||
|
||||
//find the connection by domain name
|
||||
conn, ok := connectionTable.ConnByDomain(hostname)
|
||||
if !ok {
|
||||
|
@ -279,7 +278,21 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
|
|||
track := NewTrack(extConn, hostname)
|
||||
connectionTracking.register <- track
|
||||
|
||||
loginfo.Println("Domain Accepted", conn, rAddr, rPort)
|
||||
loginfo.Println("Domain Accepted", hostname, extConn.RemoteAddr().String())
|
||||
|
||||
rAddr, rPort, err := net.SplitHostPort(extConn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
loginfo.Println("unable to decode hostport", extConn.RemoteAddr().String())
|
||||
return
|
||||
}
|
||||
|
||||
var buffer [512]byte
|
||||
for {
|
||||
cnt, err := extConn.Read(buffer[0:])
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
p := packer.NewPacker()
|
||||
p.Header.SetAddress(rAddr)
|
||||
p.Header.Port, err = strconv.Atoi(rPort)
|
||||
|
@ -292,6 +305,8 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
|
|||
p.Data.AppendBytes(buffer[0:cnt])
|
||||
buf := p.PackV1()
|
||||
|
||||
loginfo.Println(hex.Dump(buf.Bytes()))
|
||||
|
||||
sendTrack := NewSendTrack(buf.Bytes(), hostname)
|
||||
conn.SendCh() <- sendTrack
|
||||
}
|
||||
|
|
|
@ -54,10 +54,11 @@ type GenericListeners struct {
|
|||
deadTime int
|
||||
register chan *ListenerRegistration
|
||||
genericListeners *GenericListeners
|
||||
wssHostName string
|
||||
}
|
||||
|
||||
//NewGenerListeners --
|
||||
func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTrack *Tracking, secretKey string, certbundle tls.Certificate, deadTime int) (p *GenericListeners) {
|
||||
func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTrack *Tracking, secretKey string, certbundle tls.Certificate, deadTime int, wssHostName string) (p *GenericListeners) {
|
||||
p = new(GenericListeners)
|
||||
p.listeners = make(map[*net.Listener]int)
|
||||
p.ctx = ctx
|
||||
|
@ -67,6 +68,7 @@ func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTr
|
|||
p.certbundle = certbundle
|
||||
p.deadTime = deadTime
|
||||
p.register = make(chan *ListenerRegistration)
|
||||
p.wssHostName = wssHostName
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -87,6 +89,7 @@ func (gl *GenericListeners) Run(ctx context.Context, initialPort int) {
|
|||
ctx = context.WithValue(ctx, ctxConfig, config)
|
||||
ctx = context.WithValue(ctx, ctxDeadTime, gl.deadTime)
|
||||
ctx = context.WithValue(ctx, ctxListenerRegistration, gl.register)
|
||||
ctx = context.WithValue(ctx, ctxWssHostName, gl.wssHostName)
|
||||
|
||||
go func(ctx context.Context) {
|
||||
for {
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
package genericlistener
|
||||
|
||||
import "errors"
|
||||
|
||||
func getHello(b []byte) (string, error) {
|
||||
rest := b[5:]
|
||||
current := 0
|
||||
handshakeType := rest[0]
|
||||
current++
|
||||
if handshakeType != 0x1 {
|
||||
return "", errors.New("Not a ClientHello")
|
||||
}
|
||||
|
||||
// Skip over another length
|
||||
current += 3
|
||||
// Skip over protocolversion
|
||||
current += 2
|
||||
// Skip over random number
|
||||
current += 4 + 28
|
||||
// Skip over session ID
|
||||
sessionIDLength := int(rest[current])
|
||||
current++
|
||||
current += sessionIDLength
|
||||
|
||||
cipherSuiteLength := (int(rest[current]) << 8) + int(rest[current+1])
|
||||
current += 2
|
||||
current += cipherSuiteLength
|
||||
|
||||
compressionMethodLength := int(rest[current])
|
||||
current++
|
||||
current += compressionMethodLength
|
||||
|
||||
if current > len(rest) {
|
||||
return "", errors.New("no extensions")
|
||||
}
|
||||
|
||||
current += 2
|
||||
|
||||
hostname := ""
|
||||
for current < len(rest) && hostname == "" {
|
||||
extensionType := (int(rest[current]) << 8) + int(rest[current+1])
|
||||
current += 2
|
||||
|
||||
extensionDataLength := (int(rest[current]) << 8) + int(rest[current+1])
|
||||
current += 2
|
||||
|
||||
if extensionType == 0 {
|
||||
|
||||
// Skip over number of names as we're assuming there's just one
|
||||
current += 2
|
||||
|
||||
nameType := rest[current]
|
||||
current++
|
||||
if nameType != 0 {
|
||||
return "", errors.New("Not a hostname")
|
||||
}
|
||||
nameLen := (int(rest[current]) << 8) + int(rest[current+1])
|
||||
current += 2
|
||||
hostname = string(rest[current : current+nameLen])
|
||||
}
|
||||
|
||||
current += extensionDataLength
|
||||
}
|
||||
if hostname == "" {
|
||||
return "", errors.New("No hostname")
|
||||
}
|
||||
return hostname, nil
|
||||
|
||||
}
|
Loading…
Reference in New Issue