refactor(check-ip): simplify to 4 flags, push MkdirAll into libs

check-ip now takes only --serve, --geoip-conf, --blocklist-repo,
--cache-dir. Blocklist always comes from git; GeoIP mmdbs always go
through httpcache (when GeoIP.conf is available). Format negotiation
lives entirely server-side.

main.go is now straight-line wiring: parse flags, build the two
databases, run the server. All filesystem setup (MkdirAll for the
parent of the clone target and for the parent of each cache Path) is
pushed into gitshallow and httpcache so the cmd doesn't do filesystem
bookkeeping.
This commit is contained in:
AJ ONeal 2026-04-20 15:51:46 -06:00
parent 3b5812ffcd
commit cdce7da04c
No known key found for this signature in database
4 changed files with 130 additions and 299 deletions

View File

@ -1,165 +1,106 @@
// check-ip reports whether an IPv4 address appears in the bitwire-it // check-ip runs an HTTP API that reports whether an IP appears in the
// inbound/outbound blocklists and, when configured, prints GeoIP info. // configured blocklist repo and, when GeoIP.conf is available, enriches
// // the response with MaxMind GeoLite2 City + ASN data.
// Source selection (in order of precedence):
//
// - --inbound / --outbound use local files (no syncing)
// - --git URL shallow-clone a git repo of blocklists
// - (default) fetch raw blocklist files over HTTP with caching
//
// Each mode builds a sync/dataset.Group: one Fetcher shared by the inbound
// and outbound views, so a single git pull (or HTTP-304 cycle) drives both.
//
// --serve turns check-ip into a long-running HTTP server; see server.go.
package main package main
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"strings"
"syscall" "syscall"
"time" "time"
"github.com/therootcompany/golib/net/geoip" "github.com/therootcompany/golib/net/geoip"
"github.com/therootcompany/golib/net/gitshallow" "github.com/therootcompany/golib/net/gitshallow"
"github.com/therootcompany/golib/net/httpcache"
"github.com/therootcompany/golib/net/ipcohort" "github.com/therootcompany/golib/net/ipcohort"
"github.com/therootcompany/golib/sync/dataset" "github.com/therootcompany/golib/sync/dataset"
) )
const ( const (
bitwireGitURL = "https://github.com/bitwire-it/ipblocklist.git" defaultBlocklistRepo = "https://github.com/bitwire-it/ipblocklist.git"
bitwireRawBase = "https://github.com/bitwire-it/ipblocklist/raw/refs/heads/main/tables" refreshInterval = 47 * time.Minute
refreshInterval = 47 * time.Minute
) )
type Config struct { type Config struct {
DataDir string Serve string
GitURL string GeoIPConf string
Whitelist string BlocklistRepo string
Inbound string CacheDir string
Outbound string
GeoIPConf string
CityDB string
ASNDB string
Serve string
Format string
} }
func main() { func main() {
cfg := Config{} cfg := Config{}
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fs.StringVar(&cfg.DataDir, "data-dir", "", "blacklist cache dir (default ~/.cache/bitwire-it)") fs.StringVar(&cfg.Serve, "serve", "", "bind address for the HTTP API, e.g. :8080")
fs.StringVar(&cfg.GitURL, "git", "", "git URL to clone/pull blacklist from (e.g. "+bitwireGitURL+")") fs.StringVar(&cfg.GeoIPConf, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
fs.StringVar(&cfg.Whitelist, "whitelist", "", "comma-separated paths to whitelist files") fs.StringVar(&cfg.BlocklistRepo, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
fs.StringVar(&cfg.Inbound, "inbound", "", "comma-separated paths to inbound blacklist files") fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
fs.StringVar(&cfg.Outbound, "outbound", "", "comma-separated paths to outbound blacklist files")
fs.StringVar(&cfg.GeoIPConf, "geoip-conf", "", "path to GeoIP.conf (auto-discovered if absent)")
fs.StringVar(&cfg.CityDB, "city-db", "", "path to GeoLite2-City.mmdb (skips auto-download)")
fs.StringVar(&cfg.ASNDB, "asn-db", "", "path to GeoLite2-ASN.mmdb (skips auto-download)")
fs.StringVar(&cfg.Serve, "serve", "", "start HTTP server at addr:port (e.g. :8080) instead of one-shot check")
fs.StringVar(&cfg.Format, "format", "", "output format: pretty, json (default pretty)")
fs.Usage = func() { fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <ip-address>\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0])
fmt.Fprintf(os.Stderr, " %s --serve :8080 [flags]\n", os.Args[0])
fs.PrintDefaults() fs.PrintDefaults()
} }
if len(os.Args) > 1 {
switch os.Args[1] {
case "-V", "-version", "--version", "version":
fmt.Fprintln(os.Stdout, "check-ip")
os.Exit(0)
case "help", "-help", "--help":
fmt.Fprintln(os.Stdout, "check-ip")
fmt.Fprintln(os.Stdout)
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
if err := fs.Parse(os.Args[1:]); err != nil { if err := fs.Parse(os.Args[1:]); err != nil {
if errors.Is(err, flag.ErrHelp) { if errors.Is(err, flag.ErrHelp) {
os.Exit(0) os.Exit(0)
} }
os.Exit(1) os.Exit(1)
} }
format, err := parseFormat(cfg.Format) if cfg.Serve == "" {
if err != nil { fs.Usage()
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop() defer stop()
// Open the three "databases" that feed every IP check: cacheDir := cfg.CacheDir
// if cacheDir == "" {
// 1. blocklists — inbound + outbound cohorts, hot-swapped on refresh base, err := os.UserCacheDir()
// 2. whitelist — static cohort loaded once from disk if err != nil {
// 3. geoip — city + ASN mmdb readers (optional) fatal("cache-dir", err)
// }
// The blocklist Group.Tick goroutine refreshes in the background so the cacheDir = base
// serve path actually exercises dataset's hot-swap.
group, inbound, outbound, err := openBlocklists(cfg)
if err != nil {
fatal("blocklists", err)
} }
// Blocklists: one git repo, two views sharing the same pull.
repo := gitshallow.New(cfg.BlocklistRepo, filepath.Join(cacheDir, "bitwire-it"), 1, "")
group := dataset.NewGroup(repo)
inbound := dataset.Add(group, loadCohort(
repo.FilePath("tables/inbound/single_ips.txt"),
repo.FilePath("tables/inbound/networks.txt"),
))
outbound := dataset.Add(group, loadCohort(
repo.FilePath("tables/outbound/single_ips.txt"),
repo.FilePath("tables/outbound/networks.txt"),
))
if err := group.Load(ctx); err != nil { if err := group.Load(ctx); err != nil {
fatal("blocklists", err) fatal("blocklists", err)
} }
fmt.Fprintf(os.Stderr, "loaded inbound=%d outbound=%d\n",
inbound.Value().Size(), outbound.Value().Size())
go group.Tick(ctx, refreshInterval, func(err error) { go group.Tick(ctx, refreshInterval, func(err error) {
fmt.Fprintf(os.Stderr, "refresh: %v\n", err) fmt.Fprintf(os.Stderr, "refresh: %v\n", err)
}) })
whitelist, err := openWhitelist(cfg.Whitelist) // GeoIP: city + ASN readers, downloaded via httpcache when GeoIP.conf
if err != nil { // is available; otherwise read from disk at the cache paths.
fatal("whitelist", err) maxmindDir := filepath.Join(cacheDir, "maxmind")
} geo, err := geoip.OpenDatabases(
cfg.GeoIPConf,
geo, err := geoip.OpenDatabases(cfg.GeoIPConf, cfg.CityDB, cfg.ASNDB) filepath.Join(maxmindDir, geoip.CityEdition+".mmdb"),
filepath.Join(maxmindDir, geoip.ASNEdition+".mmdb"),
)
if err != nil { if err != nil {
fatal("geoip", err) fatal("geoip", err)
} }
defer func() { _ = geo.Close() }() defer func() { _ = geo.Close() }()
checker := &Checker{ checker := &Checker{Inbound: inbound, Outbound: outbound, GeoIP: geo}
whitelist: whitelist, if err := serve(ctx, cfg.Serve, checker); err != nil {
inbound: inbound, fatal("serve", err)
outbound: outbound,
geo: geo,
}
if cfg.Serve != "" {
if fs.NArg() != 0 {
fmt.Fprintln(os.Stderr, "error: --serve takes no positional args")
os.Exit(1)
}
if err := serve(ctx, cfg, checker); err != nil {
fatal("serve", err)
}
return
}
if fs.NArg() != 1 {
fs.Usage()
os.Exit(1)
}
blocked := checker.Check(fs.Arg(0)).Report(os.Stdout, format)
if blocked {
os.Exit(1)
} }
} }
@ -168,12 +109,11 @@ func fatal(what string, err error) {
os.Exit(1) os.Exit(1)
} }
// Checker bundles the three databases plus the lookup + render logic. // Checker bundles the blocklist views with the optional GeoIP databases.
type Checker struct { type Checker struct {
whitelist *ipcohort.Cohort Inbound *dataset.View[ipcohort.Cohort]
inbound *dataset.View[ipcohort.Cohort] Outbound *dataset.View[ipcohort.Cohort]
outbound *dataset.View[ipcohort.Cohort] GeoIP *geoip.Databases
geo *geoip.Databases
} }
// Result is the structured verdict for a single IP. // Result is the structured verdict for a single IP.
@ -182,197 +122,28 @@ type Result struct {
Blocked bool `json:"blocked"` Blocked bool `json:"blocked"`
BlockedInbound bool `json:"blocked_inbound"` BlockedInbound bool `json:"blocked_inbound"`
BlockedOutbound bool `json:"blocked_outbound"` BlockedOutbound bool `json:"blocked_outbound"`
Whitelisted bool `json:"whitelisted,omitempty"`
Geo geoip.Info `json:"geo,omitzero"` Geo geoip.Info `json:"geo,omitzero"`
} }
// Check returns the structured verdict for ip without rendering. // Check returns the structured verdict for ip.
func (c *Checker) Check(ip string) Result { func (c *Checker) Check(ip string) Result {
whitelisted := c.whitelist != nil && c.whitelist.Contains(ip) in := contains(c.Inbound.Value(), ip)
in := !whitelisted && cohortContains(c.inbound.Value(), ip) out := contains(c.Outbound.Value(), ip)
out := !whitelisted && cohortContains(c.outbound.Value(), ip)
return Result{ return Result{
IP: ip, IP: ip,
Blocked: in || out, Blocked: in || out,
BlockedInbound: in, BlockedInbound: in,
BlockedOutbound: out, BlockedOutbound: out,
Whitelisted: whitelisted, Geo: c.GeoIP.Lookup(ip),
Geo: c.geo.Lookup(ip),
} }
} }
// Format selects the report rendering. func contains(c *ipcohort.Cohort, ip string) bool {
type Format string

