Implemented domain tracking for external connections
- system now tracks both in and out bytes - various clean up.
This commit is contained in:
parent
d7e01e8b40
commit
f3bb9cb584
18
main.go
18
main.go
|
@ -24,7 +24,6 @@ var (
|
|||
argServerExternalBinding string
|
||||
argDeadTime int
|
||||
connectionTable *genericlistener.Table
|
||||
wssMapping *genericlistener.WssMapping
|
||||
secretKey = "abc123"
|
||||
)
|
||||
|
||||
|
@ -72,24 +71,9 @@ func main() {
|
|||
go connectionTable.Run(ctx)
|
||||
|
||||
genericListeners := genericlistener.NewGenerListeners(ctx, connectionTable, connectionTracking, secretKey, certbundle, argDeadTime)
|
||||
go genericListeners.Run(ctx, 8443)
|
||||
|
||||
//go genericlistener.GenericListenAndServe(ctx, connectionTable, secretKey, argGenericBinding, certbundle, argDeadTime)
|
||||
go genericListeners.Run(ctx, 9999)
|
||||
|
||||
//Run for 10 minutes and then shutdown cleanly
|
||||
time.Sleep(600 * time.Second)
|
||||
cancelContext()
|
||||
|
||||
//wssMapping = xlate.NewwssMapping()
|
||||
//go wssMapping.Run()
|
||||
|
||||
//go client.LaunchClientListener(connectionTable, &secretKey, &argServerBinding)
|
||||
//go external.LaunchWebRequestExternalListener(&argServerExternalBinding, connectionTable)
|
||||
//go external.LaunchExternalServer(argServerExternalBinding, connectionTable)
|
||||
//err = admin.LaunchAdminListener(&argServerAdminBinding, connectionTable)
|
||||
//if err != nil {
|
||||
// loginfo.Println("LauchAdminListener failed: ", err)
|
||||
//}
|
||||
|
||||
//genericlistener.LaunchWssListener(connectionTable, secretKey, argWssClientListener, "certs/fullchain.pem", "certs/privkey.pem")
|
||||
}
|
||||
|
|
|
@ -4,18 +4,32 @@ import "net"
|
|||
import "context"
|
||||
import "fmt"
|
||||
|
||||
//Track -- used to track connection + domain
|
||||
type Track struct {
|
||||
conn net.Conn
|
||||
domain string
|
||||
}
|
||||
|
||||
//NewTrack -- Constructor
|
||||
func NewTrack(conn net.Conn, domain string) (p *Track) {
|
||||
p = new(Track)
|
||||
p.conn = conn
|
||||
p.domain = domain
|
||||
return
|
||||
}
|
||||
|
||||
//Tracking --
|
||||
type Tracking struct {
|
||||
connections map[string]net.Conn
|
||||
register chan net.Conn
|
||||
connections map[string]*Track
|
||||
register chan *Track
|
||||
unregister chan net.Conn
|
||||
}
|
||||
|
||||
//NewTracking -- Constructor
|
||||
func NewTracking() (p *Tracking) {
|
||||
p = new(Tracking)
|
||||
p.connections = make(map[string]net.Conn)
|
||||
p.register = make(chan net.Conn)
|
||||
p.connections = make(map[string]*Track)
|
||||
p.register = make(chan *Track)
|
||||
p.unregister = make(chan net.Conn)
|
||||
return
|
||||
}
|
||||
|
@ -32,7 +46,7 @@ func (p *Tracking) Run(ctx context.Context) {
|
|||
return
|
||||
|
||||
case connection := <-p.register:
|
||||
key := connection.RemoteAddr().String()
|
||||
key := connection.conn.RemoteAddr().String()
|
||||
loginfo.Println("register fired", key)
|
||||
p.connections[key] = connection
|
||||
p.list()
|
||||
|
@ -40,7 +54,6 @@ func (p *Tracking) Run(ctx context.Context) {
|
|||
case connection := <-p.unregister:
|
||||
key := connection.RemoteAddr().String()
|
||||
loginfo.Println("unregister fired", key)
|
||||
p.connections[key] = connection
|
||||
if _, ok := p.connections[key]; ok {
|
||||
delete(p.connections, key)
|
||||
}
|
||||
|
@ -57,7 +70,7 @@ func (p *Tracking) list() {
|
|||
|
||||
//Lookup --
|
||||
// - get connection from key
|
||||
func (p *Tracking) Lookup(key string) (c net.Conn, err error) {
|
||||
func (p *Tracking) Lookup(key string) (c *Track, err error) {
|
||||
if _, ok := p.connections[key]; ok {
|
||||
c = p.connections[key]
|
||||
} else {
|
||||
|
|
|
@ -204,7 +204,6 @@ func (c *Connection) Reader(ctx context.Context) {
|
|||
msgType, message, err := c.conn.ReadMessage()
|
||||
|
||||
loginfo.Println("ReadMessage", msgType, err)
|
||||
loginfo.Println(hex.Dump(message))
|
||||
|
||||
c.Update()
|
||||
|
||||
|
@ -219,14 +218,22 @@ func (c *Connection) Reader(ctx context.Context) {
|
|||
// unpack the message.
|
||||
p, err := packer.ReadMessage(message)
|
||||
key := p.Header.Address().String() + ":" + strconv.Itoa(p.Header.Port)
|
||||
test, err := connectionTrack.Lookup(key)
|
||||
track, err := connectionTrack.Lookup(key)
|
||||
|
||||
loginfo.Println(hex.Dump(p.Data.Data()))
|
||||
|
||||
if err != nil {
|
||||
loginfo.Println("Unable to locate Tracking for ", key)
|
||||
continue
|
||||
}
|
||||
|
||||
test.Write(p.Data.Data())
|
||||
//Support for tracking outbound traffic based on domain.
|
||||
if domainTrack, ok := c.DomainTrack[track.domain]; ok {
|
||||
//if ok then add to structure, else warn there is something wrong
|
||||
domainTrack.AddIn(int64(len(message)))
|
||||
}
|
||||
|
||||
track.conn.Write(p.Data.Data())
|
||||
|
||||
c.addIn(int64(len(message)))
|
||||
loginfo.Println("end of read")
|
||||
|
|
|
@ -66,7 +66,7 @@ func (c *Table) reaper(delay int, idle int) {
|
|||
func (c *Table) Run(ctx context.Context) {
|
||||
loginfo.Println("ConnectionTable starting")
|
||||
|
||||
go c.reaper(300, 60)
|
||||
go c.reaper(3000, 60)
|
||||
|
||||
for {
|
||||
select {
|
||||
|
|
|
@ -1,112 +0,0 @@
|
|||
package genericlistener
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
//LaunchWssListener - obtains a onetime connection from wedge listener
|
||||
func LaunchWssListener(connectionTable *connection.Table, secretKey string, serverBind string, certfile string, keyfile string) (err error) {
|
||||
loginfo.Println("starting LaunchWssListener ")
|
||||
|
||||
router := mux.NewRouter().StrictSlash(true)
|
||||
|
||||
router.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||
loginfo.Println("HandleFunc /")
|
||||
switch url := r.URL.Path; url {
|
||||
case "/":
|
||||
// check to see if we are using the administrative Host
|
||||
if strings.Contains(r.Host, "rvpn.daplie.invalid") {
|
||||
http.Redirect(w, r, "/admin", 301)
|
||||
}
|
||||
|
||||
handleConnectionWebSocket(connectionTable, w, r, secretKey, false)
|
||||
|
||||
default:
|
||||
http.Error(w, "Not Found", 404)
|
||||
}
|
||||
})
|
||||
|
||||
router.HandleFunc("/admin", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "Welcome!")
|
||||
})
|
||||
|
||||
router.HandleFunc("/api/servers", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Println("here")
|
||||
serverContainer := admin.NewServerAPIContainer()
|
||||
|
||||
for c := range connectionTable.Connections() {
|
||||
serverAPI := admin.NewServerAPI(c)
|
||||
serverContainer.Servers = append(serverContainer.Servers, serverAPI)
|
||||
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
|
||||
json.NewEncoder(w).Encode(serverContainer)
|
||||
|
||||
})
|
||||
|
||||
s := &http.Server{
|
||||
Addr: serverBind,
|
||||
Handler: router,
|
||||
}
|
||||
|
||||
err = s.ListenAndServeTLS(certfile, keyfile)
|
||||
if err != nil {
|
||||
loginfo.Println("ListenAndServeTLS: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// handleConnectionWebSocket handles websocket requests from the peer.
|
||||
func handleConnectionWebSocket(connectionTable *connection.Table, w http.ResponseWriter, r *http.Request, secretKey string, admin bool) {
|
||||
loginfo.Println("websocket opening ", r.RemoteAddr, " ", r.Host)
|
||||
|
||||
tokenString := r.URL.Query().Get("access_token")
|
||||
result, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secretKey), nil
|
||||
})
|
||||
|
||||
if err != nil || !result.Valid {
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
w.Write([]byte("Not Authorized"))
|
||||
loginfo.Println("access_token invalid...closing connection")
|
||||
return
|
||||
}
|
||||
|
||||
loginfo.Println("help access_token valid")
|
||||
|
||||
claims := result.Claims.(jwt.MapClaims)
|
||||
domains, ok := claims["domains"].([]interface{})
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
loginfo.Println("WebSocket upgrade failed", err)
|
||||
return
|
||||
}
|
||||
|
||||
loginfo.Println("before connection table")
|
||||
|
||||
//newConnection := connection.NewConnection(connectionTable, conn, r.RemoteAddr, domains)
|
||||
|
||||
newRegistration := connection.NewRegistration(conn, r.RemoteAddr, domains)
|
||||
connectionTable.Register() <- newRegistration
|
||||
ok = <-newRegistration.CommCh()
|
||||
if !ok {
|
||||
loginfo.Println("connection registration failed ", newRegistration)
|
||||
return
|
||||
}
|
||||
|
||||
loginfo.Println("connection registration accepted ", newRegistration)
|
||||
}
|
|
@ -1 +0,0 @@
|
|||
package genericlistener
|
|
@ -229,7 +229,6 @@ func handleStream(ctx context.Context, wConn *WedgeConn) {
|
|||
// - get a wConn and start processing requests
|
||||
func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
|
||||
connectionTracking := ctx.Value(ctxConnectionTrack).(*Tracking)
|
||||
connectionTracking.register <- extConn
|
||||
|
||||
defer func() {
|
||||
connectionTracking.unregister <- extConn
|
||||
|
@ -277,6 +276,9 @@ func handleExternalHTTPRequest(ctx context.Context, extConn net.Conn) {
|
|||
return
|
||||
}
|
||||
|
||||
track := NewTrack(extConn, hostname)
|
||||
connectionTracking.register <- track
|
||||
|
||||
loginfo.Println("Domain Accepted", conn, rAddr, rPort)
|
||||
p := packer.NewPacker()
|
||||
p.Header.SetAddress(rAddr)
|
||||
|
|
|
@ -1,58 +0,0 @@
|
|||
package genericlistener
|
||||
|
||||
import "golang.org/x/net/websocket"
|
||||
|
||||
type domain string
|
||||
|
||||
//WssRegistration --
|
||||
type WssRegistration struct {
|
||||
domainName domain
|
||||
connection *websocket.Conn
|
||||
}
|
||||
|
||||
//WssMapping --
|
||||
type WssMapping struct {
|
||||
register chan *websocket.Conn
|
||||
unregister chan *websocket.Conn
|
||||
domainRegister chan *WssRegistration
|
||||
domainUnregister chan *WssRegistration
|
||||
connections map[*websocket.Conn][]domain
|
||||
domains map[domain]*websocket.Conn
|
||||
}
|
||||
|
||||
//NewwssMapping -- constructor
|
||||
func NewwssMapping() (p *WssMapping) {
|
||||
p = new(WssMapping)
|
||||
p.connections = make(map[*websocket.Conn][]domain)
|
||||
return
|
||||
}
|
||||
|
||||
//Run -- Execute
|
||||
func (c *WssMapping) Run() {
|
||||
loginfo.Println("WSSMapping starting")
|
||||
for {
|
||||
select {
|
||||
case wssConn := <-c.register:
|
||||
loginfo.Println("register fired")
|
||||
c.connections[wssConn] = make([]domain, initialDomains)
|
||||
|
||||
for conn := range c.connections {
|
||||
loginfo.Println(conn)
|
||||
}
|
||||
|
||||
case wssConn := <-c.unregister:
|
||||
loginfo.Println("closing connection ", wssConn)
|
||||
if _, ok := c.connections[wssConn]; ok {
|
||||
delete(c.connections, wssConn)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// register a wss connection first -- initialize the domain slice
|
||||
// add a domain
|
||||
// find the connectino add to the slice.
|
||||
// find the domain set the connection in the map.
|
||||
|
||||
// domain(s) -> connection
|
||||
// connection -> domains
|
Loading…
Reference in New Issue