refactor(check-ip): simplify to 4 flags, push MkdirAll into libs

check-ip now takes only --serve, --geoip-conf, --blocklist-repo,
--cache-dir. Blocklist always comes from git; GeoIP mmdbs always go
through httpcache (when GeoIP.conf is available). Format negotiation
lives entirely server-side.

main.go is now straight-line wiring: parse flags, build the two
databases, run the server. All filesystem setup (MkdirAll for the parent
of the clone target and for the parent of each cache Path) is pushed into
gitshallow and httpcache, so the cmd no longer does filesystem bookkeeping.
This commit is contained in:
AJ ONeal 2026-04-20 15:51:46 -06:00
parent 3b5812ffcd
commit cdce7da04c
No known key found for this signature in database
4 changed files with 130 additions and 299 deletions

View File

@ -1,165 +1,106 @@
// check-ip reports whether an IPv4 address appears in the bitwire-it
// inbound/outbound blocklists and, when configured, prints GeoIP info.
//
// Source selection (in order of precedence):
//
// - --inbound / --outbound use local files (no syncing)
// - --git URL shallow-clone a git repo of blocklists
// - (default) fetch raw blocklist files over HTTP with caching
//
// Each mode builds a sync/dataset.Group: one Fetcher shared by the inbound
// and outbound views, so a single git pull (or HTTP-304 cycle) drives both.
//
// --serve turns check-ip into a long-running HTTP server; see server.go.
// check-ip runs an HTTP API that reports whether an IP appears in the
// configured blocklist repo and, when GeoIP.conf is available, enriches
// the response with MaxMind GeoLite2 City + ASN data.
package main
import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
"github.com/therootcompany/golib/net/geoip"
"github.com/therootcompany/golib/net/gitshallow"
"github.com/therootcompany/golib/net/httpcache"
"github.com/therootcompany/golib/net/ipcohort"
"github.com/therootcompany/golib/sync/dataset"
)
const (
bitwireGitURL = "https://github.com/bitwire-it/ipblocklist.git"
bitwireRawBase = "https://github.com/bitwire-it/ipblocklist/raw/refs/heads/main/tables"
refreshInterval = 47 * time.Minute
defaultBlocklistRepo = "https://github.com/bitwire-it/ipblocklist.git"
refreshInterval = 47 * time.Minute
)
type Config struct {
DataDir string
GitURL string
Whitelist string
Inbound string
Outbound string
GeoIPConf string
CityDB string
ASNDB string
Serve string
Format string
Serve string
GeoIPConf string
BlocklistRepo string
CacheDir string
}
func main() {
cfg := Config{}
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fs.StringVar(&cfg.DataDir, "data-dir", "", "blacklist cache dir (default ~/.cache/bitwire-it)")
fs.StringVar(&cfg.GitURL, "git", "", "git URL to clone/pull blacklist from (e.g. "+bitwireGitURL+")")
fs.StringVar(&cfg.Whitelist, "whitelist", "", "comma-separated paths to whitelist files")
fs.StringVar(&cfg.Inbound, "inbound", "", "comma-separated paths to inbound blacklist files")
fs.StringVar(&cfg.Outbound, "outbound", "", "comma-separated paths to outbound blacklist files")
fs.StringVar(&cfg.GeoIPConf, "geoip-conf", "", "path to GeoIP.conf (auto-discovered if absent)")
fs.StringVar(&cfg.CityDB, "city-db", "", "path to GeoLite2-City.mmdb (skips auto-download)")
fs.StringVar(&cfg.ASNDB, "asn-db", "", "path to GeoLite2-ASN.mmdb (skips auto-download)")
fs.StringVar(&cfg.Serve, "serve", "", "start HTTP server at addr:port (e.g. :8080) instead of one-shot check")
fs.StringVar(&cfg.Format, "format", "", "output format: pretty, json (default pretty)")
fs.StringVar(&cfg.Serve, "serve", "", "bind address for the HTTP API, e.g. :8080")
fs.StringVar(&cfg.GeoIPConf, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
fs.StringVar(&cfg.BlocklistRepo, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <ip-address>\n", os.Args[0])
fmt.Fprintf(os.Stderr, " %s --serve :8080 [flags]\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0])
fs.PrintDefaults()
}
if len(os.Args) > 1 {
switch os.Args[1] {
case "-V", "-version", "--version", "version":
fmt.Fprintln(os.Stdout, "check-ip")
os.Exit(0)
case "help", "-help", "--help":
fmt.Fprintln(os.Stdout, "check-ip")
fmt.Fprintln(os.Stdout)
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
if err := fs.Parse(os.Args[1:]); err != nil {
if errors.Is(err, flag.ErrHelp) {
os.Exit(0)
}
os.Exit(1)
}
format, err := parseFormat(cfg.Format)
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
if cfg.Serve == "" {
fs.Usage()
os.Exit(1)
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
// Open the three "databases" that feed every IP check:
//
// 1. blocklists — inbound + outbound cohorts, hot-swapped on refresh
// 2. whitelist — static cohort loaded once from disk
// 3. geoip — city + ASN mmdb readers (optional)
//
// The blocklist Group.Tick goroutine refreshes in the background so the
// serve path actually exercises dataset's hot-swap.
group, inbound, outbound, err := openBlocklists(cfg)
if err != nil {
fatal("blocklists", err)
cacheDir := cfg.CacheDir
if cacheDir == "" {
base, err := os.UserCacheDir()
if err != nil {
fatal("cache-dir", err)
}
cacheDir = base
}
// Blocklists: one git repo, two views sharing the same pull.
repo := gitshallow.New(cfg.BlocklistRepo, filepath.Join(cacheDir, "bitwire-it"), 1, "")
group := dataset.NewGroup(repo)
inbound := dataset.Add(group, loadCohort(
repo.FilePath("tables/inbound/single_ips.txt"),
repo.FilePath("tables/inbound/networks.txt"),
))
outbound := dataset.Add(group, loadCohort(
repo.FilePath("tables/outbound/single_ips.txt"),
repo.FilePath("tables/outbound/networks.txt"),
))
if err := group.Load(ctx); err != nil {
fatal("blocklists", err)
}
fmt.Fprintf(os.Stderr, "loaded inbound=%d outbound=%d\n",
inbound.Value().Size(), outbound.Value().Size())
go group.Tick(ctx, refreshInterval, func(err error) {
fmt.Fprintf(os.Stderr, "refresh: %v\n", err)
})
whitelist, err := openWhitelist(cfg.Whitelist)
if err != nil {
fatal("whitelist", err)
}
geo, err := geoip.OpenDatabases(cfg.GeoIPConf, cfg.CityDB, cfg.ASNDB)
// GeoIP: city + ASN readers, downloaded via httpcache when GeoIP.conf
// is available; otherwise read from disk at the cache paths.
maxmindDir := filepath.Join(cacheDir, "maxmind")
geo, err := geoip.OpenDatabases(
cfg.GeoIPConf,
filepath.Join(maxmindDir, geoip.CityEdition+".mmdb"),
filepath.Join(maxmindDir, geoip.ASNEdition+".mmdb"),
)
if err != nil {
fatal("geoip", err)
}
defer func() { _ = geo.Close() }()
checker := &Checker{
whitelist: whitelist,
inbound: inbound,
outbound: outbound,
geo: geo,
}
if cfg.Serve != "" {
if fs.NArg() != 0 {
fmt.Fprintln(os.Stderr, "error: --serve takes no positional args")
os.Exit(1)
}
if err := serve(ctx, cfg, checker); err != nil {
fatal("serve", err)
}
return
}
if fs.NArg() != 1 {
fs.Usage()
os.Exit(1)
}
blocked := checker.Check(fs.Arg(0)).Report(os.Stdout, format)
if blocked {
os.Exit(1)
checker := &Checker{Inbound: inbound, Outbound: outbound, GeoIP: geo}
if err := serve(ctx, cfg.Serve, checker); err != nil {
fatal("serve", err)
}
}
@ -168,12 +109,11 @@ func fatal(what string, err error) {
os.Exit(1)
}
// Checker bundles the three databases plus the lookup + render logic.
// Checker bundles the blocklist views with the optional GeoIP databases.
type Checker struct {
whitelist *ipcohort.Cohort
inbound *dataset.View[ipcohort.Cohort]
outbound *dataset.View[ipcohort.Cohort]
geo *geoip.Databases
Inbound *dataset.View[ipcohort.Cohort]
Outbound *dataset.View[ipcohort.Cohort]
GeoIP *geoip.Databases
}
// Result is the structured verdict for a single IP.
@ -182,197 +122,28 @@ type Result struct {
Blocked bool `json:"blocked"`
BlockedInbound bool `json:"blocked_inbound"`
BlockedOutbound bool `json:"blocked_outbound"`
Whitelisted bool `json:"whitelisted,omitempty"`
Geo geoip.Info `json:"geo,omitzero"`
}
// Check returns the structured verdict for ip without rendering.
// Check returns the structured verdict for ip.
func (c *Checker) Check(ip string) Result {
whitelisted := c.whitelist != nil && c.whitelist.Contains(ip)
in := !whitelisted && cohortContains(c.inbound.Value(), ip)
out := !whitelisted && cohortContains(c.outbound.Value(), ip)
in := contains(c.Inbound.Value(), ip)
out := contains(c.Outbound.Value(), ip)
return Result{
IP: ip,
Blocked: in || out,
BlockedInbound: in,
BlockedOutbound: out,
Whitelisted: whitelisted,
Geo: c.geo.Lookup(ip),
Geo: c.GeoIP.Lookup(ip),
}
}
// Format names how a report is rendered.
type Format string

// Supported report renderings.
const (
	FormatPretty Format = "pretty"
	FormatJSON   Format = "json"
)

// parseFormat converts the textual form of a format into a Format.
//
// The empty string is treated as FormatPretty. Matching is exact
// (case-sensitive); any other value yields an empty Format and an error
// naming the rejected input.
func parseFormat(s string) (Format, error) {
	if s == "" {
		return FormatPretty, nil
	}
	for _, known := range []Format{FormatPretty, FormatJSON} {
		if s == string(known) {
			return known, nil
		}
	}
	return "", fmt.Errorf("invalid --format %q (want: pretty, json)", s)
}
// Report writes r to w using the requested format and hands back r.Blocked
// so callers can branch on the verdict without a second lookup.
//
// FormatJSON produces indented JSON; every other format value falls back to
// the pretty text rendering.
func (r Result) Report(w io.Writer, format Format) bool {
	if format == FormatJSON {
		enc := json.NewEncoder(w)
		enc.SetIndent("", " ")
		// The encode error is discarded: the signature only reports the verdict.
		_ = enc.Encode(r)
	} else {
		r.writePretty(w)
	}
	return r.Blocked
}
// writePretty prints a one-line human-readable verdict for r to w, followed
// by whatever GeoIP details writeGeo emits for r.Geo.
func (r Result) writePretty(w io.Writer) {
	in, out := r.BlockedInbound, r.BlockedOutbound
	if in && out {
		fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", r.IP)
	} else if in {
		fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", r.IP)
	} else if out {
		fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", r.IP)
	} else {
		fmt.Fprintf(w, "%s is allowed\n", r.IP)
	}
	writeGeo(w, r.Geo)
}
// writeGeo prints the optional GeoIP lines for info to w.
//
// A "Location" line is written when at least one of city, region or country
// is set (country is shown together with its ISO code); an "ASN" line is
// written when the ASN is non-zero. Nothing is written for an empty info.
func writeGeo(w io.Writer, info geoip.Info) {
	location := make([]string, 0, 3)
	for _, piece := range []string{info.City, info.Region} {
		if piece != "" {
			location = append(location, piece)
		}
	}
	if info.Country != "" {
		location = append(location, fmt.Sprintf("%s (%s)", info.Country, info.CountryISO))
	}
	if len(location) != 0 {
		fmt.Fprintf(w, " Location: %s\n", strings.Join(location, ", "))
	}
	if info.ASN == 0 {
		return
	}
	fmt.Fprintf(w, " ASN: AS%d %s\n", info.ASN, info.ASNOrg)
}
func cohortContains(c *ipcohort.Cohort, ip string) bool {
func contains(c *ipcohort.Cohort, ip string) bool {
return c != nil && c.Contains(ip)
}
// openBlocklists builds the blocklist dataset for cfg: it asks
// newBlocklistFetcher for a Fetcher plus the files each direction should
// parse, then registers an inbound and an outbound view on one shared
// dataset.Group so both views hang off the same fetcher.
//
// On failure every result other than err is nil.
func openBlocklists(cfg Config) (
	_ *dataset.Group,
	inbound, outbound *dataset.View[ipcohort.Cohort],
	err error,
) {
	source, inFiles, outFiles, ferr := newBlocklistFetcher(cfg)
	if ferr != nil {
		return nil, nil, nil, ferr
	}
	grp := dataset.NewGroup(source)
	// Views are registered inbound first, then outbound.
	return grp,
		dataset.Add(grp, loadCohort(inFiles)),
		dataset.Add(grp, loadCohort(outFiles)),
		nil
}
// newBlocklistFetcher chooses how blocklists are obtained and reports which
// on-disk files the inbound and outbound views should parse afterwards.
//
// Precedence:
//
//  1. cfg.Inbound / cfg.Outbound set: poll those local files as-is.
//  2. cfg.GitURL set: shallow git repo checked out in the cache dir.
//  3. otherwise: fetch the four raw bitwire-it tables over HTTP into the
//     cache dir.
//
// The cache dir (cfg.DataDir or its default) is only resolved, and created,
// for modes 2 and 3.
func newBlocklistFetcher(cfg Config) (fetcher dataset.Fetcher, inPaths, outPaths []string, err error) {
	// Mode 1: explicit local files.
	if cfg.Inbound != "" || cfg.Outbound != "" {
		inPaths = splitCSV(cfg.Inbound)
		outPaths = splitCSV(cfg.Outbound)
		watched := make([]string, 0, len(inPaths)+len(outPaths))
		watched = append(watched, inPaths...)
		watched = append(watched, outPaths...)
		return dataset.PollFiles(watched...), inPaths, outPaths, nil
	}

	// Both remaining modes store their files under the cache dir.
	dir, err := cacheDir(cfg.DataDir)
	if err != nil {
		return nil, nil, nil, err
	}

	// Mode 2: git checkout; paths are resolved relative to the repo.
	if cfg.GitURL != "" {
		repo := gitshallow.New(cfg.GitURL, dir, 1, "")
		inPaths = []string{
			repo.FilePath("tables/inbound/single_ips.txt"),
			repo.FilePath("tables/inbound/networks.txt"),
		}
		outPaths = []string{
			repo.FilePath("tables/outbound/single_ips.txt"),
			repo.FilePath("tables/outbound/networks.txt"),
		}
		return repo, inPaths, outPaths, nil
	}

	// Mode 3: one cacher per raw table; indexes 0-1 are inbound, 2-3 outbound.
	cachers := []*httpcache.Cacher{
		httpcache.New(bitwireRawBase+"/inbound/single_ips.txt", filepath.Join(dir, "inbound_single_ips.txt")),
		httpcache.New(bitwireRawBase+"/inbound/networks.txt", filepath.Join(dir, "inbound_networks.txt")),
		httpcache.New(bitwireRawBase+"/outbound/single_ips.txt", filepath.Join(dir, "outbound_single_ips.txt")),
		httpcache.New(bitwireRawBase+"/outbound/networks.txt", filepath.Join(dir, "outbound_networks.txt")),
	}
	// fetchAll reports true when any table was updated and stops at the
	// first fetch error.
	fetchAll := func() (bool, error) {
		changed := false
		for _, c := range cachers {
			fresh, ferr := c.Fetch()
			if ferr != nil {
				return false, ferr
			}
			if fresh {
				changed = true
			}
		}
		return changed, nil
	}
	return dataset.FetcherFunc(fetchAll),
		[]string{cachers[0].Path, cachers[1].Path},
		[]string{cachers[2].Path, cachers[3].Path},
		nil
}
func loadCohort(paths []string) func() (*ipcohort.Cohort, error) {
func loadCohort(paths ...string) func() (*ipcohort.Cohort, error) {
return func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(paths...)
}
}
// openWhitelist loads a cohort from the comma-separated file list in paths.
// An empty paths yields a nil cohort and a nil error.
func openWhitelist(paths string) (*ipcohort.Cohort, error) {
	if len(paths) == 0 {
		return nil, nil
	}
	files := strings.Split(paths, ",")
	return ipcohort.LoadFiles(files...)
}
// cacheDir resolves the blocklist cache directory and makes sure it exists.
//
// A non-empty override is used verbatim; otherwise the directory is
// "bitwire-it" under os.UserCacheDir. The directory (and any missing
// parents) is created with mode 0o755. On error the returned path is empty.
func cacheDir(override string) (string, error) {
	target := override
	if len(target) == 0 {
		userCache, err := os.UserCacheDir()
		if err != nil {
			return "", err
		}
		target = filepath.Join(userCache, "bitwire-it")
	}

	if mkErr := os.MkdirAll(target, 0o755); mkErr != nil {
		return "", mkErr
	}
	return target, nil
}
// splitCSV breaks s on commas. The empty string produces a nil slice
// rather than a one-element slice holding "". Fields are returned
// verbatim, with no trimming and with empty fields kept.
func splitCSV(s string) []string {
	if len(s) > 0 {
		return strings.Split(s, ",")
	}
	return nil
}

View File

@ -2,32 +2,37 @@ package main
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"strings"
"time"
"github.com/therootcompany/golib/net/geoip"
)
const shutdownTimeout = 5 * time.Second
// serve runs the HTTP server until ctx is cancelled, shutting down gracefully.
// serve runs the HTTP API until ctx is cancelled, shutting down gracefully.
//
// GET / checks the request's client IP
// GET /check same, plus ?ip= overrides
//
// Format is chosen per request via ?format=, then Accept: application/json.
func serve(ctx context.Context, cfg Config, checker *Checker) error {
// Response format is chosen per request: ?format=json, then
// Accept: application/json, else pretty text.
func serve(ctx context.Context, bind string, checker *Checker) error {
handle := func(w http.ResponseWriter, r *http.Request, ip string) {
format := requestFormat(r)
if format == FormatJSON {
if format == formatJSON {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
} else {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
}
checker.Check(ip).Report(w, format)
writeResult(w, checker.Check(ip), format)
}
mux := http.NewServeMux()
@ -43,7 +48,7 @@ func serve(ctx context.Context, cfg Config, checker *Checker) error {
})
srv := &http.Server{
Addr: cfg.Serve,
Addr: bind,
Handler: mux,
BaseContext: func(_ net.Listener) context.Context {
return ctx
@ -57,24 +62,72 @@ func serve(ctx context.Context, cfg Config, checker *Checker) error {
_ = srv.Shutdown(shutdownCtx)
}()
fmt.Fprintf(os.Stderr, "listening on %s\n", cfg.Serve)
fmt.Fprintf(os.Stderr, "listening on %s\n", bind)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
}
// format is the response rendering. Server-only.
//
// It is an unexported integer enum; the zero value is formatPretty.
type format int
const (
// formatPretty is the plain-text rendering (iota == 0).
formatPretty format = iota
// formatJSON is the JSON rendering.
formatJSON
)
// requestFormat picks a response format from ?format=, then Accept header.
func requestFormat(r *http.Request) Format {
if q := r.URL.Query().Get("format"); q != "" {
if f, err := parseFormat(q); err == nil {
return f
}
func requestFormat(r *http.Request) format {
switch r.URL.Query().Get("format") {
case "json":
return formatJSON
case "pretty":
return formatPretty
}
if strings.Contains(r.Header.Get("Accept"), "application/json") {
return FormatJSON
return formatJSON
}
return formatPretty
}
// writeResult renders r to w in format f.
//
// formatJSON emits indented JSON. Any other format emits a one-line text
// verdict followed by the GeoIP lines produced by writeGeo.
func writeResult(w io.Writer, r Result, f format) {
	switch f {
	case formatJSON:
		enc := json.NewEncoder(w)
		enc.SetIndent("", " ")
		// The encode error is discarded; this function has no error return.
		_ = enc.Encode(r)
	default:
		in, out := r.BlockedInbound, r.BlockedOutbound
		if in && out {
			fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", r.IP)
		} else if in {
			fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", r.IP)
		} else if out {
			fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", r.IP)
		} else {
			fmt.Fprintf(w, "%s is allowed\n", r.IP)
		}
		writeGeo(w, r.Geo)
	}
}
func writeGeo(w io.Writer, info geoip.Info) {
var parts []string
if info.City != "" {
parts = append(parts, info.City)
}
if info.Region != "" {
parts = append(parts, info.Region)
}
if info.Country != "" {
parts = append(parts, fmt.Sprintf("%s (%s)", info.Country, info.CountryISO))
}
if len(parts) > 0 {
fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", "))
}
if info.ASN != 0 {
fmt.Fprintf(w, " ASN: AS%d %s\n", info.ASN, info.ASNOrg)
}
return FormatPretty
}
// clientIP extracts the caller's IP, honoring X-Forwarded-For when present.

View File

@ -76,6 +76,9 @@ func (r *Repo) clone() (bool, error) {
if r.Path == "" {
return false, fmt.Errorf("local path is required")
}
if err := os.MkdirAll(filepath.Dir(r.Path), 0o755); err != nil {
return false, err
}
args := []string{"clone", "--no-tags"}
if depth := r.effectiveDepth(); depth >= 0 {

View File

@ -7,6 +7,7 @@ import (
"net"
"net/http"
"os"
"path/filepath"
"sync"
"time"
)
@ -182,6 +183,9 @@ func (c *Cacher) Fetch() (updated bool, err error) {
return false, fmt.Errorf("unexpected status %d fetching %s", resp.StatusCode, c.URL)
}
if err := os.MkdirAll(filepath.Dir(c.Path), 0o755); err != nil {
return false, err
}
if c.Transform != nil {
if err := c.Transform(resp.Body, c.Path); err != nil {
return false, err