Merge branch 'client'

This commit is contained in:
tigerbot 2017-04-17 15:15:22 -06:00
commit 01de157cfe
12 changed files with 606 additions and 71 deletions

208
go-rvpn-client/main.go Normal file
View File

@ -0,0 +1,208 @@
package main
import (
"context"
"fmt"
"regexp"
"strconv"
"strings"
jwt "github.com/dgrijalva/jwt-go"
flag "github.com/spf13/pflag"
"github.com/spf13/viper"
"git.daplie.com/Daplie/go-rvpn-server/rvpn/client"
)
var httpRegexp = regexp.MustCompile(`(?i)^http`)
func init() {
flag.StringSlice("locals", []string{}, "comma separated list of <proto>:<port> or "+
"<proto>:<hostname>:<port> to which matching incoming connections should forward. "+
"Ex: smtps:8465,https:example.com:8443")
flag.StringSlice("domains", []string{}, "comma separated list of domain names to set to the tunnel")
viper.BindPFlag("locals", flag.Lookup("locals"))
viper.BindPFlag("domains", flag.Lookup("domains"))
flag.BoolP("insecure", "k", false, "Allow TLS connections to stunneld without valid certs")
flag.String("stunneld", "", "the domain (or ip address) at which the RVPN server is running")
flag.String("secret", "", "the same secret used by stunneld (used for JWT authentication)")
flag.String("token", "", "a pre-generated token to give the server (instead of generating one with --secret)")
viper.BindPFlag("raw.insecure", flag.Lookup("insecure"))
viper.BindPFlag("raw.stunneld", flag.Lookup("stunneld"))
viper.BindPFlag("raw.secret", flag.Lookup("secret"))
viper.BindPFlag("raw.token", flag.Lookup("token"))
}
type proxy struct {
protocol string
hostname string
port int
}
func addLocals(proxies []proxy, location string) []proxy {
parts := strings.Split(location, ":")
if len(parts) > 3 {
panic(fmt.Sprintf("provided invalid location %q", location))
}
// If all that was provided as a "local" is the domain name we assume that domain
// has HTTP and HTTPS handlers on the default ports.
if len(parts) == 1 {
proxies = append(proxies, proxy{"http", parts[0], 80})
proxies = append(proxies, proxy{"https", parts[0], 443})
return proxies
}
// Make everything lower case and trim any slashes in something like https://john.example.com
parts[0] = strings.ToLower(parts[0])
parts[1] = strings.ToLower(strings.Trim(parts[1], "/"))
if len(parts) == 2 {
if strings.Contains(parts[1], ".") {
if parts[0] == "http" {
parts = append(parts, "80")
} else if parts[0] == "https" {
parts = append(parts, "443")
} else {
panic(fmt.Sprintf("port must be specified for %q", location))
}
} else {
// https:3443 -> https:*:3443
parts = []string{parts[0], "*", parts[1]}
}
}
if port, err := strconv.Atoi(parts[2]); err != nil {
panic(fmt.Sprintf("port must be a valid number, not %q: %v", parts[2], err))
} else if port <= 0 || port > 65535 {
panic(fmt.Sprintf("%d is an invalid port for local services", port))
} else {
proxies = append(proxies, proxy{parts[0], parts[1], port})
}
return proxies
}
func addDomains(proxies []proxy, location string) []proxy {
parts := strings.Split(location, ":")
if len(parts) > 3 {
panic(fmt.Sprintf("provided invalid location %q", location))
} else if len(parts) == 2 {
panic("invalid argument for --domains, use format <domainname> or <scheme>:<domainname>:<local-port>")
}
// If the scheme and port weren't provided use the zero values
if len(parts) == 1 {
return append(proxies, proxy{"", parts[0], 0})
}
if port, err := strconv.Atoi(parts[2]); err != nil {
panic(fmt.Sprintf("port must be a valid number, not %q: %v", parts[2], err))
} else if port <= 0 || port > 65535 {
panic(fmt.Sprintf("%d is an invalid port for local services", port))
} else {
proxies = append(proxies, proxy{parts[0], parts[1], port})
}
return proxies
}
func extractServicePorts(proxies []proxy) map[string]map[string]int {
result := make(map[string]map[string]int, 2)
for _, p := range proxies {
if p.protocol != "" && p.port != 0 {
hostPorts := result[p.protocol]
if hostPorts == nil {
result[p.protocol] = make(map[string]int)
hostPorts = result[p.protocol]
}
// Only HTTP and HTTPS allow us to determine the hostname from the request, so only
// those protocols support different ports for the same service.
if !httpRegexp.MatchString(p.protocol) || p.hostname == "" {
p.hostname = "*"
}
if port, ok := hostPorts[p.hostname]; ok && port != p.port {
panic(fmt.Sprintf("duplicate ports for %s://%s", p.protocol, p.hostname))
}
hostPorts[p.hostname] = p.port
}
}
// Make sure we have defaults for HTTPS and HTTP.
if result["https"] == nil {
result["https"] = make(map[string]int, 1)
}
if result["https"]["*"] == 0 {
result["https"]["*"] = 8443
}
if result["http"] == nil {
result["http"] = make(map[string]int, 1)
}
if result["http"]["*"] == 0 {
result["http"]["*"] = result["https"]["*"]
}
return result
}
func main() {
flag.Parse()
proxies := make([]proxy, 0)
for _, option := range viper.GetStringSlice("locals") {
for _, location := range strings.Split(option, ",") {
proxies = addLocals(proxies, location)
}
}
for _, option := range viper.GetStringSlice("domains") {
for _, location := range strings.Split(option, ",") {
proxies = addDomains(proxies, location)
}
}
servicePorts := extractServicePorts(proxies)
domainMap := make(map[string]bool)
for _, p := range proxies {
if p.hostname != "" && p.hostname != "*" {
domainMap[p.hostname] = true
}
}
if viper.GetString("raw.stunneld") == "" {
panic("must provide remote RVPN server to connect to")
}
var token string
if viper.GetString("raw.token") != "" {
token = viper.GetString("raw.token")
} else if viper.GetString("raw.secret") != "" {
domains := make([]string, 0, len(domainMap))
for name := range domainMap {
domains = append(domains, name)
}
tokenData := jwt.MapClaims{"domains": domains}
secret := []byte(viper.GetString("raw.secret"))
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData)
if tokenStr, err := jwtToken.SignedString(secret); err != nil {
panic(err)
} else {
token = tokenStr
}
} else {
panic("must provide either token or secret")
}
ctx, quit := context.WithCancel(context.Background())
defer quit()
config := client.Config{
Insecure: viper.GetBool("raw.insecure"),
Server: viper.GetString("raw.stunneld"),
Services: servicePorts,
Token: token,
}
panic(client.Run(ctx, &config))
}

