mirror of
https://github.com/therootcompany/golib.git
synced 2026-03-02 23:57:59 +00:00
feat: add cmd/auth-proxy to add Basic, Bearer, X-API-Key, or access_token auth to routes
This commit is contained in:
parent
8ef2f73cb0
commit
13eeb6793b
13
cmd/auth-proxy/go.mod
Normal file
13
cmd/auth-proxy/go.mod
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
module github.com/therootcompany/golib/cmd/auth-proxy
|
||||||
|
|
||||||
|
go 1.25.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/joho/godotenv v1.5.1
|
||||||
|
github.com/therootcompany/golib/auth/csvauth v1.2.2
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/therootcompany/golib/auth v1.0.0 // indirect
|
||||||
|
golang.org/x/crypto v0.42.0 // indirect
|
||||||
|
)
|
||||||
12
cmd/auth-proxy/go.sum
Normal file
12
cmd/auth-proxy/go.sum
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
||||||
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
|
github.com/therootcompany/golib/auth v1.0.0 h1:17hfwcJO/Efc22/8RcCTKUD49mhCc5tyoHiKonA3Slg=
|
||||||
|
github.com/therootcompany/golib/auth v1.0.0/go.mod h1:DSw8llmDkMtvMZWrzrTRtcaLPpPMsT6Sg+qwGf5O2U8=
|
||||||
|
github.com/therootcompany/golib/auth/csvauth v1.2.0 h1:EgCT0pft4AhJz1omOntca5axv/B/1686wGIOgAS1++k=
|
||||||
|
github.com/therootcompany/golib/auth/csvauth v1.2.0/go.mod h1:e4kjRLWD7DWsuUT8+5MvPSYygXY6Is2UVeA9UpDllWU=
|
||||||
|
github.com/therootcompany/golib/auth/csvauth v1.2.1 h1:iy6S69nj/+Us0nx5pXke7SevgH1YaEM1vgbne7bnq70=
|
||||||
|
github.com/therootcompany/golib/auth/csvauth v1.2.1/go.mod h1:e4kjRLWD7DWsuUT8+5MvPSYygXY6Is2UVeA9UpDllWU=
|
||||||
|
github.com/therootcompany/golib/auth/csvauth v1.2.2 h1:t14Y8fSPhUZS6J+gRj0EFjkfE2JdIq93vOiKYcfItE8=
|
||||||
|
github.com/therootcompany/golib/auth/csvauth v1.2.2/go.mod h1://Iy5nbRa/bSndWa8b/NM4jclPVgkFWWhYUH6Ppelz0=
|
||||||
|
golang.org/x/crypto v0.42.0 h1:chiH31gIWm57EkTXpwnqf8qeuMUi0yekh6mT2AvFlqI=
|
||||||
|
golang.org/x/crypto v0.42.0/go.mod h1:4+rDnOTJhQCx2q7/j6rAN5XDw8kPjeaXEUR2eL94ix8=
|
||||||
680
cmd/auth-proxy/main.go
Normal file
680
cmd/auth-proxy/main.go
Normal file
@ -0,0 +1,680 @@
|
|||||||
|
// auth-proxy - A reverse proxy to require Basic Auth, Bearer Token, or access_token
|
||||||
|
//
|
||||||
|
// Copyright 2026 AJ ONeal <aj@therootcompany.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||||
|
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
|
||||||
|
//
|
||||||
|
// This Source Code Form is "Incompatible With Secondary Licenses", as
|
||||||
|
// defined by the Mozilla Public License, v. 2.0.
|
||||||
|
//
|
||||||
|
// SPDX-License-Identifier: MPL-2.0
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"iter"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/joho/godotenv"
|
||||||
|
|
||||||
|
"github.com/therootcompany/golib/auth"
|
||||||
|
"github.com/therootcompany/golib/auth/csvauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
name = "auth-proxy"
|
||||||
|
licenseYear = "2026"
|
||||||
|
licenseOwner = "AJ ONeal <aj@therootcompany.com> (https://therootcompany.com)"
|
||||||
|
licenseType = "CC0-1.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
// replaced by goreleaser / ldflags
|
||||||
|
var (
|
||||||
|
version = "0.0.0-dev"
|
||||||
|
commit = "0000000"
|
||||||
|
date = "0001-01-01"
|
||||||
|
)
|
||||||
|
|
||||||
|
// printVersion displays the version, commit, and build date.
|
||||||
|
func printVersion(w io.Writer) {
|
||||||
|
_, _ = fmt.Fprintf(w, "%s v%s %s (%s)\n", name, version, commit[:7], date)
|
||||||
|
_, _ = fmt.Fprintf(w, "Copyright (C) %s %s\n", licenseYear, licenseOwner)
|
||||||
|
_, _ = fmt.Fprintf(w, "Licensed under %s\n", licenseType)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrNoAuth = errors.New("request missing the required form of authorization")
|
||||||
|
)
|
||||||
|
|
||||||
|
var creds *csvauth.Auth
|
||||||
|
|
||||||
|
const basicAPIKeyName = ""
|
||||||
|
|
||||||
|
type MainConfig struct {
|
||||||
|
Address string
|
||||||
|
Port int
|
||||||
|
CredentialsPath string
|
||||||
|
ProxyTarget string
|
||||||
|
AES128KeyPath string
|
||||||
|
ShowVersion bool
|
||||||
|
AuthorizationHeaderSchemes []string
|
||||||
|
TokenHeaderNames []string
|
||||||
|
QueryParamNames []string
|
||||||
|
comma rune
|
||||||
|
commaString string
|
||||||
|
tokenSchemeList string
|
||||||
|
tokenHeaderList string
|
||||||
|
tokenParamList string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *MainConfig) Addr() string {
|
||||||
|
return fmt.Sprintf("%s:%d", c.Address, c.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cli := MainConfig{
|
||||||
|
Address: "0.0.0.0",
|
||||||
|
Port: 8081,
|
||||||
|
CredentialsPath: "./credentials.tsv",
|
||||||
|
ProxyTarget: "http://127.0.0.1:8080",
|
||||||
|
AES128KeyPath: filepath.Join("~", ".config", "csvauth", "aes-128.key"),
|
||||||
|
comma: '\t',
|
||||||
|
commaString: "",
|
||||||
|
tokenSchemeList: "",
|
||||||
|
tokenHeaderList: "",
|
||||||
|
tokenParamList: "",
|
||||||
|
AuthorizationHeaderSchemes: nil, // []string{"Bearer", "Token"}
|
||||||
|
TokenHeaderNames: nil, // []string{"X-API-Key", "X-Auth-Token", "X-Access-Token"},
|
||||||
|
QueryParamNames: nil, // []string{"access_token", "token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Peek for --envfile early
|
||||||
|
envPath := peekOption(os.Args[1:], []string{"-envfile", "--envfile"}, ".env")
|
||||||
|
_ = godotenv.Load(envPath) // silent if missing
|
||||||
|
|
||||||
|
// Override defaults from env
|
||||||
|
if v := os.Getenv("AUTHPROXY_PORT"); v != "" {
|
||||||
|
if _, err := fmt.Sscanf(v, "%d", &cli.Port); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "invalid AUTHPROXY_PORT value: %s\n", v)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v := os.Getenv("AUTHPROXY_ADDRESS"); v != "" {
|
||||||
|
cli.Address = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("AUTHPROXY_CREDENTIALS_FILE"); v != "" {
|
||||||
|
cli.CredentialsPath = v
|
||||||
|
}
|
||||||
|
if v := os.Getenv("AUTHPROXY_TARGET"); v != "" {
|
||||||
|
cli.ProxyTarget = v
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flags
|
||||||
|
fs := flag.NewFlagSet(name, flag.ContinueOnError)
|
||||||
|
fs.BoolVar(&cli.ShowVersion, "version", false, "show version and exit")
|
||||||
|
fs.IntVar(&cli.Port, "port", cli.Port, "port to listen on")
|
||||||
|
fs.StringVar(&cli.Address, "address", cli.Address, "address to bind to (e.g. 127.0.0.1)")
|
||||||
|
fs.StringVar(&cli.AES128KeyPath, "aes-128-key", cli.AES128KeyPath, "path to credentials TSV/CSV file")
|
||||||
|
fs.StringVar(&cli.CredentialsPath, "credentials", cli.CredentialsPath, "path to credentials TSV/CSV file")
|
||||||
|
fs.StringVar(&cli.ProxyTarget, "proxy-target", cli.ProxyTarget, "upstream target to proxy requests to")
|
||||||
|
fs.StringVar(&cli.commaString, "comma", "\\t", "single-character CSV separator for credentials file (literal characters and escapes accepted)")
|
||||||
|
fs.StringVar(&cli.tokenSchemeList, "token-schemes", "Bearer,Token", "checks for header 'Authorization: <Scheme> <token>'")
|
||||||
|
fs.StringVar(&cli.tokenHeaderList, "token-headers", "X-API-Key,X-Auth-Token,X-Access-Token", "checks for header '<API-Key-Header>: <token>'")
|
||||||
|
fs.StringVar(&cli.tokenParamList, "token-params", "access_token,token", "checks for query param '?<param>=<token>'")
|
||||||
|
|
||||||
|
fs.Usage = func() {
|
||||||
|
fmt.Fprintf(os.Stderr, "USAGE\n %s [flags]\n\n", name)
|
||||||
|
fmt.Fprintf(os.Stderr, "FLAGS\n")
|
||||||
|
fs.PrintDefaults()
|
||||||
|
fmt.Fprintf(os.Stderr, "\nENVIRONMENT\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " AUTHPROXY_PORT port to listen on\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " AUTHPROXY_ADDRESS bind address\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " AUTHPROXY_CREDENTIALS_FILE path to tokens file\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " AUTHPROXY_TARGET upstream URL\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special handling for version/help
|
||||||
|
if len(os.Args) > 1 {
|
||||||
|
arg := os.Args[1]
|
||||||
|
if arg == "-V" || arg == "--version" || arg == "version" {
|
||||||
|
printVersion(os.Stdout)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
if arg == "help" || arg == "-help" || arg == "--help" {
|
||||||
|
printVersion(os.Stdout)
|
||||||
|
_, _ = fmt.Fprintln(os.Stdout, "")
|
||||||
|
fs.SetOutput(os.Stdout)
|
||||||
|
fs.Usage()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
printVersion(os.Stderr)
|
||||||
|
fmt.Fprintln(os.Stderr, "")
|
||||||
|
|
||||||
|
if err := fs.Parse(os.Args[1:]); err != nil {
|
||||||
|
if err == flag.ErrHelp {
|
||||||
|
fs.Usage()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
log.Fatalf("flag parse error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
homedir, err := os.UserHomeDir()
|
||||||
|
if err == nil {
|
||||||
|
var found bool
|
||||||
|
if cli.AES128KeyPath, found = strings.CutPrefix(cli.AES128KeyPath, "~"); found {
|
||||||
|
cli.AES128KeyPath = homedir + cli.AES128KeyPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// credentials file delimiter
|
||||||
|
var err error
|
||||||
|
cli.comma, err = DecodeDelimiter(cli.commaString)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("comma parse error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorization: <Scheme> <token>
|
||||||
|
cli.tokenSchemeList = strings.TrimSpace(cli.tokenSchemeList)
|
||||||
|
if cli.tokenSchemeList != "" && cli.tokenSchemeList != "none" {
|
||||||
|
cli.tokenSchemeList = strings.ReplaceAll(cli.tokenSchemeList, ",", " ")
|
||||||
|
cli.AuthorizationHeaderSchemes = strings.Fields(cli.tokenSchemeList)
|
||||||
|
if len(cli.AuthorizationHeaderSchemes) == 1 && cli.AuthorizationHeaderSchemes[0] == "" {
|
||||||
|
cli.AuthorizationHeaderSchemes = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// <API-Key-Header>: <token>
|
||||||
|
// trick: this allows `Authorization: <token>` without the scheme
|
||||||
|
cli.tokenHeaderList = strings.TrimSpace(cli.tokenHeaderList)
|
||||||
|
if cli.tokenHeaderList != "" && cli.tokenHeaderList != "none" {
|
||||||
|
cli.tokenHeaderList = strings.ReplaceAll(cli.tokenHeaderList, ",", " ")
|
||||||
|
cli.TokenHeaderNames = strings.Fields(cli.tokenHeaderList)
|
||||||
|
if len(cli.TokenHeaderNames) == 1 && cli.TokenHeaderNames[0] == "" {
|
||||||
|
cli.TokenHeaderNames = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ?<param>=<token>
|
||||||
|
// trick: this allows `Authorization: <token>` without the scheme
|
||||||
|
cli.tokenParamList = strings.TrimSpace(cli.tokenParamList)
|
||||||
|
if cli.tokenParamList != "" && cli.tokenParamList != "none" {
|
||||||
|
cli.tokenParamList = strings.ReplaceAll(cli.tokenParamList, ",", " ")
|
||||||
|
cli.QueryParamNames = strings.Fields(cli.tokenParamList)
|
||||||
|
if len(cli.QueryParamNames) == 1 && cli.QueryParamNames[0] == "" {
|
||||||
|
cli.QueryParamNames = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
run(&cli)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
fileSeparator = '\x1c'
|
||||||
|
groupSeparator = '\x1d'
|
||||||
|
recordSeparator = '\x1e'
|
||||||
|
unitSeparator = '\x1f'
|
||||||
|
)
|
||||||
|
|
||||||
|
func DecodeDelimiter(delimString string) (rune, error) {
|
||||||
|
switch delimString {
|
||||||
|
case "^_", "\\x1f":
|
||||||
|
delimString = string(unitSeparator)
|
||||||
|
case "^^", "\\x1e":
|
||||||
|
delimString = string(recordSeparator)
|
||||||
|
case "^]", "\\x1d":
|
||||||
|
delimString = string(groupSeparator)
|
||||||
|
case "^\\", "\\x1c":
|
||||||
|
delimString = string(fileSeparator)
|
||||||
|
case "^L", "\\f":
|
||||||
|
delimString = "\f"
|
||||||
|
case "^K", "\\v":
|
||||||
|
delimString = "\v"
|
||||||
|
case "^I", "\\t":
|
||||||
|
delimString = "\t"
|
||||||
|
default:
|
||||||
|
// it is what it is
|
||||||
|
}
|
||||||
|
delim, _ := utf8.DecodeRuneInString(delimString)
|
||||||
|
return delim, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func run(cli *MainConfig) {
|
||||||
|
defaultAESKeyENVName := "CSVAUTH_AES_128_KEY"
|
||||||
|
aesKey, err := getAESKey(defaultAESKeyENVName, cli.AES128KeyPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load credentials from CSV/TSV file once at startup
|
||||||
|
f, err := os.Open(cli.CredentialsPath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to open credentials file %q: %v", cli.CredentialsPath, err)
|
||||||
|
}
|
||||||
|
defer func() { _ = f.Close() }()
|
||||||
|
|
||||||
|
creds = csvauth.New(aesKey)
|
||||||
|
if err := creds.LoadCSV(f, cli.comma); err != nil {
|
||||||
|
log.Fatalf("Failed to load CSV auth: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var usableRoles int
|
||||||
|
for key := range creds.CredentialKeys() {
|
||||||
|
u, err := creds.LoadCredential(key)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to read users from CSV auth: %v", err)
|
||||||
|
}
|
||||||
|
if len(u.Roles) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if usableRoles == 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Current credentials, tokens, and permissions:\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " %s\t%s\t%s\n", u.Purpose, u.ID(), strings.Join(u.Roles, " "))
|
||||||
|
}
|
||||||
|
usableRoles += 1
|
||||||
|
}
|
||||||
|
|
||||||
|
var warnRoles bool
|
||||||
|
for key := range creds.CredentialKeys() {
|
||||||
|
u, _ := creds.LoadCredential(key)
|
||||||
|
if len(u.Roles) > 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !warnRoles {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "WARNING - Please Read\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "The following credentials cannot be used because they contain no Roles:\n")
|
||||||
|
warnRoles = true
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, " %q (%s)\n", u.Name, u.Purpose)
|
||||||
|
}
|
||||||
|
if warnRoles {
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "Permission must be explicitly granted in the Roles column as a space-separated\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "list of URI matchers in the form of \"[METHOD:][HOST]/[PATH]\"\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "CLI Examples\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " authcsv store --roles '/' john.doe\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " authcsv store --roles 'GET:example.com/mail POST:example.com/feed' --token openclaw\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "CSV Examples\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " GET:example.com/ # GET-only access to example.com, for all paths\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " / # Full access to everything\n")
|
||||||
|
fmt.Fprintf(os.Stderr, " GET:/ POST:/logs # GET anything, POST only to /logs/... \n")
|
||||||
|
fmt.Fprintf(os.Stderr, " ex1.com/ GET:ex2.net/logs # full access to ex1.com, GET-only for ex2.net\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
fmt.Fprintf(os.Stderr, "\n")
|
||||||
|
}
|
||||||
|
if usableRoles == 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: no usable credentials found\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build proxy handler
|
||||||
|
handler := cli.newAuthProxyHandler(cli.ProxyTarget)
|
||||||
|
|
||||||
|
// Server setup
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: cli.Addr(),
|
||||||
|
Handler: handler,
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
IdleTimeout: 60 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graceful shutdown
|
||||||
|
done := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-done
|
||||||
|
log.Println("Shutting down server...")
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := srv.Shutdown(ctx); err != nil {
|
||||||
|
log.Printf("HTTP server shutdown error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
log.Printf("Starting %s v%s on %s → %s", name, version, srv.Addr, cli.ProxyTarget)
|
||||||
|
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
log.Fatalf("HTTP server failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("Server stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *MainConfig) newAuthProxyHandler(targetURL string) http.Handler {
|
||||||
|
target, err := url.Parse(targetURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("invalid proxy target %q: %v", targetURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := &httputil.ReverseProxy{
|
||||||
|
Rewrite: func(r *httputil.ProxyRequest) {
|
||||||
|
r.SetURL(target)
|
||||||
|
r.Out.Host = r.In.Host // preserve original Host header
|
||||||
|
// X-Forwarded-* headers are preserved from incoming request
|
||||||
|
},
|
||||||
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
log.Printf("proxy error: %v", err)
|
||||||
|
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !cli.authorize(r) {
|
||||||
|
// TODO allow --realm for `WWW-Authenticate: Basic realm="My Application"`
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic`)
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
proxy.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *MainConfig) authorize(r *http.Request) bool {
|
||||||
|
cred, err := cli.authenticate(r)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, ErrNoAuth) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
cred, err = creds.Authenticate("guest", "")
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
grants := cred.Permissions()
|
||||||
|
if len(grants) == 0 {
|
||||||
|
// must have at least '/'
|
||||||
|
fmt.Fprintf(os.Stderr, "Warn: user %q correctly authenticated, but no --roles were specified (assign * or / for full access)\n", cred.ID())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if grants[0] == "*" || grants[0] == "/" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GET,POST example.com/path/{$}
|
||||||
|
for _, grant := range grants {
|
||||||
|
if matchPattern(grant, r.Method, r.Host, r.URL.Path) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// patternMatch returns true for a grant in the form of a ServeMux pattern matches the current request
|
||||||
|
// (though : is used instead of space since space is already used as a separator)
|
||||||
|
//
|
||||||
|
// grant is in the form `[METHOD:][HOST]/[PATH]`, but METHOD may be a comma-separated list GET,POST...
|
||||||
|
// rMthod must be ALL_CAPS
|
||||||
|
// rHost may be HOSTNAME or HOSTNAME:PORT
|
||||||
|
// rPath must start with /
|
||||||
|
func matchPattern(grant, rMethod, rHost, rPath string) bool {
|
||||||
|
// this should have been done already, but...
|
||||||
|
grant = strings.TrimSpace(grant)
|
||||||
|
|
||||||
|
// must have at least /
|
||||||
|
// ""
|
||||||
|
if grant == "" {
|
||||||
|
fmt.Fprintf(os.Stderr, "DEBUG: missing grant\n")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// / => ["/"]
|
||||||
|
// /path => ["/path"]
|
||||||
|
// example.com/path => ["example.com/path"]
|
||||||
|
// GET:example.com/path => ["GET", "example.com/path"]
|
||||||
|
{
|
||||||
|
var methodSep = ":"
|
||||||
|
var methods []string
|
||||||
|
grantParts := strings.SplitN(grant, methodSep, 2)
|
||||||
|
switch len(grantParts) {
|
||||||
|
case 1:
|
||||||
|
// no method
|
||||||
|
case 2:
|
||||||
|
if len(grantParts) == 2 {
|
||||||
|
methods = strings.Split(strings.ToUpper(grantParts[0]), ",")
|
||||||
|
if !slices.Contains(methods, rMethod) {
|
||||||
|
// TODO maybe propagate method-not-allowed?
|
||||||
|
fmt.Fprintf(os.Stderr, "DEBUG: method %q != %q\n", rMethod, grantParts[0])
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
grant = grantParts[1]
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
fmt.Fprintf(os.Stderr, "DEBUG: extraneous spaces in grant %q\n", grant)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// / => /
|
||||||
|
// /path => /path
|
||||||
|
// example.com/path => /path
|
||||||
|
idx := strings.Index(grant, "/")
|
||||||
|
if idx < 0 {
|
||||||
|
// host without path is invalid
|
||||||
|
fmt.Fprintf(os.Stderr, "DEBUG: missing leading / from grant %q\n", grant)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
hostname := grant[:idx]
|
||||||
|
if hostname != "" {
|
||||||
|
// example.com:443 => example.com
|
||||||
|
if h, _, _ := strings.Cut(rHost, ":"); hostname != h {
|
||||||
|
// hostname doesn't match
|
||||||
|
fmt.Fprintf(os.Stderr, "DEBUG: hostname %q != %q\n", rHost, hostname)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
grant = grant[idx:]
|
||||||
|
|
||||||
|
// Prefix-only matching
|
||||||
|
//
|
||||||
|
// /path => /path
|
||||||
|
// /path/ => /path
|
||||||
|
// /path/{var}/bar => /path/{var}/bar
|
||||||
|
// /path/{var...} = /path/{var}
|
||||||
|
// /path/{$} => /path
|
||||||
|
// var exact bool
|
||||||
|
// grant, exact = strings.CutSuffix(grant, "/{$}")
|
||||||
|
// grant, _ = strings.CutSuffix(grant, "/")
|
||||||
|
// rPath, _ = strings.CutSuffix(rPath, "/")
|
||||||
|
// if len(grantPaths) > len(rPaths) {
|
||||||
|
// return false
|
||||||
|
// } else if len(grantPaths) < len(rPaths) {
|
||||||
|
// if exact {
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // TODO replace with pattern matching as per https://pkg.go.dev/net/http#hdr-Patterns-ServeMux
|
||||||
|
// // /path/{var}/bar matches /path/foo/bar and /path/foo/bar/
|
||||||
|
// // /path/{var...} matches /path/ and /path/foo/bar
|
||||||
|
// // /path/{var}/bar/{$} matches /path/foo/bar and /path/baz/bar/ but not /path/foo/
|
||||||
|
// for i := 1; i < len(grantPaths); i++ {
|
||||||
|
// grantPath := grantPaths[i]
|
||||||
|
// rPath := rPaths[i]
|
||||||
|
// if strings.HasPrefix(grantPath, "{") {
|
||||||
|
// continue
|
||||||
|
// }
|
||||||
|
// if rPath != grantPath {
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// return true
|
||||||
|
|
||||||
|
// ServeMux pattern matching
|
||||||
|
nextGPath, gstop := iter.Pull(strings.SplitSeq(grant, "/"))
|
||||||
|
nextRPath, rstop := iter.Pull(strings.SplitSeq(rPath, "/"))
|
||||||
|
defer gstop()
|
||||||
|
defer rstop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
gp, gok := nextGPath()
|
||||||
|
rp, rok := nextRPath()
|
||||||
|
// everything has matched thus far, and the pattern has ended
|
||||||
|
if !gok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// false unless the extra length of the pattern signifies the exact match, disregarding trailing /
|
||||||
|
if !rok {
|
||||||
|
// this matches trailing /, {var}, {var}/, {var...}, and {$}
|
||||||
|
if gp == "" || (strings.HasPrefix(gp, "{") && strings.HasSuffix(gp, "}")) {
|
||||||
|
gp2, more := nextGPath()
|
||||||
|
// this allows for one more final trailing /, but nothing else
|
||||||
|
if !more {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if gp2 == "" {
|
||||||
|
// two trailing slashes are not allowed
|
||||||
|
_, more := nextGPath()
|
||||||
|
return !more
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// path parts are only allowed to disagree for trailing slashes and variables
|
||||||
|
if gp != rp {
|
||||||
|
// this allows for one more final trailing / on the pattern, but nothing else
|
||||||
|
if gp == "" {
|
||||||
|
_, more := nextGPath()
|
||||||
|
return !more
|
||||||
|
}
|
||||||
|
// this allows for a placeholder in the pattern
|
||||||
|
if strings.HasPrefix(gp, "{") && strings.HasSuffix(gp, "}") {
|
||||||
|
// normal variables pass
|
||||||
|
if gp != "{$}" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// trailing slash on exact match passes
|
||||||
|
if rp == "" {
|
||||||
|
_, more := nextRPath()
|
||||||
|
return !more
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "DEBUG: path past {$} %q vs %q\n", rp, gp)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "DEBUG: path part %q != %q\n", rp, gp)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cli *MainConfig) authenticate(r *http.Request) (auth.BasicPrinciple, error) {
|
||||||
|
// 1. Try Basic Auth first (cleanest path)
|
||||||
|
username, password, ok := r.BasicAuth()
|
||||||
|
if ok {
|
||||||
|
// Authorization: Basic <Auth> exists
|
||||||
|
return creds.Authenticate(username, password)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Any Authorization: <scheme> <token>
|
||||||
|
if len(cli.AuthorizationHeaderSchemes) > 0 {
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
if authHeader != "" {
|
||||||
|
parts := strings.SplitN(authHeader, " ", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
if cli.AuthorizationHeaderSchemes[0] == "*" ||
|
||||||
|
slices.Contains(cli.AuthorizationHeaderSchemes, parts[0]) {
|
||||||
|
token := strings.TrimSpace(parts[1])
|
||||||
|
// Authorization: <Scheme> <Token> exists
|
||||||
|
return creds.Authenticate(basicAPIKeyName, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("'Authorization' header is not properly formatted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. API-Key / X-API-Key headers
|
||||||
|
for _, h := range cli.TokenHeaderNames {
|
||||||
|
if key := r.Header.Get(h); key != "" {
|
||||||
|
// <TokenHeader>: <Token> exists
|
||||||
|
return creds.Authenticate(basicAPIKeyName, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. access_token query param
|
||||||
|
for _, h := range cli.QueryParamNames {
|
||||||
|
if token := r.URL.Query().Get(h); token != "" {
|
||||||
|
// <query_param>?=<Token> exists
|
||||||
|
return creds.Authenticate(basicAPIKeyName, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, ErrNoAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
// peekOption looks for a flag value without parsing the full set
|
||||||
|
func peekOption(args []string, names []string, def string) string {
|
||||||
|
for i := range len(args) {
|
||||||
|
for _, name := range names {
|
||||||
|
if args[i] == name {
|
||||||
|
if i+1 < len(args) {
|
||||||
|
return args[i+1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO expose this from csvauth
|
||||||
|
func getAESKey(envname, filename string) ([]byte, error) {
|
||||||
|
envKey := os.Getenv(envname)
|
||||||
|
if envKey != "" {
|
||||||
|
key, err := hex.DecodeString(strings.TrimSpace(envKey))
|
||||||
|
if err != nil || len(key) != 16 {
|
||||||
|
return nil, fmt.Errorf("invalid %s: must be 32-char hex string", envname)
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Found AES Key in %s\n", envname)
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(filename); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := os.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read %s: %v", filename, err)
|
||||||
|
}
|
||||||
|
key, err := hex.DecodeString(strings.TrimSpace(string(data)))
|
||||||
|
if err != nil || len(key) != 16 {
|
||||||
|
return nil, fmt.Errorf("invalid key in %s: must be 32-char hex string", filename)
|
||||||
|
}
|
||||||
|
// relpath := strings.Replace(filename, homedir, "~", 1)
|
||||||
|
fmt.Fprintf(os.Stderr, "Found AES Key at %s\n", filename)
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
112
cmd/auth-proxy/pattern_test.go
Normal file
112
cmd/auth-proxy/pattern_test.go
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMatchPattern(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
grant string
|
||||||
|
method string
|
||||||
|
host string
|
||||||
|
path string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
// Basic path matching
|
||||||
|
{"/", "GET", "example.com", "/", true},
|
||||||
|
{"GET:/", "POST", "example.com", "/", false},
|
||||||
|
{"/api/users", "GET", "api.example.com", "/api/users", true},
|
||||||
|
{"/api/users", "GET", "api.example.com", "/api/users/", true},
|
||||||
|
{"/api/users", "GET", "api.example.com", "/api/users", true},
|
||||||
|
{"/api/users/", "GET", "", "/api/users", true},
|
||||||
|
|
||||||
|
// Host matching
|
||||||
|
{"example.com/", "GET", "example.com", "/", true},
|
||||||
|
{"GET:example.com/", "GET", "example.com", "/", true},
|
||||||
|
{"whatever.net/", "GET", "example.com", "/", false},
|
||||||
|
{"example.comz/", "GET", "example.com", "/", false},
|
||||||
|
{"example.com/", "GET", "example.comz", "/", false},
|
||||||
|
{"aexample.com/", "GET", "example.com", "/", false},
|
||||||
|
{"example.com/", "GET", "aexample.com", "/", false},
|
||||||
|
{"example.com/", "GET", "api.example.com", "/", false},
|
||||||
|
{"api.example.com/", "GET", "example.com", "/", false},
|
||||||
|
{".example.com/", "GET", "api.example.com", "/", false},
|
||||||
|
{"api.example.com/", "GET", "", "/", false},
|
||||||
|
{"GET:api.example.com/", "GET", "example.com", "/", false},
|
||||||
|
{"example.com/", "GET", "example.com:443", "/", true},
|
||||||
|
{"GET:example.com/", "GET", "example.com:443", "/", true},
|
||||||
|
|
||||||
|
// Method lists
|
||||||
|
{"GET,POST,PUT:/api", "POST", "", "/api", true},
|
||||||
|
{"GET,DELETE:/api", "POST", "", "/api", false},
|
||||||
|
|
||||||
|
// Wildcard / placeholder segments
|
||||||
|
// bad
|
||||||
|
{"/users/{id}", "GET", "", "/user", false},
|
||||||
|
// good
|
||||||
|
{"/users/{id}", "GET", "", "/users", true},
|
||||||
|
{"/users/{id}", "GET", "", "/users/", true},
|
||||||
|
{"/users/{id}", "GET", "", "/users/123", true},
|
||||||
|
{"/users/{id}", "GET", "", "/users/123/", true},
|
||||||
|
{"/users/{id}", "GET", "", "/users/123/friends", true},
|
||||||
|
// bad
|
||||||
|
{"/users/{id}/", "GET", "", "/user", false},
|
||||||
|
// good
|
||||||
|
{"/users/{id}/", "GET", "", "/users", true},
|
||||||
|
{"/users/{id}/", "GET", "", "/users/", true},
|
||||||
|
{"/users/{id}/", "GET", "", "/users/123", true},
|
||||||
|
{"/users/{id}/", "GET", "", "/users/123/", true},
|
||||||
|
{"/users/{id}/", "GET", "", "/users/123/friends", true},
|
||||||
|
// good (these are exactly the same as /path/{var} above, but added for completeness)
|
||||||
|
{"/users/{id...}", "GET", "", "/users", true},
|
||||||
|
{"/users/{id...}", "GET", "", "/users/", true},
|
||||||
|
{"/users/{id...}", "GET", "", "/users/123", true},
|
||||||
|
{"/users/{id...}", "GET", "", "/users/123/", true},
|
||||||
|
{"/users/{id...}", "GET", "", "/users/123/friends", true},
|
||||||
|
// good
|
||||||
|
{"/users/{id}", "GET", "", "/users/123/bar", true},
|
||||||
|
{"/users/{id}", "GET", "", "/users/123/bar/", true},
|
||||||
|
{"/users/{id}", "GET", "", "/users/123/bar/456", true},
|
||||||
|
{"/users/{id}", "GET", "", "/users/123/bar", true},
|
||||||
|
{"/users/{id}", "GET", "", "/users/123/bar/", true},
|
||||||
|
{"/users/{id}", "GET", "", "/users/123/bar/456", true},
|
||||||
|
// good
|
||||||
|
{"/users/{id}/{$}", "GET", "", "/users/123", true},
|
||||||
|
{"/users/{id}/{$}", "GET", "", "/users/123/", true},
|
||||||
|
// wrong
|
||||||
|
{"/users/{id}/bar", "GET", "", "/users/123", false},
|
||||||
|
{"/users/{id}/bar", "GET", "", "/users/123/", false},
|
||||||
|
{"/users/{id}/bar", "GET", "", "/users/123/b", false},
|
||||||
|
{"/users/{id}/bar", "GET", "", "/users/123/b/", false},
|
||||||
|
{"/users/{id}/bar", "GET", "", "/users/123/b/456", false},
|
||||||
|
{"/users/{id}/{$}", "GET", "", "/users/123/b", false},
|
||||||
|
// {$}
|
||||||
|
{"/api/{ver}/items/{$}", "GET", "", "/api/v1/items", true},
|
||||||
|
{"/api/{ver}/items/{$}", "GET", "", "/api/v1/items/", true},
|
||||||
|
{"/api/{ver}/items/{$}", "GET", "", "/api/v1/items/42", false},
|
||||||
|
{"/foo/{$}/baz", "GET", "", "/foo/bar/baz", false},
|
||||||
|
|
||||||
|
// Edge cases & invalid patterns
|
||||||
|
{"", "GET", "example.com", "/", false},
|
||||||
|
{"", "GET", "", "/", false},
|
||||||
|
{"GET", "GET", "example.com", "/", false},
|
||||||
|
{"GET:", "GET", "example.com", "/", false},
|
||||||
|
{"example.com", "GET", "example.com", "/", false},
|
||||||
|
{"GET:example.com", "GET", "example.com", "/", false},
|
||||||
|
{"GET:/users ", "GET", "", "/users", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
// name := "Pattern " + tt.grant + " vs URI " + strings.TrimSpace(fmt.Sprintf("%s %s%s", tt.method, tt.host, tt.path))
|
||||||
|
name := tt.grant + " vs " + strings.TrimSpace(fmt.Sprintf("%s %s%s", tt.method, tt.host, tt.path))
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
got := matchPattern(tt.grant, tt.method, tt.host, tt.path)
|
||||||
|
if got != tt.want {
|
||||||
|
t.Errorf("matchPattern(%q, %q, %q, %q) = %v, want %v",
|
||||||
|
tt.grant, tt.method, tt.host, tt.path, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user