diff --git a/cmd/check-ip/main.go b/cmd/check-ip/main.go index 5e0e590..d234c73 100644 --- a/cmd/check-ip/main.go +++ b/cmd/check-ip/main.go @@ -24,45 +24,78 @@ import ( const ( defaultBlocklistRepo = "https://github.com/bitwire-it/ipblocklist.git" refreshInterval = 47 * time.Minute + version = "dev" ) +// IPCheck holds the parsed CLI config and the loaded data sources used by +// the HTTP handler. +type IPCheck struct { + Bind string + ConfPath string + RepoURL string + CacheDir string + + inbound *dataset.View[ipcohort.Cohort] + outbound *dataset.View[ipcohort.Cohort] + geo *geoip.Databases +} + +func printVersion(w *os.File) { + fmt.Fprintf(w, "check-ip %s\n", version) +} + func main() { - var bind, confPath, repoURL, cacheDir string + cfg := IPCheck{} fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) - fs.StringVar(&bind, "serve", "", "bind address for the HTTP API, e.g. :8080") - fs.StringVar(&confPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)") - fs.StringVar(&repoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)") - fs.StringVar(&cacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)") + fs.StringVar(&cfg.Bind, "serve", "", "bind address for the HTTP API, e.g. :8080") + fs.StringVar(&cfg.ConfPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)") + fs.StringVar(&cfg.RepoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)") + fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)") fs.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s --serve [flags]\n", os.Args[0]) fs.PrintDefaults() } + + if len(os.Args) > 1 { + switch os.Args[1] { + case "-V", "-version", "--version", "version": + printVersion(os.Stdout) + os.Exit(0) + case "help", "-help", "--help": + printVersion(os.Stdout) + fmt.Fprintln(os.Stdout, "") + fs.SetOutput(os.Stdout) + fs.Usage() + os.Exit(0) + } + } + if err := fs.Parse(os.Args[1:]); err != nil { if errors.Is(err, flag.ErrHelp) { os.Exit(0) } os.Exit(1) } - if cacheDir == "" { + if cfg.CacheDir == "" { d, err := os.UserCacheDir() if err != nil { log.Fatalf("cache-dir: %v", err) } - cacheDir = d + cfg.CacheDir = d } ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - repo := gitshallow.New(repoURL, filepath.Join(cacheDir, "bitwire-it"), 1, "") + repo := gitshallow.New(cfg.RepoURL, filepath.Join(cfg.CacheDir, "bitwire-it"), 1, "") group := dataset.NewGroup(repo) - inbound := dataset.Add(group, func() (*ipcohort.Cohort, error) { + cfg.inbound = dataset.Add(group, func() (*ipcohort.Cohort, error) { return ipcohort.LoadFiles( repo.FilePath("tables/inbound/single_ips.txt"), repo.FilePath("tables/inbound/networks.txt"), ) }) - outbound := dataset.Add(group, func() (*ipcohort.Cohort, error) { + cfg.outbound = dataset.Add(group, func() (*ipcohort.Cohort, error) { return ipcohort.LoadFiles( repo.FilePath("tables/outbound/single_ips.txt"), repo.FilePath("tables/outbound/networks.txt"), @@ -75,9 +108,9 @@ func main() { log.Printf("refresh: %v", err) }) - maxmind := filepath.Join(cacheDir, "maxmind") + maxmind := filepath.Join(cfg.CacheDir, "maxmind") geo, err := geoip.OpenDatabases( - confPath, + cfg.ConfPath, filepath.Join(maxmind, geoip.CityEdition+".mmdb"), filepath.Join(maxmind, geoip.ASNEdition+".mmdb"), ) @@ -85,11 +118,12 @@ func main() { log.Fatalf("geoip: %v", err) } defer func() { _ = geo.Close() }() + cfg.geo = geo - if bind == "" { + if cfg.Bind == "" { return } - if err := serve(ctx, bind, inbound, outbound, geo); err != nil { + if err := cfg.serve(ctx); err != nil { log.Fatalf("serve: %v", err) } } diff --git a/cmd/check-ip/server.go b/cmd/check-ip/server.go index 00382ec..7c3379d 100644 --- a/cmd/check-ip/server.go +++ b/cmd/check-ip/server.go @@ -12,8 +12,6 @@ import ( "time" "github.com/therootcompany/golib/net/geoip" - "github.com/therootcompany/golib/net/ipcohort" - "github.com/therootcompany/golib/sync/dataset" ) // Result is the JSON verdict for a single IP. @@ -25,71 +23,73 @@ type Result struct { Geo geoip.Info `json:"geo,omitzero"` } -func serve( - ctx context.Context, - bind string, - inbound, outbound *dataset.View[ipcohort.Cohort], - geo *geoip.Databases, -) error { - handle := func(w http.ResponseWriter, r *http.Request) { - ip := strings.TrimSpace(r.URL.Query().Get("ip")) - if ip == "" { - ip = clientIP(r) - } - in := inbound.Value().Contains(ip) - out := outbound.Value().Contains(ip) - res := Result{ - IP: ip, - Blocked: in || out, - BlockedInbound: in, - BlockedOutbound: out, - Geo: geo.Lookup(ip), - } - - if r.URL.Query().Get("format") == "json" || - strings.Contains(r.Header.Get("Accept"), "application/json") { - w.Header().Set("Content-Type", "application/json; charset=utf-8") - enc := json.NewEncoder(w) - enc.SetIndent("", " ") - _ = enc.Encode(res) - return - } - - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - switch { - case in && out: - fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", ip) - case in: - fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", ip) - case out: - fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", ip) - default: - fmt.Fprintf(w, "%s is allowed\n", ip) - } - var parts []string - if res.Geo.City != "" { - parts = append(parts, res.Geo.City) - } - if res.Geo.Region != "" { - parts = append(parts, res.Geo.Region) - } - if res.Geo.Country != "" { - parts = append(parts, fmt.Sprintf("%s (%s)", res.Geo.Country, res.Geo.CountryISO)) - } - if len(parts) > 0 { - fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", ")) - } - if res.Geo.ASN != 0 { - fmt.Fprintf(w, " ASN: AS%d %s\n", res.Geo.ASN, res.Geo.ASNOrg) +func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) { + ip := strings.TrimSpace(r.URL.Query().Get("ip")) + if ip == "" { + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + first, _, _ := strings.Cut(xff, ",") + ip = strings.TrimSpace(first) + } else if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { + ip = host + } else { + ip = r.RemoteAddr } } + in := c.inbound.Value().Contains(ip) + out := c.outbound.Value().Contains(ip) + res := Result{ + IP: ip, + Blocked: in || out, + BlockedInbound: in, + BlockedOutbound: out, + Geo: c.geo.Lookup(ip), + } + if r.URL.Query().Get("format") == "json" || + strings.Contains(r.Header.Get("Accept"), "application/json") { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + _ = enc.Encode(res) + return + } + + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + switch { + case in && out: + fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", ip) + case in: + fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", ip) + case out: + fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", ip) + default: + fmt.Fprintf(w, "%s is allowed\n", ip) + } + var parts []string + if res.Geo.City != "" { + parts = append(parts, res.Geo.City) + } + if res.Geo.Region != "" { + parts = append(parts, res.Geo.Region) + } + if res.Geo.Country != "" { + parts = append(parts, fmt.Sprintf("%s (%s)", res.Geo.Country, res.Geo.CountryISO)) + } + if len(parts) > 0 { + fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", ")) + } + if res.Geo.ASN != 0 { + fmt.Fprintf(w, " ASN: AS%d %s\n", res.Geo.ASN, res.Geo.ASNOrg) + } +} + +func (c *IPCheck) serve(ctx context.Context) error { mux := http.NewServeMux() - mux.HandleFunc("GET /check", handle) - mux.HandleFunc("GET /{$}", handle) + mux.HandleFunc("GET /check", c.handle) + mux.HandleFunc("GET /{$}", c.handle) srv := &http.Server{ - Addr: bind, + Addr: c.Bind, Handler: mux, BaseContext: func(_ net.Listener) context.Context { return ctx }, } @@ -100,22 +100,9 @@ func serve( _ = srv.Shutdown(shutCtx) }() - log.Printf("listening on %s", bind) + log.Printf("listening on %s", c.Bind) if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { return err } return nil } - -// clientIP extracts the caller's IP, honoring X-Forwarded-For when present. -func clientIP(r *http.Request) string { - if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - first, _, _ := strings.Cut(xff, ",") - return strings.TrimSpace(first) - } - host, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - return r.RemoteAddr - } - return host -}