71
rvpn/client/client.go Normal file
View File

@ -0,0 +1,71 @@
package client
import (
"context"
"crypto/tls"
"fmt"
"net/url"
"time"
"github.com/gorilla/websocket"
)
// The Config struct holds all of the information needed to establish and handle a connection
// with the RVPN server.
type Config struct {
Server string
Token string
Insecure bool
Services map[string]map[string]int
}
// Run establishes a connection with the RVPN server specified in the config. If the first attempt
// to connect fails it is assumed that something is wrong with the authentication and it will
// return an error. Otherwise it will continuously attempt to reconnect whenever the connection
// is broken.
func Run(ctx context.Context, config *Config) error {
serverURL, err := url.Parse(config.Server)
if err != nil {
return fmt.Errorf("Invalid server URL: %v", err)
}
if serverURL.Scheme == "" {
serverURL.Scheme = "wss"
}
serverURL.Path = ""
query := make(url.Values)
query.Set("access_token", config.Token)
serverURL.RawQuery = query.Encode()
dialer := websocket.Dialer{}
if config.Insecure {
dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
}
for name, portList := range config.Services {
if _, ok := portList["*"]; !ok {
return fmt.Errorf(`service %s missing port for "*"`, name)
}
}
handler := NewWsHandler(config.Services)
authenticated := false
for {
if conn, _, err := dialer.Dial(serverURL.String(), nil); err == nil {
loginfo.Println("connected to remote server")
authenticated = true
handler.HandleConn(ctx, conn)
} else if !authenticated {
return fmt.Errorf("First connection to server failed - check auth: %v", err)
}
loginfo.Println("disconnected from remote server")
// Sleep for a few seconds before trying again, but only if the context is still active
select {
case <-ctx.Done():
return nil
case <-time.After(5 * time.Second):
}
loginfo.Println("attempting reconnect to remote server")
}
}

