feat: geoip.ParseConf, geoip-update uses it, check-ip auto-downloads+hot-swaps GeoIP

geoip.ParseConf() extracted from geoip-update into the geoip package so
both cmds can read GeoIP.conf without duplication.

check-ip gains -geoip-conf flag: reads AccountID+LicenseKey, resolves
mmdb paths into data-dir, builds httpcache.Cachers with geoip.NewCacher.
Background runLoop now refreshes both blocklists and GeoIP DBs on each
tick, hot-swapping geoip2.Reader via atomic.Pointer.Swap + old.Close().
This commit is contained in:
AJ ONeal 2026-04-20 00:38:54 -06:00
parent 52f422ec93
commit 2abdc1c229
No known key found for this signature in database
3 changed files with 175 additions and 89 deletions

View File

@ -1,12 +1,10 @@
package main package main
import ( import (
"bufio"
"flag" "flag"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"github.com/therootcompany/golib/net/geoip" "github.com/therootcompany/golib/net/geoip"
) )
@ -17,7 +15,7 @@ func main() {
freshDays := flag.Int("fresh-days", 0, "skip download if file is younger than N days (default 3)") freshDays := flag.Int("fresh-days", 0, "skip download if file is younger than N days (default 3)")
flag.Parse() flag.Parse()
cfg, err := parseConf(*configPath) cfg, err := geoip.ParseConf(*configPath)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err) fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1) os.Exit(1)
@ -25,7 +23,7 @@ func main() {
outDir := *dir outDir := *dir
if outDir == "" { if outDir == "" {
outDir = cfg["DatabaseDirectory"] outDir = cfg.DatabaseDirectory
} }
if outDir == "" { if outDir == "" {
outDir = "." outDir = "."
@ -36,24 +34,16 @@ func main() {
os.Exit(1) os.Exit(1)
} }
accountID := cfg["AccountID"] if len(cfg.EditionIDs) == 0 {
licenseKey := cfg["LicenseKey"]
if accountID == "" || licenseKey == "" {
fmt.Fprintf(os.Stderr, "error: AccountID and LicenseKey are required in %s\n", *configPath)
os.Exit(1)
}
editions := strings.Fields(cfg["EditionIDs"])
if len(editions) == 0 {
fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", *configPath) fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", *configPath)
os.Exit(1) os.Exit(1)
} }
d := geoip.New(accountID, licenseKey) d := geoip.New(cfg.AccountID, cfg.LicenseKey)
d.FreshDays = *freshDays d.FreshDays = *freshDays
exitCode := 0 exitCode := 0
for _, edition := range editions { for _, edition := range cfg.EditionIDs {
path := filepath.Join(outDir, edition+".mmdb") path := filepath.Join(outDir, edition+".mmdb")
updated, err := d.Fetch(edition, path) updated, err := d.Fetch(edition, path)
if err != nil { if err != nil {
@ -71,24 +61,3 @@ func main() {
} }
os.Exit(exitCode) os.Exit(exitCode)
} }
// parseConf reads a geoipupdate-style config file (key value pairs, # comments).
func parseConf(path string) (map[string]string, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
cfg := make(map[string]string)
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
key, value, _ := strings.Cut(line, " ")
cfg[strings.TrimSpace(key)] = strings.TrimSpace(value)
}
return cfg, scanner.Err()
}

54
net/geoip/conf.go Normal file
View File

@ -0,0 +1,54 @@
package geoip
import (
"bufio"
"fmt"
"os"
"strings"
)
// Conf holds the fields parsed from a geoipupdate-style config file.
type Conf struct {
AccountID string
LicenseKey string
EditionIDs []string
DatabaseDirectory string
}
// ParseConf reads a geoipupdate-style config file (whitespace-separated
// key/value pairs, # comments). Compatible with GeoIP.conf files used by
// the official geoipupdate tool.
func ParseConf(path string) (*Conf, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
defer f.Close()
kv := make(map[string]string)
scanner := bufio.NewScanner(f)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") {
continue
}
key, value, _ := strings.Cut(line, " ")
kv[strings.TrimSpace(key)] = strings.TrimSpace(value)
}
if err := scanner.Err(); err != nil {
return nil, err
}
c := &Conf{
AccountID: kv["AccountID"],
LicenseKey: kv["LicenseKey"],
DatabaseDirectory: kv["DatabaseDirectory"],
}
if c.AccountID == "" || c.LicenseKey == "" {
return nil, fmt.Errorf("AccountID and LicenseKey are required in %s", path)
}
if ids := kv["EditionIDs"]; ids != "" {
c.EditionIDs = strings.Fields(ids)
}
return c, nil
}

View File

