diff --git a/rvpn/genericlistener/conn_tracking.go b/rvpn/genericlistener/conn_tracking.go index 87e129c..3cc6a34 100644 --- a/rvpn/genericlistener/conn_tracking.go +++ b/rvpn/genericlistener/conn_tracking.go @@ -1,8 +1,11 @@ package genericlistener -import "net" -import "context" -import "fmt" +import ( + "context" + "fmt" + "net" + "sync" +) //Track -- used to track connection + domain type Track struct { @@ -20,6 +23,7 @@ func NewTrack(conn net.Conn, domain string) (p *Track) { //Tracking -- type Tracking struct { + mutex *sync.Mutex connections map[string]*Track register chan *Track unregister chan net.Conn @@ -28,6 +32,7 @@ type Tracking struct { //NewTracking -- Constructor func NewTracking() (p *Tracking) { p = new(Tracking) + p.mutex = &sync.Mutex{} p.connections = make(map[string]*Track) p.register = make(chan *Track) p.unregister = make(chan net.Conn) @@ -46,18 +51,22 @@ func (p *Tracking) Run(ctx context.Context) { return case connection := <-p.register: + p.mutex.Lock() key := connection.conn.RemoteAddr().String() loginfo.Println("register fired", key) p.connections[key] = connection p.list() + p.mutex.Unlock() case connection := <-p.unregister: + p.mutex.Lock() key := connection.RemoteAddr().String() loginfo.Println("unregister fired", key) if _, ok := p.connections[key]; ok { delete(p.connections, key) } p.list() + p.mutex.Unlock() } } } @@ -71,6 +80,11 @@ func (p *Tracking) list() { //Lookup -- // - get connection from key func (p *Tracking) Lookup(key string) (c *Track, err error) { + defer func() { + p.mutex.Unlock() + }() + p.mutex.Lock() + if _, ok := p.connections[key]; ok { c = p.connections[key] } else {