diff --git a/cmd/check-ip/main.go b/cmd/check-ip/main.go index 27386a3..350be53 100644 --- a/cmd/check-ip/main.go +++ b/cmd/check-ip/main.go @@ -1,6 +1,6 @@ // check-ip runs an HTTP API that reports whether an IP appears in the -// configured blocklist repo and, when GeoIP.conf is available, enriches -// the response with MaxMind GeoLite2 City + ASN data. +// configured blocklist repo and enriches the response with MaxMind +// GeoLite2 City + ASN data. package main import ( @@ -8,6 +8,7 @@ import ( "errors" "flag" "fmt" + "log" "os" "os/signal" "path/filepath" @@ -26,12 +27,7 @@ const ( ) func main() { - var ( - bind string - confPath string - repoURL string - cacheDir string - ) + var bind, confPath, repoURL, cacheDir string fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) fs.StringVar(&bind, "serve", "", "bind address for the HTTP API, e.g. :8080") fs.StringVar(&confPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)") @@ -47,20 +43,13 @@ func main() { } os.Exit(1) } - if bind == "" { - fs.Usage() - os.Exit(1) - } if cacheDir == "" { - if d, err := os.UserCacheDir(); err == nil { - cacheDir = d - } + cacheDir, _ = os.UserCacheDir() } ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() - // Blocklists: one git repo, two views sharing the same pull. repo := gitshallow.New(repoURL, filepath.Join(cacheDir, "bitwire-it"), 1, "") group := dataset.NewGroup(repo) inbound := dataset.Add(group, func() (*ipcohort.Cohort, error) { @@ -76,30 +65,27 @@ func main() { ) }) if err := group.Load(ctx); err != nil { - fatal("blocklists", err) + log.Fatalf("blocklists: %v", err) } go group.Tick(ctx, refreshInterval, func(err error) { - fmt.Fprintf(os.Stderr, "refresh: %v\n", err) + log.Printf("refresh: %v", err) }) - // GeoIP: downloaded via httpcache when GeoIP.conf is available. - maxmindDir := filepath.Join(cacheDir, "maxmind") + maxmind := filepath.Join(cacheDir, "maxmind") geo, err := geoip.OpenDatabases( confPath, - filepath.Join(maxmindDir, geoip.CityEdition+".mmdb"), - filepath.Join(maxmindDir, geoip.ASNEdition+".mmdb"), + filepath.Join(maxmind, geoip.CityEdition+".mmdb"), + filepath.Join(maxmind, geoip.ASNEdition+".mmdb"), ) if err != nil { - fatal("geoip", err) + log.Fatalf("geoip: %v", err) } - defer func() { _ = geo.Close() }() + defer geo.Close() + if bind == "" { + return + } if err := serve(ctx, bind, inbound, outbound, geo); err != nil { - fatal("serve", err) + log.Fatalf("serve: %v", err) } } - -func fatal(what string, err error) { - fmt.Fprintf(os.Stderr, "error: %s: %v\n", what, err) - os.Exit(1) -} diff --git a/cmd/check-ip/server.go b/cmd/check-ip/server.go index 68c6dd2..00382ec 100644 --- a/cmd/check-ip/server.go +++ b/cmd/check-ip/server.go @@ -5,10 +5,9 @@ import ( "encoding/json" "errors" "fmt" - "io" + "log" "net" "net/http" - "os" "strings" "time" @@ -17,9 +16,7 @@ import ( "github.com/therootcompany/golib/sync/dataset" ) -const shutdownTimeout = 5 * time.Second - -// Result is the structured verdict for a single IP. +// Result is the JSON verdict for a single IP. type Result struct { IP string `json:"ip"` Blocked bool `json:"blocked"` @@ -28,129 +25,88 @@ type Result struct { Geo geoip.Info `json:"geo,omitzero"` } -// serve runs the HTTP API until ctx is cancelled. -// -// GET / checks the request's client IP -// GET /check same, plus ?ip= overrides -// -// Response format: ?format=json, then Accept: application/json, else pretty. func serve( ctx context.Context, bind string, inbound, outbound *dataset.View[ipcohort.Cohort], geo *geoip.Databases, ) error { - check := func(ip string) Result { + handle := func(w http.ResponseWriter, r *http.Request) { + ip := strings.TrimSpace(r.URL.Query().Get("ip")) + if ip == "" { + ip = clientIP(r) + } in := inbound.Value().Contains(ip) out := outbound.Value().Contains(ip) - return Result{ + res := Result{ IP: ip, Blocked: in || out, BlockedInbound: in, BlockedOutbound: out, Geo: geo.Lookup(ip), } - } - handle := func(w http.ResponseWriter, r *http.Request, ip string) { - f := requestFormat(r) - if f == formatJSON { + if r.URL.Query().Get("format") == "json" || + strings.Contains(r.Header.Get("Accept"), "application/json") { w.Header().Set("Content-Type", "application/json; charset=utf-8") - } else { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + _ = enc.Encode(res) + return + } + + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + switch { + case in && out: + fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", ip) + case in: + fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", ip) + case out: + fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", ip) + default: + fmt.Fprintf(w, "%s is allowed\n", ip) + } + var parts []string + if res.Geo.City != "" { + parts = append(parts, res.Geo.City) + } + if res.Geo.Region != "" { + parts = append(parts, res.Geo.Region) + } + if res.Geo.Country != "" { + parts = append(parts, fmt.Sprintf("%s (%s)", res.Geo.Country, res.Geo.CountryISO)) + } + if len(parts) > 0 { + fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", ")) + } + if res.Geo.ASN != 0 { + fmt.Fprintf(w, " ASN: AS%d %s\n", res.Geo.ASN, res.Geo.ASNOrg) } - write(w, check(ip), f) } mux := http.NewServeMux() - mux.HandleFunc("GET /check", func(w http.ResponseWriter, r *http.Request) { - ip := strings.TrimSpace(r.URL.Query().Get("ip")) - if ip == "" { - ip = clientIP(r) - } - handle(w, r, ip) - }) - mux.HandleFunc("GET /{$}", func(w http.ResponseWriter, r *http.Request) { - handle(w, r, clientIP(r)) - }) + mux.HandleFunc("GET /check", handle) + mux.HandleFunc("GET /{$}", handle) srv := &http.Server{ - Addr: bind, - Handler: mux, - BaseContext: func(_ net.Listener) context.Context { - return ctx - }, + Addr: bind, + Handler: mux, + BaseContext: func(_ net.Listener) context.Context { return ctx }, } go func() { <-ctx.Done() - shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + shutCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - _ = srv.Shutdown(shutdownCtx) + _ = srv.Shutdown(shutCtx) }() - fmt.Fprintf(os.Stderr, "listening on %s\n", bind) + log.Printf("listening on %s", bind) if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { return err } return nil } -// format is the response rendering. Server-only. -type format int - -const ( - formatPretty format = iota - formatJSON -) - -func requestFormat(r *http.Request) format { - switch r.URL.Query().Get("format") { - case "json": - return formatJSON - case "pretty": - return formatPretty - } - if strings.Contains(r.Header.Get("Accept"), "application/json") { - return formatJSON - } - return formatPretty -} - -func write(w io.Writer, r Result, f format) { - if f == formatJSON { - enc := json.NewEncoder(w) - enc.SetIndent("", " ") - _ = enc.Encode(r) - return - } - switch { - case r.BlockedInbound && r.BlockedOutbound: - fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", r.IP) - case r.BlockedInbound: - fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", r.IP) - case r.BlockedOutbound: - fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", r.IP) - default: - fmt.Fprintf(w, "%s is allowed\n", r.IP) - } - var parts []string - if r.Geo.City != "" { - parts = append(parts, r.Geo.City) - } - if r.Geo.Region != "" { - parts = append(parts, r.Geo.Region) - } - if r.Geo.Country != "" { - parts = append(parts, fmt.Sprintf("%s (%s)", r.Geo.Country, r.Geo.CountryISO)) - } - if len(parts) > 0 { - fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", ")) - } - if r.Geo.ASN != 0 { - fmt.Fprintf(w, " ASN: AS%d %s\n", r.Geo.ASN, r.Geo.ASNOrg) - } -} - // clientIP extracts the caller's IP, honoring X-Forwarded-For when present. func clientIP(r *http.Request) string { if xff := r.Header.Get("X-Forwarded-For"); xff != "" {