package api

import (
	"context"
	"fmt"
	"io"
	"log"
	"sync"
	"time"

	"github.com/gorilla/websocket"

	"git.coolaj86.com/coolaj86/go-telebitd/packer"
)

// connectionID holds the most recently issued connection identifier.
// NewConnection increments it and copies the new value onto each Connection.
var connectionID int64

// Connection tracks a websocket and facilitates inbound and outbound data.
type Connection struct {
	// mutex is locked by State, SetState and Update while they touch wssState and lastUpdate.
	mutex sync.Mutex

	// The main connection table (should be just one of these created at startup).
	connectionTable *Table

	// Used to track traffic for a domain. Not used for lookup or validation, only for tracking.
	DomainTrack map[string]*DomainTrack

	// The websocket connection.
	conn *websocket.Conn

	// Channel of outbound messages; NewConnection creates it without a buffer.
	send chan *SendTrack

	// NOTE(review): an earlier comment here referred to a "WssState channel" that must be
	// checked before transmitting. No such channel is declared in this struct; the high-level
	// state is the wssState flag below.

	// Address of the remote end point.
	source string

	// serverName -- name of the server; at this point the 1st domain registered. Will likely change with JWT.
	serverName string

	// Bytes in; incremented by addIn.
	bytesIn int64

	// Bytes out; incremented by addOut.
	bytesOut int64

	// Requests counts calls to addRequests.
	Requests int64 // TODO atomic

	// Responses counts calls to addResponse.
	Responses int64 // TODO atomic

	// Connect time, set by NewConnection.
	connectTime time.Time

	// lastUpdate is refreshed by Update and is used to track idle time.
	lastUpdate time.Time

	// initialDomains - a list of domains from the JWT.
	initialDomains []string

	// connectionTrack is the Tracking table Reader consults to find the destination of each inbound message.
	connectionTrack *Tracking

	// wssState tracks a high-level status of the connection; false means do nothing.
	wssState bool

	// connectionID is the identifier assigned by NewConnection.
	connectionID int64
}

// NewConnection builds a Connection around conn.
//
// It bumps the package-level connectionID, records the remote address, server
// name and tracking table, creates an unbuffered send channel, registers a
// DomainTrack entry for every initial domain, marks the connection state as
// true and finally stamps the connection with the new identifier.
func NewConnection(connectionTable *Table, conn *websocket.Conn, remoteAddress string,
	initialDomains []string, connectionTrack *Tracking, serverName string) (p *Connection) {
	connectionID++

	// Counters and the state flag start at their zero values.
	p = &Connection{
		connectionTable: connectionTable,
		conn:            conn,
		source:          remoteAddress,
		serverName:      serverName,
		send:            make(chan *SendTrack),
		connectTime:     time.Now(),
		initialDomains:  initialDomains,
		connectionTrack: connectionTrack,
		DomainTrack:     make(map[string]*DomainTrack),
		lastUpdate:      time.Now(),
	}

	for i := range initialDomains {
		p.AddTrackedDomain(initialDomains[i])
	}

	p.SetState(true)
	p.connectionID = connectionID
	return p
}

// AddTrackedDomain stores a fresh DomainTrack for domain in the DomainTrack
// map, replacing any entry already registered under that name.
func (c *Connection) AddTrackedDomain(domain string) {
	c.DomainTrack[domain] = &DomainTrack{DomainName: domain}
}

// ServerName returns the server name recorded for this connection.
func (c *Connection) ServerName() string {
	name := c.serverName
	return name
}

// SetServerName replaces the server name recorded for this connection.
func (c *Connection) SetServerName(serverName string) {
	c.serverName = serverName
}

// InitialDomains returns the domain list the connection was created with.
// The returned slice is the connection's own slice, not a copy.
func (c *Connection) InitialDomains() []string {
	domains := c.initialDomains
	return domains
}

// ConnectTime returns the time recorded when the connection was created.
func (c *Connection) ConnectTime() time.Time {
	connectedAt := c.connectTime
	return connectedAt
}

// BytesIn returns the running total accumulated by addIn.
func (c *Connection) BytesIn() int64 {
	total := c.bytesIn
	return total
}

// BytesOut returns the running total accumulated by addOut.
func (c *Connection) BytesOut() int64 {
	total := c.bytesOut
	return total
}

// SendCh exposes the channel that Writer receives outbound messages from.
func (c *Connection) SendCh() chan *SendTrack {
	outbound := c.send
	return outbound
}

// Source returns the remote end point address recorded for this connection.
func (c *Connection) Source() string {
	remote := c.source
	return remote
}

// addIn adds num to the connection's inbound byte counter.
func (c *Connection) addIn(num int64) {
	c.bytesIn += num
}

// addOut adds num to the connection's outbound byte counter.
func (c *Connection) addOut(num int64) {
	c.bytesOut += num
}

// addRequests bumps the Requests counter by one.
func (c *Connection) addRequests() {
	// TODO atomic
	c.Requests = c.Requests + 1
}

// addResponse bumps the Responses counter by one.
func (c *Connection) addResponse() {
	// TODO atomic
	c.Responses = c.Responses + 1
}

// ConnectionTable returns the connection table this connection was created with.
func (c *Connection) ConnectionTable() *Table {
	table := c.connectionTable
	return table
}

// State reports the high-level state flag of the socket, read while holding
// the connection mutex.
func (c *Connection) State() bool {
	c.mutex.Lock()
	current := c.wssState
	c.mutex.Unlock()

	return current
}

// SetState stores the high-level state flag of the connection while holding
// the connection mutex.
func (c *Connection) SetState(state bool) {
	c.mutex.Lock()
	c.wssState = state
	c.mutex.Unlock()
}

