From cdce7da04cfbe9dfe251431eeaa75d962731fc81 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 20 Apr 2026 15:51:46 -0600 Subject: [PATCH] refactor(check-ip): simplify to 4 flags, push MkdirAll into libs check-ip now takes only --serve, --geoip-conf, --blocklist-repo, --cache-dir. Blocklist always comes from git; GeoIP mmdbs go through httpcache when GeoIP.conf is available and are otherwise read from disk at the cache paths. Format negotiation lives entirely server-side. main.go is now straight-line wiring: parse flags, build the two databases, run the server. All filesystem setup (MkdirAll for the parent of the clone target and for cache Path parents) is pushed into gitshallow and httpcache so the cmd doesn't do filesystem bookkeeping. --- cmd/check-ip/main.go | 341 ++++++----------------------------- cmd/check-ip/server.go | 81 +++++++-- net/gitshallow/gitshallow.go | 3 + net/httpcache/httpcache.go | 4 + 4 files changed, 130 insertions(+), 299 deletions(-) diff --git a/cmd/check-ip/main.go b/cmd/check-ip/main.go index ddb497b..2093353 100644 --- a/cmd/check-ip/main.go +++ b/cmd/check-ip/main.go @@ -1,165 +1,106 @@ -// check-ip reports whether an IPv4 address appears in the bitwire-it -// inbound/outbound blocklists and, when configured, prints GeoIP info. -// -// Source selection (in order of precedence): -// -// - --inbound / --outbound use local files (no syncing) -// - --git URL shallow-clone a git repo of blocklists -// - (default) fetch raw blocklist files over HTTP with caching -// -// Each mode builds a sync/dataset.Group: one Fetcher shared by the inbound -// and outbound views, so a single git pull (or HTTP-304 cycle) drives both. -// -// --serve turns check-ip into a long-running HTTP server; see server.go. +// check-ip runs an HTTP API that reports whether an IP appears in the +// configured blocklist repo and, when GeoIP.conf is available, enriches +// the response with MaxMind GeoLite2 City + ASN data. 
package main import ( "context" - "encoding/json" "errors" "flag" "fmt" - "io" "os" "os/signal" "path/filepath" - "strings" "syscall" "time" "github.com/therootcompany/golib/net/geoip" "github.com/therootcompany/golib/net/gitshallow" - "github.com/therootcompany/golib/net/httpcache" "github.com/therootcompany/golib/net/ipcohort" "github.com/therootcompany/golib/sync/dataset" ) const ( - bitwireGitURL = "https://github.com/bitwire-it/ipblocklist.git" - bitwireRawBase = "https://github.com/bitwire-it/ipblocklist/raw/refs/heads/main/tables" - - refreshInterval = 47 * time.Minute + defaultBlocklistRepo = "https://github.com/bitwire-it/ipblocklist.git" + refreshInterval = 47 * time.Minute ) type Config struct { - DataDir string - GitURL string - Whitelist string - Inbound string - Outbound string - GeoIPConf string - CityDB string - ASNDB string - Serve string - Format string + Serve string + GeoIPConf string + BlocklistRepo string + CacheDir string } func main() { cfg := Config{} fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) - fs.StringVar(&cfg.DataDir, "data-dir", "", "blacklist cache dir (default ~/.cache/bitwire-it)") - fs.StringVar(&cfg.GitURL, "git", "", "git URL to clone/pull blacklist from (e.g. "+bitwireGitURL+")") - fs.StringVar(&cfg.Whitelist, "whitelist", "", "comma-separated paths to whitelist files") - fs.StringVar(&cfg.Inbound, "inbound", "", "comma-separated paths to inbound blacklist files") - fs.StringVar(&cfg.Outbound, "outbound", "", "comma-separated paths to outbound blacklist files") - fs.StringVar(&cfg.GeoIPConf, "geoip-conf", "", "path to GeoIP.conf (auto-discovered if absent)") - fs.StringVar(&cfg.CityDB, "city-db", "", "path to GeoLite2-City.mmdb (skips auto-download)") - fs.StringVar(&cfg.ASNDB, "asn-db", "", "path to GeoLite2-ASN.mmdb (skips auto-download)") - fs.StringVar(&cfg.Serve, "serve", "", "start HTTP server at addr:port (e.g. 
:8080) instead of one-shot check") - fs.StringVar(&cfg.Format, "format", "", "output format: pretty, json (default pretty)") + fs.StringVar(&cfg.Serve, "serve", "", "bind address for the HTTP API, e.g. :8080") + fs.StringVar(&cfg.GeoIPConf, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)") + fs.StringVar(&cfg.BlocklistRepo, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)") + fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)") fs.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s [flags] \n", os.Args[0]) - fmt.Fprintf(os.Stderr, " %s --serve :8080 [flags]\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "Usage: %s --serve [flags]\n", os.Args[0]) fs.PrintDefaults() } - - if len(os.Args) > 1 { - switch os.Args[1] { - case "-V", "-version", "--version", "version": - fmt.Fprintln(os.Stdout, "check-ip") - os.Exit(0) - case "help", "-help", "--help": - fmt.Fprintln(os.Stdout, "check-ip") - fmt.Fprintln(os.Stdout) - fs.SetOutput(os.Stdout) - fs.Usage() - os.Exit(0) - } - } if err := fs.Parse(os.Args[1:]); err != nil { if errors.Is(err, flag.ErrHelp) { os.Exit(0) } os.Exit(1) } - format, err := parseFormat(cfg.Format) - if err != nil { - fmt.Fprintf(os.Stderr, "error: %v\n", err) + if cfg.Serve == "" { + fs.Usage() os.Exit(1) } ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - // Open the three "databases" that feed every IP check: - // - // 1. blocklists — inbound + outbound cohorts, hot-swapped on refresh - // 2. whitelist — static cohort loaded once from disk - // 3. geoip — city + ASN mmdb readers (optional) - // - // The blocklist Group.Tick goroutine refreshes in the background so the - // serve path actually exercises dataset's hot-swap. 
- - group, inbound, outbound, err := openBlocklists(cfg) - if err != nil { - fatal("blocklists", err) + cacheDir := cfg.CacheDir + if cacheDir == "" { + base, err := os.UserCacheDir() + if err != nil { + fatal("cache-dir", err) + } + cacheDir = base } + + // Blocklists: one git repo, two views sharing the same pull. + repo := gitshallow.New(cfg.BlocklistRepo, filepath.Join(cacheDir, "bitwire-it"), 1, "") + group := dataset.NewGroup(repo) + inbound := dataset.Add(group, loadCohort( + repo.FilePath("tables/inbound/single_ips.txt"), + repo.FilePath("tables/inbound/networks.txt"), + )) + outbound := dataset.Add(group, loadCohort( + repo.FilePath("tables/outbound/single_ips.txt"), + repo.FilePath("tables/outbound/networks.txt"), + )) if err := group.Load(ctx); err != nil { fatal("blocklists", err) } - fmt.Fprintf(os.Stderr, "loaded inbound=%d outbound=%d\n", - inbound.Value().Size(), outbound.Value().Size()) go group.Tick(ctx, refreshInterval, func(err error) { fmt.Fprintf(os.Stderr, "refresh: %v\n", err) }) - whitelist, err := openWhitelist(cfg.Whitelist) - if err != nil { - fatal("whitelist", err) - } - - geo, err := geoip.OpenDatabases(cfg.GeoIPConf, cfg.CityDB, cfg.ASNDB) + // GeoIP: city + ASN readers, downloaded via httpcache when GeoIP.conf + // is available; otherwise read from disk at the cache paths. 
+ maxmindDir := filepath.Join(cacheDir, "maxmind") + geo, err := geoip.OpenDatabases( + cfg.GeoIPConf, + filepath.Join(maxmindDir, geoip.CityEdition+".mmdb"), + filepath.Join(maxmindDir, geoip.ASNEdition+".mmdb"), + ) if err != nil { fatal("geoip", err) } defer func() { _ = geo.Close() }() - checker := &Checker{ - whitelist: whitelist, - inbound: inbound, - outbound: outbound, - geo: geo, - } - - if cfg.Serve != "" { - if fs.NArg() != 0 { - fmt.Fprintln(os.Stderr, "error: --serve takes no positional args") - os.Exit(1) - } - if err := serve(ctx, cfg, checker); err != nil { - fatal("serve", err) - } - return - } - - if fs.NArg() != 1 { - fs.Usage() - os.Exit(1) - } - blocked := checker.Check(fs.Arg(0)).Report(os.Stdout, format) - if blocked { - os.Exit(1) + checker := &Checker{Inbound: inbound, Outbound: outbound, GeoIP: geo} + if err := serve(ctx, cfg.Serve, checker); err != nil { + fatal("serve", err) } } @@ -168,12 +109,11 @@ func fatal(what string, err error) { os.Exit(1) } -// Checker bundles the three databases plus the lookup + render logic. +// Checker bundles the blocklist views with the optional GeoIP databases. type Checker struct { - whitelist *ipcohort.Cohort - inbound *dataset.View[ipcohort.Cohort] - outbound *dataset.View[ipcohort.Cohort] - geo *geoip.Databases + Inbound *dataset.View[ipcohort.Cohort] + Outbound *dataset.View[ipcohort.Cohort] + GeoIP *geoip.Databases } // Result is the structured verdict for a single IP. @@ -182,197 +122,28 @@ type Result struct { Blocked bool `json:"blocked"` BlockedInbound bool `json:"blocked_inbound"` BlockedOutbound bool `json:"blocked_outbound"` - Whitelisted bool `json:"whitelisted,omitempty"` Geo geoip.Info `json:"geo,omitzero"` } -// Check returns the structured verdict for ip without rendering. +// Check returns the structured verdict for ip. 
func (c *Checker) Check(ip string) Result { - whitelisted := c.whitelist != nil && c.whitelist.Contains(ip) - in := !whitelisted && cohortContains(c.inbound.Value(), ip) - out := !whitelisted && cohortContains(c.outbound.Value(), ip) + in := contains(c.Inbound.Value(), ip) + out := contains(c.Outbound.Value(), ip) return Result{ IP: ip, Blocked: in || out, BlockedInbound: in, BlockedOutbound: out, - Whitelisted: whitelisted, - Geo: c.geo.Lookup(ip), + Geo: c.GeoIP.Lookup(ip), } } -// Format selects the report rendering. -type Format string - -const ( - FormatPretty Format = "pretty" - FormatJSON Format = "json" -) - -func parseFormat(s string) (Format, error) { - switch s { - case "", "pretty": - return FormatPretty, nil - case "json": - return FormatJSON, nil - default: - return "", fmt.Errorf("invalid --format %q (want: pretty, json)", s) - } -} - -// Report renders r to w in the given format. Returns r.Blocked for convenience. -func (r Result) Report(w io.Writer, format Format) bool { - switch format { - case FormatJSON: - enc := json.NewEncoder(w) - enc.SetIndent("", " ") - _ = enc.Encode(r) - default: - r.writePretty(w) - } - return r.Blocked -} - -func (r Result) writePretty(w io.Writer) { - switch { - case r.BlockedInbound && r.BlockedOutbound: - fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", r.IP) - case r.BlockedInbound: - fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", r.IP) - case r.BlockedOutbound: - fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", r.IP) - default: - fmt.Fprintf(w, "%s is allowed\n", r.IP) - } - writeGeo(w, r.Geo) -} - -func writeGeo(w io.Writer, info geoip.Info) { - var parts []string - if info.City != "" { - parts = append(parts, info.City) - } - if info.Region != "" { - parts = append(parts, info.Region) - } - if info.Country != "" { - parts = append(parts, fmt.Sprintf("%s (%s)", info.Country, info.CountryISO)) - } - if len(parts) > 0 { - fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", ")) - } - if info.ASN != 0 { - 
fmt.Fprintf(w, " ASN: AS%d %s\n", info.ASN, info.ASNOrg) - } -} - -func cohortContains(c *ipcohort.Cohort, ip string) bool { +func contains(c *ipcohort.Cohort, ip string) bool { return c != nil && c.Contains(ip) } -// openBlocklists picks a Fetcher based on cfg and wires inbound/outbound views -// into a shared dataset.Group so one pull drives both. -func openBlocklists(cfg Config) ( - _ *dataset.Group, - inbound, outbound *dataset.View[ipcohort.Cohort], - err error, -) { - fetcher, inPaths, outPaths, err := newBlocklistFetcher(cfg) - if err != nil { - return nil, nil, nil, err - } - g := dataset.NewGroup(fetcher) - inbound = dataset.Add(g, loadCohort(inPaths)) - outbound = dataset.Add(g, loadCohort(outPaths)) - return g, inbound, outbound, nil -} - -// newBlocklistFetcher returns a dataset.Fetcher and the on-disk paths each -// view should parse after a sync. -func newBlocklistFetcher(cfg Config) (fetcher dataset.Fetcher, inPaths, outPaths []string, err error) { - switch { - case cfg.Inbound != "" || cfg.Outbound != "": - inPaths := splitCSV(cfg.Inbound) - outPaths := splitCSV(cfg.Outbound) - all := append(append([]string(nil), inPaths...), outPaths...) 
- return dataset.PollFiles(all...), inPaths, outPaths, nil - - case cfg.GitURL != "": - dir, err := cacheDir(cfg.DataDir) - if err != nil { - return nil, nil, nil, err - } - repo := gitshallow.New(cfg.GitURL, dir, 1, "") - return repo, - []string{ - repo.FilePath("tables/inbound/single_ips.txt"), - repo.FilePath("tables/inbound/networks.txt"), - }, - []string{ - repo.FilePath("tables/outbound/single_ips.txt"), - repo.FilePath("tables/outbound/networks.txt"), - }, - nil - - default: - dir, err := cacheDir(cfg.DataDir) - if err != nil { - return nil, nil, nil, err - } - cachers := []*httpcache.Cacher{ - httpcache.New(bitwireRawBase+"/inbound/single_ips.txt", filepath.Join(dir, "inbound_single_ips.txt")), - httpcache.New(bitwireRawBase+"/inbound/networks.txt", filepath.Join(dir, "inbound_networks.txt")), - httpcache.New(bitwireRawBase+"/outbound/single_ips.txt", filepath.Join(dir, "outbound_single_ips.txt")), - httpcache.New(bitwireRawBase+"/outbound/networks.txt", filepath.Join(dir, "outbound_networks.txt")), - } - return dataset.FetcherFunc(func() (bool, error) { - var any bool - for _, c := range cachers { - u, err := c.Fetch() - if err != nil { - return false, err - } - any = any || u - } - return any, nil - }), - []string{cachers[0].Path, cachers[1].Path}, - []string{cachers[2].Path, cachers[3].Path}, - nil - } -} - -func loadCohort(paths []string) func() (*ipcohort.Cohort, error) { +func loadCohort(paths ...string) func() (*ipcohort.Cohort, error) { return func() (*ipcohort.Cohort, error) { return ipcohort.LoadFiles(paths...) } } - -func openWhitelist(paths string) (*ipcohort.Cohort, error) { - if paths == "" { - return nil, nil - } - return ipcohort.LoadFiles(strings.Split(paths, ",")...) 
-} - -func cacheDir(override string) (string, error) { - dir := override - if dir == "" { - base, err := os.UserCacheDir() - if err != nil { - return "", err - } - dir = filepath.Join(base, "bitwire-it") - } - if err := os.MkdirAll(dir, 0o755); err != nil { - return "", err - } - return dir, nil -} - -func splitCSV(s string) []string { - if s == "" { - return nil - } - return strings.Split(s, ",") -} diff --git a/cmd/check-ip/server.go b/cmd/check-ip/server.go index 108dd41..284dae1 100644 --- a/cmd/check-ip/server.go +++ b/cmd/check-ip/server.go @@ -2,32 +2,37 @@ package main import ( "context" + "encoding/json" "errors" "fmt" + "io" "net" "net/http" "os" "strings" "time" + + "github.com/therootcompany/golib/net/geoip" ) const shutdownTimeout = 5 * time.Second -// serve runs the HTTP server until ctx is cancelled, shutting down gracefully. +// serve runs the HTTP API until ctx is cancelled, shutting down gracefully. // // GET / checks the request's client IP // GET /check same, plus ?ip= overrides // -// Format is chosen per request via ?format=, then Accept: application/json. -func serve(ctx context.Context, cfg Config, checker *Checker) error { +// Response format is chosen per request: ?format=json, then +// Accept: application/json, else pretty text. 
+func serve(ctx context.Context, bind string, checker *Checker) error { handle := func(w http.ResponseWriter, r *http.Request, ip string) { format := requestFormat(r) - if format == FormatJSON { + if format == formatJSON { w.Header().Set("Content-Type", "application/json; charset=utf-8") } else { w.Header().Set("Content-Type", "text/plain; charset=utf-8") } - checker.Check(ip).Report(w, format) + writeResult(w, checker.Check(ip), format) } mux := http.NewServeMux() @@ -43,7 +48,7 @@ func serve(ctx context.Context, cfg Config, checker *Checker) error { }) srv := &http.Server{ - Addr: cfg.Serve, + Addr: bind, Handler: mux, BaseContext: func(_ net.Listener) context.Context { return ctx @@ -57,24 +62,72 @@ func serve(ctx context.Context, cfg Config, checker *Checker) error { _ = srv.Shutdown(shutdownCtx) }() - fmt.Fprintf(os.Stderr, "listening on %s\n", cfg.Serve) + fmt.Fprintf(os.Stderr, "listening on %s\n", bind) if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { return err } return nil } +// format is the response rendering. Server-only. +type format int + +const ( + formatPretty format = iota + formatJSON +) + // requestFormat picks a response format from ?format=, then Accept header. 
-func requestFormat(r *http.Request) Format { - if q := r.URL.Query().Get("format"); q != "" { - if f, err := parseFormat(q); err == nil { - return f - } +func requestFormat(r *http.Request) format { + switch r.URL.Query().Get("format") { + case "json": + return formatJSON + case "pretty": + return formatPretty } if strings.Contains(r.Header.Get("Accept"), "application/json") { - return FormatJSON + return formatJSON + } + return formatPretty +} + +func writeResult(w io.Writer, r Result, f format) { + if f == formatJSON { + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + _ = enc.Encode(r) + return + } + switch { + case r.BlockedInbound && r.BlockedOutbound: + fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", r.IP) + case r.BlockedInbound: + fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", r.IP) + case r.BlockedOutbound: + fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", r.IP) + default: + fmt.Fprintf(w, "%s is allowed\n", r.IP) + } + writeGeo(w, r.Geo) +} + +func writeGeo(w io.Writer, info geoip.Info) { + var parts []string + if info.City != "" { + parts = append(parts, info.City) + } + if info.Region != "" { + parts = append(parts, info.Region) + } + if info.Country != "" { + parts = append(parts, fmt.Sprintf("%s (%s)", info.Country, info.CountryISO)) + } + if len(parts) > 0 { + fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", ")) + } + if info.ASN != 0 { + fmt.Fprintf(w, " ASN: AS%d %s\n", info.ASN, info.ASNOrg) } - return FormatPretty } // clientIP extracts the caller's IP, honoring X-Forwarded-For when present. 
diff --git a/net/gitshallow/gitshallow.go b/net/gitshallow/gitshallow.go index e5867cf..b5dca2d 100644 --- a/net/gitshallow/gitshallow.go +++ b/net/gitshallow/gitshallow.go @@ -76,6 +76,9 @@ func (r *Repo) clone() (bool, error) { if r.Path == "" { return false, fmt.Errorf("local path is required") } + if err := os.MkdirAll(filepath.Dir(r.Path), 0o755); err != nil { + return false, err + } args := []string{"clone", "--no-tags"} if depth := r.effectiveDepth(); depth >= 0 { diff --git a/net/httpcache/httpcache.go b/net/httpcache/httpcache.go index 99fb353..1f5f80f 100644 --- a/net/httpcache/httpcache.go +++ b/net/httpcache/httpcache.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "os" + "path/filepath" "sync" "time" ) @@ -182,6 +183,9 @@ func (c *Cacher) Fetch() (updated bool, err error) { return false, fmt.Errorf("unexpected status %d fetching %s", resp.StatusCode, c.URL) } + if err := os.MkdirAll(filepath.Dir(c.Path), 0o755); err != nil { + return false, err + } if c.Transform != nil { if err := c.Transform(resp.Body, c.Path); err != nil { return false, err