got non-terminating traffic identified, and used SNI to figure direction

This commit is contained in:
Henry Camacho 2017-03-10 23:36:42 -06:00
parent f3bb9cb584
commit 5334649fba
4 changed files with 131 additions and 43 deletions

View File

@ -25,6 +25,7 @@ var (
argDeadTime int argDeadTime int
connectionTable *genericlistener.Table connectionTable *genericlistener.Table
secretKey = "abc123" secretKey = "abc123"
wssHostName = "localhost.daplie.me"
) )
func init() { func init() {
@ -70,7 +71,7 @@ func main() {
connectionTable = genericlistener.NewTable() connectionTable = genericlistener.NewTable()
go connectionTable.Run(ctx) go connectionTable.Run(ctx)
genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime) genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime, wssHostName)
go genericListeners.Run(ctx, 9999) go genericListeners.Run(ctx, 9999)
//Run for 10 minutes and then shutdown cleanly //Run for 10 minutes and then shutdown cleanly

View File

@ -34,6 +34,7 @@ const (
ctxDeadTime contextKey = "deadtime" ctxDeadTime contextKey = "deadtime"
ctxListenerRegistration contextKey = "listenerRegistration" ctxListenerRegistration contextKey = "listenerRegistration"
ctxConnectionTrack contextKey = "connectionTrack" ctxConnectionTrack contextKey = "connectionTrack"
ctxWssHostName contextKey = "wsshostname"
) )
const ( const (
@ -146,6 +147,35 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) {
} else if encryptMode != encryptNone { } else if encryptMode != encryptNone {
loginfo.Println("Handle Encryption") loginfo.Println("Handle Encryption")
// check SNI heading
// if matched, then looks like a WSS connection
// else external don't pull off TLS.
peek, err := wConn.PeekAll()
if err != nil {
loginfo.Println("error while peeking")
loginfo.Println(hex.Dump(peek[0:]))
return
}
wssHostName := ctx.Value(ctxWssHostName).(string)
sniHostName, err := getHello(peek)
if err != nil {
loginfo.Println(err)
return
}
loginfo.Println("sni:", sniHostName)
if wssHostName != sniHostName {
//traffic not terminating on the rvpn do not decrypt
loginfo.Println("processing non terminating traffic")
handleExternalHTTPRequest(ctx, wConn, sniHostName)
}
loginfo.Println("processing traffic terminating on RVPN")
tlsListener := tls.NewListener(oneConn, config) tlsListener := tls.NewListener(oneConn, config)
conn, err := tlsListener.Accept() conn, err := tlsListener.Accept()
@ -175,11 +205,11 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) {
// - handle other? // - handle other?
func handleStream(ctx context.Context, wConn *WedgeConn) { func handleStream(ctx context.Context, wConn *WedgeConn) {
loginfo.Println("handle Stream") loginfo.Println("handle Stream")
loginfo.Println("conn", wConn, wConn.LocalAddr().String(), wConn.RemoteAddr().String()) loginfo.Println("conn", wConn.LocalAddr().String(), wConn.RemoteAddr().String())
peek, err := wConn.PeekAll() peek, err := wConn.PeekAll()
if err != nil { if err != nil {
loginfo.Println("error while peeking") loginfo.Println("error while peeking", err)
loginfo.Println(hex.Dump(peek[0:])) loginfo.Println(hex.Dump(peek[0:]))
return return
} }
@ -217,8 +247,8 @@ func handleStream(ctx context.Context, wConn *WedgeConn) {
return return
} else { } else {
loginfo.Println("default connection") loginfo.Println("unsupported")
handleExternalHTTPRequest(ctx, wConn) loginfo.Println(hex.Dump(peek))
return return
} }
} }
@ -227,7 +257,7 @@ func handleStream(ctx context.Context, wConn *WedgeConn) {
//handleExternalHTTPRequest - //handleExternalHTTPRequest -
// - get a wConn and start processing requests // - get a wConn and start processing requests
func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) { func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn, hostname string) {
connectionTracking := ctx.Value(ctxConnectionTrack).(*Tracking) connectionTracking := ctx.Value(ctxConnectionTrack).(*Tracking)
defer func() { defer func() {
@ -236,37 +266,6 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
}() }()
connectionTable := ctx.Value(ctxConnectionTable).(*Table) connectionTable := ctx.Value(ctxConnectionTable).(*Table)
var buffer [512]byte
for {
cnt, err := extConn.Read(buffer[0:])
if err != nil {
return
}
readBuffer := bytes.NewBuffer(buffer[0:cnt])
reader := bufio.NewReader(readBuffer)
r, err := http.ReadRequest(reader)
if err != nil {
loginfo.Println("error parsing request")
return
}
hostname := r.Host
loginfo.Println("Host: ", hostname)
if strings.Contains(hostname, ":") {
arr := strings.Split(hostname, ":")
hostname = arr[0]
}
loginfo.Println("Remote: ", extConn.RemoteAddr().String())
remoteSplit := strings.Split(extConn.RemoteAddr().String(), ":")
rAddr := remoteSplit[0]
rPort := remoteSplit[1]
//find the connection by domain name //find the connection by domain name
conn, ok := connectionTable.ConnByDomain(hostname) conn, ok := connectionTable.ConnByDomain(hostname)
if !ok { if !ok {
@ -279,7 +278,21 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
track := NewTrack(extConn, hostname) track := NewTrack(extConn, hostname)
connectionTracking.register <- track connectionTracking.register <- track
loginfo.Println("Domain Accepted", conn, rAddr, rPort) loginfo.Println("Domain Accepted", hostname, extConn.RemoteAddr().String())
rAddr, rPort, err := net.SplitHostPort(extConn.RemoteAddr().String())
if err != nil {
loginfo.Println("unable to decode hostport", extConn.RemoteAddr().String())
return
}
var buffer [512]byte
for {
cnt, err := extConn.Read(buffer[0:])
if err != nil {
return
}
p := packer.NewPacker() p := packer.NewPacker()
p.Header.SetAddress(rAddr) p.Header.SetAddress(rAddr)
p.Header.Port, err = strconv.Atoi(rPort) p.Header.Port, err = strconv.Atoi(rPort)
@ -292,6 +305,8 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
p.Data.AppendBytes(buffer[0:cnt]) p.Data.AppendBytes(buffer[0:cnt])
buf := p.PackV1() buf := p.PackV1()
loginfo.Println(hex.Dump(buf.Bytes()))
sendTrack := NewSendTrack(buf.Bytes(), hostname) sendTrack := NewSendTrack(buf.Bytes(), hostname)
conn.SendCh() <- sendTrack conn.SendCh() <- sendTrack
} }

View File

@ -54,10 +54,11 @@ type GenericListeners struct {
deadTime int deadTime int
register chan *ListenerRegistration register chan *ListenerRegistration
genericListeners *GenericListeners genericListeners *GenericListeners
wssHostName string
} }
//NewGenerListeners -- //NewGenerListeners --
func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTrack *Tracking, secretKey string, certbundle tls.Certificate, deadTime int) (p *GenericListeners) { func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTrack *Tracking, secretKey string, certbundle tls.Certificate, deadTime int, wssHostName string) (p *GenericListeners) {
p = new(GenericListeners) p = new(GenericListeners)
p.listeners = make(map[*net.Listener]int) p.listeners = make(map[*net.Listener]int)
p.ctx = ctx p.ctx = ctx
@ -67,6 +68,7 @@ func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTr
p.certbundle = certbundle p.certbundle = certbundle
p.deadTime = deadTime p.deadTime = deadTime
p.register = make(chan *ListenerRegistration) p.register = make(chan *ListenerRegistration)
p.wssHostName = wssHostName
return return
} }
@ -87,6 +89,7 @@ func (gl *GenericListeners) Run(ctx context.Context, initialPort int) {
ctx = context.WithValue(ctx, ctxConfig, config) ctx = context.WithValue(ctx, ctxConfig, config)
ctx = context.WithValue(ctx, ctxDeadTime, gl.deadTime) ctx = context.WithValue(ctx, ctxDeadTime, gl.deadTime)
ctx = context.WithValue(ctx, ctxListenerRegistration, gl.register) ctx = context.WithValue(ctx, ctxListenerRegistration, gl.register)
ctx = context.WithValue(ctx, ctxWssHostName, gl.wssHostName)
go func(ctx context.Context) { go func(ctx context.Context) {
for { for {

View File

@ -0,0 +1,69 @@
package genericlistener
import "errors"
func getHello(b []byte) (string, error) {
rest := b[5:]
current := 0
handshakeType := rest[0]
current++
if handshakeType != 0x1 {
return "", errors.New("Not a ClientHello")
}
// Skip over another length
current += 3
// Skip over protocolversion
current += 2
// Skip over random number
current += 4 + 28
// Skip over session ID
sessionIDLength := int(rest[current])
current++
current += sessionIDLength
cipherSuiteLength := (int(rest[current]) << 8) + int(rest[current+1])
current += 2
current += cipherSuiteLength
compressionMethodLength := int(rest[current])
current++
current += compressionMethodLength
if current > len(rest) {
return "", errors.New("no extensions")
}
current += 2
hostname := ""
for current < len(rest) && hostname == "" {
extensionType := (int(rest[current]) << 8) + int(rest[current+1])
current += 2
extensionDataLength := (int(rest[current]) << 8) + int(rest[current+1])
current += 2
if extensionType == 0 {
// Skip over number of names as we're assuming there's just one
current += 2
nameType := rest[current]
current++
if nameType != 0 {
return "", errors.New("Not a hostname")
}
nameLen := (int(rest[current]) << 8) + int(rest[current+1])
current += 2
hostname = string(rest[current : current+nameLen])
}
current += extensionDataLength
}
if hostname == "" {
return "", errors.New("No hostname")
}
return hostname, nil
}