// Update stamps lastUpdate with the current time, which is what idle-time
// tracking is based on. The write happens while holding the connection mutex.
func (c *Connection) Update() {
	c.mutex.Lock()
	c.lastUpdate = time.Now()
	c.mutex.Unlock()
}

// LastUpdate returns the time most recently recorded by Update.
//
// Update writes lastUpdate while holding the connection mutex, so the read
// takes the same mutex to avoid racing with a concurrent Update.
func (c *Connection) LastUpdate() time.Time {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	return c.lastUpdate
}

// ConnectionID returns the identifier stamped on this connection.
func (c *Connection) ConnectionID() int64 {
	id := c.connectionID
	return id
}

// NextWriter checks the high-level connection state before asking the
// websocket for its next writer.
//
// The original rationale: the library fails if the client goes away during a
// write cycle, and a fast-moving write is not caught before the socket state
// bubbles up, so a synchronised state flag is consulted first.
//
// When the state is false no writer is requested and a non-nil error is
// returned. Returning (nil, nil) in that case would hand callers a nil writer
// with no indication that it must not be used; Writer, for example, calls
// Close on whatever it receives.
func (c *Connection) NextWriter(wssMessageType int) (io.WriteCloser, error) {
	if !c.State() {
		log.Println("NextWriter aborted, state is not true")
		return nil, fmt.Errorf("next writer: connection %d is not in a writable state", c.connectionID)
	}

	return c.conn.NextWriter(wssMessageType)
}

// Write checks the high-level connection state before writing message to w.
//
// It returns the result of w.Write when the state is true. When the state is
// false, or when w is nil, nothing is written and a non-nil error is returned
// so the caller can tell that the message was dropped instead of seeing a
// silent (0, nil).
func (c *Connection) Write(w io.WriteCloser, message []byte) (int, error) {
	// NextWriter can hand back a nil writer; calling Write on it would panic.
	if w == nil {
		return 0, fmt.Errorf("write: connection %d has no writer", c.connectionID)
	}
	if !c.State() {
		return 0, fmt.Errorf("write: connection %d is not in a writable state", c.connectionID)
	}

	return w.Write(message)
}

// Reader reads messages from the websocket until a read fails.
//
// Each message is unpacked, the packet's address and port are used to look up
// its tracking entry, and the payload is written to that entry's connection.
// Per-domain and per-connection counters are updated for every message that is
// forwarded. On exit the connection is sent to the table's unregister channel
// and the websocket is closed.
//
// ctx is accepted but not consulted by this loop.
func (c *Connection) Reader(ctx context.Context) {
	connectionTrack := c.connectionTrack

	defer func() {
		c.connectionTable.unregister <- c
		c.conn.Close()
		log.Println("reader defer", c)
	}()

	log.Println("Reader Start ", c)

	//c.conn.SetReadLimit(65535)
	for {
		_, message, err := c.conn.ReadMessage()

		// The activity timestamp is refreshed on every return from ReadMessage,
		// including failed reads.
		c.Update()

		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway) {
				c.SetState(false)
				log.Printf("error: %v", err)
			}
			return
		}

		// Unpack the message. A message that cannot be unpacked is skipped: the
		// packet may be nil, so it must not be used to build the lookup key.
		p, err := packer.ReadMessage(message)
		if err != nil {
			log.Printf("error: unable to unpack message: %v", err)
			continue
		}

		key := fmt.Sprintf("%s:%d", p.Address(), p.Port())
		track, err := connectionTrack.Lookup(key)
		if err != nil {
			// No tracking entry for this address and port; drop the message.
			continue
		}

		// Support for tracking outbound traffic based on domain. A domain
		// without an entry is not counted.
		if domainTrack, ok := c.DomainTrack[track.domain]; ok {
			domainTrack.AddOut(int64(len(message)))
			domainTrack.AddResponses()
		}

		// NOTE(review): the result of this write is not inspected — confirm
		// whether a failed write should be reported or end the loop.
		track.conn.Write(p.Data.Data())

		c.addIn(int64(len(message)))
		c.addResponse()
	}
}

// Writer receives messages from the send channel and writes each one to the
// websocket as a binary message, updating the per-connection and per-domain
// counters after every message. The websocket is closed when Writer returns.
//
// Writer returns when the send channel is closed, when no writer can be
// obtained (the state is then set to false), or when writing or closing the
// writer fails.
func (c *Connection) Writer() {
	defer c.conn.Close()

	log.Println("Writer Start ", c)

	// Ranging over the channel ends the loop cleanly if the channel is closed,
	// instead of receiving a nil message.
	for message := range c.send {
		if message == nil {
			continue
		}

		w, err := c.NextWriter(websocket.BinaryMessage)
		log.Println("next writer ", w)
		// A nil writer without an error must be treated like a failure;
		// calling Close on it would panic.
		if err != nil || w == nil {
			c.SetState(false)
			return
		}

		c.Update()

		_, writeErr := c.Write(w, message.data)

		// The writer is closed whether or not the write succeeded.
		if err := w.Close(); err != nil {
			return
		}
		if writeErr != nil {
			log.Printf("error: %v", writeErr)
			return
		}

		messageLen := int64(len(message.data))

		c.addOut(messageLen)
		c.addRequests()

		// Support for tracking traffic based on domain.
		if domainTrack, ok := c.DomainTrack[message.domain]; ok {
			domainTrack.AddIn(messageLen)
			domainTrack.AddRequests()
			log.Println("adding ", messageLen, " to ", message.domain)
		} else {
			// The domain has no tracking entry, which indicates something is wrong.
			log.Println("attempting to add bytes to ", message.domain, "it does not exist")
			log.Println("dt", c.DomainTrack)
		}
		log.Println(c)
	}
}