got non-terminating traffic identified, and used SNI to figure direction
This commit is contained in:
parent
f3bb9cb584
commit
5334649fba
3
main.go
3
main.go
|
@ -25,6 +25,7 @@ var (
|
||||||
argDeadTime int
|
argDeadTime int
|
||||||
connectionTable *genericlistener.Table
|
connectionTable *genericlistener.Table
|
||||||
secretKey = "abc123"
|
secretKey = "abc123"
|
||||||
|
wssHostName = "localhost.daplie.me"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
@ -70,7 +71,7 @@ func main() {
|
||||||
connectionTable = genericlistener.NewTable()
|
connectionTable = genericlistener.NewTable()
|
||||||
go connectionTable.Run(ctx)
|
go connectionTable.Run(ctx)
|
||||||
|
|
||||||
genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime)
|
genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime, wssHostName)
|
||||||
go genericListeners.Run(ctx, 9999)
|
go genericListeners.Run(ctx, 9999)
|
||||||
|
|
||||||
//Run for 10 minutes and then shutdown cleanly
|
//Run for 10 minutes and then shutdown cleanly
|
||||||
|
|
|
@ -34,6 +34,7 @@ const (
|
||||||
ctxDeadTime contextKey = "deadtime"
|
ctxDeadTime contextKey = "deadtime"
|
||||||
ctxListenerRegistration contextKey = "listenerRegistration"
|
ctxListenerRegistration contextKey = "listenerRegistration"
|
||||||
ctxConnectionTrack contextKey = "connectionTrack"
|
ctxConnectionTrack contextKey = "connectionTrack"
|
||||||
|
ctxWssHostName contextKey = "wsshostname"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -146,6 +147,35 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) {
|
||||||
|
|
||||||
} else if encryptMode != encryptNone {
|
} else if encryptMode != encryptNone {
|
||||||
loginfo.Println("Handle Encryption")
|
loginfo.Println("Handle Encryption")
|
||||||
|
|
||||||
|
// check SNI heading
|
||||||
|
// if matched, then looks like a WSS connection
|
||||||
|
// else external don't pull off TLS.
|
||||||
|
|
||||||
|
peek, err := wConn.PeekAll()
|
||||||
|
if err != nil {
|
||||||
|
loginfo.Println("error while peeking")
|
||||||
|
loginfo.Println(hex.Dump(peek[0:]))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wssHostName := ctx.Value(ctxWssHostName).(string)
|
||||||
|
sniHostName, err := getHello(peek)
|
||||||
|
if err != nil {
|
||||||
|
loginfo.Println(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
loginfo.Println("sni:", sniHostName)
|
||||||
|
|
||||||
|
if wssHostName != sniHostName {
|
||||||
|
//traffic not terminating on the rvpn do not decrypt
|
||||||
|
loginfo.Println("processing non terminating traffic")
|
||||||
|
handleExternalHTTPRequest(ctx, wConn, sniHostName)
|
||||||
|
}
|
||||||
|
|
||||||
|
loginfo.Println("processing traffic terminating on RVPN")
|
||||||
|
|
||||||
tlsListener := tls.NewListener(oneConn, config)
|
tlsListener := tls.NewListener(oneConn, config)
|
||||||
|
|
||||||
conn, err := tlsListener.Accept()
|
conn, err := tlsListener.Accept()
|
||||||
|
@ -175,11 +205,11 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) {
|
||||||
// - handle other?
|
// - handle other?
|
||||||
func handleStream(ctx context.Context, wConn *WedgeConn) {
|
func handleStream(ctx context.Context, wConn *WedgeConn) {
|
||||||
loginfo.Println("handle Stream")
|
loginfo.Println("handle Stream")
|
||||||
loginfo.Println("conn", wConn, wConn.LocalAddr().String(), wConn.RemoteAddr().String())
|
loginfo.Println("conn", wConn.LocalAddr().String(), wConn.RemoteAddr().String())
|
||||||
|
|
||||||
peek, err := wConn.PeekAll()
|
peek, err := wConn.PeekAll()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
loginfo.Println("error while peeking")
|
loginfo.Println("error while peeking", err)
|
||||||
loginfo.Println(hex.Dump(peek[0:]))
|
loginfo.Println(hex.Dump(peek[0:]))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -217,8 +247,8 @@ func handleStream(ctx context.Context, wConn *WedgeConn) {
|
||||||
return
|
return
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
loginfo.Println("default connection")
|
loginfo.Println("unsupported")
|
||||||
handleExternalHTTPRequest(ctx, wConn)
|
loginfo.Println(hex.Dump(peek))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -227,7 +257,7 @@ func handleStream(ctx context.Context, wConn *WedgeConn) {
|
||||||
|
|
||||||
//handleExternalHTTPRequest -
|
//handleExternalHTTPRequest -
|
||||||
// - get a wConn and start processing requests
|
// - get a wConn and start processing requests
|
||||||
func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
|
func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn, hostname string) {
|
||||||
connectionTracking := ctx.Value(ctxConnectionTrack).(*Tracking)
|
connectionTracking := ctx.Value(ctxConnectionTrack).(*Tracking)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -236,37 +266,6 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
connectionTable := ctx.Value(ctxConnectionTable).(*Table)
|
connectionTable := ctx.Value(ctxConnectionTable).(*Table)
|
||||||
|
|
||||||
var buffer [512]byte
|
|
||||||
for {
|
|
||||||
cnt, err := extConn.Read(buffer[0:])
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
readBuffer := bytes.NewBuffer(buffer[0:cnt])
|
|
||||||
reader := bufio.NewReader(readBuffer)
|
|
||||||
r, err := http.ReadRequest(reader)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
loginfo.Println("error parsing request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
hostname := r.Host
|
|
||||||
loginfo.Println("Host: ", hostname)
|
|
||||||
|
|
||||||
if strings.Contains(hostname, ":") {
|
|
||||||
arr := strings.Split(hostname, ":")
|
|
||||||
hostname = arr[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
loginfo.Println("Remote: ", extConn.RemoteAddr().String())
|
|
||||||
|
|
||||||
remoteSplit := strings.Split(extConn.RemoteAddr().String(), ":")
|
|
||||||
rAddr := remoteSplit[0]
|
|
||||||
rPort := remoteSplit[1]
|
|
||||||
|
|
||||||
//find the connection by domain name
|
//find the connection by domain name
|
||||||
conn, ok := connectionTable.ConnByDomain(hostname)
|
conn, ok := connectionTable.ConnByDomain(hostname)
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -279,7 +278,21 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
|
||||||
track := NewTrack(extConn, hostname)
|
track := NewTrack(extConn, hostname)
|
||||||
connectionTracking.register <- track
|
connectionTracking.register <- track
|
||||||
|
|
||||||
loginfo.Println("Domain Accepted", conn, rAddr, rPort)
|
loginfo.Println("Domain Accepted", hostname, extConn.RemoteAddr().String())
|
||||||
|
|
||||||
|
rAddr, rPort, err := net.SplitHostPort(extConn.RemoteAddr().String())
|
||||||
|
if err != nil {
|
||||||
|
loginfo.Println("unable to decode hostport", extConn.RemoteAddr().String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var buffer [512]byte
|
||||||
|
for {
|
||||||
|
cnt, err := extConn.Read(buffer[0:])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p := packer.NewPacker()
|
p := packer.NewPacker()
|
||||||
p.Header.SetAddress(rAddr)
|
p.Header.SetAddress(rAddr)
|
||||||
p.Header.Port, err = strconv.Atoi(rPort)
|
p.Header.Port, err = strconv.Atoi(rPort)
|
||||||
|
@ -292,6 +305,8 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
|
||||||
p.Data.AppendBytes(buffer[0:cnt])
|
p.Data.AppendBytes(buffer[0:cnt])
|
||||||
buf := p.PackV1()
|
buf := p.PackV1()
|
||||||
|
|
||||||
|
loginfo.Println(hex.Dump(buf.Bytes()))
|
||||||
|
|
||||||
sendTrack := NewSendTrack(buf.Bytes(), hostname)
|
sendTrack := NewSendTrack(buf.Bytes(), hostname)
|
||||||
conn.SendCh() <- sendTrack
|
conn.SendCh() <- sendTrack
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,10 +54,11 @@ type GenericListeners struct {
|
||||||
deadTime int
|
deadTime int
|
||||||
register chan *ListenerRegistration
|
register chan *ListenerRegistration
|
||||||
genericListeners *GenericListeners
|
genericListeners *GenericListeners
|
||||||
|
wssHostName string
|
||||||
}
|
}
|
||||||
|
|
||||||
//NewGenerListeners --
|
//NewGenerListeners --
|
||||||
func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTrack *Tracking, secretKey string, certbundle tls.Certificate, deadTime int) (p *GenericListeners) {
|
func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTrack *Tracking, secretKey string, certbundle tls.Certificate, deadTime int, wssHostName string) (p *GenericListeners) {
|
||||||
p = new(GenericListeners)
|
p = new(GenericListeners)
|
||||||
p.listeners = make(map[*net.Listener]int)
|
p.listeners = make(map[*net.Listener]int)
|
||||||
p.ctx = ctx
|
p.ctx = ctx
|
||||||
|
@ -67,6 +68,7 @@ func NewGenerListeners(ctx context.Context, connectionTable *Table, connectionTr
|
||||||
p.certbundle = certbundle
|
p.certbundle = certbundle
|
||||||
p.deadTime = deadTime
|
p.deadTime = deadTime
|
||||||
p.register = make(chan *ListenerRegistration)
|
p.register = make(chan *ListenerRegistration)
|
||||||
|
p.wssHostName = wssHostName
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -87,6 +89,7 @@ func (gl *GenericListeners) Run(ctx context.Context, initialPort int) {
|
||||||
ctx = context.WithValue(ctx, ctxConfig, config)
|
ctx = context.WithValue(ctx, ctxConfig, config)
|
||||||
ctx = context.WithValue(ctx, ctxDeadTime, gl.deadTime)
|
ctx = context.WithValue(ctx, ctxDeadTime, gl.deadTime)
|
||||||
ctx = context.WithValue(ctx, ctxListenerRegistration, gl.register)
|
ctx = context.WithValue(ctx, ctxListenerRegistration, gl.register)
|
||||||
|
ctx = context.WithValue(ctx, ctxWssHostName, gl.wssHostName)
|
||||||
|
|
||||||
go func(ctx context.Context) {
|
go func(ctx context.Context) {
|
||||||
for {
|
for {
|
||||||
|
|
|
@ -0,0 +1,69 @@
|
||||||
|
package genericlistener
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
func getHello(b []byte) (string, error) {
|
||||||
|
rest := b[5:]
|
||||||
|
current := 0
|
||||||
|
handshakeType := rest[0]
|
||||||
|
current++
|
||||||
|
if handshakeType != 0x1 {
|
||||||
|
return "", errors.New("Not a ClientHello")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip over another length
|
||||||
|
current += 3
|
||||||
|
// Skip over protocolversion
|
||||||
|
current += 2
|
||||||
|
// Skip over random number
|
||||||
|
current += 4 + 28
|
||||||
|
// Skip over session ID
|
||||||
|
sessionIDLength := int(rest[current])
|
||||||
|
current++
|
||||||
|
current += sessionIDLength
|
||||||
|
|
||||||
|
cipherSuiteLength := (int(rest[current]) << 8) + int(rest[current+1])
|
||||||
|
current += 2
|
||||||
|
current += cipherSuiteLength
|
||||||
|
|
||||||
|
compressionMethodLength := int(rest[current])
|
||||||
|
current++
|
||||||
|
current += compressionMethodLength
|
||||||
|
|
||||||
|
if current > len(rest) {
|
||||||
|
return "", errors.New("no extensions")
|
||||||
|
}
|
||||||
|
|
||||||
|
current += 2
|
||||||
|
|
||||||
|
hostname := ""
|
||||||
|
for current < len(rest) && hostname == "" {
|
||||||
|
extensionType := (int(rest[current]) << 8) + int(rest[current+1])
|
||||||
|
current += 2
|
||||||
|
|
||||||
|
extensionDataLength := (int(rest[current]) << 8) + int(rest[current+1])
|
||||||
|
current += 2
|
||||||
|
|
||||||
|
if extensionType == 0 {
|
||||||
|
|
||||||
|
// Skip over number of names as we're assuming there's just one
|
||||||
|
current += 2
|
||||||
|
|
||||||
|
nameType := rest[current]
|
||||||
|
current++
|
||||||
|
if nameType != 0 {
|
||||||
|
return "", errors.New("Not a hostname")
|
||||||
|
}
|
||||||
|
nameLen := (int(rest[current]) << 8) + int(rest[current+1])
|
||||||
|
current += 2
|
||||||
|
hostname = string(rest[current : current+nameLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
current += extensionDataLength
|
||||||
|
}
|
||||||
|
if hostname == "" {
|
||||||
|
return "", errors.New("No hostname")
|
||||||
|
}
|
||||||
|
return hostname, nil
|
||||||
|
|
||||||
|
}
|
Loading…
Reference in New Issue