Merge branch 'client'
This commit is contained in:
commit
01de157cfe
|
@ -0,0 +1,208 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
jwt "github.com/dgrijalva/jwt-go"
|
||||||
|
flag "github.com/spf13/pflag"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
|
||||||
|
"git.daplie.com/Daplie/go-rvpn-server/rvpn/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
var httpRegexp = regexp.MustCompile(`(?i)^http`)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
flag.StringSlice("locals", []string{}, "comma separated list of <proto>:<port> or "+
|
||||||
|
"<proto>:<hostname>:<port> to which matching incoming connections should forward. "+
|
||||||
|
"Ex: smtps:8465,https:example.com:8443")
|
||||||
|
flag.StringSlice("domains", []string{}, "comma separated list of domain names to set to the tunnel")
|
||||||
|
viper.BindPFlag("locals", flag.Lookup("locals"))
|
||||||
|
viper.BindPFlag("domains", flag.Lookup("domains"))
|
||||||
|
|
||||||
|
flag.BoolP("insecure", "k", false, "Allow TLS connections to stunneld without valid certs")
|
||||||
|
flag.String("stunneld", "", "the domain (or ip address) at which the RVPN server is running")
|
||||||
|
flag.String("secret", "", "the same secret used by stunneld (used for JWT authentication)")
|
||||||
|
flag.String("token", "", "a pre-generated token to give the server (instead of generating one with --secret)")
|
||||||
|
viper.BindPFlag("raw.insecure", flag.Lookup("insecure"))
|
||||||
|
viper.BindPFlag("raw.stunneld", flag.Lookup("stunneld"))
|
||||||
|
viper.BindPFlag("raw.secret", flag.Lookup("secret"))
|
||||||
|
viper.BindPFlag("raw.token", flag.Lookup("token"))
|
||||||
|
}
|
||||||
|
|
||||||
|
type proxy struct {
|
||||||
|
protocol string
|
||||||
|
hostname string
|
||||||
|
port int
|
||||||
|
}
|
||||||
|
|
||||||
|
func addLocals(proxies []proxy, location string) []proxy {
|
||||||
|
parts := strings.Split(location, ":")
|
||||||
|
if len(parts) > 3 {
|
||||||
|
panic(fmt.Sprintf("provided invalid location %q", location))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If all that was provided as a "local" is the domain name we assume that domain
|
||||||
|
// has HTTP and HTTPS handlers on the default ports.
|
||||||
|
if len(parts) == 1 {
|
||||||
|
proxies = append(proxies, proxy{"http", parts[0], 80})
|
||||||
|
proxies = append(proxies, proxy{"https", parts[0], 443})
|
||||||
|
return proxies
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make everything lower case and trim any slashes in something like https://john.example.com
|
||||||
|
parts[0] = strings.ToLower(parts[0])
|
||||||
|
parts[1] = strings.ToLower(strings.Trim(parts[1], "/"))
|
||||||
|
|
||||||
|
if len(parts) == 2 {
|
||||||
|
if strings.Contains(parts[1], ".") {
|
||||||
|
if parts[0] == "http" {
|
||||||
|
parts = append(parts, "80")
|
||||||
|
} else if parts[0] == "https" {
|
||||||
|
parts = append(parts, "443")
|
||||||
|
} else {
|
||||||
|
panic(fmt.Sprintf("port must be specified for %q", location))
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// https:3443 -> https:*:3443
|
||||||
|
parts = []string{parts[0], "*", parts[1]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if port, err := strconv.Atoi(parts[2]); err != nil {
|
||||||
|
panic(fmt.Sprintf("port must be a valid number, not %q: %v", parts[2], err))
|
||||||
|
} else if port <= 0 || port > 65535 {
|
||||||
|
panic(fmt.Sprintf("%d is an invalid port for local services", port))
|
||||||
|
} else {
|
||||||
|
proxies = append(proxies, proxy{parts[0], parts[1], port})
|
||||||
|
}
|
||||||
|
return proxies
|
||||||
|
}
|
||||||
|
|
||||||
|
func addDomains(proxies []proxy, location string) []proxy {
|
||||||
|
parts := strings.Split(location, ":")
|
||||||
|
if len(parts) > 3 {
|
||||||
|
panic(fmt.Sprintf("provided invalid location %q", location))
|
||||||
|
} else if len(parts) == 2 {
|
||||||
|
panic("invalid argument for --domains, use format <domainname> or <scheme>:<domainname>:<local-port>")
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the scheme and port weren't provided use the zero values
|
||||||
|
if len(parts) == 1 {
|
||||||
|
return append(proxies, proxy{"", parts[0], 0})
|
||||||
|
}
|
||||||
|
|
||||||
|
if port, err := strconv.Atoi(parts[2]); err != nil {
|
||||||
|
panic(fmt.Sprintf("port must be a valid number, not %q: %v", parts[2], err))
|
||||||
|
} else if port <= 0 || port > 65535 {
|
||||||
|
panic(fmt.Sprintf("%d is an invalid port for local services", port))
|
||||||
|
} else {
|
||||||
|
proxies = append(proxies, proxy{parts[0], parts[1], port})
|
||||||
|
}
|
||||||
|
return proxies
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractServicePorts(proxies []proxy) map[string]map[string]int {
|
||||||
|
result := make(map[string]map[string]int, 2)
|
||||||
|
|
||||||
|
for _, p := range proxies {
|
||||||
|
if p.protocol != "" && p.port != 0 {
|
||||||
|
hostPorts := result[p.protocol]
|
||||||
|
if hostPorts == nil {
|
||||||
|
result[p.protocol] = make(map[string]int)
|
||||||
|
hostPorts = result[p.protocol]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only HTTP and HTTPS allow us to determine the hostname from the request, so only
|
||||||
|
// those protocols support different ports for the same service.
|
||||||
|
if !httpRegexp.MatchString(p.protocol) || p.hostname == "" {
|
||||||
|
p.hostname = "*"
|
||||||
|
}
|
||||||
|
if port, ok := hostPorts[p.hostname]; ok && port != p.port {
|
||||||
|
panic(fmt.Sprintf("duplicate ports for %s://%s", p.protocol, p.hostname))
|
||||||
|
}
|
||||||
|
hostPorts[p.hostname] = p.port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make sure we have defaults for HTTPS and HTTP.
|
||||||
|
if result["https"] == nil {
|
||||||
|
result["https"] = make(map[string]int, 1)
|
||||||
|
}
|
||||||
|
if result["https"]["*"] == 0 {
|
||||||
|
result["https"]["*"] = 8443
|
||||||
|
}
|
||||||
|
|
||||||
|
if result["http"] == nil {
|
||||||
|
result["http"] = make(map[string]int, 1)
|
||||||
|
}
|
||||||
|
if result["http"]["*"] == 0 {
|
||||||
|
result["http"]["*"] = result["https"]["*"]
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
proxies := make([]proxy, 0)
|
||||||
|
for _, option := range viper.GetStringSlice("locals") {
|
||||||
|
for _, location := range strings.Split(option, ",") {
|
||||||
|
proxies = addLocals(proxies, location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, option := range viper.GetStringSlice("domains") {
|
||||||
|
for _, location := range strings.Split(option, ",") {
|
||||||
|
proxies = addDomains(proxies, location)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
servicePorts := extractServicePorts(proxies)
|
||||||
|
domainMap := make(map[string]bool)
|
||||||
|
for _, p := range proxies {
|
||||||
|
if p.hostname != "" && p.hostname != "*" {
|
||||||
|
domainMap[p.hostname] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if viper.GetString("raw.stunneld") == "" {
|
||||||
|
panic("must provide remote RVPN server to connect to")
|
||||||
|
}
|
||||||
|
|
||||||
|
var token string
|
||||||
|
if viper.GetString("raw.token") != "" {
|
||||||
|
token = viper.GetString("raw.token")
|
||||||
|
} else if viper.GetString("raw.secret") != "" {
|
||||||
|
domains := make([]string, 0, len(domainMap))
|
||||||
|
for name := range domainMap {
|
||||||
|
domains = append(domains, name)
|
||||||
|
}
|
||||||
|
tokenData := jwt.MapClaims{"domains": domains}
|
||||||
|
|
||||||
|
secret := []byte(viper.GetString("raw.secret"))
|
||||||
|
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData)
|
||||||
|
if tokenStr, err := jwtToken.SignedString(secret); err != nil {
|
||||||
|
panic(err)
|
||||||
|
} else {
|
||||||
|
token = tokenStr
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
panic("must provide either token or secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, quit := context.WithCancel(context.Background())
|
||||||
|
defer quit()
|
||||||
|
|
||||||
|
config := client.Config{
|
||||||
|
Insecure: viper.GetBool("raw.insecure"),
|
||||||
|
Server: viper.GetString("raw.stunneld"),
|
||||||
|
Services: servicePorts,
|
||||||
|
Token: token,
|
||||||
|
}
|
||||||
|
panic(client.Run(ctx, &config))
|
||||||
|
}
|
|
@ -0,0 +1,71 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The Config struct holds all of the information needed to establish and handle a connection
|
||||||
|
// with the RVPN server.
|
||||||
|
type Config struct {
|
||||||
|
Server string
|
||||||
|
Token string
|
||||||
|
Insecure bool
|
||||||
|
Services map[string]map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run establishes a connection with the RVPN server specified in the config. If the first attempt
|
||||||
|
// to connect fails it is assumed that something is wrong with the authentication and it will
|
||||||
|
// return an error. Otherwise it will continuously attempt to reconnect whenever the connection
|
||||||
|
// is broken.
|
||||||
|
func Run(ctx context.Context, config *Config) error {
|
||||||
|
serverURL, err := url.Parse(config.Server)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Invalid server URL: %v", err)
|
||||||
|
}
|
||||||
|
if serverURL.Scheme == "" {
|
||||||
|
serverURL.Scheme = "wss"
|
||||||
|
}
|
||||||
|
serverURL.Path = ""
|
||||||
|
|
||||||
|
query := make(url.Values)
|
||||||
|
query.Set("access_token", config.Token)
|
||||||
|
serverURL.RawQuery = query.Encode()
|
||||||
|
|
||||||
|
dialer := websocket.Dialer{}
|
||||||
|
if config.Insecure {
|
||||||
|
dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, portList := range config.Services {
|
||||||
|
if _, ok := portList["*"]; !ok {
|
||||||
|
return fmt.Errorf(`service %s missing port for "*"`, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
handler := NewWsHandler(config.Services)
|
||||||
|
|
||||||
|
authenticated := false
|
||||||
|
for {
|
||||||
|
if conn, _, err := dialer.Dial(serverURL.String(), nil); err == nil {
|
||||||
|
loginfo.Println("connected to remote server")
|
||||||
|
authenticated = true
|
||||||
|
handler.HandleConn(ctx, conn)
|
||||||
|
} else if !authenticated {
|
||||||
|
return fmt.Errorf("First connection to server failed - check auth: %v", err)
|
||||||
|
}
|
||||||
|
loginfo.Println("disconnected from remote server")
|
||||||
|
|
||||||
|
// Sleep for a few seconds before trying again, but only if the context is still active
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
}
|
||||||
|
loginfo.Println("attempting reconnect to remote server")
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,15 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
logFlags = log.Ldate | log.Lmicroseconds | log.Lshortfile
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
loginfo = log.New(os.Stdout, "INFO: client: ", logFlags)
|
||||||
|
logdebug = log.New(os.Stdout, "DEBUG: client:", logFlags)
|
||||||
|
)
|
|
@ -0,0 +1,233 @@
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
|
"git.daplie.com/Daplie/go-rvpn-server/rvpn/packer"
|
||||||
|
"git.daplie.com/Daplie/go-rvpn-server/rvpn/sni"
|
||||||
|
)
|
||||||
|
|
||||||
|
var hostRegexp = regexp.MustCompile(`(?im)(?:^|[\r\n])Host: *([^\r\n]+)[\r\n]`)
|
||||||
|
|
||||||
|
// WsHandler handles all of reading and writing for the websocket connection to the RVPN server
|
||||||
|
// and the TCP connections to the local servers.
|
||||||
|
type WsHandler struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
localConns map[string]net.Conn
|
||||||
|
|
||||||
|
servicePorts map[string]map[string]int
|
||||||
|
|
||||||
|
ctx context.Context
|
||||||
|
dataChan chan *packer.Packer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWsHandler creates a new handler ready to be given a websocket connection. The services
|
||||||
|
// argument specifies what port each service type should be directed to on the local interface.
|
||||||
|
func NewWsHandler(services map[string]map[string]int) *WsHandler {
|
||||||
|
h := new(WsHandler)
|
||||||
|
h.servicePorts = services
|
||||||
|
h.localConns = make(map[string]net.Conn)
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleConn handles all of the traffic on the provided websocket connection. The function
|
||||||
|
// will not return until the connection ends.
|
||||||
|
//
|
||||||
|
// The WsHandler is designed to handle exactly one connection at a time. If HandleConn is called
|
||||||
|
// again while the instance is still handling another connection (or if the previous connection
|
||||||
|
// failed to fully cleanup) calling HandleConn again will panic.
|
||||||
|
func (h *WsHandler) HandleConn(ctx context.Context, conn *websocket.Conn) {
|
||||||
|
if h.dataChan != nil {
|
||||||
|
panic("WsHandler.HandleConn called while handling a previous connection")
|
||||||
|
}
|
||||||
|
if len(h.localConns) > 0 {
|
||||||
|
panic(fmt.Sprintf("WsHandler has lingering local connections: %v", h.localConns))
|
||||||
|
}
|
||||||
|
h.dataChan = make(chan *packer.Packer)
|
||||||
|
|
||||||
|
// The sub context allows us to clean up all of the goroutines associated with this websocket
|
||||||
|
// if it closes at any point for any reason.
|
||||||
|
subCtx, socketQuit := context.WithCancel(ctx)
|
||||||
|
defer socketQuit()
|
||||||
|
h.ctx = subCtx
|
||||||
|
|
||||||
|
// Start the routine that will write all of the data from the local connection to the
|
||||||
|
// remote websocket connection.
|
||||||
|
go h.writeRemote(conn)
|
||||||
|
|
||||||
|
for {
|
||||||
|
_, message, err := conn.ReadMessage()
|
||||||
|
if err != nil {
|
||||||
|
loginfo.Println("failed to read message from websocket", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p, err := packer.ReadMessage(message)
|
||||||
|
if err != nil {
|
||||||
|
loginfo.Println("failed to parse message from websocket", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.writeLocal(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WsHandler) writeRemote(conn *websocket.Conn) {
|
||||||
|
defer h.closeConnections()
|
||||||
|
defer func() { h.dataChan = nil }()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.ctx.Done():
|
||||||
|
// We can't tell if this happened because the websocket is already closed/errored or
|
||||||
|
// if it happened because the main context closed (in which case it would be preferable
|
||||||
|
// to properly close the connection). As such we try to close the connection and ignore
|
||||||
|
// all errors if it doesn't work.
|
||||||
|
message := websocket.FormatCloseMessage(websocket.CloseGoingAway, "closing connection")
|
||||||
|
deadline := time.Now().Add(10 * time.Second)
|
||||||
|
conn.WriteControl(websocket.CloseMessage, message, deadline)
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
|
||||||
|
case p := <-h.dataChan:
|
||||||
|
packed := p.PackV1()
|
||||||
|
conn.WriteMessage(websocket.BinaryMessage, packed.Bytes())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WsHandler) sendPackedMessage(header *packer.Header, data []byte, service string) {
|
||||||
|
p := packer.NewPacker(header)
|
||||||
|
if len(data) > 0 {
|
||||||
|
p.Data.AppendBytes(data)
|
||||||
|
}
|
||||||
|
if service != "" {
|
||||||
|
p.SetService(service)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Avoid blocking on the data channel if the websocket closes or is already closed
|
||||||
|
select {
|
||||||
|
case h.dataChan <- p:
|
||||||
|
case <-h.ctx.Done():
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WsHandler) getLocalConn(p *packer.Packer) net.Conn {
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
|
key := fmt.Sprintf("%s:%d", p.Address(), p.Port())
|
||||||
|
// Simplest case: it's already open, just return it.
|
||||||
|
if conn := h.localConns[key]; conn != nil {
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
service := strings.ToLower(p.Service())
|
||||||
|
portList := h.servicePorts[service]
|
||||||
|
if portList == nil {
|
||||||
|
loginfo.Println("cannot open connection for invalid service", service)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var hostname string
|
||||||
|
if service == "http" {
|
||||||
|
if match := hostRegexp.FindSubmatch(p.Data.Data()); match != nil {
|
||||||
|
hostname = strings.Split(string(match[1]), ":")[0]
|
||||||
|
}
|
||||||
|
} else if service == "https" {
|
||||||
|
hostname, _ = sni.GetHostname(p.Data.Data())
|
||||||
|
} else {
|
||||||
|
hostname = "*"
|
||||||
|
}
|
||||||
|
if hostname == "" {
|
||||||
|
loginfo.Println("missing servername for", service, key)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
hostname = strings.ToLower(hostname)
|
||||||
|
|
||||||
|
port := portList[hostname]
|
||||||
|
if port == 0 {
|
||||||
|
port = portList["*"]
|
||||||
|
}
|
||||||
|
if port == 0 {
|
||||||
|
loginfo.Println("unable to determine local port for", service, hostname)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
|
||||||
|
if err != nil {
|
||||||
|
loginfo.Println("unable to open local connection on port", port, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
h.localConns[key] = conn
|
||||||
|
loginfo.Printf("new client %q for %s:%d (%d clients)\n", key, hostname, port, len(h.localConns))
|
||||||
|
go h.readLocal(key, &p.Header)
|
||||||
|
return conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WsHandler) writeLocal(p *packer.Packer) {
|
||||||
|
conn := h.getLocalConn(p)
|
||||||
|
if conn == nil {
|
||||||
|
h.sendPackedMessage(&p.Header, nil, "error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Service() == "error" || p.Service() == "end" {
|
||||||
|
conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := conn.Write(p.Data.Data()); err != nil {
|
||||||
|
h.sendPackedMessage(&p.Header, nil, "error")
|
||||||
|
loginfo.Println("failed to write to local connection", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WsHandler) readLocal(key string, header *packer.Header) {
|
||||||
|
h.lock.Lock()
|
||||||
|
conn := h.localConns[key]
|
||||||
|
h.lock.Unlock()
|
||||||
|
|
||||||
|
defer conn.Close()
|
||||||
|
defer func() {
|
||||||
|
h.lock.Lock()
|
||||||
|
delete(h.localConns, key)
|
||||||
|
loginfo.Printf("closing client %q: (%d clients)\n", key, len(h.localConns))
|
||||||
|
h.lock.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
for {
|
||||||
|
size, err := conn.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
|
h.sendPackedMessage(header, nil, "end")
|
||||||
|
} else {
|
||||||
|
loginfo.Println("failed to read from local connection for", key, err)
|
||||||
|
h.sendPackedMessage(header, nil, "error")
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.sendPackedMessage(header, buf[:size], "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *WsHandler) closeConnections() {
|
||||||
|
h.lock.Lock()
|
||||||
|
defer h.lock.Unlock()
|
||||||
|
|
||||||
|
for _, conn := range h.localConns {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,16 +16,19 @@ const (
|
||||||
|
|
||||||
//Packer -- contains both header and data
|
//Packer -- contains both header and data
|
||||||
type Packer struct {
|
type Packer struct {
|
||||||
Header *packerHeader
|
Header
|
||||||
Data *packerData
|
Data packerData
|
||||||
}
|
}
|
||||||
|
|
||||||
//NewPacker -- Structre
|
// NewPacker creates a new Packer struct using the information from the provided header as
|
||||||
func NewPacker() (p *Packer) {
|
// its own header. (Because the header is stored directly and not as a pointer/reference
|
||||||
p = new(Packer)
|
// it should be safe to override items like the service without affecting the template header.)
|
||||||
p.Header = newPackerHeader()
|
func NewPacker(header *Header) *Packer {
|
||||||
p.Data = newPackerData()
|
p := new(Packer)
|
||||||
return
|
if header != nil {
|
||||||
|
p.Header = *header
|
||||||
|
}
|
||||||
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
func splitHeader(header []byte, names []string) (map[string]string, error) {
|
func splitHeader(header []byte, names []string) (map[string]string, error) {
|
||||||
|
@ -48,7 +51,7 @@ func ReadMessage(b []byte) (*Packer, error) {
|
||||||
// Detect protocol in use
|
// Detect protocol in use
|
||||||
if b[0] == packerV1 {
|
if b[0] == packerV1 {
|
||||||
// Separate the header and body using the header length in the second byte.
|
// Separate the header and body using the header length in the second byte.
|
||||||
p := NewPacker()
|
p := NewPacker(nil)
|
||||||
header := b[2 : b[1]+2]
|
header := b[2 : b[1]+2]
|
||||||
data := b[b[1]+2:]
|
data := b[b[1]+2:]
|
||||||
|
|
||||||
|
@ -69,8 +72,8 @@ func ReadMessage(b []byte) (*Packer, error) {
|
||||||
p.Header.address = net.ParseIP(parts["address"])
|
p.Header.address = net.ParseIP(parts["address"])
|
||||||
if p.Header.address == nil {
|
if p.Header.address == nil {
|
||||||
return nil, fmt.Errorf("Invalid network address %q", parts["address"])
|
return nil, fmt.Errorf("Invalid network address %q", parts["address"])
|
||||||
} else if p.Header.Family() == FamilyIPv4 && p.Header.address.To4() == nil {
|
} else if p.Header.family == FamilyIPv4 && p.Header.address.To4() == nil {
|
||||||
return nil, fmt.Errorf("Address %q is not in address family %s", parts["address"], p.Header.FamilyText())
|
return nil, fmt.Errorf("Address %q is not in address family %s", parts["address"], p.Header.Family())
|
||||||
}
|
}
|
||||||
|
|
||||||
//handle port
|
//handle port
|
||||||
|
@ -79,7 +82,7 @@ func ReadMessage(b []byte) (*Packer, error) {
|
||||||
} else if port <= 0 || port > 65535 {
|
} else if port <= 0 || port > 65535 {
|
||||||
return nil, fmt.Errorf("Port %d out of range", port)
|
return nil, fmt.Errorf("Port %d out of range", port)
|
||||||
} else {
|
} else {
|
||||||
p.Header.Port = port
|
p.Header.port = port
|
||||||
}
|
}
|
||||||
|
|
||||||
//handle data length
|
//handle data length
|
||||||
|
@ -90,7 +93,7 @@ func ReadMessage(b []byte) (*Packer, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
//handle Service
|
//handle Service
|
||||||
p.Header.Service = parts["service"]
|
p.Header.service = parts["service"]
|
||||||
|
|
||||||
//handle payload
|
//handle payload
|
||||||
p.Data.AppendBytes(data)
|
p.Data.AppendBytes(data)
|
||||||
|
@ -103,11 +106,11 @@ func ReadMessage(b []byte) (*Packer, error) {
|
||||||
//PackV1 -- Outputs version 1 of packer
|
//PackV1 -- Outputs version 1 of packer
|
||||||
func (p *Packer) PackV1() bytes.Buffer {
|
func (p *Packer) PackV1() bytes.Buffer {
|
||||||
header := strings.Join([]string{
|
header := strings.Join([]string{
|
||||||
p.Header.FamilyText(),
|
p.Header.Family(),
|
||||||
p.Header.AddressString(),
|
p.Header.Address(),
|
||||||
strconv.Itoa(p.Header.Port),
|
strconv.Itoa(p.Header.Port()),
|
||||||
strconv.Itoa(p.Data.DataLen()),
|
strconv.Itoa(p.Data.DataLen()),
|
||||||
p.Header.Service,
|
p.Header.Service(),
|
||||||
}, ",")
|
}, ",")
|
||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
|
|
|
@ -9,10 +9,6 @@ type packerData struct {
|
||||||
buffer bytes.Buffer
|
buffer bytes.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPackerData() *packerData {
|
|
||||||
return new(packerData)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *packerData) AppendString(dataString string) (int, error) {
|
func (p *packerData) AppendString(dataString string) (int, error) {
|
||||||
return p.buffer.WriteString(dataString)
|
return p.buffer.WriteString(dataString)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,12 +7,15 @@ import (
|
||||||
|
|
||||||
type addressFamily int
|
type addressFamily int
|
||||||
|
|
||||||
// packerHeader structure to hold our header information.
|
// The Header struct holds most of the information contained in the header for packets
|
||||||
type packerHeader struct {
|
// between the client and the server (the length of the data is not included here). It
|
||||||
|
// is used to uniquely identify remote connections on the servers end and to communicate
|
||||||
|
// which service the remote client is trying to connect to.
|
||||||
|
type Header struct {
|
||||||
family addressFamily
|
family addressFamily
|
||||||
address net.IP
|
address net.IP
|
||||||
Port int
|
port int
|
||||||
Service string
|
service string
|
||||||
}
|
}
|
||||||
|
|
||||||
//Family -- ENUM for Address Family
|
//Family -- ENUM for Address Family
|
||||||
|
@ -26,16 +29,19 @@ var addressFamilyText = [...]string{
|
||||||
"IPv6",
|
"IPv6",
|
||||||
}
|
}
|
||||||
|
|
||||||
func newPackerHeader() (p *packerHeader) {
|
// NewHeader create a new Header object.
|
||||||
p = new(packerHeader)
|
func NewHeader(address string, port int, service string) (*Header, error) {
|
||||||
p.SetAddress("127.0.0.1")
|
h := new(Header)
|
||||||
p.Port = 65535
|
if err := h.setAddress(address); err != nil {
|
||||||
p.Service = "na"
|
return nil, err
|
||||||
return
|
}
|
||||||
|
h.port = port
|
||||||
|
h.service = service
|
||||||
|
return h, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//SetAddress -- Set Address. which sets address family automatically
|
// setAddress parses the provided address string and automatically sets the IP family.
|
||||||
func (p *packerHeader) SetAddress(addr string) {
|
func (p *Header) setAddress(addr string) error {
|
||||||
p.address = net.ParseIP(addr)
|
p.address = net.ParseIP(addr)
|
||||||
|
|
||||||
if p.address.To4() != nil {
|
if p.address.To4() != nil {
|
||||||
|
@ -43,30 +49,33 @@ func (p *packerHeader) SetAddress(addr string) {
|
||||||
} else if p.address.To16() != nil {
|
} else if p.address.To16() != nil {
|
||||||
p.family = FamilyIPv6
|
p.family = FamilyIPv6
|
||||||
} else {
|
} else {
|
||||||
panic(fmt.Sprintf("setAddress does not support %q", addr))
|
return fmt.Errorf("invalid IP address %q", addr)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packerHeader) AddressBytes() []byte {
|
// Family returns the string corresponding to the address's IP family.
|
||||||
if ip4 := p.address.To4(); ip4 != nil {
|
func (p *Header) Family() string {
|
||||||
p.address = ip4
|
return addressFamilyText[p.family]
|
||||||
}
|
|
||||||
|
|
||||||
return []byte(p.address)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packerHeader) AddressString() string {
|
// Address returns the string form of the header's remote address.
|
||||||
|
func (p *Header) Address() string {
|
||||||
return p.address.String()
|
return p.address.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packerHeader) Address() net.IP {
|
// Port returns the connected port of the remote connection.
|
||||||
return p.address
|
func (p *Header) Port() int {
|
||||||
|
return p.port
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packerHeader) Family() addressFamily {
|
// SetService overrides the header's original service. This is primarily useful
|
||||||
return p.family
|
// for sending 'error' and 'end' messages.
|
||||||
|
func (p *Header) SetService(service string) {
|
||||||
|
p.service = service
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *packerHeader) FamilyText() string {
|
// Service returns the service stored in the header.
|
||||||
return addressFamilyText[p.family]
|
func (p *Header) Service() string {
|
||||||
|
return p.service
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,8 +2,8 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strconv"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -252,7 +252,7 @@ func (c *Connection) Reader(ctx context.Context) {
|
||||||
|
|
||||||
// unpack the message.
|
// unpack the message.
|
||||||
p, err := packer.ReadMessage(message)
|
p, err := packer.ReadMessage(message)
|
||||||
key := p.Header.Address().String() + ":" + strconv.Itoa(p.Header.Port)
|
key := fmt.Sprintf("%s:%d", p.Address(), p.Port())
|
||||||
track, err := connectionTrack.Lookup(key)
|
track, err := connectionTrack.Lookup(key)
|
||||||
|
|
||||||
//loginfo.Println(hex.Dump(p.Data.Data()))
|
//loginfo.Println(hex.Dump(p.Data.Data()))
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
|
|
||||||
"git.daplie.com/Daplie/go-rvpn-server/rvpn/packer"
|
"git.daplie.com/Daplie/go-rvpn-server/rvpn/packer"
|
||||||
|
"git.daplie.com/Daplie/go-rvpn-server/rvpn/sni"
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextKey string
|
type contextKey string
|
||||||
|
@ -160,7 +161,7 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) {
|
||||||
wssHostName := ctx.Value(ctxWssHostName).(string)
|
wssHostName := ctx.Value(ctxWssHostName).(string)
|
||||||
adminHostName := ctx.Value(ctxAdminHostName).(string)
|
adminHostName := ctx.Value(ctxAdminHostName).(string)
|
||||||
|
|
||||||
sniHostName, err := getHello(peek)
|
sniHostName, err := sni.GetHostname(peek)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
loginfo.Println(err)
|
loginfo.Println(err)
|
||||||
return
|
return
|
||||||
|
@ -292,11 +293,19 @@ func handleExternalHTTPRequest(ctx context.Context, extConn *WedgeConn, hostname
|
||||||
track := NewTrack(extConn, hostname)
|
track := NewTrack(extConn, hostname)
|
||||||
serverStatus.ExtConnectionRegister(track)
|
serverStatus.ExtConnectionRegister(track)
|
||||||
|
|
||||||
loginfo.Println("Domain Accepted", hostname, extConn.RemoteAddr().String())
|
remoteStr := extConn.RemoteAddr().String()
|
||||||
|
loginfo.Println("Domain Accepted", hostname, remoteStr)
|
||||||
|
|
||||||
rAddr, rPort, err := net.SplitHostPort(extConn.RemoteAddr().String())
|
var header *packer.Header
|
||||||
if err != nil {
|
if rAddr, rPort, err := net.SplitHostPort(remoteStr); err != nil {
|
||||||
loginfo.Println("unable to decode hostport", extConn.RemoteAddr().String())
|
loginfo.Println("unable to decode hostport", remoteStr, err)
|
||||||
|
} else if port, err := strconv.Atoi(rPort); err != nil {
|
||||||
|
loginfo.Printf("unable to parse port string %q: %v\n", rPort, err)
|
||||||
|
} else if header, err = packer.NewHeader(rAddr, port, service); err != nil {
|
||||||
|
loginfo.Println("unable to create packer header", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if header == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -309,18 +318,8 @@ func handleExternalHTTPRequest(ctx context.Context, extConn *WedgeConn, hostname
|
||||||
|
|
||||||
loginfo.Println("Before Packer", hex.Dump(buffer))
|
loginfo.Println("Before Packer", hex.Dump(buffer))
|
||||||
|
|
||||||
cnt := len(buffer)
|
p := packer.NewPacker(header)
|
||||||
|
p.Data.AppendBytes(buffer)
|
||||||
p := packer.NewPacker()
|
|
||||||
p.Header.SetAddress(rAddr)
|
|
||||||
p.Header.Port, err = strconv.Atoi(rPort)
|
|
||||||
if err != nil {
|
|
||||||
loginfo.Println("Unable to set Remote port", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
p.Header.Service = service
|
|
||||||
p.Data.AppendBytes(buffer[0:cnt])
|
|
||||||
buf := p.PackV1()
|
buf := p.PackV1()
|
||||||
|
|
||||||
//loginfo.Println(hex.Dump(buf.Bytes()))
|
//loginfo.Println(hex.Dump(buf.Bytes()))
|
||||||
|
@ -329,8 +328,8 @@ func handleExternalHTTPRequest(ctx context.Context, extConn *WedgeConn, hostname
|
||||||
sendTrack := NewSendTrack(buf.Bytes(), hostname)
|
sendTrack := NewSendTrack(buf.Bytes(), hostname)
|
||||||
serverStatus.SendExtRequest(conn, sendTrack)
|
serverStatus.SendExtRequest(conn, sendTrack)
|
||||||
|
|
||||||
_, err = extConn.Discard(cnt)
|
cnt := len(buffer)
|
||||||
if err != nil {
|
if _, err = extConn.Discard(cnt); err != nil {
|
||||||
loginfo.Println("unable to discard", cnt, err)
|
loginfo.Println("unable to discard", cnt, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
package server
|
package sni
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func getHello(b []byte) (string, error) {
|
// GetHostname uses SNI to determine the intended target of a new TLS connection.
|
||||||
|
func GetHostname(b []byte) (string, error) {
|
||||||
rest := b[5:]
|
rest := b[5:]
|
||||||
current := 0
|
current := 0
|
||||||
handshakeType := rest[0]
|
handshakeType := rest[0]
|
Loading…
Reference in New Issue