// Report formats accepted by --format.
const (
	FormatPretty Format = "pretty"
	FormatJSON   Format = "json"
)

// parseFormat converts a --format value into a Format.
//
// The empty string and "pretty" both select FormatPretty, and "json" selects
// FormatJSON. Any other value returns the empty Format together with an
// error that quotes the rejected input.
func parseFormat(s string) (Format, error) {
	if s == "" {
		// An unset flag means the default rendering.
		return FormatPretty, nil
	}
	if f := Format(s); f == FormatPretty || f == FormatJSON {
		return f, nil
	}
	return "", fmt.Errorf("invalid --format %q (want: pretty, json)", s)
}
// Report writes r to w in the requested format and returns r.Blocked for
// convenience.
//
// FormatJSON emits indented JSON; every other format value falls through to
// the pretty text rendering.
func (r Result) Report(w io.Writer, format Format) bool {
	if format != FormatJSON {
		r.writePretty(w)
		return r.Blocked
	}

	encoder := json.NewEncoder(w)
	encoder.SetIndent("", " ")
	// The encode error is discarded: the method only reports the verdict.
	_ = encoder.Encode(r)
	return r.Blocked
}
// writePretty prints the one-line human-readable verdict for r to w, naming
// which direction(s) matched, and then lets writeGeo append any GeoIP details.
func (r Result) writePretty(w io.Writer) {
	if r.BlockedInbound {
		if r.BlockedOutbound {
			fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", r.IP)
		} else {
			fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", r.IP)
		}
	} else if r.BlockedOutbound {
		fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", r.IP)
	} else {
		fmt.Fprintf(w, "%s is allowed\n", r.IP)
	}
	writeGeo(w, r.Geo)
}
// writeGeo prints the GeoIP details carried by info to w.
//
// A "Location" line is written when at least one of city, region, or country
// is set, joining the present parts with ", "; the country is rendered with
// its ISO code in parentheses. An "ASN" line is written when info.ASN is
// non-zero. When neither applies, nothing is written.
func writeGeo(w io.Writer, info geoip.Info) {
	country := ""
	if info.Country != "" {
		country = fmt.Sprintf("%s (%s)", info.Country, info.CountryISO)
	}

	// Keep only the location fields that are actually populated, in order.
	var location []string
	for _, field := range []string{info.City, info.Region, country} {
		if field != "" {
			location = append(location, field)
		}
	}

	if len(location) != 0 {
		fmt.Fprintf(w, " Location: %s\n", strings.Join(location, ", "))
	}
	if info.ASN != 0 {
		fmt.Fprintf(w, " ASN: AS%d %s\n", info.ASN, info.ASNOrg)
	}
}
func cohortContains(c *ipcohort.Cohort, ip string) bool {
return c != nil && c.Contains(ip) return c != nil && c.Contains(ip)
} }
// openBlocklists picks a Fetcher based on cfg and wires inbound/outbound views func loadCohort(paths ...string) func() (*ipcohort.Cohort, error) {
// into a shared dataset.Group so one pull drives both.
func openBlocklists(cfg Config) (
	_ *dataset.Group,
	inbound, outbound *dataset.View[ipcohort.Cohort],
	err error,
) {
	source, inboundFiles, outboundFiles, err := newBlocklistFetcher(cfg)
	if err != nil {
		return nil, nil, nil, err
	}

	// Both views are registered on the same group, which is built around the
	// single fetcher chosen above.
	group := dataset.NewGroup(source)
	inbound = dataset.Add(group, loadCohort(inboundFiles))
	outbound = dataset.Add(group, loadCohort(outboundFiles))
	return group, inbound, outbound, nil
}
// newBlocklistFetcher picks the blocklist source for cfg and reports which
// on-disk files the inbound and outbound views should parse after a sync.
//
// Source precedence:
//   - cfg.Inbound or cfg.Outbound set: poll the listed local files
//   - cfg.GitURL set: a gitshallow repo checked out in the cache dir
//   - otherwise: the four raw bitwire-it tables fetched over HTTP via
//     httpcache into the cache dir
func newBlocklistFetcher(cfg Config) (fetcher dataset.Fetcher, inPaths, outPaths []string, err error) {
	if cfg.Inbound != "" || cfg.Outbound != "" {
		inPaths = splitCSV(cfg.Inbound)
		outPaths = splitCSV(cfg.Outbound)
		// Copy before appending so neither per-view slice is aliased.
		watched := append(append([]string(nil), inPaths...), outPaths...)
		return dataset.PollFiles(watched...), inPaths, outPaths, nil
	}

	// Both remaining modes keep their files under the cache directory.
	dir, err := cacheDir(cfg.DataDir)
	if err != nil {
		return nil, nil, nil, err
	}

	if cfg.GitURL != "" {
		repo := gitshallow.New(cfg.GitURL, dir, 1, "")
		inPaths = []string{
			repo.FilePath("tables/inbound/single_ips.txt"),
			repo.FilePath("tables/inbound/networks.txt"),
		}
		outPaths = []string{
			repo.FilePath("tables/outbound/single_ips.txt"),
			repo.FilePath("tables/outbound/networks.txt"),
		}
		return repo, inPaths, outPaths, nil
	}

	cachers := []*httpcache.Cacher{
		httpcache.New(bitwireRawBase+"/inbound/single_ips.txt", filepath.Join(dir, "inbound_single_ips.txt")),
		httpcache.New(bitwireRawBase+"/inbound/networks.txt", filepath.Join(dir, "inbound_networks.txt")),
		httpcache.New(bitwireRawBase+"/outbound/single_ips.txt", filepath.Join(dir, "outbound_single_ips.txt")),
		httpcache.New(bitwireRawBase+"/outbound/networks.txt", filepath.Join(dir, "outbound_networks.txt")),
	}

	// fetchAll fetches every table in order, stops at the first error, and
	// reports an update if any single table changed.
	fetchAll := dataset.FetcherFunc(func() (bool, error) {
		updated := false
		for _, c := range cachers {
			changed, err := c.Fetch()
			if err != nil {
				return false, err
			}
			if changed {
				updated = true
			}
		}
		return updated, nil
	})

	return fetchAll,
		[]string{cachers[0].Path, cachers[1].Path},
		[]string{cachers[2].Path, cachers[3].Path},
		nil
}
func loadCohort(paths []string) func() (*ipcohort.Cohort, error) {
return func() (*ipcohort.Cohort, error) { return func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(paths...) return ipcohort.LoadFiles(paths...)
} }
} }
// openWhitelist loads the comma-separated whitelist files named in paths into
// one cohort. An empty paths string means no whitelist was requested, in
// which case both the cohort and the error are nil.
func openWhitelist(paths string) (*ipcohort.Cohort, error) {
	if paths != "" {
		files := strings.Split(paths, ",")
		return ipcohort.LoadFiles(files...)
	}
	return nil, nil
}
// cacheDir resolves the blocklist cache directory and makes sure it exists.
//
// A non-empty override is used as given; otherwise the directory is
// "bitwire-it" under os.UserCacheDir. The directory and any missing parents
// are created with mode 0o755. On failure the returned path is empty.
func cacheDir(override string) (string, error) {
	target := override
	if target == "" {
		userCache, err := os.UserCacheDir()
		if err != nil {
			return "", err
		}
		target = filepath.Join(userCache, "bitwire-it")
	}

	if mkErr := os.MkdirAll(target, 0o755); mkErr != nil {
		return "", mkErr
	}
	return target, nil
}
// splitCSV splits a comma-separated list of file paths into its entries.
//
// Surrounding whitespace is trimmed from every entry and empty entries are
// dropped, so inputs such as "a.txt, b.txt", "a.txt,,b.txt" or a trailing
// comma never produce blank or space-padded paths. It returns nil when s
// contains no entries at all.
func splitCSV(s string) []string {
	var paths []string
	for _, field := range strings.Split(s, ",") {
		if field = strings.TrimSpace(field); field != "" {
			paths = append(paths, field)
		}
	}
	return paths
}

