refactor: strip all optional/nil-guard plumbing from check-ip + geoip

- drop Checker struct, loadCohort helper, and contains() nil-wrapper
- inline check logic into server as a closure
- geoip.Databases: no nil-receiver guards, no nil-field branches, no
  "disabled" mode. City + ASN are both required; caller hands explicit
  paths and OpenDatabases returns a fully-initialized value or an err
- main.go is now straight-line wiring with no helper functions
This commit is contained in:
AJ ONeal 2026-04-20 15:55:55 -06:00
parent cdce7da04c
commit a84116f806
No known key found for this signature in database
4 changed files with 107 additions and 185 deletions

View File

@ -25,21 +25,18 @@ const (
refreshInterval = 47 * time.Minute refreshInterval = 47 * time.Minute
) )
type Config struct {
Serve string
GeoIPConf string
BlocklistRepo string
CacheDir string
}
func main() { func main() {
cfg := Config{} var (
bind string
confPath string
repoURL string
cacheDir string
)
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fs.StringVar(&cfg.Serve, "serve", "", "bind address for the HTTP API, e.g. :8080") fs.StringVar(&bind, "serve", "", "bind address for the HTTP API, e.g. :8080")
fs.StringVar(&cfg.GeoIPConf, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)") fs.StringVar(&confPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
fs.StringVar(&cfg.BlocklistRepo, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)") fs.StringVar(&repoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)") fs.StringVar(&cacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
fs.Usage = func() { fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0])
fs.PrintDefaults() fs.PrintDefaults()
@ -50,34 +47,34 @@ func main() {
} }
os.Exit(1) os.Exit(1)
} }
if cfg.Serve == "" { if bind == "" {
fs.Usage() fs.Usage()
os.Exit(1) os.Exit(1)
} }
if cacheDir == "" {
if d, err := os.UserCacheDir(); err == nil {
cacheDir = d
}
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop() defer stop()
cacheDir := cfg.CacheDir
if cacheDir == "" {
base, err := os.UserCacheDir()
if err != nil {
fatal("cache-dir", err)
}
cacheDir = base
}
// Blocklists: one git repo, two views sharing the same pull. // Blocklists: one git repo, two views sharing the same pull.
repo := gitshallow.New(cfg.BlocklistRepo, filepath.Join(cacheDir, "bitwire-it"), 1, "") repo := gitshallow.New(repoURL, filepath.Join(cacheDir, "bitwire-it"), 1, "")
group := dataset.NewGroup(repo) group := dataset.NewGroup(repo)
inbound := dataset.Add(group, loadCohort( inbound := dataset.Add(group, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(
repo.FilePath("tables/inbound/single_ips.txt"), repo.FilePath("tables/inbound/single_ips.txt"),
repo.FilePath("tables/inbound/networks.txt"), repo.FilePath("tables/inbound/networks.txt"),
)) )
outbound := dataset.Add(group, loadCohort( })
outbound := dataset.Add(group, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(
repo.FilePath("tables/outbound/single_ips.txt"), repo.FilePath("tables/outbound/single_ips.txt"),
repo.FilePath("tables/outbound/networks.txt"), repo.FilePath("tables/outbound/networks.txt"),
)) )
})
if err := group.Load(ctx); err != nil { if err := group.Load(ctx); err != nil {
fatal("blocklists", err) fatal("blocklists", err)
} }
@ -85,11 +82,10 @@ func main() {
fmt.Fprintf(os.Stderr, "refresh: %v\n", err) fmt.Fprintf(os.Stderr, "refresh: %v\n", err)
}) })
// GeoIP: city + ASN readers, downloaded via httpcache when GeoIP.conf // GeoIP: downloaded via httpcache when GeoIP.conf is available.
// is available; otherwise read from disk at the cache paths.
maxmindDir := filepath.Join(cacheDir, "maxmind") maxmindDir := filepath.Join(cacheDir, "maxmind")
geo, err := geoip.OpenDatabases( geo, err := geoip.OpenDatabases(
cfg.GeoIPConf, confPath,
filepath.Join(maxmindDir, geoip.CityEdition+".mmdb"), filepath.Join(maxmindDir, geoip.CityEdition+".mmdb"),
filepath.Join(maxmindDir, geoip.ASNEdition+".mmdb"), filepath.Join(maxmindDir, geoip.ASNEdition+".mmdb"),
) )
@ -98,8 +94,7 @@ func main() {
} }
defer func() { _ = geo.Close() }() defer func() { _ = geo.Close() }()
checker := &Checker{Inbound: inbound, Outbound: outbound, GeoIP: geo} if err := serve(ctx, bind, inbound, outbound, geo); err != nil {
if err := serve(ctx, cfg.Serve, checker); err != nil {
fatal("serve", err) fatal("serve", err)
} }
} }
@ -108,42 +103,3 @@ func fatal(what string, err error) {
fmt.Fprintf(os.Stderr, "error: %s: %v\n", what, err) fmt.Fprintf(os.Stderr, "error: %s: %v\n", what, err)
os.Exit(1) os.Exit(1)
} }
// Checker bundles the blocklist views with the optional GeoIP databases.
type Checker struct {
Inbound *dataset.View[ipcohort.Cohort]
Outbound *dataset.View[ipcohort.Cohort]
GeoIP *geoip.Databases
}
// Result is the structured verdict for a single IP.
type Result struct {
IP string `json:"ip"`
Blocked bool `json:"blocked"`
BlockedInbound bool `json:"blocked_inbound"`
BlockedOutbound bool `json:"blocked_outbound"`
Geo geoip.Info `json:"geo,omitzero"`
}
// Check returns the structured verdict for ip.
func (c *Checker) Check(ip string) Result {
in := contains(c.Inbound.Value(), ip)
out := contains(c.Outbound.Value(), ip)
return Result{
IP: ip,
Blocked: in || out,
BlockedInbound: in,
BlockedOutbound: out,
Geo: c.GeoIP.Lookup(ip),
}
}
func contains(c *ipcohort.Cohort, ip string) bool {
return c != nil && c.Contains(ip)
}
func loadCohort(paths ...string) func() (*ipcohort.Cohort, error) {
return func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(paths...)
}
}

