mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 20:58:00 +00:00
refactor: strip all optional/nil-guard plumbing from check-ip + geoip
- drop Checker struct, loadCohort helper, and contains() nil-wrapper - inline check logic into server as a closure - geoip.Databases: no nil-receiver guards, no nil-field branches, no "disabled" mode. City + ASN are both required; caller hands explicit paths and OpenDatabases returns a fully-initialized value or an err - main.go is now straight-line wiring with no helper functions
This commit is contained in:
parent
cdce7da04c
commit
a84116f806
@ -25,21 +25,18 @@ const (
|
||||
refreshInterval = 47 * time.Minute
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Serve string
|
||||
GeoIPConf string
|
||||
BlocklistRepo string
|
||||
CacheDir string
|
||||
}
|
||||
|
||||
func main() {
|
||||
cfg := Config{}
|
||||
|
||||
var (
|
||||
bind string
|
||||
confPath string
|
||||
repoURL string
|
||||
cacheDir string
|
||||
)
|
||||
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
|
||||
fs.StringVar(&cfg.Serve, "serve", "", "bind address for the HTTP API, e.g. :8080")
|
||||
fs.StringVar(&cfg.GeoIPConf, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
|
||||
fs.StringVar(&cfg.BlocklistRepo, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
|
||||
fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
|
||||
fs.StringVar(&bind, "serve", "", "bind address for the HTTP API, e.g. :8080")
|
||||
fs.StringVar(&confPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
|
||||
fs.StringVar(&repoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
|
||||
fs.StringVar(&cacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0])
|
||||
fs.PrintDefaults()
|
||||
@ -50,34 +47,34 @@ func main() {
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
if cfg.Serve == "" {
|
||||
if bind == "" {
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
if cacheDir == "" {
|
||||
if d, err := os.UserCacheDir(); err == nil {
|
||||
cacheDir = d
|
||||
}
|
||||
}
|
||||
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
cacheDir := cfg.CacheDir
|
||||
if cacheDir == "" {
|
||||
base, err := os.UserCacheDir()
|
||||
if err != nil {
|
||||
fatal("cache-dir", err)
|
||||
}
|
||||
cacheDir = base
|
||||
}
|
||||
|
||||
// Blocklists: one git repo, two views sharing the same pull.
|
||||
repo := gitshallow.New(cfg.BlocklistRepo, filepath.Join(cacheDir, "bitwire-it"), 1, "")
|
||||
repo := gitshallow.New(repoURL, filepath.Join(cacheDir, "bitwire-it"), 1, "")
|
||||
group := dataset.NewGroup(repo)
|
||||
inbound := dataset.Add(group, loadCohort(
|
||||
repo.FilePath("tables/inbound/single_ips.txt"),
|
||||
repo.FilePath("tables/inbound/networks.txt"),
|
||||
))
|
||||
outbound := dataset.Add(group, loadCohort(
|
||||
repo.FilePath("tables/outbound/single_ips.txt"),
|
||||
repo.FilePath("tables/outbound/networks.txt"),
|
||||
))
|
||||
inbound := dataset.Add(group, func() (*ipcohort.Cohort, error) {
|
||||
return ipcohort.LoadFiles(
|
||||
repo.FilePath("tables/inbound/single_ips.txt"),
|
||||
repo.FilePath("tables/inbound/networks.txt"),
|
||||
)
|
||||
})
|
||||
outbound := dataset.Add(group, func() (*ipcohort.Cohort, error) {
|
||||
return ipcohort.LoadFiles(
|
||||
repo.FilePath("tables/outbound/single_ips.txt"),
|
||||
repo.FilePath("tables/outbound/networks.txt"),
|
||||
)
|
||||
})
|
||||
if err := group.Load(ctx); err != nil {
|
||||
fatal("blocklists", err)
|
||||
}
|
||||
@ -85,11 +82,10 @@ func main() {
|
||||
fmt.Fprintf(os.Stderr, "refresh: %v\n", err)
|
||||
})
|
||||
|
||||
// GeoIP: city + ASN readers, downloaded via httpcache when GeoIP.conf
|
||||
// is available; otherwise read from disk at the cache paths.
|
||||
// GeoIP: downloaded via httpcache when GeoIP.conf is available.
|
||||
maxmindDir := filepath.Join(cacheDir, "maxmind")
|
||||
geo, err := geoip.OpenDatabases(
|
||||
cfg.GeoIPConf,
|
||||
confPath,
|
||||
filepath.Join(maxmindDir, geoip.CityEdition+".mmdb"),
|
||||
filepath.Join(maxmindDir, geoip.ASNEdition+".mmdb"),
|
||||
)
|
||||
@ -98,8 +94,7 @@ func main() {
|
||||
}
|
||||
defer func() { _ = geo.Close() }()
|
||||
|
||||
checker := &Checker{Inbound: inbound, Outbound: outbound, GeoIP: geo}
|
||||
if err := serve(ctx, cfg.Serve, checker); err != nil {
|
||||
if err := serve(ctx, bind, inbound, outbound, geo); err != nil {
|
||||
fatal("serve", err)
|
||||
}
|
||||
}
|
||||
@ -108,42 +103,3 @@ func fatal(what string, err error) {
|
||||
fmt.Fprintf(os.Stderr, "error: %s: %v\n", what, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Checker bundles the blocklist views with the optional GeoIP databases.
|
||||
type Checker struct {
|
||||
Inbound *dataset.View[ipcohort.Cohort]
|
||||
Outbound *dataset.View[ipcohort.Cohort]
|
||||
GeoIP *geoip.Databases
|
||||
}
|
||||
|
||||
// Result is the structured verdict for a single IP.
|
||||
type Result struct {
|
||||
IP string `json:"ip"`
|
||||
Blocked bool `json:"blocked"`
|
||||
BlockedInbound bool `json:"blocked_inbound"`
|
||||
BlockedOutbound bool `json:"blocked_outbound"`
|
||||
Geo geoip.Info `json:"geo,omitzero"`
|
||||
}
|
||||
|
||||
// Check returns the structured verdict for ip.
|
||||
func (c *Checker) Check(ip string) Result {
|
||||
in := contains(c.Inbound.Value(), ip)
|
||||
out := contains(c.Outbound.Value(), ip)
|
||||
return Result{
|
||||
IP: ip,
|
||||
Blocked: in || out,
|
||||
BlockedInbound: in,
|
||||
BlockedOutbound: out,
|
||||
Geo: c.GeoIP.Lookup(ip),
|
||||
}
|
||||
}
|
||||
|
||||
func contains(c *ipcohort.Cohort, ip string) bool {
|
||||
return c != nil && c.Contains(ip)
|
||||
}
|
||||
|
||||
func loadCohort(paths ...string) func() (*ipcohort.Cohort, error) {
|
||||
return func() (*ipcohort.Cohort, error) {
|
||||
return ipcohort.LoadFiles(paths...)
|
||||
}
|
||||
}
|
||||
|
||||
@ -13,26 +13,53 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/therootcompany/golib/net/geoip"
|
||||
"github.com/therootcompany/golib/net/ipcohort"
|
||||
"github.com/therootcompany/golib/sync/dataset"
|
||||
)
|
||||
|
||||
const shutdownTimeout = 5 * time.Second
|
||||
|
||||
// serve runs the HTTP API until ctx is cancelled, shutting down gracefully.
|
||||
// Result is the structured verdict for a single IP.
|
||||
type Result struct {
|
||||
IP string `json:"ip"`
|
||||
Blocked bool `json:"blocked"`
|
||||
BlockedInbound bool `json:"blocked_inbound"`
|
||||
BlockedOutbound bool `json:"blocked_outbound"`
|
||||
Geo geoip.Info `json:"geo,omitzero"`
|
||||
}
|
||||
|
||||
// serve runs the HTTP API until ctx is cancelled.
|
||||
//
|
||||
// GET / checks the request's client IP
|
||||
// GET /check same, plus ?ip= overrides
|
||||
//
|
||||
// Response format is chosen per request: ?format=json, then
|
||||
// Accept: application/json, else pretty text.
|
||||
func serve(ctx context.Context, bind string, checker *Checker) error {
|
||||
// Response format: ?format=json, then Accept: application/json, else pretty.
|
||||
func serve(
|
||||
ctx context.Context,
|
||||
bind string,
|
||||
inbound, outbound *dataset.View[ipcohort.Cohort],
|
||||
geo *geoip.Databases,
|
||||
) error {
|
||||
check := func(ip string) Result {
|
||||
in := inbound.Value().Contains(ip)
|
||||
out := outbound.Value().Contains(ip)
|
||||
return Result{
|
||||
IP: ip,
|
||||
Blocked: in || out,
|
||||
BlockedInbound: in,
|
||||
BlockedOutbound: out,
|
||||
Geo: geo.Lookup(ip),
|
||||
}
|
||||
}
|
||||
|
||||
handle := func(w http.ResponseWriter, r *http.Request, ip string) {
|
||||
format := requestFormat(r)
|
||||
if format == formatJSON {
|
||||
f := requestFormat(r)
|
||||
if f == formatJSON {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
} else {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
}
|
||||
writeResult(w, checker.Check(ip), format)
|
||||
write(w, check(ip), f)
|
||||
}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
@ -54,7 +81,6 @@ func serve(ctx context.Context, bind string, checker *Checker) error {
|
||||
return ctx
|
||||
},
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||
@ -77,7 +103,6 @@ const (
|
||||
formatJSON
|
||||
)
|
||||
|
||||
// requestFormat picks a response format from ?format=, then Accept header.
|
||||
func requestFormat(r *http.Request) format {
|
||||
switch r.URL.Query().Get("format") {
|
||||
case "json":
|
||||
@ -91,7 +116,7 @@ func requestFormat(r *http.Request) format {
|
||||
return formatPretty
|
||||
}
|
||||
|
||||
func writeResult(w io.Writer, r Result, f format) {
|
||||
func write(w io.Writer, r Result, f format) {
|
||||
if f == formatJSON {
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
@ -108,31 +133,25 @@ func writeResult(w io.Writer, r Result, f format) {
|
||||
default:
|
||||
fmt.Fprintf(w, "%s is allowed\n", r.IP)
|
||||
}
|
||||
writeGeo(w, r.Geo)
|
||||
}
|
||||
|
||||
func writeGeo(w io.Writer, info geoip.Info) {
|
||||
var parts []string
|
||||
if info.City != "" {
|
||||
parts = append(parts, info.City)
|
||||
if r.Geo.City != "" {
|
||||
parts = append(parts, r.Geo.City)
|
||||
}
|
||||
if info.Region != "" {
|
||||
parts = append(parts, info.Region)
|
||||
if r.Geo.Region != "" {
|
||||
parts = append(parts, r.Geo.Region)
|
||||
}
|
||||
if info.Country != "" {
|
||||
parts = append(parts, fmt.Sprintf("%s (%s)", info.Country, info.CountryISO))
|
||||
if r.Geo.Country != "" {
|
||||
parts = append(parts, fmt.Sprintf("%s (%s)", r.Geo.Country, r.Geo.CountryISO))
|
||||
}
|
||||
if len(parts) > 0 {
|
||||
fmt.Fprintf(w, " Location: %s\n", strings.Join(parts, ", "))
|
||||
}
|
||||
if info.ASN != 0 {
|
||||
fmt.Fprintf(w, " ASN: AS%d %s\n", info.ASN, info.ASNOrg)
|
||||
if r.Geo.ASN != 0 {
|
||||
fmt.Fprintf(w, " ASN: AS%d %s\n", r.Geo.ASN, r.Geo.ASNOrg)
|
||||
}
|
||||
}
|
||||
|
||||
// clientIP extracts the caller's IP, honoring X-Forwarded-For when present.
|
||||
// The leftmost entry in X-Forwarded-For is the originating client; intermediate
|
||||
// proxies append themselves rightward.
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
first, _, _ := strings.Cut(xff, ",")
|
||||
|
||||
@ -9,10 +9,10 @@ import (
|
||||
|
||||
// Conf holds the fields parsed from a geoipupdate-style config file.
|
||||
type Conf struct {
|
||||
AccountID string
|
||||
LicenseKey string
|
||||
EditionIDs []string
|
||||
DatabaseDirectory string
|
||||
AccountID string
|
||||
LicenseKey string
|
||||
EditionIDs []string
|
||||
DatabaseDirectory string
|
||||
}
|
||||
|
||||
// ParseConf reads a geoipupdate-style config file (whitespace-separated
|
||||
|
||||
@ -10,9 +10,7 @@ import (
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
)
|
||||
|
||||
// Databases holds open GeoLite2 readers. A nil field means that edition
|
||||
// wasn't configured. A nil *Databases means geoip is disabled; all methods
|
||||
// are nil-safe no-ops so callers need not branch.
|
||||
// Databases holds open GeoLite2 City + ASN readers.
|
||||
type Databases struct {
|
||||
City *geoip2.Reader
|
||||
ASN *geoip2.Reader
|
||||
@ -22,9 +20,8 @@ type Databases struct {
|
||||
// GeoIP.conf with credentials is available), and opens the readers.
|
||||
//
|
||||
// - confPath="" → auto-discover from DefaultConfPaths
|
||||
// - conf found → auto-download; cityPath/asnPath override default locations
|
||||
// - conf found → auto-download to cityPath/asnPath
|
||||
// - no conf → cityPath and asnPath must point to existing .mmdb files
|
||||
// - no conf and no paths → returns nil, nil (geoip disabled)
|
||||
func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) {
|
||||
if confPath == "" {
|
||||
for _, p := range DefaultConfPaths() {
|
||||
@ -40,20 +37,8 @@ func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("geoip-conf: %w", err)
|
||||
}
|
||||
dbDir := cfg.DatabaseDirectory
|
||||
if dbDir == "" {
|
||||
if dbDir, err = DefaultCacheDir(); err != nil {
|
||||
return nil, fmt.Errorf("geoip cache dir: %w", err)
|
||||
}
|
||||
}
|
||||
if err := os.MkdirAll(dbDir, 0o755); err != nil {
|
||||
return nil, fmt.Errorf("mkdir %s: %w", dbDir, err)
|
||||
}
|
||||
if cityPath == "" {
|
||||
cityPath = filepath.Join(dbDir, CityEdition+".mmdb")
|
||||
}
|
||||
if asnPath == "" {
|
||||
asnPath = filepath.Join(dbDir, ASNEdition+".mmdb")
|
||||
if err := os.MkdirAll(filepath.Dir(cityPath), 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dl := New(cfg.AccountID, cfg.LicenseKey)
|
||||
if _, err := dl.NewCacher(CityEdition, cityPath).Fetch(); err != nil {
|
||||
@ -62,60 +47,30 @@ func OpenDatabases(confPath, cityPath, asnPath string) (*Databases, error) {
|
||||
if _, err := dl.NewCacher(ASNEdition, asnPath).Fetch(); err != nil {
|
||||
return nil, fmt.Errorf("fetch %s: %w", ASNEdition, err)
|
||||
}
|
||||
return Open(cityPath, asnPath)
|
||||
}
|
||||
|
||||
if cityPath == "" && asnPath == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return Open(cityPath, asnPath)
|
||||
}
|
||||
|
||||
// Open opens city and ASN .mmdb files from the given paths. Empty paths are
|
||||
// treated as unconfigured (the corresponding field stays nil).
|
||||
// Open opens city and ASN .mmdb files from the given paths.
|
||||
func Open(cityPath, asnPath string) (*Databases, error) {
|
||||
d := &Databases{}
|
||||
if cityPath != "" {
|
||||
r, err := geoip2.Open(cityPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %s: %w", cityPath, err)
|
||||
}
|
||||
d.City = r
|
||||
city, err := geoip2.Open(cityPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open %s: %w", cityPath, err)
|
||||
}
|
||||
if asnPath != "" {
|
||||
r, err := geoip2.Open(asnPath)
|
||||
if err != nil {
|
||||
if d.City != nil {
|
||||
_ = d.City.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("open %s: %w", asnPath, err)
|
||||
}
|
||||
d.ASN = r
|
||||
asn, err := geoip2.Open(asnPath)
|
||||
if err != nil {
|
||||
_ = city.Close()
|
||||
return nil, fmt.Errorf("open %s: %w", asnPath, err)
|
||||
}
|
||||
return d, nil
|
||||
return &Databases{City: city, ASN: asn}, nil
|
||||
}
|
||||
|
||||
// Close closes any open readers. No-op on nil receiver.
|
||||
// Close closes the city and ASN readers.
|
||||
func (d *Databases) Close() error {
|
||||
if d == nil {
|
||||
return nil
|
||||
}
|
||||
var errs []error
|
||||
if d.City != nil {
|
||||
if err := d.City.Close(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
if d.ASN != nil {
|
||||
if err := d.ASN.Close(); err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
return errors.Join(errs...)
|
||||
return errors.Join(d.City.Close(), d.ASN.Close())
|
||||
}
|
||||
|
||||
// Info is the structured result of a GeoIP lookup. Zero-valued fields mean
|
||||
// the database didn't return a value (or wasn't configured).
|
||||
// Info is the structured result of a GeoIP lookup.
|
||||
type Info struct {
|
||||
City string `json:"city,omitempty"`
|
||||
Region string `json:"region,omitempty"`
|
||||
@ -125,37 +80,29 @@ type Info struct {
|
||||
ASNOrg string `json:"asn_org,omitempty"`
|
||||
}
|
||||
|
||||
// Lookup returns city + ASN info for ip. Returns a zero Info on nil receiver,
|
||||
// unparseable IP, or database miss.
|
||||
// Lookup returns city + ASN info for ip. Returns a zero Info on unparseable
|
||||
// IP or database miss.
|
||||
func (d *Databases) Lookup(ip string) Info {
|
||||
var info Info
|
||||
if d == nil {
|
||||
return info
|
||||
}
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return info
|
||||
}
|
||||
stdIP := addr.AsSlice()
|
||||
|
||||
if d.City != nil {
|
||||
if rec, err := d.City.City(stdIP); err == nil {
|
||||
info.City = rec.City.Names["en"]
|
||||
info.Country = rec.Country.Names["en"]
|
||||
info.CountryISO = rec.Country.IsoCode
|
||||
if len(rec.Subdivisions) > 0 {
|
||||
if sub := rec.Subdivisions[0].Names["en"]; sub != "" && sub != info.City {
|
||||
info.Region = sub
|
||||
}
|
||||
if rec, err := d.City.City(stdIP); err == nil {
|
||||
info.City = rec.City.Names["en"]
|
||||
info.Country = rec.Country.Names["en"]
|
||||
info.CountryISO = rec.Country.IsoCode
|
||||
if len(rec.Subdivisions) > 0 {
|
||||
if sub := rec.Subdivisions[0].Names["en"]; sub != "" && sub != info.City {
|
||||
info.Region = sub
|
||||
}
|
||||
}
|
||||
}
|
||||
if d.ASN != nil {
|
||||
if rec, err := d.ASN.ASN(stdIP); err == nil {
|
||||
info.ASN = rec.AutonomousSystemNumber
|
||||
info.ASNOrg = rec.AutonomousSystemOrganization
|
||||
}
|
||||
if rec, err := d.ASN.ASN(stdIP); err == nil {
|
||||
info.ASN = rec.AutonomousSystemNumber
|
||||
info.ASNOrg = rec.AutonomousSystemOrganization
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user