View File

@ -2,32 +2,37 @@ package main
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"net/http" "net/http"
"os" "os"
"strings" "strings"
"time" "time"
"github.com/therootcompany/golib/net/geoip"
) )
const shutdownTimeout = 5 * time.Second const shutdownTimeout = 5 * time.Second
// serve runs the HTTP server until ctx is cancelled, shutting down gracefully. // serve runs the HTTP API until ctx is cancelled, shutting down gracefully.
// //
// GET / checks the request's client IP // GET / checks the request's client IP
// GET /check same, plus ?ip= overrides // GET /check same, plus ?ip= overrides
// //
// Format is chosen per request via ?format=, then Accept: application/json. // Response format is chosen per request: ?format=json, then
func serve(ctx context.Context, cfg Config, checker *Checker) error { // Accept: application/json, else pretty text.
func serve(ctx context.Context, bind string, checker *Checker) error {
handle := func(w http.ResponseWriter, r *http.Request, ip string) { handle := func(w http.ResponseWriter, r *http.Request, ip string) {
format := requestFormat(r) format := requestFormat(r)
if format == FormatJSON { if format == formatJSON {
w.Header().Set("Content-Type", "application/json; charset=utf-8") w.Header().Set("Content-Type", "application/json; charset=utf-8")
} else { } else {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
} }
checker.Check(ip).Report(w, format) writeResult(w, checker.Check(ip), format)
} }
mux := http.NewServeMux() mux := http.NewServeMux()
@ -43,7 +48,7 @@ func serve(ctx context.Context, cfg Config, checker *Checker) error {
}) })
srv := &http.Server{ srv := &http.Server{
Addr: cfg.Serve, Addr: bind,
Handler: mux, Handler: mux,
BaseContext: func(_ net.Listener) context.Context { BaseContext: func(_ net.Listener) context.Context {
return ctx return ctx
@ -57,24 +62,72 @@ func serve(ctx context.Context, cfg Config, checker *Checker) error {
_ = srv.Shutdown(shutdownCtx) _ = srv.Shutdown(shutdownCtx)
}() }()
fmt.Fprintf(os.Stderr, "listening on %s\n", cfg.Serve) fmt.Fprintf(os.Stderr, "listening on %s\n", bind)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err return err
} }
return nil return nil
} }
// format is the response rendering. Server-only.
type format int
const (
// formatPretty is the plain-text rendering. It is the first iota value, so
// it is also the zero value of format.
formatPretty format = iota
// formatJSON is the JSON rendering.
formatJSON
)
// requestFormat picks a response format from ?format=, then Accept header. // requestFormat picks a response format from ?format=, then Accept header.
func requestFormat(r *http.Request) Format { func requestFormat(r *http.Request) format {
if q := r.URL.Query().Get("format"); q != "" { switch r.URL.Query().Get("format") {
if f, err := parseFormat(q); err == nil { case "json":
return f return formatJSON
} case "pretty":
return formatPretty
} }
if strings.Contains(r.Header.Get("Accept"), "application/json") { if strings.Contains(r.Header.Get("Accept"), "application/json") {
return FormatJSON return formatJSON
}
return formatPretty
}
// writeResult renders r to w in format f.
//
// formatJSON produces indented JSON. Any other value produces a one-line
// plain-text verdict naming the direction(s) that matched, followed by the
// GeoIP details printed by writeGeo.
func writeResult(w io.Writer, r Result, f format) {
	if f == formatJSON {
		encoder := json.NewEncoder(w)
		encoder.SetIndent("", " ")
		// The encode error is discarded; there is no error return to carry it.
		_ = encoder.Encode(r)
		return
	}

	inbound, outbound := r.BlockedInbound, r.BlockedOutbound
	if inbound && outbound {
		fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", r.IP)
	} else if inbound {
		fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", r.IP)
	} else if outbound {
		fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", r.IP)
	} else {
		fmt.Fprintf(w, "%s is allowed\n", r.IP)
	}
	writeGeo(w, r.Geo)
}
func writeGeo(w io.Writer, info geoip.Info) {
var parts []string
if info.City != "" {
parts = append(parts, info.City)
}
if info.Region != "" {
parts = append(parts, info.Region)
}
if info.Country != "" {
parts = append(parts, fmt.Sprintf("%s (%s)", info.Country, info.CountryISO))
}
if len(parts) > 0 {
fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", "))
}
if info.ASN != 0 {
fmt.Fprintf(w, " ASN: AS%d %s\n", info.ASN, info.ASNOrg)
} }
return FormatPretty
} }
// clientIP extracts the caller's IP, honoring X-Forwarded-For when present. // clientIP extracts the caller's IP, honoring X-Forwarded-For when present.