15
rvpn/client/setup.go Normal file
View File

@ -0,0 +1,15 @@
package client
import (
"log"
"os"
)
const (
logFlags = log.Ldate | log.Lmicroseconds | log.Lshortfile
)
var (
loginfo = log.New(os.Stdout, "INFO: client: ", logFlags)
logdebug = log.New(os.Stdout, "DEBUG: client:", logFlags)
)

233
rvpn/client/ws_handler.go Normal file
View File

@ -0,0 +1,233 @@
package client
import (
"context"
"fmt"
"io"
"net"
"regexp"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"git.daplie.com/Daplie/go-rvpn-server/rvpn/packer"
"git.daplie.com/Daplie/go-rvpn-server/rvpn/sni"
)
var hostRegexp = regexp.MustCompile(`(?im)(?:^|[\r\n])Host: *([^\r\n]+)[\r\n]`)
// WsHandler handles all of reading and writing for the websocket connection to the RVPN server
// and the TCP connections to the local servers.
type WsHandler struct {
lock sync.Mutex
localConns map[string]net.Conn
servicePorts map[string]map[string]int
ctx context.Context
dataChan chan *packer.Packer
}
// NewWsHandler creates a new handler ready to be given a websocket connection. The services
// argument specifies what port each service type should be directed to on the local interface.
func NewWsHandler(services map[string]map[string]int) *WsHandler {
h := new(WsHandler)
h.servicePorts = services
h.localConns = make(map[string]net.Conn)
return h
}
// HandleConn handles all of the traffic on the provided websocket connection. The function
// will not return until the connection ends.
//
// The WsHandler is designed to handle exactly one connection at a time. If HandleConn is called
// again while the instance is still handling another connection (or if the previous connection
// failed to fully cleanup) calling HandleConn again will panic.
func (h *WsHandler) HandleConn(ctx context.Context, conn *websocket.Conn) {
if h.dataChan != nil {
panic("WsHandler.HandleConn called while handling a previous connection")
}
if len(h.localConns) > 0 {
panic(fmt.Sprintf("WsHandler has lingering local connections: %v", h.localConns))
}
h.dataChan = make(chan *packer.Packer)
// The sub context allows us to clean up all of the goroutines associated with this websocket
// if it closes at any point for any reason.
subCtx, socketQuit := context.WithCancel(ctx)
defer socketQuit()
h.ctx = subCtx
// Start the routine that will write all of the data from the local connection to the
// remote websocket connection.
go h.writeRemote(conn)
for {
_, message, err := conn.ReadMessage()
if err != nil {
loginfo.Println("failed to read message from websocket", err)
return
}
p, err := packer.ReadMessage(message)
if err != nil {
loginfo.Println("failed to parse message from websocket", err)
return
}
h.writeLocal(p)
}
}
func (h *WsHandler) writeRemote(conn *websocket.Conn) {
defer h.closeConnections()
defer func() { h.dataChan = nil }()
for {
select {
case <-h.ctx.Done():
// We can't tell if this happened because the websocket is already closed/errored or
// if it happened because the main context closed (in which case it would be preferable
// to properly close the connection). As such we try to close the connection and ignore
// all errors if it doesn't work.
message := websocket.FormatCloseMessage(websocket.CloseGoingAway, "closing connection")
deadline := time.Now().Add(10 * time.Second)
conn.WriteControl(websocket.CloseMessage, message, deadline)
conn.Close()
return
case p := <-h.dataChan:
packed := p.PackV1()
conn.WriteMessage(websocket.BinaryMessage, packed.Bytes())
}
}
}
func (h *WsHandler) sendPackedMessage(header *packer.Header, data []byte, service string) {
p := packer.NewPacker(header)
if len(data) > 0 {
p.Data.AppendBytes(data)
}
if service != "" {
p.SetService(service)
}
// Avoid blocking on the data channel if the websocket closes or is already closed
select {
case h.dataChan <- p:
case <-h.ctx.Done():
}
}
func (h *WsHandler) getLocalConn(p *packer.Packer) net.Conn {
h.lock.Lock()
defer h.lock.Unlock()
key := fmt.Sprintf("%s:%d", p.Address(), p.Port())
// Simplest case: it's already open, just return it.
if conn := h.localConns[key]; conn != nil {
return conn
}
service := strings.ToLower(p.Service())
portList := h.servicePorts[service]
if portList == nil {
loginfo.Println("cannot open connection for invalid service", service)
return nil
}
var hostname string
if service == "http" {
if match := hostRegexp.FindSubmatch(p.Data.Data()); match != nil {
hostname = strings.Split(string(match[1]), ":")[0]
}
} else if service == "https" {
hostname, _ = sni.GetHostname(p.Data.Data())
} else {
hostname = "*"
}
if hostname == "" {
loginfo.Println("missing servername for", service, key)
return nil
}
hostname = strings.ToLower(hostname)
port := portList[hostname]
if port == 0 {
port = portList["*"]
}
if port == 0 {
loginfo.Println("unable to determine local port for", service, hostname)
return nil
}
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", port))
if err != nil {
loginfo.Println("unable to open local connection on port", port, err)
return nil
}
h.localConns[key] = conn
loginfo.Printf("new client %q for %s:%d (%d clients)\n", key, hostname, port, len(h.localConns))
go h.readLocal(key, &p.Header)
return conn
}
func (h *WsHandler) writeLocal(p *packer.Packer) {
conn := h.getLocalConn(p)
if conn == nil {
h.sendPackedMessage(&p.Header, nil, "error")
return
}
if p.Service() == "error" || p.Service() == "end" {
conn.Close()
return
}
if _, err := conn.Write(p.Data.Data()); err != nil {
h.sendPackedMessage(&p.Header, nil, "error")
loginfo.Println("failed to write to local connection", err)
}
}
func (h *WsHandler) readLocal(key string, header *packer.Header) {
h.lock.Lock()
conn := h.localConns[key]
h.lock.Unlock()
defer conn.Close()
defer func() {
h.lock.Lock()
delete(h.localConns, key)
loginfo.Printf("closing client %q: (%d clients)\n", key, len(h.localConns))
h.lock.Unlock()
}()
buf := make([]byte, 4096)
for {
size, err := conn.Read(buf)
if err != nil {
if err == io.EOF || strings.Contains(err.Error(), "use of closed network connection") {
h.sendPackedMessage(header, nil, "end")
} else {
loginfo.Println("failed to read from local connection for", key, err)
h.sendPackedMessage(header, nil, "error")
}
return
}
h.sendPackedMessage(header, buf[:size], "")
}
}
func (h *WsHandler) closeConnections() {
h.lock.Lock()
defer h.lock.Unlock()
for _, conn := range h.localConns {
conn.Close()
}
}

