diff --git a/net/geoip/cmd/geoip-update/main.go b/net/geoip/cmd/geoip-update/main.go index 661ea13..62c666a 100644 --- a/net/geoip/cmd/geoip-update/main.go +++ b/net/geoip/cmd/geoip-update/main.go @@ -1,12 +1,10 @@ package main import ( - "bufio" "flag" "fmt" "os" "path/filepath" - "strings" "github.com/therootcompany/golib/net/geoip" ) @@ -17,7 +15,7 @@ func main() { freshDays := flag.Int("fresh-days", 0, "skip download if file is younger than N days (default 3)") flag.Parse() - cfg, err := parseConf(*configPath) + cfg, err := geoip.ParseConf(*configPath) if err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) @@ -25,7 +23,7 @@ func main() { outDir := *dir if outDir == "" { - outDir = cfg["DatabaseDirectory"] + outDir = cfg.DatabaseDirectory } if outDir == "" { outDir = "." @@ -36,24 +34,16 @@ func main() { os.Exit(1) } - accountID := cfg["AccountID"] - licenseKey := cfg["LicenseKey"] - if accountID == "" || licenseKey == "" { - fmt.Fprintf(os.Stderr, "error: AccountID and LicenseKey are required in %s\n", *configPath) - os.Exit(1) - } - - editions := strings.Fields(cfg["EditionIDs"]) - if len(editions) == 0 { + if len(cfg.EditionIDs) == 0 { fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", *configPath) os.Exit(1) } - d := geoip.New(accountID, licenseKey) + d := geoip.New(cfg.AccountID, cfg.LicenseKey) d.FreshDays = *freshDays exitCode := 0 - for _, edition := range editions { + for _, edition := range cfg.EditionIDs { path := filepath.Join(outDir, edition+".mmdb") updated, err := d.Fetch(edition, path) if err != nil { @@ -71,24 +61,3 @@ func main() { } os.Exit(exitCode) } - -// parseConf reads a geoipupdate-style config file (key value pairs, # comments). 
-func parseConf(path string) (map[string]string, error) { - f, err := os.Open(path) - if err != nil { - return nil, err - } - defer f.Close() - - cfg := make(map[string]string) - scanner := bufio.NewScanner(f) - for scanner.Scan() { - line := strings.TrimSpace(scanner.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - key, value, _ := strings.Cut(line, " ") - cfg[strings.TrimSpace(key)] = strings.TrimSpace(value) - } - return cfg, scanner.Err() -} diff --git a/net/geoip/conf.go b/net/geoip/conf.go new file mode 100644 index 0000000..2e034c5 --- /dev/null +++ b/net/geoip/conf.go @@ -0,0 +1,54 @@ +package geoip + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +// Conf holds the fields parsed from a geoipupdate-style config file. +type Conf struct { + AccountID string + LicenseKey string + EditionIDs []string + DatabaseDirectory string +} + +// ParseConf reads a geoipupdate-style config file (whitespace-separated +// key/value pairs, # comments). Compatible with GeoIP.conf files used by +// the official geoipupdate tool. 
+func ParseConf(path string) (*Conf, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + kv := make(map[string]string) + scanner := bufio.NewScanner(f) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + key, value, _ := strings.Cut(line, " ") + kv[strings.TrimSpace(key)] = strings.TrimSpace(value) + } + if err := scanner.Err(); err != nil { + return nil, err + } + + c := &Conf{ + AccountID: kv["AccountID"], + LicenseKey: kv["LicenseKey"], + DatabaseDirectory: kv["DatabaseDirectory"], + } + if c.AccountID == "" || c.LicenseKey == "" { + return nil, fmt.Errorf("AccountID and LicenseKey are required in %s", path) + } + if ids := kv["EditionIDs"]; ids != "" { + c.EditionIDs = strings.Fields(ids) + } + return c, nil +} diff --git a/net/ipcohort/cmd/check-ip/main.go b/net/ipcohort/cmd/check-ip/main.go index 825a115..e11c288 100644 --- a/net/ipcohort/cmd/check-ip/main.go +++ b/net/ipcohort/cmd/check-ip/main.go @@ -6,11 +6,14 @@ import ( "fmt" "net/netip" "os" + "path/filepath" "strings" "sync/atomic" "time" "github.com/oschwald/geoip2-golang" + "github.com/therootcompany/golib/net/geoip" + "github.com/therootcompany/golib/net/httpcache" "github.com/therootcompany/golib/net/ipcohort" ) @@ -27,8 +30,9 @@ const ( ) func main() { - cityDBPath := flag.String("city-db", "", "path to GeoLite2-City.mmdb") - asnDBPath := flag.String("asn-db", "", "path to GeoLite2-ASN.mmdb") + cityDBPath := flag.String("city-db", "", "path to GeoLite2-City.mmdb (overrides -geoip-conf)") + asnDBPath := flag.String("asn-db", "", "path to GeoLite2-ASN.mmdb (overrides -geoip-conf)") + geoipConf := flag.String("geoip-conf", "", "path to GeoIP.conf; auto-downloads City+ASN into data-dir") gitURL := flag.String("git", "", "clone/pull blocklist from this git URL into data-dir") flag.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s [flags] \n", os.Args[0]) @@ -46,6 +50,7 @@ func 
main() { dataPath := flag.Arg(0) ipStr := flag.Arg(1) + // Blocklist source. var src *Sources switch { case *gitURL != "": @@ -76,38 +81,65 @@ func main() { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } - if err := reload(src, &whitelist, &inbound, &outbound); err != nil { + if err := reloadBlocklists(src, &whitelist, &inbound, &outbound); err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } - fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n", - size(&inbound), size(&outbound)) + fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n", + cohortSize(&inbound), cohortSize(&outbound)) - // GeoIP readers. + // GeoIP: resolve paths and build cachers if we have credentials. var cityDB, asnDB atomic.Pointer[geoip2.Reader] - if *cityDBPath != "" { - if r, err := geoip2.Open(*cityDBPath); err != nil { - fmt.Fprintf(os.Stderr, "warn: city-db: %v\n", err) + var cityCacher, asnCacher *httpcache.Cacher + + resolvedCityPath := *cityDBPath + resolvedASNPath := *asnDBPath + + if *geoipConf != "" { + cfg, err := geoip.ParseConf(*geoipConf) + if err != nil { + fmt.Fprintf(os.Stderr, "warn: geoip-conf: %v\n", err) } else { - cityDB.Store(r) - defer r.Close() - } - } - if *asnDBPath != "" { - if r, err := geoip2.Open(*asnDBPath); err != nil { - fmt.Fprintf(os.Stderr, "warn: asn-db: %v\n", err) - } else { - asnDB.Store(r) - defer r.Close() + dbDir := cfg.DatabaseDirectory + if dbDir == "" { + dbDir = dataPath + } + d := geoip.New(cfg.AccountID, cfg.LicenseKey) + if resolvedCityPath == "" { + resolvedCityPath = filepath.Join(dbDir, geoip.CityEdition+".mmdb") + } + if resolvedASNPath == "" { + resolvedASNPath = filepath.Join(dbDir, geoip.ASNEdition+".mmdb") + } + cityCacher = d.NewCacher(geoip.CityEdition, resolvedCityPath) + asnCacher = d.NewCacher(geoip.ASNEdition, resolvedASNPath) + if err := os.MkdirAll(dbDir, 0o755); err != nil { + fmt.Fprintf(os.Stderr, "warn: mkdir %s: %v\n", dbDir, err) + } } } - // Keep data fresh in the background if running as a daemon. 
+ // Fetch GeoIP DBs if we have cachers; otherwise just open existing files. + if cityCacher != nil { + if _, err := cityCacher.Fetch(); err != nil { + fmt.Fprintf(os.Stderr, "warn: city DB fetch: %v\n", err) + } + } + if asnCacher != nil { + if _, err := asnCacher.Fetch(); err != nil { + fmt.Fprintf(os.Stderr, "warn: ASN DB fetch: %v\n", err) + } + } + openGeoIPReader(resolvedCityPath, &cityDB) + openGeoIPReader(resolvedASNPath, &asnDB) + + // Keep everything fresh in the background if running as a daemon. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - go run(ctx, src, &whitelist, &inbound, &outbound) + go runLoop(ctx, src, &whitelist, &inbound, &outbound, + cityCacher, asnCacher, &cityDB, &asnDB) + // Check and report. blockedInbound := containsInbound(ipStr, &whitelist, &inbound) blockedOutbound := containsOutbound(ipStr, &whitelist, &outbound) @@ -129,6 +161,65 @@ func main() { } } +func openGeoIPReader(path string, ptr *atomic.Pointer[geoip2.Reader]) { + if path == "" { + return + } + r, err := geoip2.Open(path) + if err != nil { + return + } + if old := ptr.Swap(r); old != nil { + old.Close() + } +} + +func runLoop(ctx context.Context, src *Sources, + whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort], + cityCacher, asnCacher *httpcache.Cacher, + cityDB, asnDB *atomic.Pointer[geoip2.Reader], +) { + ticker := time.NewTicker(47 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + // Blocklists. + if updated, err := src.Fetch(false); err != nil { + fmt.Fprintf(os.Stderr, "error: blocklist sync: %v\n", err) + } else if updated { + if err := reloadBlocklists(src, whitelist, inbound, outbound); err != nil { + fmt.Fprintf(os.Stderr, "error: blocklist reload: %v\n", err) + } else { + fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n", + cohortSize(inbound), cohortSize(outbound)) + } + } + + // GeoIP DBs. 
+ if cityCacher != nil { + if updated, err := cityCacher.Fetch(); err != nil { + fmt.Fprintf(os.Stderr, "error: city DB sync: %v\n", err) + } else if updated { + openGeoIPReader(cityCacher.Path, cityDB) + fmt.Fprintf(os.Stderr, "reloaded: %s\n", cityCacher.Path) + } + } + if asnCacher != nil { + if updated, err := asnCacher.Fetch(); err != nil { + fmt.Fprintf(os.Stderr, "error: ASN DB sync: %v\n", err) + } else if updated { + openGeoIPReader(asnCacher.Path, asnDB) + fmt.Fprintf(os.Stderr, "reloaded: %s\n", asnCacher.Path) + } + } + case <-ctx.Done(): + return + } + } +} + func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) { ip, err := netip.ParseAddr(ipStr) if err != nil { @@ -161,12 +252,13 @@ func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) { if r := asnDB.Load(); r != nil { if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 { - fmt.Printf(" ASN: AS%d %s\n", rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization) + fmt.Printf(" ASN: AS%d %s\n", + rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization) } } } -func reload(src *Sources, +func reloadBlocklists(src *Sources, whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort], ) error { if wl, err := src.LoadWhitelist(); err != nil { @@ -187,35 +279,6 @@ func reload(src *Sources, return nil } -func run(ctx context.Context, src *Sources, - whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort], -) { - ticker := time.NewTicker(47 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - updated, err := src.Fetch(false) - if err != nil { - fmt.Fprintf(os.Stderr, "error: sync: %v\n", err) - continue - } - if !updated { - continue - } - if err := reload(src, whitelist, inbound, outbound); err != nil { - fmt.Fprintf(os.Stderr, "error: reload: %v\n", err) - continue - } - fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n", - size(inbound), size(outbound)) - case <-ctx.Done(): - return - } 
- } -} - func containsInbound(ip string, whitelist, inbound *atomic.Pointer[ipcohort.Cohort]) bool { if wl := whitelist.Load(); wl != nil && wl.Contains(ip) { return false @@ -232,7 +295,7 @@ func containsOutbound(ip string, whitelist, outbound *atomic.Pointer[ipcohort.Co return c != nil && c.Contains(ip) } -func size(ptr *atomic.Pointer[ipcohort.Cohort]) int { +func cohortSize(ptr *atomic.Pointer[ipcohort.Cohort]) int { if c := ptr.Load(); c != nil { return c.Size() }