View File

@ -76,6 +76,9 @@ func (r *Repo) clone() (bool, error) {
if r.Path == "" { if r.Path == "" {
return false, fmt.Errorf("local path is required") return false, fmt.Errorf("local path is required")
} }
if err := os.MkdirAll(filepath.Dir(r.Path), 0o755); err != nil {
return false, err
}
args := []string{"clone", "--no-tags"} args := []string{"clone", "--no-tags"}
if depth := r.effectiveDepth(); depth >= 0 { if depth := r.effectiveDepth(); depth >= 0 {

View File

@ -7,6 +7,7 @@ import (
"net" "net"
"net/http" "net/http"
"os" "os"
"path/filepath"
"sync" "sync"
"time" "time"
) )
@ -182,6 +183,9 @@ func (c *Cacher) Fetch() (updated bool, err error) {
return false, fmt.Errorf("unexpected status %d fetching %s", resp.StatusCode, c.URL) return false, fmt.Errorf("unexpected status %d fetching %s", resp.StatusCode, c.URL)
} }
if err := os.MkdirAll(filepath.Dir(c.Path), 0o755); err != nil {
return false, err
}
if c.Transform != nil { if c.Transform != nil {
if err := c.Transform(resp.Body, c.Path); err != nil { if err := c.Transform(resp.Body, c.Path); err != nil {
return false, err return false, err