From a84116f806080cc51dedc39831a2ac3bcb21b5e6 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 20 Apr 2026 15:55:55 -0600 Subject: [PATCH] refactor: strip all optional/nil-guard plumbing from check-ip + geoip - drop Checker struct, loadCohort helper, and contains() nil-wrapper - inline check logic into server as a closure - geoip.Databases: no nil-receiver guards, no nil-field branches, no "disabled" mode. City + ASN are both required; caller hands explicit paths and OpenDatabases returns a fully-initialized value or an err - main.go is now straight-line wiring with no helper functions --- cmd/check-ip/main.go | 108 ++++++++++++---------------------------- cmd/check-ip/server.go | 67 ++++++++++++++++--------- net/geoip/conf.go | 8 +-- net/geoip/databases.go | 109 +++++++++++------------------------------ 4 files changed, 107 insertions(+), 185 deletions(-) diff --git a/cmd/check-ip/main.go b/cmd/check-ip/main.go index 2093353..27386a3 100644 --- a/cmd/check-ip/main.go +++ b/cmd/check-ip/main.go @@ -25,21 +25,18 @@ const ( refreshInterval = 47 * time.Minute ) -type Config struct { - Serve string - GeoIPConf string - BlocklistRepo string - CacheDir string -} - func main() { - cfg := Config{} - + var ( + bind string + confPath string + repoURL string + cacheDir string + ) fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) - fs.StringVar(&cfg.Serve, "serve", "", "bind address for the HTTP API, e.g. :8080") - fs.StringVar(&cfg.GeoIPConf, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)") - fs.StringVar(&cfg.BlocklistRepo, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)") - fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)") + fs.StringVar(&bind, "serve", "", "bind address for the HTTP API, e.g. :8080") + fs.StringVar(&confPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)") + fs.StringVar(&repoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)") + fs.StringVar(&cacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)") fs.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s --serve [flags]\n", os.Args[0]) fs.PrintDefaults() @@ -50,34 +47,34 @@ func main() { } os.Exit(1) } - if cfg.Serve == "" { + if bind == "" { fs.Usage() os.Exit(1) } + if cacheDir == "" { + if d, err := os.UserCacheDir(); err == nil { + cacheDir = d + } + } ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - cacheDir := cfg.CacheDir - if cacheDir == "" { - base, err := os.UserCacheDir() - if err != nil { - fatal("cache-dir", err) - } - cacheDir = base - } - // Blocklists: one git repo, two views sharing the same pull. - repo := gitshallow.New(cfg.BlocklistRepo, filepath.Join(cacheDir, "bitwire-it"), 1, "") + repo := gitshallow.New(repoURL, filepath.Join(cacheDir, "bitwire-it"), 1, "") group := dataset.NewGroup(repo) - inbound := dataset.Add(group, loadCohort( - repo.FilePath("tables/inbound/single_ips.txt"), - repo.FilePath("tables/inbound/networks.txt"), - )) - outbound := dataset.Add(group, loadCohort( - repo.FilePath("tables/outbound/single_ips.txt"), - repo.FilePath("tables/outbound/networks.txt"), - )) + inbound := dataset.Add(group, func() (*ipcohort.Cohort, error) { + return ipcohort.LoadFiles( + repo.FilePath("tables/inbound/single_ips.txt"), + repo.FilePath("tables/inbound/networks.txt"), + ) + }) + outbound := dataset.Add(group, func() (*ipcohort.Cohort, error) { + return ipcohort.LoadFiles( + repo.FilePath("tables/outbound/single_ips.txt"), + repo.FilePath("tables/outbound/networks.txt"), + ) + }) if err := group.Load(ctx); err != nil { fatal("blocklists", err) } @@ -85,11 +82,10 @@ func main() { fmt.Fprintf(os.Stderr, "refresh: %v\n", err) }) - // GeoIP: city + ASN readers, downloaded via httpcache when GeoIP.conf - // is available; otherwise read from disk at the cache paths. + // GeoIP: downloaded via httpcache when GeoIP.conf is available. maxmindDir := filepath.Join(cacheDir, "maxmind") geo, err := geoip.OpenDatabases( - cfg.GeoIPConf, + confPath, filepath.Join(maxmindDir, geoip.CityEdition+".mmdb"), filepath.Join(maxmindDir, geoip.ASNEdition+".mmdb"), ) @@ -98,8 +94,7 @@ func main() { } defer func() { _ = geo.Close() }() - checker := &Checker{Inbound: inbound, Outbound: outbound, GeoIP: geo} - if err := serve(ctx, cfg.Serve, checker); err != nil { + if err := serve(ctx, bind, inbound, outbound, geo); err != nil { fatal("serve", err) } } @@ -108,42 +103,3 @@ func fatal(what string, err error) { fmt.Fprintf(os.Stderr, "error: %s: %v\n", what, err) os.Exit(1) } - -// Checker bundles the blocklist views with the optional GeoIP databases. -type Checker struct { - Inbound *dataset.View[ipcohort.Cohort] - Outbound *dataset.View[ipcohort.Cohort] - GeoIP *geoip.Databases -} - -// Result is the structured verdict for a single IP. -type Result struct { - IP string `json:"ip"` - Blocked bool `json:"blocked"` - BlockedInbound bool `json:"blocked_inbound"` - BlockedOutbound bool `json:"blocked_outbound"` - Geo geoip.Info `json:"geo,omitzero"` -} - -// Check returns the structured verdict for ip. -func (c *Checker) Check(ip string) Result { - in := contains(c.Inbound.Value(), ip) - out := contains(c.Outbound.Value(), ip) - return Result{ - IP: ip, - Blocked: in || out, - BlockedInbound: in, - BlockedOutbound: out, - Geo: c.GeoIP.Lookup(ip), - } -} - -func contains(c *ipcohort.Cohort, ip string) bool { - return c != nil && c.Contains(ip) -} - -func loadCohort(paths ...string) func() (*ipcohort.Cohort, error) { - return func() (*ipcohort.Cohort, error) { - return ipcohort.LoadFiles(paths...) - } -} diff --git a/cmd/check-ip/server.go b/cmd/check-ip/server.go index 284dae1..68c6dd2 100644 --- a/cmd/check-ip/server.go +++ b/cmd/check-ip/server.go @@ -13,26 +13,53 @@ import ( "time" "github.com/therootcompany/golib/net/geoip" + "github.com/therootcompany/golib/net/ipcohort" + "github.com/therootcompany/golib/sync/dataset" ) const shutdownTimeout = 5 * time.Second -// serve runs the HTTP API until ctx is cancelled, shutting down gracefully. +// Result is the structured verdict for a single IP. +type Result struct { + IP string `json:"ip"` + Blocked bool `json:"blocked"` + BlockedInbound bool `json:"blocked_inbound"` + BlockedOutbound bool `json:"blocked_outbound"` + Geo geoip.Info `json:"geo,omitzero"` +} + +// serve runs the HTTP API until ctx is cancelled. // // GET / checks the request's client IP // GET /check same, plus ?ip= overrides // -// Response format is chosen per request: ?format=json, then -// Accept: application/json, else pretty text. -func serve(ctx context.Context, bind string, checker *Checker) error { +// Response format: ?format=json, then Accept: application/json, else pretty. +func serve( + ctx context.Context, + bind string, + inbound, outbound *dataset.View[ipcohort.Cohort], + geo *geoip.Databases, +) error { + check := func(ip string) Result { + in := inbound.Value().Contains(ip) + out := outbound.Value().Contains(ip) + return Result{ + IP: ip, + Blocked: in || out, + BlockedInbound: in, + BlockedOutbound: out, + Geo: geo.Lookup(ip), + } + } + handle := func(w http.ResponseWriter, r *http.Request, ip string) { - format := requestFormat(r) - if format == formatJSON { + f := requestFormat(r) + if f == formatJSON { w.Header().Set("Content-Type", "application/json; charset=utf-8") } else { w.Header().Set("Content-Type", "text/plain; charset=utf-8") } - writeResult(w, checker.Check(ip), format) + write(w, check(ip), f) } mux := http.NewServeMux() @@ -54,7 +81,6 @@ func serve(ctx context.Context, bind string, checker *Checker) error { return ctx }, } - go func() { <-ctx.Done() shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) @@ -77,7 +103,6 @@ const ( formatJSON ) -// requestFormat picks a response format from ?format=, then Accept header. func requestFormat(r *http.Request) format { switch r.URL.Query().Get("format") { case "json": @@ -91,7 +116,7 @@ func requestFormat(r *http.Request) format { return formatPretty } -func writeResult(w io.Writer, r Result, f format) { +func write(w io.Writer, r Result, f format) { if f == formatJSON { enc := json.NewEncoder(w) enc.SetIndent("", " ") @@ -108,31 +133,25 @@ func writeResult(w io.Writer, r Result, f format) { default: fmt.Fprintf(w, "%s is allowed\n", r.IP) } - writeGeo(w, r.Geo) -} - -func writeGeo(w io.Writer, info geoip.Info) { var parts []string - if info.City != "" { - parts = append(parts, info.City) + if r.Geo.City != "" { + parts = append(parts, r.Geo.City) } - if info.Region != "" { - parts = append(parts, info.Region) + if r.Geo.Region != "" { + parts = append(parts, r.Geo.Region) } - if info.Country != "" { - parts = append(parts, fmt.Sprintf("%s (%s)", info.Country, info.CountryISO)) + if r.Geo.Country != "" { + parts = append(parts, fmt.Sprintf("%s (%s)", r.Geo.Country, r.Geo.CountryISO)) } if len(parts) > 0 { fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", ")) } - if info.ASN != 0 { - fmt.Fprintf(w, " ASN: AS%d %s\n", info.ASN, info.ASNOrg) + if r.Geo.ASN != 0 { + fmt.Fprintf(w, " ASN: AS%d %s\n", r.Geo.ASN, r.Geo.ASNOrg) } } // clientIP extracts the caller's IP, honoring X-Forwarded-For when present. -// The leftmost entry in X-Forwarded-For is the originating client; intermediate -// proxies append themselves rightward. func clientIP(r *http.Request) string { if xff := r.Header.Get("X-Forwarded-For"); xff != "" { first, _, _ := strings.Cut(xff, ",") diff --git a/net/geoip/conf.go b/net/geoip/conf.go index 2e034c5..7c1aaf0 100644 --- a/net/geoip/conf.go +++ b/net/geoip/conf.go @@ -9,10 +9,10 @@ import ( // Conf holds the fields parsed from a geoipupdate-style config file. type Conf struct { - AccountID string - LicenseKey string - EditionIDs []string - DatabaseDirectory string + AccountID string + LicenseKey string + EditionIDs []string + DatabaseDirectory string } // ParseConf reads a geoipupdate-style config file (whitespace-separated diff --git a/net/geoip/databases.go b/net/geoip/databases.go index 66a013d..e87c43c 100644 --- a/net/geoip/databases.go +++ b/net/geoip/databases.go @@ -10,9 +10,7 @@ import ( "github.com/oschwald/geoip2-golang" ) -// Databases holds open GeoLite2 readers. A nil field means that edition -// wasn't configured. A nil *Databases means geoip is disabled; all methods -// are nil-safe no-ops so callers need not branch. +// Databases holds open GeoLite2 City + ASN readers. type Databases struct { City *geoip2.Reader ASN *geoip2.Reader @@ -22,9 +20,8 @@ type Databases struct { // GeoIP.conf with credentials is available), and opens the readers. // // - confPath="" → auto-discover from DefaultConfPaths -// - conf found → auto-download; cityPath/asnPath override default locations +// - conf found → auto-download to cityPath/asnPath // - no conf → cityPath and asnPath must point to existing .mmdb files -// - no conf and no paths → returns nil, nil (geoip disabled) func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) { if confPath == "" { for _, p := range DefaultConfPaths() { @@ -40,20 +37,8 @@ func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) { if err != nil { return nil, fmt.Errorf("geoip-conf: %w", err) } - dbDir := cfg.DatabaseDirectory - if dbDir == "" { - if dbDir, err = DefaultCacheDir(); err != nil { - return nil, fmt.Errorf("geoip cache dir: %w", err) - } - } - if err := os.MkdirAll(dbDir, 0o755); err != nil { - return nil, fmt.Errorf("mkdir %s: %w", dbDir, err) - } - if cityPath == "" { - cityPath = filepath.Join(dbDir, CityEdition+".mmdb") - } - if asnPath == "" { - asnPath = filepath.Join(dbDir, ASNEdition+".mmdb") + if err := os.MkdirAll(filepath.Dir(cityPath), 0o755); err != nil { + return nil, err } dl := New(cfg.AccountID, cfg.LicenseKey) if _, err := dl.NewCacher(CityEdition, cityPath).Fetch(); err != nil { @@ -62,60 +47,30 @@ func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) { if _, err := dl.NewCacher(ASNEdition, asnPath).Fetch(); err != nil { return nil, fmt.Errorf("fetch %s: %w", ASNEdition, err) } - return Open(cityPath, asnPath) - } - - if cityPath == "" && asnPath == "" { - return nil, nil } return Open(cityPath, asnPath) } -// Open opens city and ASN .mmdb files from the given paths. Empty paths are -// treated as unconfigured (the corresponding field stays nil). +// Open opens city and ASN .mmdb files from the given paths. func Open(cityPath, asnPath string) (*Databases, error) { - d := &Databases{} - if cityPath != "" { - r, err := geoip2.Open(cityPath) - if err != nil { - return nil, fmt.Errorf("open %s: %w", cityPath, err) - } - d.City = r + city, err := geoip2.Open(cityPath) + if err != nil { + return nil, fmt.Errorf("open %s: %w", cityPath, err) } - if asnPath != "" { - r, err := geoip2.Open(asnPath) - if err != nil { - if d.City != nil { - _ = d.City.Close() - } - return nil, fmt.Errorf("open %s: %w", asnPath, err) - } - d.ASN = r + asn, err := geoip2.Open(asnPath) + if err != nil { + _ = city.Close() + return nil, fmt.Errorf("open %s: %w", asnPath, err) } - return d, nil + return &Databases{City: city, ASN: asn}, nil } -// Close closes any open readers. No-op on nil receiver. +// Close closes the city and ASN readers. func (d *Databases) Close() error { - if d == nil { - return nil - } - var errs []error - if d.City != nil { - if err := d.City.Close(); err != nil { - errs = append(errs, err) - } - } - if d.ASN != nil { - if err := d.ASN.Close(); err != nil { - errs = append(errs, err) - } - } - return errors.Join(errs...) + return errors.Join(d.City.Close(), d.ASN.Close()) } -// Info is the structured result of a GeoIP lookup. Zero-valued fields mean -// the database didn't return a value (or wasn't configured). +// Info is the structured result of a GeoIP lookup. type Info struct { City string `json:"city,omitempty"` Region string `json:"region,omitempty"` @@ -125,37 +80,29 @@ type Info struct { ASNOrg string `json:"asn_org,omitempty"` } -// Lookup returns city + ASN info for ip. Returns a zero Info on nil receiver, -// unparseable IP, or database miss. +// Lookup returns city + ASN info for ip. Returns a zero Info on unparseable +// IP or database miss. func (d *Databases) Lookup(ip string) Info { var info Info - if d == nil { - return info - } addr, err := netip.ParseAddr(ip) if err != nil { return info } stdIP := addr.AsSlice() - if d.City != nil { - if rec, err := d.City.City(stdIP); err == nil { - info.City = rec.City.Names["en"] - info.Country = rec.Country.Names["en"] - info.CountryISO = rec.Country.IsoCode - if len(rec.Subdivisions) > 0 { - if sub := rec.Subdivisions[0].Names["en"]; sub != "" && sub != info.City { - info.Region = sub - } + if rec, err := d.City.City(stdIP); err == nil { + info.City = rec.City.Names["en"] + info.Country = rec.Country.Names["en"] + info.CountryISO = rec.Country.IsoCode + if len(rec.Subdivisions) > 0 { + if sub := rec.Subdivisions[0].Names["en"]; sub != "" && sub != info.City { + info.Region = sub } } } - if d.ASN != nil { - if rec, err := d.ASN.ASN(stdIP); err == nil { - info.ASN = rec.AutonomousSystemNumber - info.ASNOrg = rec.AutonomousSystemOrganization - } + if rec, err := d.ASN.ASN(stdIP); err == nil { + info.ASN = rec.AutonomousSystemNumber + info.ASNOrg = rec.AutonomousSystemOrganization } return info } -