refactor(check-ip): IPCheck struct holds flag config + handler method

Follow golang-cli-flags pattern: config struct holds parsed flags and
loaded resources; handle and serve are methods on *IPCheck. Adds -V/help
pre-parse handling. Inlines clientIP into the handler.
This commit is contained in:
AJ ONeal 2026-04-20 16:02:55 -06:00
parent 0c281a494b
commit 7aa4493cb0
No known key found for this signature in database
2 changed files with 110 additions and 89 deletions

View File

@ -24,45 +24,78 @@ import (
const (
defaultBlocklistRepo = "https://github.com/bitwire-it/ipblocklist.git"
refreshInterval = 47 * time.Minute
version = "dev"
)
// IPCheck holds the parsed CLI config and the loaded data sources used by
// the HTTP handler.
type IPCheck struct {
Bind string
ConfPath string
RepoURL string
CacheDir string
inbound *dataset.View[ipcohort.Cohort]
outbound *dataset.View[ipcohort.Cohort]
geo *geoip.Databases
}
func printVersion(w *os.File) {
fmt.Fprintf(w, "check-ip %s\n", version)
}
func main() {
var bind, confPath, repoURL, cacheDir string
cfg := IPCheck{}
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fs.StringVar(&bind, "serve", "", "bind address for the HTTP API, e.g. :8080")
fs.StringVar(&confPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
fs.StringVar(&repoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
fs.StringVar(&cacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
fs.StringVar(&cfg.Bind, "serve", "", "bind address for the HTTP API, e.g. :8080")
fs.StringVar(&cfg.ConfPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
fs.StringVar(&cfg.RepoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0])
fs.PrintDefaults()
}
if len(os.Args) > 1 {
switch os.Args[1] {
case "-V", "-version", "--version", "version":
printVersion(os.Stdout)
os.Exit(0)
case "help", "-help", "--help":
printVersion(os.Stdout)
fmt.Fprintln(os.Stdout, "")
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
if err := fs.Parse(os.Args[1:]); err != nil {
if errors.Is(err, flag.ErrHelp) {
os.Exit(0)
}
os.Exit(1)
}
if cacheDir == "" {
if cfg.CacheDir == "" {
d, err := os.UserCacheDir()
if err != nil {
log.Fatalf("cache-dir: %v", err)
}
cacheDir = d
cfg.CacheDir = d
}
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
repo := gitshallow.New(repoURL, filepath.Join(cacheDir, "bitwire-it"), 1, "")
repo := gitshallow.New(cfg.RepoURL, filepath.Join(cfg.CacheDir, "bitwire-it"), 1, "")
group := dataset.NewGroup(repo)
inbound := dataset.Add(group, func() (*ipcohort.Cohort, error) {
cfg.inbound = dataset.Add(group, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(
repo.FilePath("tables/inbound/single_ips.txt"),
repo.FilePath("tables/inbound/networks.txt"),
)
})
outbound := dataset.Add(group, func() (*ipcohort.Cohort, error) {
cfg.outbound = dataset.Add(group, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(
repo.FilePath("tables/outbound/single_ips.txt"),
repo.FilePath("tables/outbound/networks.txt"),
@ -75,9 +108,9 @@ func main() {
log.Printf("refresh: %v", err)
})
maxmind := filepath.Join(cacheDir, "maxmind")
maxmind := filepath.Join(cfg.CacheDir, "maxmind")
geo, err := geoip.OpenDatabases(
confPath,
cfg.ConfPath,
filepath.Join(maxmind, geoip.CityEdition+".mmdb"),
filepath.Join(maxmind, geoip.ASNEdition+".mmdb"),
)
@ -85,11 +118,12 @@ func main() {
log.Fatalf("geoip: %v", err)
}
defer func() { _ = geo.Close() }()
cfg.geo = geo
if bind == "" {
if cfg.Bind == "" {
return
}
if err := serve(ctx, bind, inbound, outbound, geo); err != nil {
if err := cfg.serve(ctx); err != nil {
log.Fatalf("serve: %v", err)
}
}

View File

@ -12,8 +12,6 @@ import (
"time"
"github.com/therootcompany/golib/net/geoip"
"github.com/therootcompany/golib/net/ipcohort"
"github.com/therootcompany/golib/sync/dataset"
)
// Result is the JSON verdict for a single IP.
@ -25,71 +23,73 @@ type Result struct {
Geo geoip.Info `json:"geo,omitzero"`
}
func serve(
ctx context.Context,
bind string,
inbound, outbound *dataset.View[ipcohort.Cohort],
geo *geoip.Databases,
) error {
handle := func(w http.ResponseWriter, r *http.Request) {
ip := strings.TrimSpace(r.URL.Query().Get("ip"))
if ip == "" {
ip = clientIP(r)
}
in := inbound.Value().Contains(ip)
out := outbound.Value().Contains(ip)
res := Result{
IP: ip,
Blocked: in || out,
BlockedInbound: in,
BlockedOutbound: out,
Geo: geo.Lookup(ip),
}
if r.URL.Query().Get("format") == "json" ||
strings.Contains(r.Header.Get("Accept"), "application/json") {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
_ = enc.Encode(res)
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
switch {
case in && out:
fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", ip)
case in:
fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", ip)
case out:
fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", ip)
default:
fmt.Fprintf(w, "%s is allowed\n", ip)
}
var parts []string
if res.Geo.City != "" {
parts = append(parts, res.Geo.City)
}
if res.Geo.Region != "" {
parts = append(parts, res.Geo.Region)
}
if res.Geo.Country != "" {
parts = append(parts, fmt.Sprintf("%s (%s)", res.Geo.Country, res.Geo.CountryISO))
}
if len(parts) > 0 {
fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", "))
}
if res.Geo.ASN != 0 {
fmt.Fprintf(w, " ASN: AS%d %s\n", res.Geo.ASN, res.Geo.ASNOrg)
func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) {
ip := strings.TrimSpace(r.URL.Query().Get("ip"))
if ip == "" {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
first, _, _ := strings.Cut(xff, ",")
ip = strings.TrimSpace(first)
} else if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
ip = host
} else {
ip = r.RemoteAddr
}
}
in := c.inbound.Value().Contains(ip)
out := c.outbound.Value().Contains(ip)
res := Result{
IP: ip,
Blocked: in || out,
BlockedInbound: in,
BlockedOutbound: out,
Geo: c.geo.Lookup(ip),
}
if r.URL.Query().Get("format") == "json" ||
strings.Contains(r.Header.Get("Accept"), "application/json") {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
enc := json.NewEncoder(w)
enc.SetIndent("", " ")
_ = enc.Encode(res)
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
switch {
case in && out:
fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", ip)
case in:
fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", ip)
case out:
fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", ip)
default:
fmt.Fprintf(w, "%s is allowed\n", ip)
}
var parts []string
if res.Geo.City != "" {
parts = append(parts, res.Geo.City)
}
if res.Geo.Region != "" {
parts = append(parts, res.Geo.Region)
}
if res.Geo.Country != "" {
parts = append(parts, fmt.Sprintf("%s (%s)", res.Geo.Country, res.Geo.CountryISO))
}
if len(parts) > 0 {
fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", "))
}
if res.Geo.ASN != 0 {
fmt.Fprintf(w, " ASN: AS%d %s\n", res.Geo.ASN, res.Geo.ASNOrg)
}
}
func (c *IPCheck) serve(ctx context.Context) error {
mux := http.NewServeMux()
mux.HandleFunc("GET /check", handle)
mux.HandleFunc("GET /{$}", handle)
mux.HandleFunc("GET /check", c.handle)
mux.HandleFunc("GET /{$}", c.handle)
srv := &http.Server{
Addr: bind,
Addr: c.Bind,
Handler: mux,
BaseContext: func(_ net.Listener) context.Context { return ctx },
}
@ -100,22 +100,9 @@ func serve(
_ = srv.Shutdown(shutCtx)
}()
log.Printf("listening on %s", bind)
log.Printf("listening on %s", c.Bind)
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
return nil
}
// clientIP extracts the caller's IP, honoring X-Forwarded-For when present.
func clientIP(r *http.Request) string {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
first, _, _ := strings.Cut(xff, ",")
return strings.TrimSpace(first)
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return r.RemoteAddr
}
return host
}