@ -6,11 +6,14 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"os" "os"
"path/filepath"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/oschwald/geoip2-golang" "github.com/oschwald/geoip2-golang"
"github.com/therootcompany/golib/net/geoip"
"github.com/therootcompany/golib/net/httpcache"
"github.com/therootcompany/golib/net/ipcohort" "github.com/therootcompany/golib/net/ipcohort"
) )
@ -27,8 +30,9 @@ const (
) )
func main() { func main() {
cityDBPath := flag.String("city-db", "", "path to GeoLite2-City.mmdb") cityDBPath := flag.String("city-db", "", "path to GeoLite2-City.mmdb (overrides -geoip-conf)")
asnDBPath := flag.String("asn-db", "", "path to GeoLite2-ASN.mmdb") asnDBPath := flag.String("asn-db", "", "path to GeoLite2-ASN.mmdb (overrides -geoip-conf)")
geoipConf := flag.String("geoip-conf", "", "path to GeoIP.conf; auto-downloads City+ASN into data-dir")
gitURL := flag.String("git", "", "clone/pull blocklist from this git URL into data-dir") gitURL := flag.String("git", "", "clone/pull blocklist from this git URL into data-dir")
flag.Usage = func() { flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <data-dir|blacklist.txt> <ip-address>\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage: %s [flags] <data-dir|blacklist.txt> <ip-address>\n", os.Args[0])
@ -46,6 +50,7 @@ func main() {
dataPath := flag.Arg(0) dataPath := flag.Arg(0)
ipStr := flag.Arg(1) ipStr := flag.Arg(1)
// Blocklist source.
var src *Sources var src *Sources
switch { switch {
case *gitURL != "": case *gitURL != "":
@ -76,38 +81,65 @@ func main() {
fmt.Fprintf(os.Stderr, "error: %v\n", err) fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
if err := reload(src, &whitelist, &inbound, &outbound); err != nil { if err := reloadBlocklists(src, &whitelist, &inbound, &outbound); err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err) fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n", fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n",
size(&inbound), size(&outbound)) cohortSize(&inbound), cohortSize(&outbound))
// GeoIP readers. // GeoIP: resolve paths and build cachers if we have credentials.
var cityDB, asnDB atomic.Pointer[geoip2.Reader] var cityDB, asnDB atomic.Pointer[geoip2.Reader]
if *cityDBPath != "" { var cityCacher, asnCacher *httpcache.Cacher
if r, err := geoip2.Open(*cityDBPath); err != nil {
fmt.Fprintf(os.Stderr, "warn: city-db: %v\n", err) resolvedCityPath := *cityDBPath
resolvedASNPath := *asnDBPath
if *geoipConf != "" {
cfg, err := geoip.ParseConf(*geoipConf)
if err != nil {
fmt.Fprintf(os.Stderr, "warn: geoip-conf: %v\n", err)
} else { } else {
cityDB.Store(r) dbDir := cfg.DatabaseDirectory
defer r.Close() if dbDir == "" {
} dbDir = dataPath
} }
if *asnDBPath != "" { d := geoip.New(cfg.AccountID, cfg.LicenseKey)
if r, err := geoip2.Open(*asnDBPath); err != nil { if resolvedCityPath == "" {
fmt.Fprintf(os.Stderr, "warn: asn-db: %v\n", err) resolvedCityPath = filepath.Join(dbDir, geoip.CityEdition+".mmdb")
} else { }
asnDB.Store(r) if resolvedASNPath == "" {
defer r.Close() resolvedASNPath = filepath.Join(dbDir, geoip.ASNEdition+".mmdb")
}
cityCacher = d.NewCacher(geoip.CityEdition, resolvedCityPath)
asnCacher = d.NewCacher(geoip.ASNEdition, resolvedASNPath)
if err := os.MkdirAll(dbDir, 0o755); err != nil {
fmt.Fprintf(os.Stderr, "warn: mkdir %s: %v\n", dbDir, err)
}
} }
} }
// Keep data fresh in the background if running as a daemon. // Fetch GeoIP DBs if we have cachers; otherwise just open existing files.
if cityCacher != nil {
if _, err := cityCacher.Fetch(); err != nil {
fmt.Fprintf(os.Stderr, "warn: city DB fetch: %v\n", err)
}
}
if asnCacher != nil {
if _, err := asnCacher.Fetch(); err != nil {
fmt.Fprintf(os.Stderr, "warn: ASN DB fetch: %v\n", err)
}
}
openGeoIPReader(resolvedCityPath, &cityDB)
openGeoIPReader(resolvedASNPath, &asnDB)
// Keep everything fresh in the background if running as a daemon.
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
go run(ctx, src, &whitelist, &inbound, &outbound) go runLoop(ctx, src, &whitelist, &inbound, &outbound,
cityCacher, asnCacher, &cityDB, &asnDB)
// Check and report.
blockedInbound := containsInbound(ipStr, &whitelist, &inbound) blockedInbound := containsInbound(ipStr, &whitelist, &inbound)
blockedOutbound := containsOutbound(ipStr, &whitelist, &outbound) blockedOutbound := containsOutbound(ipStr, &whitelist, &outbound)
@ -129,6 +161,65 @@ func main() {
} }
} }
// openGeoIPReader opens the .mmdb at path and hot-swaps it into ptr,
// closing any previously stored reader. An empty path is a no-op.
// Failures are non-fatal: a missing or corrupt database must never kill
// the daemon, but they are reported to stderr instead of being silently
// swallowed (the pre-refactor code printed "warn:" here too).
//
// NOTE(review): Close on the old reader can race with a concurrent
// lookup that Load()ed it moments earlier — confirm lookups are
// short-lived relative to the 47-minute refresh tick.
func openGeoIPReader(path string, ptr *atomic.Pointer[geoip2.Reader]) {
	if path == "" {
		return
	}
	r, err := geoip2.Open(path)
	if err != nil {
		fmt.Fprintf(os.Stderr, "warn: geoip open %s: %v\n", path, err)
		return
	}
	// Publish the new reader first, then close the one it replaced.
	if old := ptr.Swap(r); old != nil {
		old.Close()
	}
}
// runLoop keeps the daemon's data fresh: every tick it re-syncs the
// blocklists and, when cachers were configured, the GeoIP databases,
// hot-swapping readers in place. Returns when ctx is cancelled.
func runLoop(ctx context.Context, src *Sources,
	whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
	cityCacher, asnCacher *httpcache.Cacher,
	cityDB, asnDB *atomic.Pointer[geoip2.Reader],
) {
	// refreshDB re-fetches one GeoIP edition and swaps in a fresh reader
	// when the cacher reports that the file on disk changed. A nil cacher
	// means that edition was never configured.
	refreshDB := func(label string, cacher *httpcache.Cacher, db *atomic.Pointer[geoip2.Reader]) {
		if cacher == nil {
			return
		}
		updated, err := cacher.Fetch()
		if err != nil {
			fmt.Fprintf(os.Stderr, "error: "+label+" sync: %v\n", err)
			return
		}
		if updated {
			openGeoIPReader(cacher.Path, db)
			fmt.Fprintf(os.Stderr, "reloaded: %s\n", cacher.Path)
		}
	}

	ticker := time.NewTicker(47 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			// Blocklists first.
			if updated, err := src.Fetch(false); err != nil {
				fmt.Fprintf(os.Stderr, "error: blocklist sync: %v\n", err)
			} else if updated {
				if err := reloadBlocklists(src, whitelist, inbound, outbound); err != nil {
					fmt.Fprintf(os.Stderr, "error: blocklist reload: %v\n", err)
				} else {
					fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n",
						cohortSize(inbound), cohortSize(outbound))
				}
			}
			// Then the GeoIP databases.
			refreshDB("city DB", cityCacher, cityDB)
			refreshDB("ASN DB", asnCacher, asnDB)
		}
	}
}
func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) { func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) {
ip, err := netip.ParseAddr(ipStr) ip, err := netip.ParseAddr(ipStr)
if err != nil { if err != nil {
@ -161,12 +252,13 @@ func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) {
if r := asnDB.Load(); r != nil { if r := asnDB.Load(); r != nil {
if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 { if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 {
fmt.Printf(" ASN: AS%d %s\n", rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization) fmt.Printf(" ASN: AS%d %s\n",
rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization)
} }
} }
} }
func reload(src *Sources, func reloadBlocklists(src *Sources,
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort], whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
) error { ) error {
if wl, err := src.LoadWhitelist(); err != nil { if wl, err := src.LoadWhitelist(); err != nil {
@ -187,35 +279,6 @@ func reload(src *Sources,
return nil return nil
} }
func run(ctx context.Context, src *Sources,
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
) {
ticker := time.NewTicker(47 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
updated, err := src.Fetch(false)
if err != nil {
fmt.Fprintf(os.Stderr, "error: sync: %v\n", err)
continue
}
if !updated {
continue
}
if err := reload(src, whitelist, inbound, outbound); err != nil {
fmt.Fprintf(os.Stderr, "error: reload: %v\n", err)
continue
}
fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n",
size(inbound), size(outbound))
case <-ctx.Done():
return
}
}
}
func containsInbound(ip string, whitelist, inbound *atomic.Pointer[ipcohort.Cohort]) bool { func containsInbound(ip string, whitelist, inbound *atomic.Pointer[ipcohort.Cohort]) bool {
if wl := whitelist.Load(); wl != nil && wl.Contains(ip) { if wl := whitelist.Load(); wl != nil && wl.Contains(ip) {
return false return false
@ -232,7 +295,7 @@ func containsOutbound(ip string, whitelist, outbound *atomic.Pointer[ipcohort.Co
return c != nil && c.Contains(ip) return c != nil && c.Contains(ip)
} }
func size(ptr *atomic.Pointer[ipcohort.Cohort]) int { func cohortSize(ptr *atomic.Pointer[ipcohort.Cohort]) int {
if c := ptr.Load(); c != nil { if c := ptr.Load(); c != nil {
return c.Size() return c.Size()
} }