Implemented domain tracking for external connections

- system now tracks both in and out bytes
- various clean up.
This commit is contained in:
Henry Camacho 2017-03-09 21:38:23 -06:00
parent d7e01e8b40
commit f3bb9cb584
8 changed files with 35 additions and 200 deletions

18
main.go
View File

@ -24,7 +24,6 @@ var (
argServerExternalBinding string argServerExternalBinding string
argDeadTime int argDeadTime int
connectionTable *genericlistener.Table connectionTable *genericlistener.Table
wssMapping *genericlistener.WssMapping
secretKey = "abc123" secretKey = "abc123"
) )
@ -72,24 +71,9 @@ func main() {
go connectionTable.Run(ctx) go connectionTable.Run(ctx)
genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime) genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime)
go genericListeners.Run(ctx, 8443) go genericListeners.Run(ctx, 9999)
//go genericlistener.GenericListenAndServe(ctx, connectionTable, secretKey, argGenericBinding, certbundle, argDeadTime)
//Run for 10 minutes and then shutdown cleanly //Run for 10 minutes and then shutdown cleanly
time.Sleep(600 * time.Second) time.Sleep(600 * time.Second)
cancelContext() cancelContext()
//wssMapping = xlate.NewwssMapping()
//go wssMapping.Run()
//go client.LaunchClientListener(connectionTable, &secretKey, &argServerBinding)
//go external.LaunchWebRequestExternalListener(&argServerExternalBinding, connectionTable)
//go external.LaunchExternalServer(argServerExternalBinding, connectionTable)
//err = admin.LaunchAdminListener(&argServerAdminBinding, connectionTable)
//if err != nil {
// loginfo.Println("LauchAdminListener failed: ", err)
//}
//genericlistener.LaunchWssListener(connectionTable, secretKey, argWssClientListener, "certs/fullchain.pem", "certs/privkey.pem")
} }

View File

