refactor(check-ip): IPCheck struct holds flag config + handler method

Follow golang-cli-flags pattern: config struct holds parsed flags and
loaded resources; handle and serve are methods on *IPCheck. Adds -V/help
pre-parse handling. Inlines clientIP into the handler.
This commit is contained in:
AJ ONeal 2026-04-20 16:02:55 -06:00
parent 0c281a494b
commit 7aa4493cb0
No known key found for this signature in database
2 changed files with 110 additions and 89 deletions

View File

@ -24,45 +24,78 @@ import (
const ( const (
defaultBlocklistRepo = "https://github.com/bitwire-it/ipblocklist.git" defaultBlocklistRepo = "https://github.com/bitwire-it/ipblocklist.git"
refreshInterval = 47 * time.Minute refreshInterval = 47 * time.Minute
version = "dev"
) )
// IPCheck holds the parsed CLI config and the loaded data sources used by
// the HTTP handler.
type IPCheck struct {
Bind string
ConfPath string
RepoURL string
CacheDir string
inbound *dataset.View[ipcohort.Cohort]
outbound *dataset.View[ipcohort.Cohort]
geo *geoip.Databases
}
func printVersion(w *os.File) {
fmt.Fprintf(w, "check-ip %s\n", version)
}
func main() { func main() {
var bind, confPath, repoURL, cacheDir string cfg := IPCheck{}
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fs.StringVar(&bind, "serve", "", "bind address for the HTTP API, e.g. :8080") fs.StringVar(&cfg.Bind, "serve", "", "bind address for the HTTP API, e.g. :8080")
fs.StringVar(&confPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)") fs.StringVar(&cfg.ConfPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
fs.StringVar(&repoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)") fs.StringVar(&cfg.RepoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
fs.StringVar(&cacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)") fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
fs.Usage = func() { fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0])
fs.PrintDefaults() fs.PrintDefaults()
} }
if len(os.Args) > 1 {
switch os.Args[1] {
case "-V", "-version", "--version", "version":
printVersion(os.Stdout)
os.Exit(0)
case "help", "-help", "--help":
printVersion(os.Stdout)
fmt.Fprintln(os.Stdout, "")
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
if err := fs.Parse(os.Args[1:]); err != nil { if err := fs.Parse(os.Args[1:]); err != nil {
if errors.Is(err, flag.ErrHelp) { if errors.Is(err, flag.ErrHelp) {
os.Exit(0) os.Exit(0)
} }
os.Exit(1) os.Exit(1)
} }
if cacheDir == "" { if cfg.CacheDir == "" {
d, err := os.UserCacheDir() d, err := os.UserCacheDir()
if err != nil { if err != nil {
log.Fatalf("cache-dir: %v", err) log.Fatalf("cache-dir: %v", err)
} }
cacheDir = d cfg.CacheDir = d
} }
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop() defer stop()
repo := gitshallow.New(repoURL, filepath.Join(cacheDir, "bitwire-it"), 1, "") repo := gitshallow.New(cfg.RepoURL, filepath.Join(cfg.CacheDir, "bitwire-it"), 1, "")
group := dataset.NewGroup(repo) group := dataset.NewGroup(repo)
inbound := dataset.Add(group, func() (*ipcohort.Cohort, error) { cfg.inbound = dataset.Add(group, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles( return ipcohort.LoadFiles(
repo.FilePath("tables/inbound/single_ips.txt"), repo.FilePath("tables/inbound/single_ips.txt"),
repo.FilePath("tables/inbound/networks.txt"), repo.FilePath("tables/inbound/networks.txt"),
) )
}) })
outbound := dataset.Add(group, func() (*ipcohort.Cohort, error) { cfg.outbound = dataset.Add(group, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles( return ipcohort.LoadFiles(
repo.FilePath("tables/outbound/single_ips.txt"), repo.FilePath("tables/outbound/single_ips.txt"),
repo.FilePath("tables/outbound/networks.txt"), repo.FilePath("tables/outbound/networks.txt"),
@ -75,9 +108,9 @@ func main() {
log.Printf("refresh: %v", err) log.Printf("refresh: %v", err)
}) })
maxmind := filepath.Join(cacheDir, "maxmind") maxmind := filepath.Join(cfg.CacheDir, "maxmind")
geo, err := geoip.OpenDatabases( geo, err := geoip.OpenDatabases(
confPath, cfg.ConfPath,
filepath.Join(maxmind, geoip.CityEdition+".mmdb"), filepath.Join(maxmind, geoip.CityEdition+".mmdb"),
filepath.Join(maxmind, geoip.ASNEdition+".mmdb"), filepath.Join(maxmind, geoip.ASNEdition+".mmdb"),
) )
@ -85,11 +118,12 @@ func main() {
log.Fatalf("geoip: %v", err) log.Fatalf("geoip: %v", err)
} }
defer func() { _ = geo.Close() }() defer func() { _ = geo.Close() }()
cfg.geo = geo
if bind == "" { if cfg.Bind == "" {
return return
} }
if err := serve(ctx, bind, inbound, outbound, geo); err != nil { if err := cfg.serve(ctx); err != nil {
log.Fatalf("serve: %v", err) log.Fatalf("serve: %v", err)
} }
} }

