mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 20:58:00 +00:00
feat(check-ip): --async-load flag for non-blocking server startup
With --serve --async-load, blocklists + whitelist start empty and load in background goroutines so the HTTP server binds immediately. /healthz returns 503 until loads complete, then 200. Ignored in CLI mode. Geo stays synchronous — geoip readers aren't nil-safe.
This commit is contained in:
parent
1c99c8b831
commit
b40abe0a06
@ -64,6 +64,7 @@ type IPCheck struct {
|
|||||||
RepoURL string
|
RepoURL string
|
||||||
CacheDir string
|
CacheDir string
|
||||||
WhitelistPath string
|
WhitelistPath string
|
||||||
|
AsyncLoad bool
|
||||||
|
|
||||||
// GeoIPBasicAuth is the pre-encoded Authorization header value for
|
// GeoIPBasicAuth is the pre-encoded Authorization header value for
|
||||||
// MaxMind downloads.
|
// MaxMind downloads.
|
||||||
@ -83,6 +84,7 @@ func main() {
|
|||||||
fs.StringVar(&cfg.RepoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
|
fs.StringVar(&cfg.RepoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
|
||||||
fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: ~/.cache)")
|
fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: ~/.cache)")
|
||||||
fs.StringVar(&cfg.WhitelistPath, "whitelist", "", "path to a file of IPs and/or CIDRs (one per line) that override block decisions")
|
fs.StringVar(&cfg.WhitelistPath, "whitelist", "", "path to a file of IPs and/or CIDRs (one per line) that override block decisions")
|
||||||
|
fs.BoolVar(&cfg.AsyncLoad, "async-load", false, "with --serve: start the HTTP server immediately and populate blocklists+whitelist in the background (/healthz returns 503 until ready). Ignored in CLI mode.")
|
||||||
fs.Usage = func() {
|
fs.Usage = func() {
|
||||||
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <ip> [ip...]\n", os.Args[0])
|
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <ip> [ip...]\n", os.Args[0])
|
||||||
fmt.Fprintf(os.Stderr, " %s --serve <bind> [flags]\n", os.Args[0])
|
fmt.Fprintf(os.Stderr, " %s --serve <bind> [flags]\n", os.Args[0])
|
||||||
@ -149,29 +151,45 @@ func main() {
|
|||||||
repo := gitshallow.New(cfg.RepoURL, filepath.Join(cfg.CacheDir, "bitwire-it"), 1, "")
|
repo := gitshallow.New(cfg.RepoURL, filepath.Join(cfg.CacheDir, "bitwire-it"), 1, "")
|
||||||
repo.MaxAge = refreshInterval
|
repo.MaxAge = refreshInterval
|
||||||
blocklists := dataset.NewSet(repo)
|
blocklists := dataset.NewSet(repo)
|
||||||
cfg.inbound = dataset.Add(blocklists, func() (*ipcohort.Cohort, error) {
|
asyncServe := cfg.AsyncLoad && cfg.Bind != ""
|
||||||
|
addCohort := func(s *dataset.Set, loader func() (*ipcohort.Cohort, error)) *dataset.View[ipcohort.Cohort] {
|
||||||
|
if asyncServe {
|
||||||
|
return dataset.AddInitial(s, ipcohort.New(), loader)
|
||||||
|
}
|
||||||
|
return dataset.Add(s, loader)
|
||||||
|
}
|
||||||
|
cfg.inbound = addCohort(blocklists, func() (*ipcohort.Cohort, error) {
|
||||||
return ipcohort.LoadFiles(
|
return ipcohort.LoadFiles(
|
||||||
repo.FilePath("tables/inbound/single_ips.txt"),
|
repo.FilePath("tables/inbound/single_ips.txt"),
|
||||||
repo.FilePath("tables/inbound/networks.txt"),
|
repo.FilePath("tables/inbound/networks.txt"),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
cfg.outbound = dataset.Add(blocklists, func() (*ipcohort.Cohort, error) {
|
cfg.outbound = addCohort(blocklists, func() (*ipcohort.Cohort, error) {
|
||||||
return ipcohort.LoadFiles(
|
return ipcohort.LoadFiles(
|
||||||
repo.FilePath("tables/outbound/single_ips.txt"),
|
repo.FilePath("tables/outbound/single_ips.txt"),
|
||||||
repo.FilePath("tables/outbound/networks.txt"),
|
repo.FilePath("tables/outbound/networks.txt"),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
loadBlocklists := func() {
|
||||||
fmt.Fprint(os.Stderr, "Loading blocklists... ")
|
fmt.Fprint(os.Stderr, "Loading blocklists... ")
|
||||||
t := time.Now()
|
t := time.Now()
|
||||||
if err := blocklists.Load(context.Background()); err != nil {
|
if err := blocklists.Load(context.Background()); err != nil {
|
||||||
fmt.Fprintln(os.Stderr)
|
fmt.Fprintln(os.Stderr)
|
||||||
log.Fatalf("blocklists: %v", err)
|
log.Printf("blocklists: %v", err)
|
||||||
|
if !asyncServe {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "%s (inbound=%s, outbound=%s)\n",
|
fmt.Fprintf(os.Stderr, "%s (inbound=%s, outbound=%s)\n",
|
||||||
time.Since(t).Round(time.Millisecond),
|
time.Since(t).Round(time.Millisecond),
|
||||||
commafy(cfg.inbound.Value().Size()),
|
commafy(cfg.inbound.Value().Size()),
|
||||||
commafy(cfg.outbound.Value().Size()),
|
commafy(cfg.outbound.Value().Size()),
|
||||||
)
|
)
|
||||||
|
}
|
||||||
|
if !asyncServe {
|
||||||
|
loadBlocklists()
|
||||||
|
}
|
||||||
|
|
||||||
// GeoIP: download the City + ASN tar.gz archives via httpcache
|
// GeoIP: download the City + ASN tar.gz archives via httpcache
|
||||||
// conditional GETs. geoip.Open extracts in-memory — no .mmdb files
|
// conditional GETs. geoip.Open extracts in-memory — no .mmdb files
|
||||||
@ -206,33 +224,43 @@ func main() {
|
|||||||
return geoip.Open(maxmindDir)
|
return geoip.Open(maxmindDir)
|
||||||
})
|
})
|
||||||
fmt.Fprint(os.Stderr, "Loading geoip... ")
|
fmt.Fprint(os.Stderr, "Loading geoip... ")
|
||||||
t = time.Now()
|
tGeo := time.Now()
|
||||||
if err := geoSet.Load(context.Background()); err != nil {
|
if err := geoSet.Load(context.Background()); err != nil {
|
||||||
fmt.Fprintln(os.Stderr)
|
fmt.Fprintln(os.Stderr)
|
||||||
log.Fatalf("geoip: %v", err)
|
log.Fatalf("geoip: %v", err)
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "%s\n", time.Since(t).Round(time.Millisecond))
|
fmt.Fprintf(os.Stderr, "%s\n", time.Since(tGeo).Round(time.Millisecond))
|
||||||
defer func() { _ = cfg.geo.Value().Close() }()
|
defer func() { _ = cfg.geo.Value().Close() }()
|
||||||
|
|
||||||
// Whitelist: combined IPs + CIDRs in one file, polled for mtime changes.
|
// Whitelist: combined IPs + CIDRs in one file, polled for mtime changes.
|
||||||
// A match here overrides any block decision from the blocklists.
|
// A match here overrides any block decision from the blocklists.
|
||||||
var whitelistSet *dataset.Set
|
var whitelistSet *dataset.Set
|
||||||
|
var loadWhitelist func()
|
||||||
if cfg.WhitelistPath != "" {
|
if cfg.WhitelistPath != "" {
|
||||||
whitelistSet = dataset.NewSet(dataset.PollFiles(cfg.WhitelistPath))
|
whitelistSet = dataset.NewSet(dataset.PollFiles(cfg.WhitelistPath))
|
||||||
cfg.whitelist = dataset.Add(whitelistSet, func() (*ipcohort.Cohort, error) {
|
cfg.whitelist = addCohort(whitelistSet, func() (*ipcohort.Cohort, error) {
|
||||||
return ipcohort.LoadFile(cfg.WhitelistPath)
|
return ipcohort.LoadFile(cfg.WhitelistPath)
|
||||||
})
|
})
|
||||||
|
loadWhitelist = func() {
|
||||||
fmt.Fprint(os.Stderr, "Loading whitelist... ")
|
fmt.Fprint(os.Stderr, "Loading whitelist... ")
|
||||||
t = time.Now()
|
t := time.Now()
|
||||||
if err := whitelistSet.Load(context.Background()); err != nil {
|
if err := whitelistSet.Load(context.Background()); err != nil {
|
||||||
fmt.Fprintln(os.Stderr)
|
fmt.Fprintln(os.Stderr)
|
||||||
log.Fatalf("whitelist: %v", err)
|
log.Printf("whitelist: %v", err)
|
||||||
|
if !asyncServe {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
fmt.Fprintf(os.Stderr, "%s (entries=%s)\n",
|
fmt.Fprintf(os.Stderr, "%s (entries=%s)\n",
|
||||||
time.Since(t).Round(time.Millisecond),
|
time.Since(t).Round(time.Millisecond),
|
||||||
commafy(cfg.whitelist.Value().Size()),
|
commafy(cfg.whitelist.Value().Size()),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
if !asyncServe {
|
||||||
|
loadWhitelist()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Blank line separates the stderr "Loading ..." block from the real
|
// Blank line separates the stderr "Loading ..." block from the real
|
||||||
// output (stdout results for CLI mode, or the stderr "listening on"
|
// output (stdout results for CLI mode, or the stderr "listening on"
|
||||||
@ -248,6 +276,12 @@ func main() {
|
|||||||
|
|
||||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
defer stop()
|
defer stop()
|
||||||
|
if asyncServe {
|
||||||
|
go loadBlocklists()
|
||||||
|
if loadWhitelist != nil {
|
||||||
|
go loadWhitelist()
|
||||||
|
}
|
||||||
|
}
|
||||||
go blocklists.Tick(ctx, refreshInterval, func(err error) {
|
go blocklists.Tick(ctx, refreshInterval, func(err error) {
|
||||||
log.Printf("blocklists refresh: %v", err)
|
log.Printf("blocklists refresh: %v", err)
|
||||||
})
|
})
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user