diff --git a/cmd/check-ip/main.go b/cmd/check-ip/main.go index c2d5b10..aaa82ff 100644 --- a/cmd/check-ip/main.go +++ b/cmd/check-ip/main.go @@ -17,6 +17,7 @@ import ( "github.com/therootcompany/golib/net/geoip" "github.com/therootcompany/golib/net/gitshallow" + "github.com/therootcompany/golib/net/httpcache" "github.com/therootcompany/golib/net/ipcohort" "github.com/therootcompany/golib/sync/dataset" ) @@ -37,11 +38,7 @@ type IPCheck struct { inbound *dataset.View[ipcohort.Cohort] outbound *dataset.View[ipcohort.Cohort] - geo *dataset.View[geoip.Databases] -} - -func printVersion(w *os.File) { - fmt.Fprintf(w, "check-ip %s\n", version) + geo *geoip.Databases } func main() { @@ -59,11 +56,10 @@ func main() { if len(os.Args) > 1 { switch os.Args[1] { case "-V", "-version", "--version", "version": - printVersion(os.Stdout) + fmt.Fprintf(os.Stdout, "check-ip %s\n", version) os.Exit(0) case "help", "-help", "--help": - printVersion(os.Stdout) - fmt.Fprintln(os.Stdout, "") + fmt.Fprintf(os.Stdout, "check-ip %s\n\n", version) fs.SetOutput(os.Stdout) fs.Usage() os.Exit(0) @@ -84,6 +80,7 @@ func main() { cfg.CacheDir = d } + // Blocklists: git repo with inbound + outbound IP cohort files. repo := gitshallow.New(cfg.RepoURL, filepath.Join(cfg.CacheDir, "bitwire-it"), 1, "") group := dataset.NewGroup(repo) cfg.inbound = dataset.Add(group, func() (*ipcohort.Cohort, error) { @@ -102,15 +99,53 @@ func main() { log.Fatalf("blocklists: %v", err) } - maxmind := filepath.Join(cfg.CacheDir, "maxmind") - geoGroup := dataset.NewGroup(geoFetcher(cfg.ConfPath, maxmind)) - cfg.geo = dataset.Add(geoGroup, func() (*geoip.Databases, error) { - return geoip.Open(maxmind) - }) - if err := geoGroup.Load(context.Background()); err != nil { + // GeoIP: with GeoIP.conf, download the City + ASN tar.gz archives via + // httpcache conditional GETs. Without it, expect the tar.gz files to + // already be in maxmindDir. geoip.Open extracts in-memory — no .mmdb + // files are written to disk. + maxmindDir := filepath.Join(cfg.CacheDir, "maxmind") + confPath := cfg.ConfPath + if confPath == "" { + for _, p := range geoip.DefaultConfPaths() { + if _, err := os.Stat(p); err == nil { + confPath = p + break + } + } + } + if confPath != "" { + conf, err := geoip.ParseConf(confPath) + if err != nil { + log.Fatalf("geoip-conf: %v", err) + } + auth := httpcache.BasicAuth(conf.AccountID, conf.LicenseKey) + city := &httpcache.Cacher{ + URL: geoip.DownloadBase + "/GeoLite2-City/download?suffix=tar.gz", + Path: filepath.Join(maxmindDir, "GeoLite2-City.tar.gz"), + MaxAge: 3 * 24 * time.Hour, + AuthHeader: "Authorization", + AuthValue: auth, + } + asn := &httpcache.Cacher{ + URL: geoip.DownloadBase + "/GeoLite2-ASN/download?suffix=tar.gz", + Path: filepath.Join(maxmindDir, "GeoLite2-ASN.tar.gz"), + MaxAge: 3 * 24 * time.Hour, + AuthHeader: "Authorization", + AuthValue: auth, + } + if _, err := city.Fetch(); err != nil { + log.Fatalf("fetch GeoLite2-City: %v", err) + } + if _, err := asn.Fetch(); err != nil { + log.Fatalf("fetch GeoLite2-ASN: %v", err) + } + } + geo, err := geoip.Open(maxmindDir) + if err != nil { log.Fatalf("geoip: %v", err) } - defer func() { _ = cfg.geo.Value().Close() }() + defer func() { _ = geo.Close() }() + cfg.geo = geo if cfg.Bind == "" { return @@ -121,48 +156,7 @@ func main() { go group.Tick(ctx, refreshInterval, func(err error) { log.Printf("blocklists refresh: %v", err) }) - go geoGroup.Tick(ctx, refreshInterval, func(err error) { - log.Printf("geoip refresh: %v", err) - }) if err := cfg.serve(ctx); err != nil { log.Fatalf("serve: %v", err) } } - -// geoFetcher returns a Fetcher for the GeoLite2 City + ASN .mmdb files. -// With a GeoIP.conf (explicit path or auto-discovered) both files are -// downloaded via httpcache conditional GETs; otherwise the files are -// expected to exist on disk and are polled for out-of-band changes. -func geoFetcher(confPath, dir string) dataset.Fetcher { - cityPath := filepath.Join(dir, "GeoLite2-City.mmdb") - asnPath := filepath.Join(dir, "GeoLite2-ASN.mmdb") - if confPath == "" { - for _, p := range geoip.DefaultConfPaths() { - if _, err := os.Stat(p); err == nil { - confPath = p - break - } - } - } - if confPath == "" { - return dataset.PollFiles(cityPath, asnPath) - } - conf, err := geoip.ParseConf(confPath) - if err != nil { - log.Fatalf("geoip-conf: %v", err) - } - dl := geoip.New(conf.AccountID, conf.LicenseKey) - city := dl.NewCacher(geoip.CityEdition, cityPath) - asn := dl.NewCacher(geoip.ASNEdition, asnPath) - return dataset.FetcherFunc(func() (bool, error) { - cityUpdated, err := city.Fetch() - if err != nil { - return false, fmt.Errorf("fetch %s: %w", geoip.CityEdition, err) - } - asnUpdated, err := asn.Fetch() - if err != nil { - return false, fmt.Errorf("fetch %s: %w", geoip.ASNEdition, err) - } - return cityUpdated || asnUpdated, nil - }) -} diff --git a/cmd/check-ip/server.go b/cmd/check-ip/server.go index bf5be04..7c3379d 100644 --- a/cmd/check-ip/server.go +++ b/cmd/check-ip/server.go @@ -42,7 +42,7 @@ func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) { Blocked: in || out, BlockedInbound: in, BlockedOutbound: out, - Geo: c.geo.Value().Lookup(ip), + Geo: c.geo.Lookup(ip), } if r.URL.Query().Get("format") == "json" || diff --git a/net/geoip/cmd/geoip-update/main.go b/net/geoip/cmd/geoip-update/main.go index 62c666a..5d3fae2 100644 --- a/net/geoip/cmd/geoip-update/main.go +++ b/net/geoip/cmd/geoip-update/main.go @@ -5,14 +5,16 @@ import ( "fmt" "os" "path/filepath" + "time" "github.com/therootcompany/golib/net/geoip" + "github.com/therootcompany/golib/net/httpcache" ) func main() { configPath := flag.String("config", "GeoIP.conf", "path to GeoIP.conf") - dir := flag.String("dir", "", "directory to store .mmdb files (overrides DatabaseDirectory in config)") - freshDays := flag.Int("fresh-days", 0, "skip download if file is younger than N days (default 3)") + dir := flag.String("dir", "", "directory to store .tar.gz files (overrides DatabaseDirectory in config)") + freshDays := flag.Int("fresh-days", 3, "skip download if file is younger than N days") flag.Parse() cfg, err := geoip.ParseConf(*configPath) @@ -28,7 +30,6 @@ func main() { if outDir == "" { outDir = "." } - if err := os.MkdirAll(outDir, 0o755); err != nil { fmt.Fprintf(os.Stderr, "error: mkdir %s: %v\n", outDir, err) os.Exit(1) @@ -39,25 +40,31 @@ func main() { os.Exit(1) } - d := geoip.New(cfg.AccountID, cfg.LicenseKey) - d.FreshDays = *freshDays + auth := httpcache.BasicAuth(cfg.AccountID, cfg.LicenseKey) + maxAge := time.Duration(*freshDays) * 24 * time.Hour exitCode := 0 for _, edition := range cfg.EditionIDs { - path := filepath.Join(outDir, edition+".mmdb") - updated, err := d.Fetch(edition, path) + path := filepath.Join(outDir, edition+".tar.gz") + cacher := &httpcache.Cacher{ + URL: geoip.DownloadBase + "/" + edition + "/download?suffix=tar.gz", + Path: path, + MaxAge: maxAge, + AuthHeader: "Authorization", + AuthValue: auth, + } + updated, err := cacher.Fetch() if err != nil { fmt.Fprintf(os.Stderr, "error: %s: %v\n", edition, err) exitCode = 1 continue } + info, _ := os.Stat(path) + state := "fresh: " if updated { - info, _ := os.Stat(path) - fmt.Printf("updated: %s -> %s (%s)\n", edition, path, info.ModTime().Format("2006-01-02")) - } else { - info, _ := os.Stat(path) - fmt.Printf("fresh: %s (%s)\n", edition, info.ModTime().Format("2006-01-02")) + state = "updated:" } + fmt.Printf("%s %s -> %s (%s)\n", state, edition, path, info.ModTime().Format("2006-01-02")) } os.Exit(exitCode) } diff --git a/net/geoip/databases.go b/net/geoip/databases.go index 2b82237..fac14ee 100644 --- a/net/geoip/databases.go +++ b/net/geoip/databases.go @@ -1,10 +1,15 @@ package geoip import ( + "archive/tar" + "compress/gzip" "errors" "fmt" + "io" "net/netip" + "os" "path/filepath" + "strings" "github.com/oschwald/geoip2-golang" ) @@ -15,22 +20,53 @@ type Databases struct { ASN *geoip2.Reader } -// Open opens /GeoLite2-City.mmdb and /GeoLite2-ASN.mmdb. +// Open reads /GeoLite2-City.tar.gz and /GeoLite2-ASN.tar.gz, +// extracts the .mmdb entry from each archive in memory, and returns open +// readers. No .mmdb files are written to disk. func Open(dir string) (*Databases, error) { - cityPath := filepath.Join(dir, "GeoLite2-City.mmdb") - asnPath := filepath.Join(dir, "GeoLite2-ASN.mmdb") - city, err := geoip2.Open(cityPath) + city, err := openMMDBTarGz(filepath.Join(dir, "GeoLite2-City.tar.gz")) if err != nil { - return nil, fmt.Errorf("open %s: %w", cityPath, err) + return nil, fmt.Errorf("city: %w", err) } - asn, err := geoip2.Open(asnPath) + asn, err := openMMDBTarGz(filepath.Join(dir, "GeoLite2-ASN.tar.gz")) if err != nil { _ = city.Close() - return nil, fmt.Errorf("open %s: %w", asnPath, err) + return nil, fmt.Errorf("asn: %w", err) } return &Databases{City: city, ASN: asn}, nil } +func openMMDBTarGz(path string) (*geoip2.Reader, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + gr, err := gzip.NewReader(f) + if err != nil { + return nil, fmt.Errorf("gzip %s: %w", path, err) + } + defer gr.Close() + tr := tar.NewReader(gr) + for { + hdr, err := tr.Next() + if err == io.EOF { + return nil, fmt.Errorf("no .mmdb entry in %s", path) + } + if err != nil { + return nil, err + } + if !strings.HasSuffix(hdr.Name, ".mmdb") { + continue + } + data, err := io.ReadAll(tr) + if err != nil { + return nil, err + } + return geoip2.FromBytes(data) + } +} + // Close closes the city and ASN readers. func (d *Databases) Close() error { return errors.Join(d.City.Close(), d.ASN.Close()) diff --git a/net/geoip/geoip.go b/net/geoip/geoip.go index ad05a01..8756edb 100644 --- a/net/geoip/geoip.go +++ b/net/geoip/geoip.go @@ -1,17 +1,8 @@ package geoip import ( - "archive/tar" - "compress/gzip" - "encoding/base64" - "fmt" - "io" "os" "path/filepath" - "strings" - "time" - - "github.com/therootcompany/golib/net/httpcache" ) const ( @@ -19,21 +10,11 @@ const ( ASNEdition = "GeoLite2-ASN" CountryEdition = "GeoLite2-Country" - downloadBase = "https://download.maxmind.com/geoip/databases" - defaultFreshDays = 3 - defaultTimeout = 5 * time.Minute + // DownloadBase is the MaxMind databases download endpoint. Full URL: + // //download?suffix=tar.gz + DownloadBase = "https://download.maxmind.com/geoip/databases" ) -// Downloader fetches MaxMind GeoLite2 .mmdb files from the download API. -// For one-shot use call Fetch; for polling loops call NewCacher and reuse -// the Cacher so ETag state is preserved across calls. -type Downloader struct { - AccountID string - LicenseKey string - FreshDays int // 0 uses 3 - Timeout time.Duration // 0 uses 5m -} - // DefaultConfPaths returns the standard locations where GeoIP.conf is looked // up: ./GeoIP.conf, then ~/.config/maxmind/GeoIP.conf. func DefaultConfPaths() []string { @@ -53,87 +34,3 @@ func DefaultCacheDir() (string, error) { } return filepath.Join(base, "maxmind"), nil } - -// New returns a Downloader configured with the given credentials. -func New(accountID, licenseKey string) *Downloader { - return &Downloader{AccountID: accountID, LicenseKey: licenseKey} -} - -// NewCacher returns an httpcache.Cacher pre-configured for this edition and -// path. Hold the Cacher and call Fetch() on it periodically — ETag state is -// preserved across calls, enabling conditional GETs that skip the download -// count on unchanged releases. -func (d *Downloader) NewCacher(edition, path string) *httpcache.Cacher { - freshDays := d.FreshDays - if freshDays == 0 { - freshDays = defaultFreshDays - } - timeout := d.Timeout - if timeout == 0 { - timeout = defaultTimeout - } - creds := base64.StdEncoding.EncodeToString([]byte(d.AccountID + ":" + d.LicenseKey)) - return &httpcache.Cacher{ - URL: fmt.Sprintf("%s/%s/download?suffix=tar.gz", downloadBase, edition), - Path: path, - MaxAge: time.Duration(freshDays) * 24 * time.Hour, - Timeout: timeout, - AuthHeader: "Authorization", - AuthValue: "Basic " + creds, - Transform: ExtractMMDB, - } -} - -// Fetch downloads edition to path if the file is stale. Convenience wrapper -// around NewCacher for one-shot use; ETag state is not retained. -func (d *Downloader) Fetch(edition, path string) (bool, error) { - return d.NewCacher(edition, path).Fetch() -} - -// ExtractMMDB reads a MaxMind tar.gz archive, writes the .mmdb entry to path -// atomically (via tmp+rename), and sets its mtime to MaxMind's release date. -func ExtractMMDB(r io.Reader, path string) error { - gr, err := gzip.NewReader(r) - if err != nil { - return err - } - defer gr.Close() - - tr := tar.NewReader(gr) - for { - hdr, err := tr.Next() - if err == io.EOF { - return fmt.Errorf("no .mmdb file found in archive") - } - if err != nil { - return err - } - if !strings.HasSuffix(hdr.Name, ".mmdb") { - continue - } - - tmp := path + ".tmp" - f, err := os.Create(tmp) - if err != nil { - return err - } - if _, err := io.Copy(f, tr); err != nil { - f.Close() - os.Remove(tmp) - return err - } - f.Close() - - if err := os.Rename(tmp, path); err != nil { - os.Remove(tmp) - return err - } - - // Preserve MaxMind's release date so mtime == data age, not download time. - if !hdr.ModTime.IsZero() { - os.Chtimes(path, hdr.ModTime, hdr.ModTime) - } - - return nil - } -} diff --git a/net/geoip/geoip_integration_test.go b/net/geoip/geoip_integration_test.go index 6e486ab..8db9e6d 100644 --- a/net/geoip/geoip_integration_test.go +++ b/net/geoip/geoip_integration_test.go @@ -8,6 +8,7 @@ import ( "testing" "github.com/therootcompany/golib/net/geoip" + "github.com/therootcompany/golib/net/httpcache" ) func testdataDir(t *testing.T) string { @@ -27,7 +28,6 @@ func testdataDir(t *testing.T) string { func geoipConf(t *testing.T) *geoip.Conf { t.Helper() - // Look for GeoIP.conf relative to the module root. dir, _ := filepath.Abs(".") for { p := filepath.Join(dir, "GeoIP.conf") @@ -48,19 +48,25 @@ func geoipConf(t *testing.T) *geoip.Conf { return nil } -func TestDownloader_CityAndASN(t *testing.T) { +func newCacher(cfg *geoip.Conf, edition, path string) *httpcache.Cacher { + return &httpcache.Cacher{ + URL: geoip.DownloadBase + "/" + edition + "/download?suffix=tar.gz", + Path: path, + AuthHeader: "Authorization", + AuthValue: httpcache.BasicAuth(cfg.AccountID, cfg.LicenseKey), + } +} + +func TestDownload_CityAndASN(t *testing.T) { cfg := geoipConf(t) td := testdataDir(t) - d := geoip.New(cfg.AccountID, cfg.LicenseKey) - for _, edition := range []string{geoip.CityEdition, geoip.ASNEdition} { - path := filepath.Join(td, edition+".mmdb") + path := filepath.Join(td, edition+".tar.gz") os.Remove(path) os.Remove(path + ".meta") - cacher := d.NewCacher(edition, path) - updated, err := cacher.Fetch() + updated, err := newCacher(cfg, edition, path).Fetch() if err != nil { t.Fatalf("%s Fetch: %v", edition, err) } @@ -83,23 +89,18 @@ func TestDownloader_CityAndASN(t *testing.T) { } } -func TestDownloader_ConditionalGet_FreshCacher(t *testing.T) { +func TestDownload_ConditionalGet_FreshCacher(t *testing.T) { cfg := geoipConf(t) td := testdataDir(t) - d := geoip.New(cfg.AccountID, cfg.LicenseKey) - for _, edition := range []string{geoip.CityEdition, geoip.ASNEdition} { - path := filepath.Join(td, edition+".mmdb") + path := filepath.Join(td, edition+".tar.gz") - // Ensure downloaded. - if _, err := d.NewCacher(edition, path).Fetch(); err != nil { + if _, err := newCacher(cfg, edition, path).Fetch(); err != nil { t.Fatalf("%s initial Fetch: %v", edition, err) } - // Fresh cacher — no in-memory ETag, must use sidecar. - fresh := d.NewCacher(edition, path) - updated, err := fresh.Fetch() + updated, err := newCacher(cfg, edition, path).Fetch() if err != nil { t.Fatalf("%s fresh Fetch: %v", edition, err) } diff --git a/net/httpcache/httpcache.go b/net/httpcache/httpcache.go index 1f5f80f..783f9b6 100644 --- a/net/httpcache/httpcache.go +++ b/net/httpcache/httpcache.go @@ -1,6 +1,7 @@ package httpcache import ( + "encoding/base64" "encoding/json" "fmt" "io" @@ -12,6 +13,19 @@ import ( "time" ) +// BasicAuth returns an HTTP Basic Authorization header value: +// "Basic " + base64(user:pass). Assign to Cacher.AuthValue with +// AuthHeader "Authorization". +func BasicAuth(user, pass string) string { + return "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+pass)) +} + +// Bearer returns a Bearer Authorization header value: "Bearer " + token. +// Assign to Cacher.AuthValue with AuthHeader "Authorization". +func Bearer(token string) string { + return "Bearer " + token +} + const ( defaultConnTimeout = 5 * time.Second // TCP connect + TLS handshake defaultTimeout = 5 * time.Minute // overall including body read @@ -34,11 +48,9 @@ const ( // Auth — AuthHeader/AuthValue set a request header on every attempt. Auth is // stripped before following redirects so presigned targets (e.g. S3/R2 URLs) // never receive credentials. Use any scheme: "Authorization"/"Bearer token", -// "X-API-Key"/"secret", "Authorization"/"Basic base64(user:pass)", etc. -// -// Transform — if set, called with the response body instead of the default -// atomic file copy. The func is responsible for writing to path atomically. -// Use this for archives (e.g. extracting a .mmdb from a MaxMind tar.gz). +// "X-API-Key"/"secret", "Authorization"/"Basic base64(user:pass)", etc. The +// BasicAuth and Bearer helpers produce the right AuthValue for the common +// cases. type Cacher struct { URL string Path string @@ -48,7 +60,6 @@ type Cacher struct { MinInterval time.Duration // 0 disables; skip HTTP if last Fetch attempt was within this AuthHeader string // e.g. "Authorization" or "X-API-Key" AuthValue string // e.g. "Bearer token" or "Basic base64(user:pass)" - Transform func(r io.Reader, path string) error // nil = direct atomic copy mu sync.Mutex etag string @@ -186,30 +197,24 @@ func (c *Cacher) Fetch() (updated bool, err error) { if err := os.MkdirAll(filepath.Dir(c.Path), 0o755); err != nil { return false, err } - if c.Transform != nil { - if err := c.Transform(resp.Body, c.Path); err != nil { - return false, err - } - } else { - tmp := c.Path + ".tmp" - f, err := os.Create(tmp) - if err != nil { - return false, err - } - n, err := io.Copy(f, resp.Body) - f.Close() - if err != nil { - os.Remove(tmp) - return false, err - } - if n == 0 { - os.Remove(tmp) - return false, fmt.Errorf("empty response from %s", c.URL) - } - if err := os.Rename(tmp, c.Path); err != nil { - os.Remove(tmp) - return false, err - } + tmp := c.Path + ".tmp" + f, err := os.Create(tmp) + if err != nil { + return false, err + } + n, err := io.Copy(f, resp.Body) + f.Close() + if err != nil { + os.Remove(tmp) + return false, err + } + if n == 0 { + os.Remove(tmp) + return false, fmt.Errorf("empty response from %s", c.URL) + } + if err := os.Rename(tmp, c.Path); err != nil { + os.Remove(tmp) + return false, err } if etag := resp.Header.Get("ETag"); etag != "" {