mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 20:58:00 +00:00
feat: geoip.ParseConf, geoip-update uses it, check-ip auto-downloads+hot-swaps GeoIP
geoip.ParseConf() extracted from geoip-update into the geoip package so both cmds can read GeoIP.conf without duplication. check-ip gains -geoip-conf flag: reads AccountID+LicenseKey, resolves mmdb paths into data-dir, builds httpcache.Cachers with geoip.NewCacher. Background runLoop now refreshes both blocklists and GeoIP DBs on each tick, hot-swapping geoip2.Reader via atomic.Pointer.Swap + old.Close().
This commit is contained in:
parent
52f422ec93
commit
2abdc1c229
@ -1,12 +1,10 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/therootcompany/golib/net/geoip"
|
"github.com/therootcompany/golib/net/geoip"
|
||||||
)
|
)
|
||||||
@ -17,7 +15,7 @@ func main() {
|
|||||||
freshDays := flag.Int("fresh-days", 0, "skip download if file is younger than N days (default 3)")
|
freshDays := flag.Int("fresh-days", 0, "skip download if file is younger than N days (default 3)")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
cfg, err := parseConf(*configPath)
|
cfg, err := geoip.ParseConf(*configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
@ -25,7 +23,7 @@ func main() {
|
|||||||
|
|
||||||
outDir := *dir
|
outDir := *dir
|
||||||
if outDir == "" {
|
if outDir == "" {
|
||||||
outDir = cfg["DatabaseDirectory"]
|
outDir = cfg.DatabaseDirectory
|
||||||
}
|
}
|
||||||
if outDir == "" {
|
if outDir == "" {
|
||||||
outDir = "."
|
outDir = "."
|
||||||
@ -36,24 +34,16 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
accountID := cfg["AccountID"]
|
if len(cfg.EditionIDs) == 0 {
|
||||||
licenseKey := cfg["LicenseKey"]
|
|
||||||
if accountID == "" || licenseKey == "" {
|
|
||||||
fmt.Fprintf(os.Stderr, "error: AccountID and LicenseKey are required in %s\n", *configPath)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
editions := strings.Fields(cfg["EditionIDs"])
|
|
||||||
if len(editions) == 0 {
|
|
||||||
fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", *configPath)
|
fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", *configPath)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
d := geoip.New(accountID, licenseKey)
|
d := geoip.New(cfg.AccountID, cfg.LicenseKey)
|
||||||
d.FreshDays = *freshDays
|
d.FreshDays = *freshDays
|
||||||
|
|
||||||
exitCode := 0
|
exitCode := 0
|
||||||
for _, edition := range editions {
|
for _, edition := range cfg.EditionIDs {
|
||||||
path := filepath.Join(outDir, edition+".mmdb")
|
path := filepath.Join(outDir, edition+".mmdb")
|
||||||
updated, err := d.Fetch(edition, path)
|
updated, err := d.Fetch(edition, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -71,24 +61,3 @@ func main() {
|
|||||||
}
|
}
|
||||||
os.Exit(exitCode)
|
os.Exit(exitCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseConf reads a geoipupdate-style config file (key value pairs, # comments).
|
|
||||||
func parseConf(path string) (map[string]string, error) {
|
|
||||||
f, err := os.Open(path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
cfg := make(map[string]string)
|
|
||||||
scanner := bufio.NewScanner(f)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := strings.TrimSpace(scanner.Text())
|
|
||||||
if line == "" || strings.HasPrefix(line, "#") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
key, value, _ := strings.Cut(line, " ")
|
|
||||||
cfg[strings.TrimSpace(key)] = strings.TrimSpace(value)
|
|
||||||
}
|
|
||||||
return cfg, scanner.Err()
|
|
||||||
}
|
|
||||||
|
|||||||
54
net/geoip/conf.go
Normal file
54
net/geoip/conf.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
package geoip
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Conf holds the fields parsed from a geoipupdate-style config file.
|
||||||
|
type Conf struct {
|
||||||
|
AccountID string
|
||||||
|
LicenseKey string
|
||||||
|
EditionIDs []string
|
||||||
|
DatabaseDirectory string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseConf reads a geoipupdate-style config file (whitespace-separated
|
||||||
|
// key/value pairs, # comments). Compatible with GeoIP.conf files used by
|
||||||
|
// the official geoipupdate tool.
|
||||||
|
func ParseConf(path string) (*Conf, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
kv := make(map[string]string)
|
||||||
|
scanner := bufio.NewScanner(f)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if line == "" || strings.HasPrefix(line, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key, value, _ := strings.Cut(line, " ")
|
||||||
|
kv[strings.TrimSpace(key)] = strings.TrimSpace(value)
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &Conf{
|
||||||
|
AccountID: kv["AccountID"],
|
||||||
|
LicenseKey: kv["LicenseKey"],
|
||||||
|
DatabaseDirectory: kv["DatabaseDirectory"],
|
||||||
|
}
|
||||||
|
if c.AccountID == "" || c.LicenseKey == "" {
|
||||||
|
return nil, fmt.Errorf("AccountID and LicenseKey are required in %s", path)
|
||||||
|
}
|
||||||
|
if ids := kv["EditionIDs"]; ids != "" {
|
||||||
|
c.EditionIDs = strings.Fields(ids)
|
||||||
|
}
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
@ -6,11 +6,14 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/oschwald/geoip2-golang"
|
"github.com/oschwald/geoip2-golang"
|
||||||
|
"github.com/therootcompany/golib/net/geoip"
|
||||||
|
"github.com/therootcompany/golib/net/httpcache"
|
||||||
"github.com/therootcompany/golib/net/ipcohort"
|
"github.com/therootcompany/golib/net/ipcohort"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -27,8 +30,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
cityDBPath := flag.String("city-db", "", "path to GeoLite2-City.mmdb")
|
cityDBPath := flag.String("city-db", "", "path to GeoLite2-City.mmdb (overrides -geoip-conf)")
|
||||||
asnDBPath := flag.String("asn-db", "", "path to GeoLite2-ASN.mmdb")
|
asnDBPath := flag.String("asn-db", "", "path to GeoLite2-ASN.mmdb (overrides -geoip-conf)")
|
||||||
|
geoipConf := flag.String("geoip-conf", "", "path to GeoIP.conf; auto-downloads City+ASN into data-dir")
|
||||||
gitURL := flag.String("git", "", "clone/pull blocklist from this git URL into data-dir")
|
gitURL := flag.String("git", "", "clone/pull blocklist from this git URL into data-dir")
|
||||||
flag.Usage = func() {
|
flag.Usage = func() {
|
||||||
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <data-dir|blacklist.txt> <ip-address>\n", os.Args[0])
|
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <data-dir|blacklist.txt> <ip-address>\n", os.Args[0])
|
||||||
@ -46,6 +50,7 @@ func main() {
|
|||||||
dataPath := flag.Arg(0)
|
dataPath := flag.Arg(0)
|
||||||
ipStr := flag.Arg(1)
|
ipStr := flag.Arg(1)
|
||||||
|
|
||||||
|
// Blocklist source.
|
||||||
var src *Sources
|
var src *Sources
|
||||||
switch {
|
switch {
|
||||||
case *gitURL != "":
|
case *gitURL != "":
|
||||||
@ -76,38 +81,65 @@ func main() {
|
|||||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
if err := reload(src, &whitelist, &inbound, &outbound); err != nil {
|
if err := reloadBlocklists(src, &whitelist, &inbound, &outbound); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n",
|
fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n",
|
||||||
size(&inbound), size(&outbound))
|
cohortSize(&inbound), cohortSize(&outbound))
|
||||||
|
|
||||||
// GeoIP readers.
|
// GeoIP: resolve paths and build cachers if we have credentials.
|
||||||
var cityDB, asnDB atomic.Pointer[geoip2.Reader]
|
var cityDB, asnDB atomic.Pointer[geoip2.Reader]
|
||||||
if *cityDBPath != "" {
|
var cityCacher, asnCacher *httpcache.Cacher
|
||||||
if r, err := geoip2.Open(*cityDBPath); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "warn: city-db: %v\n", err)
|
resolvedCityPath := *cityDBPath
|
||||||
|
resolvedASNPath := *asnDBPath
|
||||||
|
|
||||||
|
if *geoipConf != "" {
|
||||||
|
cfg, err := geoip.ParseConf(*geoipConf)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "warn: geoip-conf: %v\n", err)
|
||||||
} else {
|
} else {
|
||||||
cityDB.Store(r)
|
dbDir := cfg.DatabaseDirectory
|
||||||
defer r.Close()
|
if dbDir == "" {
|
||||||
}
|
dbDir = dataPath
|
||||||
}
|
}
|
||||||
if *asnDBPath != "" {
|
d := geoip.New(cfg.AccountID, cfg.LicenseKey)
|
||||||
if r, err := geoip2.Open(*asnDBPath); err != nil {
|
if resolvedCityPath == "" {
|
||||||
fmt.Fprintf(os.Stderr, "warn: asn-db: %v\n", err)
|
resolvedCityPath = filepath.Join(dbDir, geoip.CityEdition+".mmdb")
|
||||||
} else {
|
}
|
||||||
asnDB.Store(r)
|
if resolvedASNPath == "" {
|
||||||
defer r.Close()
|
resolvedASNPath = filepath.Join(dbDir, geoip.ASNEdition+".mmdb")
|
||||||
|
}
|
||||||
|
cityCacher = d.NewCacher(geoip.CityEdition, resolvedCityPath)
|
||||||
|
asnCacher = d.NewCacher(geoip.ASNEdition, resolvedASNPath)
|
||||||
|
if err := os.MkdirAll(dbDir, 0o755); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "warn: mkdir %s: %v\n", dbDir, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Keep data fresh in the background if running as a daemon.
|
// Fetch GeoIP DBs if we have cachers; otherwise just open existing files.
|
||||||
|
if cityCacher != nil {
|
||||||
|
if _, err := cityCacher.Fetch(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "warn: city DB fetch: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if asnCacher != nil {
|
||||||
|
if _, err := asnCacher.Fetch(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "warn: ASN DB fetch: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
openGeoIPReader(resolvedCityPath, &cityDB)
|
||||||
|
openGeoIPReader(resolvedASNPath, &asnDB)
|
||||||
|
|
||||||
|
// Keep everything fresh in the background if running as a daemon.
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
go run(ctx, src, &whitelist, &inbound, &outbound)
|
go runLoop(ctx, src, &whitelist, &inbound, &outbound,
|
||||||
|
cityCacher, asnCacher, &cityDB, &asnDB)
|
||||||
|
|
||||||
|
// Check and report.
|
||||||
blockedInbound := containsInbound(ipStr, &whitelist, &inbound)
|
blockedInbound := containsInbound(ipStr, &whitelist, &inbound)
|
||||||
blockedOutbound := containsOutbound(ipStr, &whitelist, &outbound)
|
blockedOutbound := containsOutbound(ipStr, &whitelist, &outbound)
|
||||||
|
|
||||||
@ -129,6 +161,65 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func openGeoIPReader(path string, ptr *atomic.Pointer[geoip2.Reader]) {
|
||||||
|
if path == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r, err := geoip2.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if old := ptr.Swap(r); old != nil {
|
||||||
|
old.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runLoop(ctx context.Context, src *Sources,
|
||||||
|
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
|
||||||
|
cityCacher, asnCacher *httpcache.Cacher,
|
||||||
|
cityDB, asnDB *atomic.Pointer[geoip2.Reader],
|
||||||
|
) {
|
||||||
|
ticker := time.NewTicker(47 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
// Blocklists.
|
||||||
|
if updated, err := src.Fetch(false); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: blocklist sync: %v\n", err)
|
||||||
|
} else if updated {
|
||||||
|
if err := reloadBlocklists(src, whitelist, inbound, outbound); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: blocklist reload: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n",
|
||||||
|
cohortSize(inbound), cohortSize(outbound))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeoIP DBs.
|
||||||
|
if cityCacher != nil {
|
||||||
|
if updated, err := cityCacher.Fetch(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: city DB sync: %v\n", err)
|
||||||
|
} else if updated {
|
||||||
|
openGeoIPReader(cityCacher.Path, cityDB)
|
||||||
|
fmt.Fprintf(os.Stderr, "reloaded: %s\n", cityCacher.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if asnCacher != nil {
|
||||||
|
if updated, err := asnCacher.Fetch(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: ASN DB sync: %v\n", err)
|
||||||
|
} else if updated {
|
||||||
|
openGeoIPReader(asnCacher.Path, asnDB)
|
||||||
|
fmt.Fprintf(os.Stderr, "reloaded: %s\n", asnCacher.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) {
|
func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) {
|
||||||
ip, err := netip.ParseAddr(ipStr)
|
ip, err := netip.ParseAddr(ipStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -161,12 +252,13 @@ func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) {
|
|||||||
|
|
||||||
if r := asnDB.Load(); r != nil {
|
if r := asnDB.Load(); r != nil {
|
||||||
if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 {
|
if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 {
|
||||||
fmt.Printf(" ASN: AS%d %s\n", rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization)
|
fmt.Printf(" ASN: AS%d %s\n",
|
||||||
|
rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func reload(src *Sources,
|
func reloadBlocklists(src *Sources,
|
||||||
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
|
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
|
||||||
) error {
|
) error {
|
||||||
if wl, err := src.LoadWhitelist(); err != nil {
|
if wl, err := src.LoadWhitelist(); err != nil {
|
||||||
@ -187,35 +279,6 @@ func reload(src *Sources,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func run(ctx context.Context, src *Sources,
|
|
||||||
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
|
|
||||||
) {
|
|
||||||
ticker := time.NewTicker(47 * time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
updated, err := src.Fetch(false)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "error: sync: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !updated {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := reload(src, whitelist, inbound, outbound); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "error: reload: %v\n", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n",
|
|
||||||
size(inbound), size(outbound))
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func containsInbound(ip string, whitelist, inbound *atomic.Pointer[ipcohort.Cohort]) bool {
|
func containsInbound(ip string, whitelist, inbound *atomic.Pointer[ipcohort.Cohort]) bool {
|
||||||
if wl := whitelist.Load(); wl != nil && wl.Contains(ip) {
|
if wl := whitelist.Load(); wl != nil && wl.Contains(ip) {
|
||||||
return false
|
return false
|
||||||
@ -232,7 +295,7 @@ func containsOutbound(ip string, whitelist, outbound *atomic.Pointer[ipcohort.Co
|
|||||||
return c != nil && c.Contains(ip)
|
return c != nil && c.Contains(ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func size(ptr *atomic.Pointer[ipcohort.Cohort]) int {
|
func cohortSize(ptr *atomic.Pointer[ipcohort.Cohort]) int {
|
||||||
if c := ptr.Load(); c != nil {
|
if c := ptr.Load(); c != nil {
|
||||||
return c.Size()
|
return c.Size()
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user