View File

@ -16,16 +16,19 @@ const (
//Packer -- contains both header and data
type Packer struct {
Header *packerHeader
Data *packerData
Header
Data packerData
}
//NewPacker -- Structre
func NewPacker() (p *Packer) {
p = new(Packer)
p.Header = newPackerHeader()
p.Data = newPackerData()
return
// NewPacker creates a new Packer struct using the information from the provided header as
// its own header. (Because the header is stored directly and not as a pointer/reference
// it should be safe to override items like the service without affecting the template header.)
func NewPacker(header *Header) *Packer {
p := new(Packer)
if header != nil {
p.Header = *header
}
return p
}
func splitHeader(header []byte, names []string) (map[string]string, error) {
@ -48,7 +51,7 @@ func ReadMessage(b []byte) (*Packer, error) {
// Detect protocol in use
if b[0] == packerV1 {
// Separate the header and body using the header length in the second byte.
p := NewPacker()
p := NewPacker(nil)
header := b[2 : b[1]+2]
data := b[b[1]+2:]
@ -69,8 +72,8 @@ func ReadMessage(b []byte) (*Packer, error) {
p.Header.address = net.ParseIP(parts["address"])
if p.Header.address == nil {
return nil, fmt.Errorf("Invalid network address %q", parts["address"])
} else if p.Header.Family() == FamilyIPv4 && p.Header.address.To4() == nil {
return nil, fmt.Errorf("Address %q is not in address family %s", parts["address"], p.Header.FamilyText())
} else if p.Header.family == FamilyIPv4 && p.Header.address.To4() == nil {
return nil, fmt.Errorf("Address %q is not in address family %s", parts["address"], p.Header.Family())
}
//handle port
@ -79,7 +82,7 @@ func ReadMessage(b []byte) (*Packer, error) {
} else if port <= 0 || port > 65535 {
return nil, fmt.Errorf("Port %d out of range", port)
} else {
p.Header.Port = port
p.Header.port = port
}
//handle data length
@ -90,7 +93,7 @@ func ReadMessage(b []byte) (*Packer, error) {
}
//handle Service
p.Header.Service = parts["service"]
p.Header.service = parts["service"]
//handle payload
p.Data.AppendBytes(data)
@ -103,11 +106,11 @@ func ReadMessage(b []byte) (*Packer, error) {
//PackV1 -- Outputs version 1 of packer
func (p *Packer) PackV1() bytes.Buffer {
header := strings.Join([]string{
p.Header.FamilyText(),
p.Header.AddressString(),
strconv.Itoa(p.Header.Port),
p.Header.Family(),
p.Header.Address(),
strconv.Itoa(p.Header.Port()),
strconv.Itoa(p.Data.DataLen()),
p.Header.Service,
p.Header.Service(),
}, ",")
var buf bytes.Buffer

View File

@ -9,10 +9,6 @@ type packerData struct {
buffer bytes.Buffer
}
func newPackerData() *packerData {
return new(packerData)
}
func (p *packerData) AppendString(dataString string) (int, error) {
return p.buffer.WriteString(dataString)
}

View File

@ -7,12 +7,15 @@ import (
type addressFamily int
// packerHeader structure to hold our header information.
type packerHeader struct {
// The Header struct holds most of the information contained in the header for packets
// between the client and the server (the length of the data is not included here). It
// is used to uniquely identify remote connections on the servers end and to communicate
// which service the remote client is trying to connect to.
type Header struct {
family addressFamily
address net.IP
Port int
Service string
port int
service string
}
//Family -- ENUM for Address Family
@ -26,16 +29,19 @@ var addressFamilyText = [...]string{
"IPv6",
}
func newPackerHeader() (p *packerHeader) {
p = new(packerHeader)
p.SetAddress("127.0.0.1")
p.Port = 65535
p.Service = "na"
return
// NewHeader create a new Header object.
func NewHeader(address string, port int, service string) (*Header, error) {
h := new(Header)
if err := h.setAddress(address); err != nil {
return nil, err
}
h.port = port
h.service = service
return h, nil
}
//SetAddress -- Set Address. which sets address family automatically
func (p *packerHeader) SetAddress(addr string) {
// setAddress parses the provided address string and automatically sets the IP family.
func (p *Header) setAddress(addr string) error {
p.address = net.ParseIP(addr)
if p.address.To4() != nil {
@ -43,30 +49,33 @@ func (p *packerHeader) SetAddress(addr string) {
} else if p.address.To16() != nil {
p.family = FamilyIPv6
} else {
panic(fmt.Sprintf("setAddress does not support %q", addr))
return fmt.Errorf("invalid IP address %q", addr)
}
return nil
}
func (p *packerHeader) AddressBytes() []byte {
if ip4 := p.address.To4(); ip4 != nil {
p.address = ip4
}
return []byte(p.address)
// Family returns the string corresponding to the address's IP family.
func (p *Header) Family() string {
return addressFamilyText[p.family]
}
func (p *packerHeader) AddressString() string {
// Address returns the string form of the header's remote address.
func (p *Header) Address() string {
return p.address.String()
}
func (p *packerHeader) Address() net.IP {
return p.address
// Port returns the connected port of the remote connection.
func (p *Header) Port() int {
return p.port
}
func (p *packerHeader) Family() addressFamily {
return p.family
// SetService overrides the header's original service. This is primarily useful
// for sending 'error' and 'end' messages.
func (p *Header) SetService(service string) {
p.service = service
}
func (p *packerHeader) FamilyText() string {
return addressFamilyText[p.family]
// Service returns the service stored in the header.
func (p *Header) Service() string {
return p.service
}

View File

@ -2,8 +2,8 @@ package server
import (
"context"
"fmt"
"io"
"strconv"
"sync"
"time"
@ -252,7 +252,7 @@ func (c *Connection) Reader(ctx context.Context) {
// unpack the message.
p, err := packer.ReadMessage(message)
key := p.Header.Address().String() + ":" + strconv.Itoa(p.Header.Port)
key := fmt.Sprintf("%s:%d", p.Address(), p.Port())
track, err := connectionTrack.Lookup(key)
//loginfo.Println(hex.Dump(p.Data.Data()))

View File

@ -18,6 +18,7 @@ import (
"github.com/gorilla/websocket"
"git.daplie.com/Daplie/go-rvpn-server/rvpn/packer"
"git.daplie.com/Daplie/go-rvpn-server/rvpn/sni"
)
type contextKey string
@ -160,7 +161,7 @@ func handleConnection(ctx context.Context, wConn *WedgeConn) {
wssHostName := ctx.Value(ctxWssHostName).(string)
adminHostName := ctx.Value(ctxAdminHostName).(string)
sniHostName, err := getHello(peek)
sniHostName, err := sni.GetHostname(peek)
if err != nil {
loginfo.Println(err)
return
@ -292,11 +293,19 @@ func handleExternalHTTPRequest(ctx context.Context, extConn *WedgeConn, hostname
track := NewTrack(extConn, hostname)
serverStatus.ExtConnectionRegister(track)
loginfo.Println("Domain Accepted", hostname, extConn.RemoteAddr().String())
remoteStr := extConn.RemoteAddr().String()
loginfo.Println("Domain Accepted", hostname, remoteStr)
rAddr, rPort, err := net.SplitHostPort(extConn.RemoteAddr().String())
if err != nil {
loginfo.Println("unable to decode hostport", extConn.RemoteAddr().String())
var header *packer.Header
if rAddr, rPort, err := net.SplitHostPort(remoteStr); err != nil {
loginfo.Println("unable to decode hostport", remoteStr, err)
} else if port, err := strconv.Atoi(rPort); err != nil {
loginfo.Printf("unable to parse port string %q: %v\n", rPort, err)
} else if header, err = packer.NewHeader(rAddr, port, service); err != nil {
loginfo.Println("unable to create packer header", err)
}
if header == nil {
return
}
@ -309,18 +318,8 @@ func handleExternalHTTPRequest(ctx context.Context, extConn *WedgeConn, hostname
loginfo.Println("Before Packer", hex.Dump(buffer))
cnt := len(buffer)
p := packer.NewPacker()
p.Header.SetAddress(rAddr)
p.Header.Port, err = strconv.Atoi(rPort)
if err != nil {
loginfo.Println("Unable to set Remote port", err)
return
}
p.Header.Service = service
p.Data.AppendBytes(buffer[0:cnt])
p := packer.NewPacker(header)
p.Data.AppendBytes(buffer)
buf := p.PackV1()
//loginfo.Println(hex.Dump(buf.Bytes()))
@ -329,8 +328,8 @@ func handleExternalHTTPRequest(ctx context.Context, extConn *WedgeConn, hostname
sendTrack := NewSendTrack(buf.Bytes(), hostname)
serverStatus.SendExtRequest(conn, sendTrack)
_, err = extConn.Discard(cnt)
if err != nil {
cnt := len(buffer)
if _, err = extConn.Discard(cnt); err != nil {
loginfo.Println("unable to discard", cnt, err)
return
}

View File

@ -52,7 +52,7 @@ type servers struct {
secretKey string
certbundle tls.Certificate
register chan *ListenerRegistration
servers *servers
servers *servers
wssHostName string
adminHostName string
cancelCheck int

View File

@ -1,10 +1,11 @@
package server
package sni
import (
"errors"
)
func getHello(b []byte) (string, error) {
// GetHostname uses SNI to determine the intended target of a new TLS connection.
func GetHostname(b []byte) (string, error) {
rest := b[5:]
current := 0
handshakeType := rest[0]