@ -4,18 +4,32 @@ import "net"
import "context" import "context"
import "fmt" import "fmt"
//Track -- used to track connection + domain
type Track struct {
conn net.Conn
domain string
}
//NewTrack -- Constructor
func NewTrack(conn net.Conn, domain string) (p *Track) {
p = new(Track)
p.conn = conn
p.domain = domain
return
}
//Tracking -- //Tracking --
type Tracking struct { type Tracking struct {
connections map[string]net.Conn connections map[string]*Track
register chan net.Conn register chan *Track
unregister chan net.Conn unregister chan net.Conn
} }
//NewTracking -- Constructor //NewTracking -- Constructor
func NewTracking() (p *Tracking) { func NewTracking() (p *Tracking) {
p = new(Tracking) p = new(Tracking)
p.connections = make(map[string]net.Conn) p.connections = make(map[string]*Track)
p.register = make(chan net.Conn) p.register = make(chan *Track)
p.unregister = make(chan net.Conn) p.unregister = make(chan net.Conn)
return return
} }
@ -32,7 +46,7 @@ func (p *Tracking) Run(ctx context.Context) {
return return
case connection := <-p.register: case connection := <-p.register:
key := connection.RemoteAddr().String() key := connection.conn.RemoteAddr().String()
loginfo.Println("register fired", key) loginfo.Println("register fired", key)
p.connections[key] = connection p.connections[key] = connection
p.list() p.list()
@ -40,7 +54,6 @@ func (p *Tracking) Run(ctx context.Context) {
case connection := <-p.unregister: case connection := <-p.unregister:
key := connection.RemoteAddr().String() key := connection.RemoteAddr().String()
loginfo.Println("unregister fired", key) loginfo.Println("unregister fired", key)
p.connections[key] = connection
if _, ok := p.connections[key]; ok { if _, ok := p.connections[key]; ok {
delete(p.connections, key) delete(p.connections, key)
} }
@ -57,7 +70,7 @@ func (p *Tracking) list() {
//Lookup -- //Lookup --
// - get connection from key // - get connection from key
func (p *Tracking) Lookup(key string) (c net.Conn, err error) { func (p *Tracking) Lookup(key string) (c *Track, err error) {
if _, ok := p.connections[key]; ok { if _, ok := p.connections[key]; ok {
c = p.connections[key] c = p.connections[key]
} else { } else {

View File

@ -204,7 +204,6 @@ func (c *Connection) Reader(ctx context.Context) {
msgType, message, err := c.conn.ReadMessage() msgType, message, err := c.conn.ReadMessage()
loginfo.Println("ReadMessage", msgType, err) loginfo.Println("ReadMessage", msgType, err)
loginfo.Println(hex.Dump(message))
c.Update() c.Update()
@ -219,14 +218,22 @@ func (c *Connection) Reader(ctx context.Context) {
// unpack the message. // unpack the message.
p, err := packer.ReadMessage(message) p, err := packer.ReadMessage(message)
key := p.Header.Address().String() + ":" + strconv.Itoa(p.Header.Port) key := p.Header.Address().String() + ":" + strconv.Itoa(p.Header.Port)
test, err := connectionTrack.Lookup(key) track, err := connectionTrack.Lookup(key)
loginfo.Println(hex.Dump(p.Data.Data()))
if err != nil { if err != nil {
loginfo.Println("Unable to locate Tracking for ", key) loginfo.Println("Unable to locate Tracking for ", key)
continue continue
} }
test.Write(p.Data.Data()) //Support for tracking outbound traffic based on domain.
if domainTrack, ok := c.DomainTrack[track.domain]; ok {
//if ok then add to structure, else warn there is something wrong
domainTrack.AddIn(int64(len(message)))
}
track.conn.Write(p.Data.Data())
c.addIn(int64(len(message))) c.addIn(int64(len(message)))
loginfo.Println("end of read") loginfo.Println("end of read")

View File

@ -66,7 +66,7 @@ func (c *Table) reaper(delay int, idle int) {
func (c *Table) Run(ctx context.Context) { func (c *Table) Run(ctx context.Context) {
loginfo.Println("ConnectionTable starting") loginfo.Println("ConnectionTable starting")
go c.reaper(300, 60) go c.reaper(3000, 60)
for { for {
select { select {

View File

@ -1,112 +0,0 @@
package genericlistener
import (
"encoding/json"
"fmt"
"net/http"
"strings"
jwt "github.com/dgrijalva/jwt-go"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
)
//LaunchWssListener - obtains a onetime connection from wedge listener
func LaunchWssListener(connectionTable *connection.Table, secretKey string, serverBind string, certfile string, keyfile string) (err error) {
loginfo.Println("starting LaunchWssListener ")
router := mux.NewRouter().StrictSlash(true)
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
loginfo.Println("HandleFunc /")
switch url := r.URL.Path; url {
case "/":
// check to see if we are using the administrative Host
if strings.Contains(r.Host, "rvpn.daplie.invalid") {
http.Redirect(w, r, "/admin", 301)
}
handleConnectionWebSocket(connectionTable, w, r, secretKey, false)
default:
http.Error(w, "Not Found", 404)
}
})
router.HandleFunc("/admin", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "Welcome!")
})
router.HandleFunc("/api/servers", func(w http.ResponseWriter, r *http.Request) {
fmt.Println("here")
serverContainer := admin.NewServerAPIContainer()
for c := range connectionTable.Connections() {
serverAPI := admin.NewServerAPI(c)
serverContainer.Servers = append(serverContainer.Servers, serverAPI)
}
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
json.NewEncoder(w).Encode(serverContainer)
})
s := &http.Server{
Addr: serverBind,
Handler: router,
}
err = s.ListenAndServeTLS(certfile, keyfile)
if err != nil {
loginfo.Println("ListenAndServeTLS: ", err)
}
return
}
// handleConnectionWebSocket handles websocket requests from the peer.
func handleConnectionWebSocket(connectionTable *connection.Table, w http.ResponseWriter, r *http.Request, secretKey string, admin bool) {
loginfo.Println("websocket opening ", r.RemoteAddr, " ", r.Host)
tokenString := r.URL.Query().Get("access_token")
result, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte(secretKey), nil
})
if err != nil || !result.Valid {
w.WriteHeader(http.StatusForbidden)
w.Write([]byte("Not Authorized"))
loginfo.Println("access_token invalid...closing connection")
return
}
loginfo.Println("help access_token valid")
claims := result.Claims.(jwt.MapClaims)
domains, ok := claims["domains"].([]interface{})
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
loginfo.Println("WebSocket upgrade failed", err)
return
}
loginfo.Println("before connection table")
//newConnection := connection.NewConnection(connectionTable, conn, r.RemoteAddr, domains)
newRegistration := connection.NewRegistration(conn, r.RemoteAddr, domains)
connectionTable.Register() <- newRegistration
ok = <-newRegistration.CommCh()
if !ok {
loginfo.Println("connection registration failed ", newRegistration)
return
}
loginfo.Println("connection registration accepted ", newRegistration)
}

View File

@ -1 +0,0 @@
package genericlistener

View File

@ -229,7 +229,6 @@ func handleStream(ctx context.Context, wConn *WedgeConn) {
// - get a wConn and start processing requests // - get a wConn and start processing requests
func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) { func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
connectionTracking := ctx.Value(ctxConnectionTrack).(*Tracking) connectionTracking := ctx.Value(ctxConnectionTrack).(*Tracking)
connectionTracking.register <- extConn
defer func() { defer func() {
connectionTracking.unregister <- extConn connectionTracking.unregister <- extConn
@ -277,6 +276,9 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
return return
} }
track := NewTrack(extConn, hostname)
connectionTracking.register <- track
loginfo.Println("Domain Accepted", conn, rAddr, rPort) loginfo.Println("Domain Accepted", conn, rAddr, rPort)
p := packer.NewPacker() p := packer.NewPacker()
p.Header.SetAddress(rAddr) p.Header.SetAddress(rAddr)

View File

@ -1,58 +0,0 @@
package genericlistener
import "golang.org/x/net/websocket"
type domain string
//WssRegistration --
type WssRegistration struct {
domainName domain
connection *websocket.Conn
}
//WssMapping --
type WssMapping struct {
register chan *websocket.Conn
unregister chan *websocket.Conn
domainRegister chan *WssRegistration
domainUnregister chan *WssRegistration
connections map[*websocket.Conn][]domain
domains map[domain]*websocket.Conn
}
//NewwssMapping -- constructor
func NewwssMapping() (p *WssMapping) {
p = new(WssMapping)
p.connections = make(map[*websocket.Conn][]domain)
return
}
//Run -- Execute
func (c *WssMapping) Run() {
loginfo.Println("WSSMapping starting")
for {
select {
case wssConn := <-c.register:
loginfo.Println("register fired")
c.connections[wssConn] = make([]domain, initialDomains)
for conn := range c.connections {
loginfo.Println(conn)
}
case wssConn := <-c.unregister:
loginfo.Println("closing connection ", wssConn)
if _, ok := c.connections[wssConn]; ok {
delete(c.connections, wssConn)
}
}
}
}
// register a wss connection first -- initialize the domain slice
// add a domain
// find the connectino add to the slice.
// find the domain set the connection in the map.
// domain(s) -> connection
// connection -> domains