refactor(httpcache): use http.Header instead of AuthHeader/AuthValue

Cacher.Header is a stdlib http.Header that's merged into every request.
Authorization is stripped on redirect unconditionally (presigned S3/R2
targets, etc). Callers build the header with the usual http.Header
literal; BasicAuth/Bearer still produce the Authorization value.
This commit is contained in:
AJ ONeal 2026-04-20 16:55:15 -06:00
parent 4753888402
commit f75d5c489a
No known key found for this signature in database
4 changed files with 39 additions and 41 deletions

View File

@ -9,6 +9,7 @@ import (
"flag" "flag"
"fmt" "fmt"
"log" "log"
"net/http"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
@ -143,20 +144,19 @@ func main() {
asnTarPath := filepath.Join(maxmindDir, "GeoLite2-ASN.tar.gz") asnTarPath := filepath.Join(maxmindDir, "GeoLite2-ASN.tar.gz")
var geoSet *dataset.Set var geoSet *dataset.Set
if cfg.GeoIPBasicAuth != "" { if cfg.GeoIPBasicAuth != "" {
authHeader := http.Header{"Authorization": []string{cfg.GeoIPBasicAuth}}
geoSet = dataset.NewSet( geoSet = dataset.NewSet(
&httpcache.Cacher{ &httpcache.Cacher{
URL: geoip.DownloadBase + "/GeoLite2-City/download?suffix=tar.gz", URL: geoip.DownloadBase + "/GeoLite2-City/download?suffix=tar.gz",
Path: cityTarPath, Path: cityTarPath,
MaxAge: 3 * 24 * time.Hour, MaxAge: 3 * 24 * time.Hour,
AuthHeader: "Authorization", Header: authHeader,
AuthValue: cfg.GeoIPBasicAuth,
}, },
&httpcache.Cacher{ &httpcache.Cacher{
URL: geoip.DownloadBase + "/GeoLite2-ASN/download?suffix=tar.gz", URL: geoip.DownloadBase + "/GeoLite2-ASN/download?suffix=tar.gz",
Path: asnTarPath, Path: asnTarPath,
MaxAge: 3 * 24 * time.Hour, MaxAge: 3 * 24 * time.Hour,
AuthHeader: "Authorization", Header: authHeader,
AuthValue: cfg.GeoIPBasicAuth,
}, },
) )
} else { } else {

View File

@ -3,6 +3,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"time" "time"
@ -45,7 +46,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
auth := httpcache.BasicAuth(cfg.AccountID, cfg.LicenseKey) authHeader := http.Header{"Authorization": []string{httpcache.BasicAuth(cfg.AccountID, cfg.LicenseKey)}}
maxAge := time.Duration(*freshDays) * 24 * time.Hour maxAge := time.Duration(*freshDays) * 24 * time.Hour
exitCode := 0 exitCode := 0
@ -55,8 +56,7 @@ func main() {
URL: geoip.DownloadBase + "/" + edition + "/download?suffix=tar.gz", URL: geoip.DownloadBase + "/" + edition + "/download?suffix=tar.gz",
Path: path, Path: path,
MaxAge: maxAge, MaxAge: maxAge,
AuthHeader: "Authorization", Header: authHeader,
AuthValue: auth,
} }
updated, err := cacher.Fetch() updated, err := cacher.Fetch()
if err != nil { if err != nil {

View File

@ -3,6 +3,7 @@
package geoip_test package geoip_test
import ( import (
"net/http"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -52,8 +53,9 @@ func newCacher(cfg *geoip.Conf, edition, path string) *httpcache.Cacher {
return &httpcache.Cacher{ return &httpcache.Cacher{
URL: geoip.DownloadBase + "/" + edition + "/download?suffix=tar.gz", URL: geoip.DownloadBase + "/" + edition + "/download?suffix=tar.gz",
Path: path, Path: path,
AuthHeader: "Authorization", Header: http.Header{
AuthValue: httpcache.BasicAuth(cfg.AccountID, cfg.LicenseKey), "Authorization": []string{httpcache.BasicAuth(cfg.AccountID, cfg.LicenseKey)},
},
} }
} }

View File

@ -14,14 +14,14 @@ import (
) )
// BasicAuth returns an HTTP Basic Authorization header value: // BasicAuth returns an HTTP Basic Authorization header value:
// "Basic " + base64(user:pass). Assign to Cacher.AuthValue with // "Basic " + base64(user:pass). Pair with the "Authorization" header in
// AuthHeader "Authorization". // Cacher.Header.
func BasicAuth(user, pass string) string { func BasicAuth(user, pass string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass)) return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass))
} }
// Bearer returns a Bearer Authorization header value: "Bearer " + token. // Bearer returns a Bearer Authorization header value: "Bearer " + token.
// Assign to Cacher.AuthValue with AuthHeader "Authorization". // Pair with the "Authorization" header in Cacher.Header.
func Bearer(token string) string { func Bearer(token string) string {
return "Bearer " + token return "Bearer " + token
} }
@ -45,12 +45,10 @@ const (
// Caching — ETag and Last-Modified values are persisted to a <path>.meta // Caching — ETag and Last-Modified values are persisted to a <path>.meta
// sidecar file so conditional GETs survive process restarts. // sidecar file so conditional GETs survive process restarts.
// //
// Auth — AuthHeader/AuthValue set a request header on every attempt. Auth is // Header — any values in Header are sent on every request. Authorization
// stripped before following redirects so presigned targets (e.g. S3/R2 URLs) // headers are stripped before following redirects so presigned targets
// never receive credentials. Use any scheme: "Authorization"/"Bearer token", // (e.g. S3/R2 URLs) never receive credentials. The BasicAuth and Bearer
// "X-API-Key"/"secret", "Authorization"/"Basic base64(user:pass)", etc. The // helpers produce Authorization values for the common cases.
// BasicAuth and Bearer helpers produce the right AuthValue for the common
// cases.
type Cacher struct { type Cacher struct {
URL string URL string
Path string Path string
@ -58,8 +56,7 @@ type Cacher struct {
Timeout time.Duration // 0 uses 5m; caps overall request including body read Timeout time.Duration // 0 uses 5m; caps overall request including body read
MaxAge time.Duration // 0 disables; skip HTTP if file mtime is within this MaxAge time.Duration // 0 disables; skip HTTP if file mtime is within this
MinInterval time.Duration // 0 disables; skip HTTP if last Fetch attempt was within this MinInterval time.Duration // 0 disables; skip HTTP if last Fetch attempt was within this
AuthHeader string // e.g. "Authorization" or "X-API-Key" Header http.Header // headers sent on every request (Authorization is stripped on redirect)
AuthValue string // e.g. "Bearer token" or "Basic base64(user:pass)"
mu sync.Mutex mu sync.Mutex
etag string etag string
@ -166,20 +163,19 @@ func (c *Cacher) Fetch() (updated bool, err error) {
TLSHandshakeTimeout: connTimeout, TLSHandshakeTimeout: connTimeout,
} }
if c.AuthHeader != "" { for k, vs := range c.Header {
req.Header.Set(c.AuthHeader, c.AuthValue) for _, v := range vs {
req.Header.Add(k, v)
}
} }
client := &http.Client{Timeout: timeout, Transport: transport} client := &http.Client{Timeout: timeout, Transport: transport}
if c.AuthHeader != "" { // Strip Authorization before following any redirect — redirect targets
// Strip auth before following any redirect — redirect targets (e.g. // (e.g. presigned S3/R2 URLs) must not receive our credentials.
// presigned S3/R2 URLs) must not receive our credentials.
authHeader := c.AuthHeader
client.CheckRedirect = func(req *http.Request, via []*http.Request) error { client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
req.Header.Del(authHeader) req.Header.Del("Authorization")
return nil return nil
} }
}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {