feat(check-ip): add --whitelist override, require GeoIP.conf

Whitelist is a combined IP+CIDR cohort file polled for mtime changes;
a match short-circuits the blocklist check and marks the result
allowlisted. Drops the geoip PollFiles fallback — missing GeoIP.conf
now fails fast instead of silently polling local tarballs.
This commit is contained in:
AJ ONeal 2026-04-20 17:05:55 -06:00
parent 159cf2d4d3
commit f293f86b16
No known key found for this signature in database
2 changed files with 64 additions and 38 deletions

View File

@ -36,15 +36,16 @@ type IPCheck struct {
GeoIPConfPath string
RepoURL string
CacheDir string
WhitelistPath string
// GeoIPBasicAuth is the pre-encoded Authorization header value for
// MaxMind downloads. Empty when no GeoIP.conf was found — in that case
// the .tar.gz archives must already exist in <CacheDir>/maxmind/.
// MaxMind downloads.
GeoIPBasicAuth string
inbound *dataset.View[ipcohort.Cohort]
outbound *dataset.View[ipcohort.Cohort]
geo *dataset.View[geoip.Databases]
inbound *dataset.View[ipcohort.Cohort]
outbound *dataset.View[ipcohort.Cohort]
whitelist *dataset.View[ipcohort.Cohort]
geo *dataset.View[geoip.Databases]
}
func main() {
@ -54,6 +55,7 @@ func main() {
fs.StringVar(&cfg.GeoIPConfPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
fs.StringVar(&cfg.RepoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
fs.StringVar(&cfg.WhitelistPath, "whitelist", "", "path to a file of IPs and/or CIDRs (one per line) that override block decisions")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <ip> [ip...]\n", os.Args[0])
fmt.Fprintf(os.Stderr, " %s --serve <bind> [flags]\n", os.Args[0])
@ -135,33 +137,28 @@ func main() {
log.Fatalf("blocklists: %v", err)
}
// GeoIP: with credentials, download the City + ASN tar.gz archives via
// httpcache conditional GETs. Without them, poll the existing tar.gz
// files in maxmindDir. geoip.Open extracts in-memory — no .mmdb files
// GeoIP: download the City + ASN tar.gz archives via httpcache
// conditional GETs. geoip.Open extracts in-memory — no .mmdb files
// are written to disk.
maxmindDir := filepath.Join(cfg.CacheDir, "maxmind")
cityTarPath := filepath.Join(maxmindDir, "GeoLite2-City.tar.gz")
asnTarPath := filepath.Join(maxmindDir, "GeoLite2-ASN.tar.gz")
var geoSet *dataset.Set
if cfg.GeoIPBasicAuth != "" {
authHeader := http.Header{"Authorization": []string{cfg.GeoIPBasicAuth}}
geoSet = dataset.NewSet(
&httpcache.Cacher{
URL: geoip.DownloadBase + "/GeoLite2-City/download?suffix=tar.gz",
Path: cityTarPath,
MaxAge: 3 * 24 * time.Hour,
Header: authHeader,
},
&httpcache.Cacher{
URL: geoip.DownloadBase + "/GeoLite2-ASN/download?suffix=tar.gz",
Path: asnTarPath,
MaxAge: 3 * 24 * time.Hour,
Header: authHeader,
},
)
} else {
geoSet = dataset.NewSet(dataset.PollFiles(cityTarPath, asnTarPath))
if cfg.GeoIPBasicAuth == "" {
log.Fatalf("geoip-conf: not found; set --geoip-conf or place GeoIP.conf in a default location")
}
maxmindDir := filepath.Join(cfg.CacheDir, "maxmind")
authHeader := http.Header{"Authorization": []string{cfg.GeoIPBasicAuth}}
geoSet := dataset.NewSet(
&httpcache.Cacher{
URL: geoip.DownloadBase + "/GeoLite2-City/download?suffix=tar.gz",
Path: filepath.Join(maxmindDir, "GeoLite2-City.tar.gz"),
MaxAge: 3 * 24 * time.Hour,
Header: authHeader,
},
&httpcache.Cacher{
URL: geoip.DownloadBase + "/GeoLite2-ASN/download?suffix=tar.gz",
Path: filepath.Join(maxmindDir, "GeoLite2-ASN.tar.gz"),
MaxAge: 3 * 24 * time.Hour,
Header: authHeader,
},
)
cfg.geo = dataset.Add(geoSet, func() (*geoip.Databases, error) {
return geoip.Open(maxmindDir)
})
@ -170,6 +167,19 @@ func main() {
}
defer func() { _ = cfg.geo.Value().Close() }()
// Whitelist: combined IPs + CIDRs in one file, polled for mtime changes.
// A match here overrides any block decision from the blocklists.
var whitelistSet *dataset.Set
if cfg.WhitelistPath != "" {
whitelistSet = dataset.NewSet(dataset.PollFiles(cfg.WhitelistPath))
cfg.whitelist = dataset.Add(whitelistSet, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFile(cfg.WhitelistPath)
})
if err := whitelistSet.Load(context.Background()); err != nil {
log.Fatalf("whitelist: %v", err)
}
}
for _, ip := range ips {
cfg.writeText(os.Stdout, cfg.lookup(ip))
}
@ -185,6 +195,11 @@ func main() {
go geoSet.Tick(ctx, refreshInterval, func(err error) {
log.Printf("geoip refresh: %v", err)
})
if whitelistSet != nil {
go whitelistSet.Tick(ctx, refreshInterval, func(err error) {
log.Printf("whitelist refresh: %v", err)
})
}
if err := cfg.serve(ctx); err != nil {
log.Fatalf("serve: %v", err)
}

View File

@ -9,6 +9,7 @@ import (
"log"
"net"
"net/http"
"net/netip"
"strings"
"time"
@ -21,26 +22,36 @@ type Result struct {
Blocked bool `json:"blocked"`
BlockedInbound bool `json:"blocked_inbound"`
BlockedOutbound bool `json:"blocked_outbound"`
Allowlisted bool `json:"allowlisted,omitzero"`
Geo geoip.Info `json:"geo,omitzero"`
}
// lookup builds a Result for ip against the currently loaded blocklists
// and GeoIP databases.
func (c *IPCheck) lookup(ip string) Result {
in := c.inbound.Value().Contains(ip)
out := c.outbound.Value().Contains(ip)
return Result{
IP: ip,
Blocked: in || out,
BlockedInbound: in,
BlockedOutbound: out,
Geo: c.geo.Value().Lookup(ip),
res := Result{IP: ip, Geo: c.geo.Value().Lookup(ip)}
addr, err := netip.ParseAddr(ip)
if err != nil {
res.Blocked = true
res.BlockedInbound = true
res.BlockedOutbound = true
return res
}
if c.whitelist != nil && c.whitelist.Value().ContainsAddr(addr) {
res.Allowlisted = true
return res
}
res.BlockedInbound = c.inbound.Value().ContainsAddr(addr)
res.BlockedOutbound = c.outbound.Value().ContainsAddr(addr)
res.Blocked = res.BlockedInbound || res.BlockedOutbound
return res
}
// writeText renders res as human-readable plain text.
func (c *IPCheck) writeText(w io.Writer, res Result) {
switch {
case res.Allowlisted:
fmt.Fprintf(w, "%s is ALLOWED (whitelist)\n", res.IP)
case res.BlockedInbound && res.BlockedOutbound:
fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", res.IP)
case res.BlockedInbound: