diff --git a/cmd/check-ip/main.go b/cmd/check-ip/main.go index c147dc8..a90f676 100644 --- a/cmd/check-ip/main.go +++ b/cmd/check-ip/main.go @@ -64,6 +64,7 @@ type IPCheck struct { RepoURL string CacheDir string WhitelistPath string + AsyncLoad bool // GeoIPBasicAuth is the pre-encoded Authorization header value for // MaxMind downloads. @@ -83,6 +84,7 @@ func main() { fs.StringVar(&cfg.RepoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)") fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: ~/.cache)") fs.StringVar(&cfg.WhitelistPath, "whitelist", "", "path to a file of IPs and/or CIDRs (one per line) that override block decisions") + fs.BoolVar(&cfg.AsyncLoad, "async-load", false, "with --serve: start the HTTP server immediately and populate blocklists+whitelist in the background (/healthz returns 503 until ready). Ignored in CLI mode.") fs.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s [flags] [ip...]\n", os.Args[0]) fmt.Fprintf(os.Stderr, " %s --serve [flags]\n", os.Args[0]) @@ -149,29 +151,45 @@ func main() { repo := gitshallow.New(cfg.RepoURL, filepath.Join(cfg.CacheDir, "bitwire-it"), 1, "") repo.MaxAge = refreshInterval blocklists := dataset.NewSet(repo) - cfg.inbound = dataset.Add(blocklists, func() (*ipcohort.Cohort, error) { + asyncServe := cfg.AsyncLoad && cfg.Bind != "" + addCohort := func(s *dataset.Set, loader func() (*ipcohort.Cohort, error)) *dataset.View[ipcohort.Cohort] { + if asyncServe { + return dataset.AddInitial(s, ipcohort.New(), loader) + } + return dataset.Add(s, loader) + } + cfg.inbound = addCohort(blocklists, func() (*ipcohort.Cohort, error) { return ipcohort.LoadFiles( repo.FilePath("tables/inbound/single_ips.txt"), repo.FilePath("tables/inbound/networks.txt"), ) }) - cfg.outbound = dataset.Add(blocklists, func() (*ipcohort.Cohort, error) { + cfg.outbound = addCohort(blocklists, func() (*ipcohort.Cohort, error) { return ipcohort.LoadFiles( repo.FilePath("tables/outbound/single_ips.txt"), repo.FilePath("tables/outbound/networks.txt"), ) }) - fmt.Fprint(os.Stderr, "Loading blocklists... ") - t := time.Now() - if err := blocklists.Load(context.Background()); err != nil { - fmt.Fprintln(os.Stderr) - log.Fatalf("blocklists: %v", err) + loadBlocklists := func() { + fmt.Fprint(os.Stderr, "Loading blocklists... ") + t := time.Now() + if err := blocklists.Load(context.Background()); err != nil { + fmt.Fprintln(os.Stderr) + log.Printf("blocklists: %v", err) + if !asyncServe { + os.Exit(1) + } + return + } + fmt.Fprintf(os.Stderr, "%s (inbound=%s, outbound=%s)\n", + time.Since(t).Round(time.Millisecond), + commafy(cfg.inbound.Value().Size()), + commafy(cfg.outbound.Value().Size()), + ) + } + if !asyncServe { + loadBlocklists() } - fmt.Fprintf(os.Stderr, "%s (inbound=%s, outbound=%s)\n", - time.Since(t).Round(time.Millisecond), - commafy(cfg.inbound.Value().Size()), - commafy(cfg.outbound.Value().Size()), - ) // GeoIP: download the City + ASN tar.gz archives via httpcache // conditional GETs. geoip.Open extracts in-memory — no .mmdb files @@ -206,32 +224,42 @@ func main() { return geoip.Open(maxmindDir) }) fmt.Fprint(os.Stderr, "Loading geoip... ") - t = time.Now() + tGeo := time.Now() if err := geoSet.Load(context.Background()); err != nil { fmt.Fprintln(os.Stderr) log.Fatalf("geoip: %v", err) } - fmt.Fprintf(os.Stderr, "%s\n", time.Since(t).Round(time.Millisecond)) + fmt.Fprintf(os.Stderr, "%s\n", time.Since(tGeo).Round(time.Millisecond)) defer func() { _ = cfg.geo.Value().Close() }() // Whitelist: combined IPs + CIDRs in one file, polled for mtime changes. // A match here overrides any block decision from the blocklists. var whitelistSet *dataset.Set + var loadWhitelist func() if cfg.WhitelistPath != "" { whitelistSet = dataset.NewSet(dataset.PollFiles(cfg.WhitelistPath)) - cfg.whitelist = dataset.Add(whitelistSet, func() (*ipcohort.Cohort, error) { + cfg.whitelist = addCohort(whitelistSet, func() (*ipcohort.Cohort, error) { return ipcohort.LoadFile(cfg.WhitelistPath) }) - fmt.Fprint(os.Stderr, "Loading whitelist... ") - t = time.Now() - if err := whitelistSet.Load(context.Background()); err != nil { - fmt.Fprintln(os.Stderr) - log.Fatalf("whitelist: %v", err) + loadWhitelist = func() { + fmt.Fprint(os.Stderr, "Loading whitelist... ") + t := time.Now() + if err := whitelistSet.Load(context.Background()); err != nil { + fmt.Fprintln(os.Stderr) + log.Printf("whitelist: %v", err) + if !asyncServe { + os.Exit(1) + } + return + } + fmt.Fprintf(os.Stderr, "%s (entries=%s)\n", + time.Since(t).Round(time.Millisecond), + commafy(cfg.whitelist.Value().Size()), + ) + } + if !asyncServe { + loadWhitelist() } - fmt.Fprintf(os.Stderr, "%s (entries=%s)\n", - time.Since(t).Round(time.Millisecond), - commafy(cfg.whitelist.Value().Size()), - ) } // Blank line separates the stderr "Loading ..." block from the real @@ -248,6 +276,12 @@ func main() { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() + if asyncServe { + go loadBlocklists() + if loadWhitelist != nil { + go loadWhitelist() + } + } go blocklists.Tick(ctx, refreshInterval, func(err error) { log.Printf("blocklists refresh: %v", err) })