From 3e48e0a863eb3716a72d4fce1c1fac22ace0c85a Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 20 Apr 2026 12:08:03 -0600 Subject: [PATCH] fix: check-ip fails on startup if data cannot be downloaded --- .gitignore | 4 +- cmd/check-ip/main.go | 112 +++++++++++++++++++++---------------------- 2 files changed, 56 insertions(+), 60 deletions(-) diff --git a/.gitignore b/.gitignore index f4bf406..0436191 100644 --- a/.gitignore +++ b/.gitignore @@ -10,8 +10,8 @@ env.* # Project binaries dist/ -check-ip -cmd/check-ip/check-ip +/check-ip +/cmd/check-ip/check-ip auth/csvauth/cmd/csvauth/csvauth cmd/auth-proxy/auth-proxy cmd/httplog/httplog diff --git a/cmd/check-ip/main.go b/cmd/check-ip/main.go index fcc58c6..697f603 100644 --- a/cmd/check-ip/main.go +++ b/cmd/check-ip/main.go @@ -115,27 +115,29 @@ func main() { if *geoipConf != "" { cfg, err := geoip.ParseConf(*geoipConf) if err != nil { - fmt.Fprintf(os.Stderr, "warn: geoip-conf: %v\n", err) - } else { - dbDir := cfg.DatabaseDirectory - if dbDir == "" { - if d, err := geoip.DefaultCacheDir(); err == nil { - dbDir = d - } - } - if err := os.MkdirAll(dbDir, 0o755); err != nil { - fmt.Fprintf(os.Stderr, "warn: mkdir %s: %v\n", dbDir, err) - } - d := geoip.New(cfg.AccountID, cfg.LicenseKey) - if resolvedCityPath == "" { - resolvedCityPath = filepath.Join(dbDir, geoip.CityEdition+".mmdb") - } - if resolvedASNPath == "" { - resolvedASNPath = filepath.Join(dbDir, geoip.ASNEdition+".mmdb") - } - cityDS = newGeoIPDataset(d, geoip.CityEdition, resolvedCityPath) - asnDS = newGeoIPDataset(d, geoip.ASNEdition, resolvedASNPath) + fmt.Fprintf(os.Stderr, "error: geoip-conf: %v\n", err) + os.Exit(1) } + dbDir := cfg.DatabaseDirectory + if dbDir == "" { + if dbDir, err = geoip.DefaultCacheDir(); err != nil { + fmt.Fprintf(os.Stderr, "error: geoip cache dir: %v\n", err) + os.Exit(1) + } + } + if err := os.MkdirAll(dbDir, 0o755); err != nil { + fmt.Fprintf(os.Stderr, "error: mkdir %s: %v\n", dbDir, err) + os.Exit(1) + } + d := geoip.New(cfg.AccountID, cfg.LicenseKey) + if resolvedCityPath == "" { + resolvedCityPath = filepath.Join(dbDir, geoip.CityEdition+".mmdb") + } + if resolvedASNPath == "" { + resolvedASNPath = filepath.Join(dbDir, geoip.ASNEdition+".mmdb") + } + cityDS = newGeoIPDataset(d, geoip.CityEdition, resolvedCityPath) + asnDS = newGeoIPDataset(d, geoip.ASNEdition, resolvedASNPath) } else { // Manual paths: no auto-download, just open existing files. if resolvedCityPath != "" { @@ -148,12 +150,14 @@ func main() { if cityDS != nil { if err := cityDS.Init(); err != nil { - fmt.Fprintf(os.Stderr, "warn: city DB: %v\n", err) + fmt.Fprintf(os.Stderr, "error: city DB: %v\n", err) + os.Exit(1) } } if asnDS != nil { if err := asnDS.Init(); err != nil { - fmt.Fprintf(os.Stderr, "warn: ASN DB: %v\n", err) + fmt.Fprintf(os.Stderr, "error: ASN DB: %v\n", err) + os.Exit(1) } } @@ -210,31 +214,25 @@ func newGeoIPDataset(d *geoip.Downloader, edition, path string) *dataset.Dataset func containsInbound(ip string, whitelist, inbound *dataset.View[ipcohort.Cohort], ) bool { - if whitelist != nil { - if wl := whitelist.Load(); wl != nil && wl.Contains(ip) { - return false - } + if whitelist != nil && whitelist.Load().Contains(ip) { + return false } if inbound == nil { return false } - c := inbound.Load() - return c != nil && c.Contains(ip) + return inbound.Load().Contains(ip) } func containsOutbound(ip string, whitelist, outbound *dataset.View[ipcohort.Cohort], ) bool { - if whitelist != nil { - if wl := whitelist.Load(); wl != nil && wl.Contains(ip) { - return false - } + if whitelist != nil && whitelist.Load().Contains(ip) { + return false } if outbound == nil { return false } - c := outbound.Load() - return c != nil && c.Contains(ip) + return outbound.Load().Contains(ip) } func printGeoInfo(ipStr string, cityDS, asnDS *dataset.Dataset[geoip2.Reader]) { @@ -245,36 +243,34 @@ func printGeoInfo(ipStr string, cityDS, asnDS *dataset.Dataset[geoip2.Reader]) { stdIP := ip.AsSlice() if cityDS != nil { - if r := cityDS.Load(); r != nil { - if rec, err := r.City(stdIP); err == nil { - city := rec.City.Names["en"] - country := rec.Country.Names["en"] - iso := rec.Country.IsoCode - var parts []string - if city != "" { - parts = append(parts, city) - } - if len(rec.Subdivisions) > 0 { - if sub := rec.Subdivisions[0].Names["en"]; sub != "" && sub != city { - parts = append(parts, sub) - } - } - if country != "" { - parts = append(parts, fmt.Sprintf("%s (%s)", country, iso)) - } - if len(parts) > 0 { - fmt.Printf(" Location: %s\n", strings.Join(parts, ", ")) + r := cityDS.Load() + if rec, err := r.City(stdIP); err == nil { + city := rec.City.Names["en"] + country := rec.Country.Names["en"] + iso := rec.Country.IsoCode + var parts []string + if city != "" { + parts = append(parts, city) + } + if len(rec.Subdivisions) > 0 { + if sub := rec.Subdivisions[0].Names["en"]; sub != "" && sub != city { + parts = append(parts, sub) } } + if country != "" { + parts = append(parts, fmt.Sprintf("%s (%s)", country, iso)) + } + if len(parts) > 0 { + fmt.Printf(" Location: %s\n", strings.Join(parts, ", ")) + } } } if asnDS != nil { - if r := asnDS.Load(); r != nil { - if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 { - fmt.Printf(" ASN: AS%d %s\n", - rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization) - } + r := asnDS.Load() + if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 { + fmt.Printf(" ASN: AS%d %s\n", + rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization) } } }