refactor: strip all optional/nil-guard plumbing from check-ip + geoip

- drop Checker struct, loadCohort helper, and contains() nil-wrapper
- inline check logic into server as a closure
- geoip.Databases: no nil-receiver guards, no nil-field branches, no
  "disabled" mode. City + ASN are both required; caller hands explicit
  paths and OpenDatabases returns a fully-initialized value or an err
- main.go is now straight-line wiring with no helper functions
This commit is contained in:
AJ ONeal 2026-04-20 15:55:55 -06:00
parent cdce7da04c
commit a84116f806
No known key found for this signature in database
4 changed files with 107 additions and 185 deletions

View File

@ -25,21 +25,18 @@ const (
refreshInterval = 47 * time.Minute
)
type Config struct {
Serve string
GeoIPConf string
BlocklistRepo string
CacheDir string
}
func main() {
cfg := Config{}
var (
bind string
confPath string
repoURL string
cacheDir string
)
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fs.StringVar(&cfg.Serve, "serve", "", "bind address for the HTTP API, e.g. :8080")
fs.StringVar(&cfg.GeoIPConf, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
fs.StringVar(&cfg.BlocklistRepo, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
fs.StringVar(&bind, "serve", "", "bind address for the HTTP API, e.g. :8080")
fs.StringVar(&confPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
fs.StringVar(&repoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
fs.StringVar(&cacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0])
fs.PrintDefaults()
@ -50,34 +47,34 @@ func main() {
}
os.Exit(1)
}
if cfg.Serve == "" {
if bind == "" {
fs.Usage()
os.Exit(1)
}
if cacheDir == "" {
if d, err := os.UserCacheDir(); err == nil {
cacheDir = d
}
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
cacheDir := cfg.CacheDir
if cacheDir == "" {
base, err := os.UserCacheDir()
if err != nil {
fatal("cache-dir", err)
}
cacheDir = base
}
// Blocklists: one git repo, two views sharing the same pull.
repo := gitshallow.New(cfg.BlocklistRepo, filepath.Join(cacheDir, "bitwire-it"), 1, "")
repo := gitshallow.New(repoURL, filepath.Join(cacheDir, "bitwire-it"), 1, "")
group := dataset.NewGroup(repo)
inbound := dataset.Add(group, loadCohort(
inbound := dataset.Add(group, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(
repo.FilePath("tables/inbound/single_ips.txt"),
repo.FilePath("tables/inbound/networks.txt"),
))
outbound := dataset.Add(group, loadCohort(
)
})
outbound := dataset.Add(group, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(
repo.FilePath("tables/outbound/single_ips.txt"),
repo.FilePath("tables/outbound/networks.txt"),
))
)
})
if err := group.Load(ctx); err != nil {
fatal("blocklists", err)
}
@ -85,11 +82,10 @@ func main() {
fmt.Fprintf(os.Stderr, "refresh: %v\n", err)
})
// GeoIP: city + ASN readers, downloaded via httpcache when GeoIP.conf
// is available; otherwise read from disk at the cache paths.
// GeoIP: downloaded via httpcache when GeoIP.conf is available.
maxmindDir := filepath.Join(cacheDir, "maxmind")
geo, err := geoip.OpenDatabases(
cfg.GeoIPConf,
confPath,
filepath.Join(maxmindDir, geoip.CityEdition+".mmdb"),
filepath.Join(maxmindDir, geoip.ASNEdition+".mmdb"),
)
@ -98,8 +94,7 @@ func main() {
}
defer func() { _ = geo.Close() }()
checker := &Checker{Inbound: inbound, Outbound: outbound, GeoIP: geo}
if err := serve(ctx, cfg.Serve, checker); err != nil {
if err := serve(ctx, bind, inbound, outbound, geo); err != nil {
fatal("serve", err)
}
}
@ -108,42 +103,3 @@ func fatal(what string, err error) {
fmt.Fprintf(os.Stderr, "error: %s: %v\n", what, err)
os.Exit(1)
}
// Checker bundles the blocklist views with the optional GeoIP databases.
type Checker struct {
Inbound *dataset.View[ipcohort.Cohort]
Outbound *dataset.View[ipcohort.Cohort]
GeoIP *geoip.Databases
}
// Result is the structured verdict for a single IP.
type Result struct {
IP string `json:"ip"`
Blocked bool `json:"blocked"`
BlockedInbound bool `json:"blocked_inbound"`
BlockedOutbound bool `json:"blocked_outbound"`
Geo geoip.Info `json:"geo,omitzero"`
}
// Check returns the structured verdict for ip.
func (c *Checker) Check(ip string) Result {
in := contains(c.Inbound.Value(), ip)
out := contains(c.Outbound.Value(), ip)
return Result{
IP: ip,
Blocked: in || out,
BlockedInbound: in,
BlockedOutbound: out,
Geo: c.GeoIP.Lookup(ip),
}
}
func contains(c *ipcohort.Cohort, ip string) bool {
return c != nil && c.Contains(ip)
}
func loadCohort(paths ...string) func() (*ipcohort.Cohort, error) {
return func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(paths...)
}
}

View File

@ -13,26 +13,53 @@ import (
"time"
"github.com/therootcompany/golib/net/geoip"
"github.com/therootcompany/golib/net/ipcohort"
"github.com/therootcompany/golib/sync/dataset"
)
const shutdownTimeout = 5 * time.Second
// serve runs the HTTP API until ctx is cancelled, shutting down gracefully.
// Result is the structured verdict for a single IP.
type Result struct {
IP string `json:"ip"`
Blocked bool `json:"blocked"`
BlockedInbound bool `json:"blocked_inbound"`
BlockedOutbound bool `json:"blocked_outbound"`
Geo geoip.Info `json:"geo,omitzero"`
}
// serve runs the HTTP API until ctx is cancelled.
//
// GET / checks the request's client IP
// GET /check same, plus ?ip= overrides
//
// Response format is chosen per request: ?format=json, then
// Accept: application/json, else pretty text.
func serve(ctx context.Context, bind string, checker *Checker) error {
// Response format: ?format=json, then Accept: application/json, else pretty.
func serve(
ctx context.Context,
bind string,
inbound, outbound *dataset.View[ipcohort.Cohort],
geo *geoip.Databases,
) error {
check := func(ip string) Result {
in := inbound.Value().Contains(ip)
out := outbound.Value().Contains(ip)
return Result{
IP: ip,
Blocked: in || out,
BlockedInbound: in,
BlockedOutbound: out,
Geo: geo.Lookup(ip),
}
}
handle := func(w http.ResponseWriter, r *http.Request, ip string) {
format := requestFormat(r)
if format == formatJSON {
f := requestFormat(r)
if f == formatJSON {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
} else {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
}
writeResult(w, checker.Check(ip), format)
write(w, check(ip), f)
}
mux := http.NewServeMux()
@ -54,7 +81,6 @@ func serve(ctx context.Context, bind string, checker *Checker) error {
return ctx
},
}
go func() {
<-ctx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
@ -77,7 +103,6 @@ const (
formatJSON
)
// requestFormat picks a response format from ?format=, then Accept header.
func requestFormat(r *http.Request) format {
switch r.URL.Query().Get("format") {
case "json":
@ -91,7 +116,7 @@ func requestFormat(r *http.Request) format {
return formatPretty
}
func writeResult(w io.Writer, r Result, f format) {
func write(w io.Writer, r Result, f format) {
if f == formatJSON {
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
@ -108,31 +133,25 @@ func writeResult(w io.Writer, r Result, f format) {
default:
fmt.Fprintf(w, "%s is allowed\n", r.IP)
}
writeGeo(w, r.Geo)
}
func writeGeo(w io.Writer, info geoip.Info) {
var parts []string
if info.City != "" {
parts = append(parts, info.City)
if r.Geo.City != "" {
parts = append(parts, r.Geo.City)
}
if info.Region != "" {
parts = append(parts, info.Region)
if r.Geo.Region != "" {
parts = append(parts, r.Geo.Region)
}
if info.Country != "" {
parts = append(parts, fmt.Sprintf("%s (%s)", info.Country, info.CountryISO))
if r.Geo.Country != "" {
parts = append(parts, fmt.Sprintf("%s (%s)", r.Geo.Country, r.Geo.CountryISO))
}
if len(parts) > 0 {
fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", "))
}
if info.ASN != 0 {
fmt.Fprintf(w, " ASN: AS%d %s\n", info.ASN, info.ASNOrg)
if r.Geo.ASN != 0 {
fmt.Fprintf(w, " ASN: AS%d %s\n", r.Geo.ASN, r.Geo.ASNOrg)
}
}
// clientIP extracts the caller's IP, honoring X-Forwarded-For when present.
// The leftmost entry in X-Forwarded-For is the originating client; intermediate
// proxies append themselves rightward.
func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
first, _, _ := strings.Cut(xff, ",")

View File

@ -10,9 +10,7 @@ import (
"github.com/oschwald/geoip2-golang"
)
// Databases holds open GeoLite2 readers. A nil field means that edition
// wasn't configured. A nil *Databases means geoip is disabled; all methods
// are nil-safe no-ops so callers need not branch.
// Databases holds open GeoLite2 City + ASN readers.
type Databases struct {
City *geoip2.Reader
ASN *geoip2.Reader
@ -22,9 +20,8 @@ type Databases struct {
// GeoIP.conf with credentials is available), and opens the readers.
//
// - confPath="" → auto-discover from DefaultConfPaths
// - conf found → auto-download; cityPath/asnPath override default locations
// - conf found → auto-download to cityPath/asnPath
// - no conf → cityPath and asnPath must point to existing .mmdb files
// - no conf and no paths → returns nil, nil (geoip disabled)
func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) {
if confPath == "" {
for _, p := range DefaultConfPaths() {
@ -40,20 +37,8 @@ func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) {
if err != nil {
return nil, fmt.Errorf("geoip-conf: %w", err)
}
dbDir := cfg.DatabaseDirectory
if dbDir == "" {
if dbDir, err = DefaultCacheDir(); err != nil {
return nil, fmt.Errorf("geoip cache dir: %w", err)
}
}
if err := os.MkdirAll(dbDir, 0o755); err != nil {
return nil, fmt.Errorf("mkdir %s: %w", dbDir, err)
}
if cityPath == "" {
cityPath = filepath.Join(dbDir, CityEdition+".mmdb")
}
if asnPath == "" {
asnPath = filepath.Join(dbDir, ASNEdition+".mmdb")
if err := os.MkdirAll(filepath.Dir(cityPath), 0o755); err != nil {
return nil, err
}
dl := New(cfg.AccountID, cfg.LicenseKey)
if _, err := dl.NewCacher(CityEdition, cityPath).Fetch(); err != nil {
@ -62,60 +47,30 @@ func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) {
if _, err := dl.NewCacher(ASNEdition, asnPath).Fetch(); err != nil {
return nil, fmt.Errorf("fetch %s: %w", ASNEdition, err)
}
return Open(cityPath, asnPath)
}
if cityPath == "" && asnPath == "" {
return nil, nil
}
return Open(cityPath, asnPath)
}
// Open opens city and ASN .mmdb files from the given paths. Empty paths are
// treated as unconfigured (the corresponding field stays nil).
// Open opens city and ASN .mmdb files from the given paths.
func Open(cityPath, asnPath string) (*Databases, error) {
d := &Databases{}
if cityPath != "" {
r, err := geoip2.Open(cityPath)
city, err := geoip2.Open(cityPath)
if err != nil {
return nil, fmt.Errorf("open %s: %w", cityPath, err)
}
d.City = r
}
if asnPath != "" {
r, err := geoip2.Open(asnPath)
asn, err := geoip2.Open(asnPath)
if err != nil {
if d.City != nil {
_ = d.City.Close()
}
_ = city.Close()
return nil, fmt.Errorf("open %s: %w", asnPath, err)
}
d.ASN = r
}
return d, nil
return &Databases{City: city, ASN: asn}, nil
}
// Close closes any open readers. No-op on nil receiver.
// Close closes the city and ASN readers.
func (d *Databases) Close() error {
if d == nil {
return nil
}
var errs []error
if d.City != nil {
if err := d.City.Close(); err != nil {
errs = append(errs, err)
}
}
if d.ASN != nil {
if err := d.ASN.Close(); err != nil {
errs = append(errs, err)
}
}
return errors.Join(errs...)
return errors.Join(d.City.Close(), d.ASN.Close())
}
// Info is the structured result of a GeoIP lookup. Zero-valued fields mean
// the database didn't return a value (or wasn't configured).
// Info is the structured result of a GeoIP lookup.
type Info struct {
City string `json:"city,omitempty"`
Region string `json:"region,omitempty"`
@ -125,20 +80,16 @@ type Info struct {
ASNOrg string `json:"asn_org,omitempty"`
}
// Lookup returns city + ASN info for ip. Returns a zero Info on nil receiver,
// unparseable IP, or database miss.
// Lookup returns city + ASN info for ip. Returns a zero Info on unparseable
// IP or database miss.
func (d *Databases) Lookup(ip string) Info {
var info Info
if d == nil {
return info
}
addr, err := netip.ParseAddr(ip)
if err != nil {
return info
}
stdIP := addr.AsSlice()
if d.City != nil {
if rec, err := d.City.City(stdIP); err == nil {
info.City = rec.City.Names["en"]
info.Country = rec.Country.Names["en"]
@ -149,13 +100,9 @@ func (d *Databases) Lookup(ip string) Info {
}
}
}
}
if d.ASN != nil {
if rec, err := d.ASN.ASN(stdIP); err == nil {
info.ASN = rec.AutonomousSystemNumber
info.ASNOrg = rec.AutonomousSystemOrganization
}
}
return info
}