mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 20:58:00 +00:00
feat: geoip.ParseConf, geoip-update uses it, check-ip auto-downloads+hot-swaps GeoIP
geoip.ParseConf() extracted from geoip-update into the geoip package so both cmds can read GeoIP.conf without duplication. check-ip gains -geoip-conf flag: reads AccountID+LicenseKey, resolves mmdb paths into data-dir, builds httpcache.Cachers with geoip.NewCacher. Background runLoop now refreshes both blocklists and GeoIP DBs on each tick, hot-swapping geoip2.Reader via atomic.Pointer.Swap + old.Close().
This commit is contained in:
parent
52f422ec93
commit
2abdc1c229
@ -1,12 +1,10 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/therootcompany/golib/net/geoip"
|
||||
)
|
||||
@ -17,7 +15,7 @@ func main() {
|
||||
freshDays := flag.Int("fresh-days", 0, "skip download if file is younger than N days (default 3)")
|
||||
flag.Parse()
|
||||
|
||||
cfg, err := parseConf(*configPath)
|
||||
cfg, err := geoip.ParseConf(*configPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
@ -25,7 +23,7 @@ func main() {
|
||||
|
||||
outDir := *dir
|
||||
if outDir == "" {
|
||||
outDir = cfg["DatabaseDirectory"]
|
||||
outDir = cfg.DatabaseDirectory
|
||||
}
|
||||
if outDir == "" {
|
||||
outDir = "."
|
||||
@ -36,24 +34,16 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
accountID := cfg["AccountID"]
|
||||
licenseKey := cfg["LicenseKey"]
|
||||
if accountID == "" || licenseKey == "" {
|
||||
fmt.Fprintf(os.Stderr, "error: AccountID and LicenseKey are required in %s\n", *configPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
editions := strings.Fields(cfg["EditionIDs"])
|
||||
if len(editions) == 0 {
|
||||
if len(cfg.EditionIDs) == 0 {
|
||||
fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", *configPath)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
d := geoip.New(accountID, licenseKey)
|
||||
d := geoip.New(cfg.AccountID, cfg.LicenseKey)
|
||||
d.FreshDays = *freshDays
|
||||
|
||||
exitCode := 0
|
||||
for _, edition := range editions {
|
||||
for _, edition := range cfg.EditionIDs {
|
||||
path := filepath.Join(outDir, edition+".mmdb")
|
||||
updated, err := d.Fetch(edition, path)
|
||||
if err != nil {
|
||||
@ -71,24 +61,3 @@ func main() {
|
||||
}
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
// parseConf reads a geoipupdate-style config file (key value pairs, # comments).
|
||||
func parseConf(path string) (map[string]string, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
cfg := make(map[string]string)
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
key, value, _ := strings.Cut(line, " ")
|
||||
cfg[strings.TrimSpace(key)] = strings.TrimSpace(value)
|
||||
}
|
||||
return cfg, scanner.Err()
|
||||
}
|
||||
|
||||
54
net/geoip/conf.go
Normal file
54
net/geoip/conf.go
Normal file
@ -0,0 +1,54 @@
|
||||
package geoip
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Conf holds the fields parsed from a geoipupdate-style config file.
|
||||
type Conf struct {
|
||||
AccountID string
|
||||
LicenseKey string
|
||||
EditionIDs []string
|
||||
DatabaseDirectory string
|
||||
}
|
||||
|
||||
// ParseConf reads a geoipupdate-style config file (whitespace-separated
|
||||
// key/value pairs, # comments). Compatible with GeoIP.conf files used by
|
||||
// the official geoipupdate tool.
|
||||
func ParseConf(path string) (*Conf, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
kv := make(map[string]string)
|
||||
scanner := bufio.NewScanner(f)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
key, value, _ := strings.Cut(line, " ")
|
||||
kv[strings.TrimSpace(key)] = strings.TrimSpace(value)
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c := &Conf{
|
||||
AccountID: kv["AccountID"],
|
||||
LicenseKey: kv["LicenseKey"],
|
||||
DatabaseDirectory: kv["DatabaseDirectory"],
|
||||
}
|
||||
if c.AccountID == "" || c.LicenseKey == "" {
|
||||
return nil, fmt.Errorf("AccountID and LicenseKey are required in %s", path)
|
||||
}
|
||||
if ids := kv["EditionIDs"]; ids != "" {
|
||||
c.EditionIDs = strings.Fields(ids)
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
@ -6,11 +6,14 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
"github.com/therootcompany/golib/net/geoip"
|
||||
"github.com/therootcompany/golib/net/httpcache"
|
||||
"github.com/therootcompany/golib/net/ipcohort"
|
||||
)
|
||||
|
||||
@ -27,8 +30,9 @@ const (
|
||||
)
|
||||
|
||||
func main() {
|
||||
cityDBPath := flag.String("city-db", "", "path to GeoLite2-City.mmdb")
|
||||
asnDBPath := flag.String("asn-db", "", "path to GeoLite2-ASN.mmdb")
|
||||
cityDBPath := flag.String("city-db", "", "path to GeoLite2-City.mmdb (overrides -geoip-conf)")
|
||||
asnDBPath := flag.String("asn-db", "", "path to GeoLite2-ASN.mmdb (overrides -geoip-conf)")
|
||||
geoipConf := flag.String("geoip-conf", "", "path to GeoIP.conf; auto-downloads City+ASN into data-dir")
|
||||
gitURL := flag.String("git", "", "clone/pull blocklist from this git URL into data-dir")
|
||||
flag.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <data-dir|blacklist.txt> <ip-address>\n", os.Args[0])
|
||||
@ -46,6 +50,7 @@ func main() {
|
||||
dataPath := flag.Arg(0)
|
||||
ipStr := flag.Arg(1)
|
||||
|
||||
// Blocklist source.
|
||||
var src *Sources
|
||||
switch {
|
||||
case *gitURL != "":
|
||||
@ -76,38 +81,65 @@ func main() {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := reload(src, &whitelist, &inbound, &outbound); err != nil {
|
||||
if err := reloadBlocklists(src, &whitelist, &inbound, &outbound); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n",
|
||||
size(&inbound), size(&outbound))
|
||||
cohortSize(&inbound), cohortSize(&outbound))
|
||||
|
||||
// GeoIP readers.
|
||||
// GeoIP: resolve paths and build cachers if we have credentials.
|
||||
var cityDB, asnDB atomic.Pointer[geoip2.Reader]
|
||||
if *cityDBPath != "" {
|
||||
if r, err := geoip2.Open(*cityDBPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warn: city-db: %v\n", err)
|
||||
var cityCacher, asnCacher *httpcache.Cacher
|
||||
|
||||
resolvedCityPath := *cityDBPath
|
||||
resolvedASNPath := *asnDBPath
|
||||
|
||||
if *geoipConf != "" {
|
||||
cfg, err := geoip.ParseConf(*geoipConf)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warn: geoip-conf: %v\n", err)
|
||||
} else {
|
||||
cityDB.Store(r)
|
||||
defer r.Close()
|
||||
dbDir := cfg.DatabaseDirectory
|
||||
if dbDir == "" {
|
||||
dbDir = dataPath
|
||||
}
|
||||
d := geoip.New(cfg.AccountID, cfg.LicenseKey)
|
||||
if resolvedCityPath == "" {
|
||||
resolvedCityPath = filepath.Join(dbDir, geoip.CityEdition+".mmdb")
|
||||
}
|
||||
if resolvedASNPath == "" {
|
||||
resolvedASNPath = filepath.Join(dbDir, geoip.ASNEdition+".mmdb")
|
||||
}
|
||||
cityCacher = d.NewCacher(geoip.CityEdition, resolvedCityPath)
|
||||
asnCacher = d.NewCacher(geoip.ASNEdition, resolvedASNPath)
|
||||
if err := os.MkdirAll(dbDir, 0o755); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warn: mkdir %s: %v\n", dbDir, err)
|
||||
}
|
||||
if *asnDBPath != "" {
|
||||
if r, err := geoip2.Open(*asnDBPath); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warn: asn-db: %v\n", err)
|
||||
} else {
|
||||
asnDB.Store(r)
|
||||
defer r.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// Keep data fresh in the background if running as a daemon.
|
||||
// Fetch GeoIP DBs if we have cachers; otherwise just open existing files.
|
||||
if cityCacher != nil {
|
||||
if _, err := cityCacher.Fetch(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warn: city DB fetch: %v\n", err)
|
||||
}
|
||||
}
|
||||
if asnCacher != nil {
|
||||
if _, err := asnCacher.Fetch(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "warn: ASN DB fetch: %v\n", err)
|
||||
}
|
||||
}
|
||||
openGeoIPReader(resolvedCityPath, &cityDB)
|
||||
openGeoIPReader(resolvedASNPath, &asnDB)
|
||||
|
||||
// Keep everything fresh in the background if running as a daemon.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
go run(ctx, src, &whitelist, &inbound, &outbound)
|
||||
go runLoop(ctx, src, &whitelist, &inbound, &outbound,
|
||||
cityCacher, asnCacher, &cityDB, &asnDB)
|
||||
|
||||
// Check and report.
|
||||
blockedInbound := containsInbound(ipStr, &whitelist, &inbound)
|
||||
blockedOutbound := containsOutbound(ipStr, &whitelist, &outbound)
|
||||
|
||||
@ -129,6 +161,65 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
func openGeoIPReader(path string, ptr *atomic.Pointer[geoip2.Reader]) {
|
||||
if path == "" {
|
||||
return
|
||||
}
|
||||
r, err := geoip2.Open(path)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if old := ptr.Swap(r); old != nil {
|
||||
old.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func runLoop(ctx context.Context, src *Sources,
|
||||
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
|
||||
cityCacher, asnCacher *httpcache.Cacher,
|
||||
cityDB, asnDB *atomic.Pointer[geoip2.Reader],
|
||||
) {
|
||||
ticker := time.NewTicker(47 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Blocklists.
|
||||
if updated, err := src.Fetch(false); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: blocklist sync: %v\n", err)
|
||||
} else if updated {
|
||||
if err := reloadBlocklists(src, whitelist, inbound, outbound); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: blocklist reload: %v\n", err)
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n",
|
||||
cohortSize(inbound), cohortSize(outbound))
|
||||
}
|
||||
}
|
||||
|
||||
// GeoIP DBs.
|
||||
if cityCacher != nil {
|
||||
if updated, err := cityCacher.Fetch(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: city DB sync: %v\n", err)
|
||||
} else if updated {
|
||||
openGeoIPReader(cityCacher.Path, cityDB)
|
||||
fmt.Fprintf(os.Stderr, "reloaded: %s\n", cityCacher.Path)
|
||||
}
|
||||
}
|
||||
if asnCacher != nil {
|
||||
if updated, err := asnCacher.Fetch(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: ASN DB sync: %v\n", err)
|
||||
} else if updated {
|
||||
openGeoIPReader(asnCacher.Path, asnDB)
|
||||
fmt.Fprintf(os.Stderr, "reloaded: %s\n", asnCacher.Path)
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) {
|
||||
ip, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
@ -161,12 +252,13 @@ func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) {
|
||||
|
||||
if r := asnDB.Load(); r != nil {
|
||||
if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 {
|
||||
fmt.Printf(" ASN: AS%d %s\n", rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization)
|
||||
fmt.Printf(" ASN: AS%d %s\n",
|
||||
rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func reload(src *Sources,
|
||||
func reloadBlocklists(src *Sources,
|
||||
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
|
||||
) error {
|
||||
if wl, err := src.LoadWhitelist(); err != nil {
|
||||
@ -187,35 +279,6 @@ func reload(src *Sources,
|
||||
return nil
|
||||
}
|
||||
|
||||
func run(ctx context.Context, src *Sources,
|
||||
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
|
||||
) {
|
||||
ticker := time.NewTicker(47 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
updated, err := src.Fetch(false)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: sync: %v\n", err)
|
||||
continue
|
||||
}
|
||||
if !updated {
|
||||
continue
|
||||
}
|
||||
if err := reload(src, whitelist, inbound, outbound); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: reload: %v\n", err)
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n",
|
||||
size(inbound), size(outbound))
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func containsInbound(ip string, whitelist, inbound *atomic.Pointer[ipcohort.Cohort]) bool {
|
||||
if wl := whitelist.Load(); wl != nil && wl.Contains(ip) {
|
||||
return false
|
||||
@ -232,7 +295,7 @@ func containsOutbound(ip string, whitelist, outbound *atomic.Pointer[ipcohort.Co
|
||||
return c != nil && c.Contains(ip)
|
||||
}
|
||||
|
||||
func size(ptr *atomic.Pointer[ipcohort.Cohort]) int {
|
||||
func cohortSize(ptr *atomic.Pointer[ipcohort.Cohort]) int {
|
||||
if c := ptr.Load(); c != nil {
|
||||
return c.Size()
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user