View File

@ -13,26 +13,53 @@ import (
"time" "time"
"github.com/therootcompany/golib/net/geoip" "github.com/therootcompany/golib/net/geoip"
"github.com/therootcompany/golib/net/ipcohort"
"github.com/therootcompany/golib/sync/dataset"
) )
const shutdownTimeout = 5 * time.Second const shutdownTimeout = 5 * time.Second
// serve runs the HTTP API until ctx is cancelled, shutting down gracefully. // Result is the structured verdict for a single IP.
type Result struct {
IP string `json:"ip"`
Blocked bool `json:"blocked"`
BlockedInbound bool `json:"blocked_inbound"`
BlockedOutbound bool `json:"blocked_outbound"`
Geo geoip.Info `json:"geo,omitzero"`
}
// serve runs the HTTP API until ctx is cancelled.
// //
// GET / checks the request's client IP // GET / checks the request's client IP
// GET /check same, plus ?ip= overrides // GET /check same, plus ?ip= overrides
// //
// Response format is chosen per request: ?format=json, then // Response format: ?format=json, then Accept: application/json, else pretty.
// Accept: application/json, else pretty text. func serve(
func serve(ctx context.Context, bind string, checker *Checker) error { ctx context.Context,
bind string,
inbound, outbound *dataset.View[ipcohort.Cohort],
geo *geoip.Databases,
) error {
check := func(ip string) Result {
in := inbound.Value().Contains(ip)
out := outbound.Value().Contains(ip)
return Result{
IP: ip,
Blocked: in || out,
BlockedInbound: in,
BlockedOutbound: out,
Geo: geo.Lookup(ip),
}
}
handle := func(w http.ResponseWriter, r *http.Request, ip string) { handle := func(w http.ResponseWriter, r *http.Request, ip string) {
format := requestFormat(r) f := requestFormat(r)
if format == formatJSON { if f == formatJSON {
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
} else { } else {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
} }
writeResult(w, checker.Check(ip), format) write(w, check(ip), f)
} }
mux := http.NewServeMux() mux := http.NewServeMux()
@ -54,7 +81,6 @@ func serve(ctx context.Context, bind string, checker *Checker) error {
return ctx return ctx
}, },
} }
go func() { go func() {
<-ctx.Done() <-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
@ -77,7 +103,6 @@ const (
formatJSON formatJSON
) )
// requestFormat picks a response format from ?format=, then Accept header.
func requestFormat(r *http.Request) format { func requestFormat(r *http.Request) format {
switch r.URL.Query().Get("format") { switch r.URL.Query().Get("format") {
case "json": case "json":
@ -91,7 +116,7 @@ func requestFormat(r *http.Request) format {
return formatPretty return formatPretty
} }
func writeResult(w io.Writer, r Result, f format) { func write(w io.Writer, r Result, f format) {
if f == formatJSON { if f == formatJSON {
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
enc.SetIndent("", " ") enc.SetIndent("", " ")
@ -108,31 +133,25 @@ func writeResult(w io.Writer, r Result, f format) {
default: default:
fmt.Fprintf(w, "%s is allowed\n", r.IP) fmt.Fprintf(w, "%s is allowed\n", r.IP)
} }
writeGeo(w, r.Geo)
}
func writeGeo(w io.Writer, info geoip.Info) {
var parts []string var parts []string
if info.City != "" { if r.Geo.City != "" {
parts = append(parts, info.City) parts = append(parts, r.Geo.City)
} }
if info.Region != "" { if r.Geo.Region != "" {
parts = append(parts, info.Region) parts = append(parts, r.Geo.Region)
} }
if info.Country != "" { if r.Geo.Country != "" {
parts = append(parts, fmt.Sprintf("%s (%s)", info.Country, info.CountryISO)) parts = append(parts, fmt.Sprintf("%s (%s)", r.Geo.Country, r.Geo.CountryISO))
} }
if len(parts) > 0 { if len(parts) > 0 {
fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", ")) fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", "))
} }
if info.ASN != 0 { if r.Geo.ASN != 0 {
fmt.Fprintf(w, " ASN: AS%d %s\n", info.ASN, info.ASNOrg) fmt.Fprintf(w, " ASN: AS%d %s\n", r.Geo.ASN, r.Geo.ASNOrg)
} }
} }
// clientIP extracts the caller's IP, honoring X-Forwarded-For when present. // clientIP extracts the caller's IP, honoring X-Forwarded-For when present.
// The leftmost entry in X-Forwarded-For is the originating client; intermediate
// proxies append themselves rightward.
func clientIP(r *http.Request) string { func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" { if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
first, _, _ := strings.Cut(xff, ",") first, _, _ := strings.Cut(xff, ",")

View File

@ -10,9 +10,7 @@ import (
"github.com/oschwald/geoip2-golang" "github.com/oschwald/geoip2-golang"
) )
// Databases holds open GeoLite2 readers. A nil field means that edition // Databases holds open GeoLite2 City + ASN readers.
// wasn't configured. A nil *Databases means geoip is disabled; all methods
// are nil-safe no-ops so callers need not branch.
type Databases struct { type Databases struct {
City *geoip2.Reader City *geoip2.Reader
ASN *geoip2.Reader ASN *geoip2.Reader
@ -22,9 +20,8 @@ type Databases struct {
// GeoIP.conf with credentials is available), and opens the readers. // GeoIP.conf with credentials is available), and opens the readers.
// //
// - confPath="" → auto-discover from DefaultConfPaths // - confPath="" → auto-discover from DefaultConfPaths
// - conf found → auto-download; cityPath/asnPath override default locations // - conf found → auto-download to cityPath/asnPath
// - no conf → cityPath and asnPath must point to existing .mmdb files // - no conf → cityPath and asnPath must point to existing .mmdb files
// - no conf and no paths → returns nil, nil (geoip disabled)
func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) { func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) {
if confPath == "" { if confPath == "" {
for _, p := range DefaultConfPaths() { for _, p := range DefaultConfPaths() {
@ -40,20 +37,8 @@ func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("geoip-conf: %w", err) return nil, fmt.Errorf("geoip-conf: %w", err)
} }
dbDir := cfg.DatabaseDirectory if err := os.MkdirAll(filepath.Dir(cityPath), 0o755); err != nil {
if dbDir == "" { return nil, err
if dbDir, err = DefaultCacheDir(); err != nil {
return nil, fmt.Errorf("geoip cache dir: %w", err)
}
}
if err := os.MkdirAll(dbDir, 0o755); err != nil {
return nil, fmt.Errorf("mkdir %s: %w", dbDir, err)
}
if cityPath == "" {
cityPath = filepath.Join(dbDir, CityEdition+".mmdb")
}
if asnPath == "" {
asnPath = filepath.Join(dbDir, ASNEdition+".mmdb")
} }
dl := New(cfg.AccountID, cfg.LicenseKey) dl := New(cfg.AccountID, cfg.LicenseKey)
if _, err := dl.NewCacher(CityEdition, cityPath).Fetch(); err != nil { if _, err := dl.NewCacher(CityEdition, cityPath).Fetch(); err != nil {
@ -62,60 +47,30 @@ func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) {
if _, err := dl.NewCacher(ASNEdition, asnPath).Fetch(); err != nil { if _, err := dl.NewCacher(ASNEdition, asnPath).Fetch(); err != nil {
return nil, fmt.Errorf("fetch %s: %w", ASNEdition, err) return nil, fmt.Errorf("fetch %s: %w", ASNEdition, err)
} }
return Open(cityPath, asnPath)
}
if cityPath == "" && asnPath == "" {
return nil, nil
} }
return Open(cityPath, asnPath) return Open(cityPath, asnPath)
} }
// Open opens city and ASN .mmdb files from the given paths. Empty paths are // Open opens city and ASN .mmdb files from the given paths.
// treated as unconfigured (the corresponding field stays nil).
func Open(cityPath, asnPath string) (*Databases, error) { func Open(cityPath, asnPath string) (*Databases, error) {
d := &Databases{} city, err := geoip2.Open(cityPath)
if cityPath != "" {
r, err := geoip2.Open(cityPath)
if err != nil { if err != nil {
return nil, fmt.Errorf("open %s: %w", cityPath, err) return nil, fmt.Errorf("open %s: %w", cityPath, err)
} }
d.City = r asn, err := geoip2.Open(asnPath)
}
if asnPath != "" {
r, err := geoip2.Open(asnPath)
if err != nil { if err != nil {
if d.City != nil { _ = city.Close()
_ = d.City.Close()
}
return nil, fmt.Errorf("open %s: %w", asnPath, err) return nil, fmt.Errorf("open %s: %w", asnPath, err)
} }
d.ASN = r return &Databases{City: city, ASN: asn}, nil
}
return d, nil
} }
// Close closes any open readers. No-op on nil receiver. // Close closes the city and ASN readers.
func (d *Databases) Close() error { func (d *Databases) Close() error {
if d == nil { return errors.Join(d.City.Close(), d.ASN.Close())
return nil
}
var errs []error
if d.City != nil {
if err := d.City.Close(); err != nil {
errs = append(errs, err)
}
}
if d.ASN != nil {
if err := d.ASN.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
} }
// Info is the structured result of a GeoIP lookup. Zero-valued fields mean // Info is the structured result of a GeoIP lookup.
// the database didn't return a value (or wasn't configured).
type Info struct { type Info struct {
City string `json:"city,omitempty"` City string `json:"city,omitempty"`
Region string `json:"region,omitempty"` Region string `json:"region,omitempty"`
@ -125,20 +80,16 @@ type Info struct {
ASNOrg string `json:"asn_org,omitempty"` ASNOrg string `json:"asn_org,omitempty"`
} }
// Lookup returns city + ASN info for ip. Returns a zero Info on nil receiver, // Lookup returns city + ASN info for ip. Returns a zero Info on unparseable
// unparseable IP, or database miss. // IP or database miss.
func (d *Databases) Lookup(ip string) Info { func (d *Databases) Lookup(ip string) Info {
var info Info var info Info
if d == nil {
return info
}
addr, err := netip.ParseAddr(ip) addr, err := netip.ParseAddr(ip)
if err != nil { if err != nil {
return info return info
} }
stdIP := addr.AsSlice() stdIP := addr.AsSlice()
if d.City != nil {
if rec, err := d.City.City(stdIP); err == nil { if rec, err := d.City.City(stdIP); err == nil {
info.City = rec.City.Names["en"] info.City = rec.City.Names["en"]
info.Country = rec.Country.Names["en"] info.Country = rec.Country.Names["en"]
@ -149,13 +100,9 @@ func (d *Databases) Lookup(ip string) Info {
} }
} }
} }
}
if d.ASN != nil {
if rec, err := d.ASN.ASN(stdIP); err == nil { if rec, err := d.ASN.ASN(stdIP); err == nil {
info.ASN = rec.AutonomousSystemNumber info.ASN = rec.AutonomousSystemNumber
info.ASNOrg = rec.AutonomousSystemOrganization info.ASNOrg = rec.AutonomousSystemOrganization
} }
}
return info return info
} }