From b9295608dba1b9374c766f3b4426f5e3548516fe Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 20 Apr 2026 16:40:56 -0600 Subject: [PATCH] feat(check-ip): accept IP args, require --serve or args MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Positional args are IPs to check and print; at least one IP or --serve must be provided. Refactor server.go: split handle into lookup + writeText methods so main can reuse them. geoip is no longer managed via dataset.Group — it's a single atomic.Pointer[geoip.Databases]. The Fetcher (httpcache or PollFiles) still drives refresh, but via an inline ticker in the serve branch that fetches, reopens, and swaps. --- cmd/check-ip/main.go | 58 ++++++++++++++++++++++++++------- cmd/check-ip/server.go | 74 ++++++++++++++++++++++++------------------ 2 files changed, 90 insertions(+), 42 deletions(-) diff --git a/cmd/check-ip/main.go b/cmd/check-ip/main.go index 1376e82..6c5dfa3 100644 --- a/cmd/check-ip/main.go +++ b/cmd/check-ip/main.go @@ -12,6 +12,7 @@ import ( "os" "os/signal" "path/filepath" + "sync/atomic" "syscall" "time" @@ -43,7 +44,7 @@ type IPCheck struct { inbound *dataset.View[ipcohort.Cohort] outbound *dataset.View[ipcohort.Cohort] - geo *dataset.View[geoip.Databases] + geo atomic.Pointer[geoip.Databases] } func main() { @@ -54,7 +55,8 @@ func main() { fs.StringVar(&cfg.RepoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)") fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)") fs.Usage = func() { - fmt.Fprintf(os.Stderr, "Usage: %s --serve [flags]\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "Usage: %s [flags] [ip...]\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --serve [flags]\n", os.Args[0]) fs.PrintDefaults() } @@ -77,6 +79,12 @@ func main() { } os.Exit(1) } + ips := fs.Args() + if cfg.Bind == "" && len(ips) == 0 { + fmt.Fprintln(os.Stderr, "error: provide at least one IP argument or --serve ") + fs.Usage() + os.Exit(1) + } if cfg.CacheDir == "" { d, err := os.UserCacheDir() if err != nil { @@ -160,15 +168,19 @@ func main() { } else { geoFetcher = dataset.PollFiles(cityTarPath, asnTarPath) } - geoGroup := dataset.NewGroup(geoFetcher) - cfg.geo = dataset.Add(geoGroup, func() (*geoip.Databases, error) { - return geoip.Open(maxmindDir) - }) - if err := geoGroup.Load(context.Background()); err != nil { + if _, err := geoFetcher.Fetch(); err != nil { log.Fatalf("geoip: %v", err) } - defer func() { _ = cfg.geo.Value().Close() }() + geoDB, err := geoip.Open(maxmindDir) + if err != nil { + log.Fatalf("geoip: %v", err) + } + cfg.geo.Store(geoDB) + defer func() { _ = cfg.geo.Load().Close() }() + for _, ip := range ips { + cfg.writeText(os.Stdout, cfg.lookup(ip)) + } if cfg.Bind == "" { return } @@ -178,9 +190,33 @@ func main() { go group.Tick(ctx, refreshInterval, func(err error) { log.Printf("blocklists refresh: %v", err) }) - go geoGroup.Tick(ctx, refreshInterval, func(err error) { - log.Printf("geoip refresh: %v", err) - }) + go func() { + t := time.NewTicker(refreshInterval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + updated, err := geoFetcher.Fetch() + if err != nil { + log.Printf("geoip refresh: %v", err) + continue + } + if !updated { + continue + } + db, err := geoip.Open(maxmindDir) + if err != nil { + log.Printf("geoip refresh: %v", err) + continue + } + if old := cfg.geo.Swap(db); old != nil { + _ = old.Close() + } + } + } + }() if err := cfg.serve(ctx); err != nil { log.Fatalf("serve: %v", err) } diff --git a/cmd/check-ip/server.go b/cmd/check-ip/server.go index bf5be04..23cd5c7 100644 --- a/cmd/check-ip/server.go +++ b/cmd/check-ip/server.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "log" "net" "net/http" @@ -23,47 +24,31 @@ type Result struct { Geo geoip.Info `json:"geo,omitzero"` } -func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) { - ip := strings.TrimSpace(r.URL.Query().Get("ip")) - if ip == "" { - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - first, _, _ := strings.Cut(xff, ",") - ip = strings.TrimSpace(first) - } else if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { - ip = host - } else { - ip = r.RemoteAddr - } - } +// lookup builds a Result for ip against the currently loaded blocklists +// and GeoIP databases. +func (c *IPCheck) lookup(ip string) Result { in := c.inbound.Value().Contains(ip) out := c.outbound.Value().Contains(ip) - res := Result{ + return Result{ IP: ip, Blocked: in || out, BlockedInbound: in, BlockedOutbound: out, - Geo: c.geo.Value().Lookup(ip), + Geo: c.geo.Load().Lookup(ip), } +} - if r.URL.Query().Get("format") == "json" || - strings.Contains(r.Header.Get("Accept"), "application/json") { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - enc := json.NewEncoder(w) - enc.SetIndent("", " ") - _ = enc.Encode(res) - return - } - - w.Header().Set("Content-Type", "text/plain; charset=utf-8") +// writeText renders res as human-readable plain text. +func (c *IPCheck) writeText(w io.Writer, res Result) { switch { - case in && out: - fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", ip) - case in: - fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", ip) - case out: - fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", ip) + case res.BlockedInbound && res.BlockedOutbound: + fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", res.IP) + case res.BlockedInbound: + fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", res.IP) + case res.BlockedOutbound: + fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", res.IP) default: - fmt.Fprintf(w, "%s is allowed\n", ip) + fmt.Fprintf(w, "%s is allowed\n", res.IP) } var parts []string if res.Geo.City != "" { @@ -83,6 +68,33 @@ func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) { } } +func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) { + ip := strings.TrimSpace(r.URL.Query().Get("ip")) + if ip == "" { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + first, _, _ := strings.Cut(xff, ",") + ip = strings.TrimSpace(first) + } else if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + ip = host + } else { + ip = r.RemoteAddr + } + } + res := c.lookup(ip) + + if r.URL.Query().Get("format") == "json" || + strings.Contains(r.Header.Get("Accept"), "application/json") { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + _ = enc.Encode(res) + return + } + + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + c.writeText(w, res) +} + func (c *IPCheck) serve(ctx context.Context) error { mux := http.NewServeMux() mux.HandleFunc("GET /check", c.handle)