// telebit/cmd/telebit-relay/telebit-relay.go
//
// (source listing: 306 lines, 8.3 KiB, Go)

package main
import (
"context"
"crypto/tls"
"flag"
"fmt"
"io"
"io/ioutil"
golog "log"
"net/http"
"os"
"strconv"
"strings"
"git.coolaj86.com/coolaj86/go-telebitd/log"
"git.coolaj86.com/coolaj86/go-telebitd/relay"
"git.coolaj86.com/coolaj86/go-telebitd/relay/api"
"git.coolaj86.com/coolaj86/go-telebitd/relay/mplexy"
"github.com/caddyserver/certmagic"
jwt "github.com/dgrijalva/jwt-go"
"github.com/go-acme/lego/v3/providers/dns/duckdns"
lumberjack "gopkg.in/natefinch/lumberjack.v2"
_ "github.com/joho/godotenv/autoload"
)
// Package-level aliases for the loggers exported by go-telebitd/log, so this
// command can call Loginfo/Logdebug directly. TODO remove these aliases.
var (
	// Loginfo aliases log.Loginfo (TODO remove).
	Loginfo = log.Loginfo
	// Logdebug aliases log.Logdebug (TODO remove).
	Logdebug = log.Logdebug
)

// init sets the standard-library log flags used by the go-telebitd/log
// package: date, microsecond-resolution time, and the caller's short
// file:line.
func init() {
	log.LogFlags = golog.Lshortfile | golog.Lmicroseconds | golog.Ldate
}
// Process-wide settings for the relay. Most are bound to command-line flags
// or filled from environment variables in main.
// NOTE(review): main declares its own local idle, dwell, cancelcheck and
// connectionTable, so the package-level variables of those names are not
// assigned anywhere in this file.
var (
	// logging and configuration location
	logfile    = "stdout"
	configPath = "./"
	configFile = "telebit-relay"

	// listener and relay tuning
	tcpPort           int
	argDeadTime       int
	idle              int
	dwell             int
	cancelcheck       int
	connectionTable   *api.Table
	loadBalanceMethod api.LoadBalanceStrategy

	// identity, hostnames and authentication
	secretKey     string
	wssHostName   string
	adminHostName string
	nickname      string
	allclients    string

	// ACME / Let's Encrypt
	acmeEmail   string
	acmeStorage string
	acmeAgree   bool
	acmeStaging bool
)
func init() {
flag.StringVar(&allclients, "clients", "", "list of client:secret pairings such as example.com:secret123,foo.com:secret321")
flag.StringVar(&acmeEmail, "acme-email", "", "email to use for Let's Encrypt / ACME registration")
flag.StringVar(&acmeStorage, "acme-storage", "./acme.d/", "path to ACME storage directory")
flag.BoolVar(&acmeAgree, "acme-agree", false, "agree to the terms of the ACME service provider (required)")
flag.BoolVar(&acmeStaging, "staging", false, "get fake certificates for testing")
flag.StringVar(&adminHostName, "admin-hostname", "", "the management domain")
flag.StringVar(&wssHostName, "wss-hostname", "", "the wss domain for connecting devices, if different from admin")
flag.StringVar(&configPath, "config-path", configPath, "Configuration File Path")
flag.StringVar(&secretKey, "secret", "", "a >= 16-character random string for JWT key signing") // SECRET
flag.StringVar(&logfile, "log", logfile, "Log file (or stdout/stderr; empty for none)")
flag.IntVar(&tcpPort, "port", 0, "tcp port on which to listen") // PORT
flag.StringVar(&nickname, "nickname", "", "a nickname for this server, as an identifier") // NICKNAME
}
// logoutput is the destination for log output. main selects it from the
// --log flag and hands it to log.InitLogging.
var (
	logoutput io.Writer
)
// Client pairs a client's domain name with the shared secret configured
// for it through the --clients flag.
type Client struct {
	domain, secret string
}
//Main -- main entry point
func main() {
flag.Parse()
if !acmeAgree {
fmt.Fprintf(os.Stderr, "set --acme-agree=true to accept the terms of the ACME service provider.\n")
os.Exit(1)
}
clients := []Client{}
for _, pair := range strings.Split(allclients, ", ") {
if 0 == len(pair) {
continue
}
keyval := strings.Split(pair, ":")
clients = append(clients, Client{
domain: keyval[0],
secret: keyval[1],
})
}
if 0 == len(secretKey) {
secretKey = os.Getenv("SECRET")
}
if len(secretKey) < 16 {
fmt.Fprintf(os.Stderr, "Invalid secret: %q. See --help for details.\n", secretKey)
os.Exit(1)
}
switch logfile {
case "stdout":
logoutput = os.Stdout
case "stderr":
logoutput = os.Stderr
case "":
logoutput = ioutil.Discard
default:
logoutput = &lumberjack.Logger{
Filename: logfile,
MaxSize: 100,
MaxAge: 120,
MaxBackups: 10,
}
}
// send the output io.Writing to the other packages
log.InitLogging(logoutput)
flag.IntVar(&argDeadTime, "dead-time-counter", 5, "deadtime counter in seconds")
if 0 == tcpPort {
tcpPort, _ = strconv.Atoi(os.Getenv("PORT"))
if 0 == tcpPort {
fmt.Fprintf(os.Stderr, "must specify --port or PORT\n")
os.Exit(1)
}
}
if 0 == len(adminHostName) {
adminHostName = os.Getenv("ADMIN_HOSTNAME")
}
if 0 == len(wssHostName) {
wssHostName = os.Getenv("WSS_HOSTNAME")
}
if 0 == len(wssHostName) {
wssHostName = adminHostName
}
// load balancer method
loadBalanceMethod = api.RoundRobin
if 0 == len(nickname) {
nickname = os.Getenv("NICKNAME")
}
// TODO what do these "deadtimes" do exactly?
dwell := 120
idle := 60
cancelcheck := 10
Loginfo.Println("startup")
ctx, cancelContext := context.WithCancel(context.Background())
defer cancelContext()
// CertMagic is Greenlock for Go
directory := certmagic.LetsEncryptProductionCA
if acmeStaging {
directory = certmagic.LetsEncryptStagingCA
}
magic, err := newCertMagic(directory, acmeEmail, &certmagic.FileStorage{Path: acmeStorage})
if nil != err {
fmt.Fprintf(os.Stderr, "failed to initialize certificate management (discovery url? local folder perms?): %s\n", err)
os.Exit(1)
}
serverStatus := api.NewStatus(ctx)
serverStatus.AdminDomain = adminHostName
serverStatus.WssDomain = wssHostName
serverStatus.Name = nickname
serverStatus.DeadTime = api.NewStatusDeadTime(dwell, idle, cancelcheck)
serverStatus.LoadbalanceDefaultMethod = string(loadBalanceMethod)
connectionTable := api.NewTable(dwell, idle, loadBalanceMethod)
tlsConfig := &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return magic.GetCertificate(hello)
/*
if false {
_, _ = magic.GetCertificate(hello)
}
// TODO
// 1. call out to greenlock for validation
// 2. push challenges through http channel
// 3. receive certificates (or don't)
certbundleT, err := tls.LoadX509KeyPair("certs/fullchain.pem", "certs/privkey.pem")
certbundle := &certbundleT
if err != nil {
return nil, err
}
return certbundle, nil
*/
},
}
authorizer := func(r *http.Request) (*mplexy.Authz, error) {
// do we have a valid wss_client?
var tokenString string
if auth := strings.Split(r.Header.Get("Authorization"), " "); len(auth) > 1 {
// TODO handle Basic auth tokens as well
tokenString = auth[1]
}
if "" == tokenString {
tokenString = r.URL.Query().Get("access_token")
}
tok, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte(secretKey), nil
})
if nil != err {
fmt.Println("return an error, do not go on")
return nil, err
}
fmt.Printf("client claims:\n%+v\n", tok.Claims)
domains := []string{}
for _, name := range tok.Claims.(jwt.MapClaims)["domains"].([]interface{}) {
domains = append(domains, name.(string))
}
authz := &mplexy.Authz{
Domains: domains,
}
return authz, err
/*
tokenString := r.URL.Query().Get("access_token")
result, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
return []byte(secretKey), nil
})
if err != nil || !result.Valid {
w.WriteHeader(http.StatusForbidden)
w.Write([]byte("Not Authorized"))
Loginfo.Println("access_token invalid...closing connection")
return
}
// TODO
claims := result.Claims.(jwt.MapClaims)
domains, ok := claims["domains"].([]interface{})
*/
}
r := relay.New(ctx, tlsConfig, authorizer, serverStatus, connectionTable)
r.ListenAndServe(tcpPort)
}
// newCertMagic builds a CertMagic configuration with these properties:
//   - Certificates are stored in storage.
//   - Certificates are obtained on demand from the ACME CA at directory,
//     registered as email.
//   - Only the DNS-01 challenge is used, through DuckDNS; HTTP and TLS-ALPN
//     challenges are disabled.
//
// The DuckDNS API token is read from the DUCKDNS_TOKEN environment variable.
// An error is returned if the DNS provider cannot be created.
func newCertMagic(directory string, email string, storage certmagic.Storage) (*certmagic.Config, error) {
	// Create the provider first. If it fails, no certificate cache (and no
	// background maintenance) has been started yet.
	provider, err := newDuckDNSProvider(os.Getenv("DUCKDNS_TOKEN"))
	if err != nil {
		return nil, err
	}

	// The cache calls GetConfigForCert to choose the config used when it
	// maintains (renews) a cached certificate. It previously returned an empty
	// Config, which is completed from CertMagic's Default config. Renewals
	// then bypassed this storage and DNS-01 issuer. Return this same config.
	// magic is assigned below, before any certificate can be in the cache.
	var magic *certmagic.Config
	cache := certmagic.NewCache(certmagic.CacheOptions{
		GetConfigForCert: func(cert certmagic.Certificate) (*certmagic.Config, error) {
			return magic, nil
		},
	})

	magic = certmagic.New(cache, certmagic.Config{
		Storage: storage,
		OnDemand: &certmagic.OnDemandConfig{
			// NOTE(review): this approves every hostname requested via SNI.
			// Consider restricting it to known domains to avoid unbounded
			// issuance.
			DecisionFunc: func(name string) error {
				return nil
			},
		},
	})

	// The issuer needs a pointer to the config it belongs to, so it is
	// attached after the config is created.
	magic.Issuer = certmagic.NewACMEManager(magic, certmagic.ACMEManager{
		DNSProvider:             provider,
		CA:                      directory,
		Email:                   email,
		Agreed:                  true,
		DisableHTTPChallenge:    true,
		DisableTLSALPNChallenge: true,
		// plus any other customizations you need
	})

	return magic, nil
}
// newDuckDNSProvider returns a lego DuckDNS DNS-01 provider that starts from
// lego's default DuckDNS settings and authenticates with token. It exists
// for the sake of demoing the tunnel.
func newDuckDNSProvider(token string) (*duckdns.DNSProvider, error) {
	cfg := duckdns.NewDefaultConfig()
	cfg.Token = token
	provider, err := duckdns.NewDNSProviderConfig(cfg)
	return provider, err
}