View File

@ -12,8 +12,6 @@ import (
"time" "time"
"github.com/therootcompany/golib/net/geoip" "github.com/therootcompany/golib/net/geoip"
"github.com/therootcompany/golib/net/ipcohort"
"github.com/therootcompany/golib/sync/dataset"
) )
// Result is the JSON verdict for a single IP. // Result is the JSON verdict for a single IP.
@ -25,25 +23,26 @@ type Result struct {
Geo geoip.Info `json:"geo,omitzero"` Geo geoip.Info `json:"geo,omitzero"`
} }
func serve( func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) {
ctx context.Context,
bind string,
inbound, outbound *dataset.View[ipcohort.Cohort],
geo *geoip.Databases,
) error {
handle := func(w http.ResponseWriter, r *http.Request) {
ip := strings.TrimSpace(r.URL.Query().Get("ip")) ip := strings.TrimSpace(r.URL.Query().Get("ip"))
if ip == "" { if ip == "" {
ip = clientIP(r) if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
first, _, _ := strings.Cut(xff, ",")
ip = strings.TrimSpace(first)
} else if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
ip = host
} else {
ip = r.RemoteAddr
} }
in := inbound.Value().Contains(ip) }
out := outbound.Value().Contains(ip) in := c.inbound.Value().Contains(ip)
out := c.outbound.Value().Contains(ip)
res := Result{ res := Result{
IP: ip, IP: ip,
Blocked: in || out, Blocked: in || out,
BlockedInbound: in, BlockedInbound: in,
BlockedOutbound: out, BlockedOutbound: out,
Geo: geo.Lookup(ip), Geo: c.geo.Lookup(ip),
} }
if r.URL.Query().Get("format") == "json" || if r.URL.Query().Get("format") == "json" ||
@ -82,14 +81,15 @@ func serve(
if res.Geo.ASN != 0 { if res.Geo.ASN != 0 {
fmt.Fprintf(w, " ASN: AS%d %s\n", res.Geo.ASN, res.Geo.ASNOrg) fmt.Fprintf(w, " ASN: AS%d %s\n", res.Geo.ASN, res.Geo.ASNOrg)
} }
} }
func (c *IPCheck) serve(ctx context.Context) error {
mux := http.NewServeMux() mux := http.NewServeMux()
mux.HandleFunc("GET /check", handle) mux.HandleFunc("GET /check", c.handle)
mux.HandleFunc("GET /{$}", handle) mux.HandleFunc("GET /{$}", c.handle)
srv := &http.Server{ srv := &http.Server{
Addr: bind, Addr: c.Bind,
Handler: mux, Handler: mux,
BaseContext: func(_ net.Listener) context.Context { return ctx }, BaseContext: func(_ net.Listener) context.Context { return ctx },
} }
@ -100,22 +100,9 @@ func serve(
_ = srv.Shutdown(shutCtx) _ = srv.Shutdown(shutCtx)
}() }()
log.Printf("listening on %s", bind) log.Printf("listening on %s", c.Bind)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err return err
} }
return nil return nil
} }
// clientIP extracts the caller's IP, honoring X-Forwarded-For when present.
func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
first, _, _ := strings.Cut(xff, ",")
return strings.TrimSpace(first)
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}