diff --git a/net/dataset/dataset.go b/net/dataset/dataset.go index eb5fb8d..26d481e 100644 --- a/net/dataset/dataset.go +++ b/net/dataset/dataset.go @@ -2,22 +2,22 @@ // atomic.Pointer (hot-swap), providing a generic periodically-updated // in-memory dataset with lock-free reads. // -// Single dataset: +// Standalone dataset (one syncer, one value): // -// ds := dataset.New(cacher, ipcohort.LoadFile, path) +// ds := dataset.New(cacher, func() (*MyType, error) { +// return mytype.LoadFile(path) +// }) // if err := ds.Init(); err != nil { ... } // go ds.Run(ctx, 47*time.Minute) -// cohort := ds.Load() +// val := ds.Load() // *MyType, lock-free // -// Multiple datasets sharing one syncer (e.g. inbound + outbound from one git repo): +// Group (one syncer, multiple values — e.g. inbound+outbound from one git repo): // // g := dataset.NewGroup(repo) -// inbound := dataset.Add(g, ipcohort.LoadFile, inboundPath) -// outbound := dataset.Add(g, ipcohort.LoadFile, outboundPath) +// inbound := dataset.Add(g, func() (*ipcohort.Cohort, error) { return ipcohort.LoadFiles(inboundPaths...) }) +// outbound := dataset.Add(g, func() (*ipcohort.Cohort, error) { return ipcohort.LoadFiles(outboundPaths...) }) // if err := g.Init(); err != nil { ... } // go g.Run(ctx, 47*time.Minute) -// in := inbound.Load() -// out := outbound.Load() package dataset import ( @@ -30,21 +30,24 @@ import ( "github.com/therootcompany/golib/net/httpcache" ) -// Loader reads path and returns the parsed value, or an error. -type Loader[T any] func(path string) (*T, error) - -// Dataset couples a Syncer, a Loader, and an atomic.Pointer. +// Dataset couples a Syncer, a load function, and an atomic.Pointer[T]. // Load is safe for concurrent use without locks. type Dataset[T any] struct { + // Name is used in error messages. Optional. + Name string + // Close is called with the previous value after each successful swap. + // Use this for values that hold resources, e.g. func(r *geoip2.Reader) { r.Close() }. + Close func(*T) + syncer httpcache.Syncer - load Loader[T] - path string + load func() (*T, error) ptr atomic.Pointer[T] } -// New creates a Dataset. The syncer fetches updates to path; load parses it. -func New[T any](syncer httpcache.Syncer, load Loader[T], path string) *Dataset[T] { - return &Dataset[T]{syncer: syncer, load: load, path: path} +// New creates a Dataset. The syncer fetches updates; load produces the value. +// load is a closure — it captures whatever paths or config it needs. +func New[T any](syncer httpcache.Syncer, load func() (*T, error)) *Dataset[T] { + return &Dataset[T]{syncer: syncer, load: load} } // Load returns the current value. Returns nil before Init is called. @@ -52,7 +55,7 @@ func (d *Dataset[T]) Load() *T { return d.ptr.Load() } -// Init fetches (if the syncer needs it) then loads, ensuring the dataset is +// Init fetches (if needed) then always loads, ensuring the dataset is // populated on startup from an existing local file even if nothing changed. func (d *Dataset[T]) Init() error { if _, err := d.syncer.Fetch(); err != nil { @@ -61,8 +64,7 @@ func (d *Dataset[T]) Init() error { return d.reload() } -// Sync fetches from the remote and reloads if the content changed. -// Returns whether the value was updated. +// Sync fetches and reloads if the content changed. Returns whether updated. func (d *Dataset[T]) Sync() (bool, error) { updated, err := d.syncer.Fetch() if err != nil || !updated { @@ -80,7 +82,11 @@ func (d *Dataset[T]) Run(ctx context.Context, interval time.Duration) { select { case <-ticker.C: if _, err := d.Sync(); err != nil { - fmt.Fprintf(os.Stderr, "dataset %s: sync error: %v\n", d.path, err) + name := d.Name + if name == "" { + name = "dataset" + } + fmt.Fprintf(os.Stderr, "%s: sync error: %v\n", name, err) } case <-ctx.Done(): return @@ -89,27 +95,28 @@ func (d *Dataset[T]) Run(ctx context.Context, interval time.Duration) { } func (d *Dataset[T]) reload() error { - val, err := d.load(d.path) + val, err := d.load() if err != nil { return err } - d.ptr.Store(val) + if old := d.ptr.Swap(val); old != nil && d.Close != nil { + d.Close(old) + } return nil } // -- Group: one Syncer driving multiple datasets --------------------------- -// entry is the type-erased reload handle stored in a Group. -type entry interface { +// member is the type-erased reload handle stored in a Group. +type member interface { reload() error } // Group ties one Syncer to multiple datasets so a single Fetch drives all -// reloads — avoiding redundant network calls when datasets share a source -// (e.g. multiple files from the same git repo or HTTP directory). +// reloads — no redundant network calls when datasets share a source. type Group struct { syncer httpcache.Syncer - entries []entry + members []member } // NewGroup creates a Group backed by syncer. @@ -117,11 +124,12 @@ func NewGroup(syncer httpcache.Syncer) *Group { return &Group{syncer: syncer} } -// Add registers a new dataset in g and returns it. Subsequent Init/Sync/Run -// calls on g will reload this dataset whenever the syncer reports an update. -func Add[T any](g *Group, load Loader[T], path string) *Dataset[T] { - d := &Dataset[T]{load: load, path: path} - g.entries = append(g.entries, d) +// Add registers a new dataset in g and returns it. Call Init or Run on g — +// not on the returned dataset — to drive updates. +// load is a closure capturing whatever paths or config it needs. +func Add[T any](g *Group, load func() (*T, error)) *Dataset[T] { + d := &Dataset[T]{load: load} + g.members = append(g.members, d) return d } @@ -159,8 +167,8 @@ func (g *Group) Run(ctx context.Context, interval time.Duration) { } func (g *Group) reloadAll() error { - for _, e := range g.entries { - if err := e.reload(); err != nil { + for _, m := range g.members { + if err := m.reload(); err != nil { return err } } diff --git a/net/ipcohort/cmd/check-ip/blacklist.go b/net/ipcohort/cmd/check-ip/blacklist.go index c62c1b2..a630d85 100644 --- a/net/ipcohort/cmd/check-ip/blacklist.go +++ b/net/ipcohort/cmd/check-ip/blacklist.go @@ -3,6 +3,7 @@ package main import ( "path/filepath" + "github.com/therootcompany/golib/net/dataset" "github.com/therootcompany/golib/net/gitshallow" "github.com/therootcompany/golib/net/httpcache" "github.com/therootcompany/golib/net/ipcohort" @@ -14,15 +15,15 @@ type HTTPSource struct { Path string } -// Sources holds the configuration for fetching and loading the three cohorts. +// Sources holds fetch configuration for the three blocklist cohorts. // It knows how to pull data from git or HTTP, but owns no atomic state. type Sources struct { whitelistPaths []string inboundPaths []string outboundPaths []string - gitRepo *gitshallow.Repo // non-nil for git source; used by Init for clone-if-missing - syncs []httpcache.Syncer // all syncable sources (git repo or HTTP cachers) + gitRepo *gitshallow.Repo // non-nil for git source; used by Init for clone-if-missing + syncs []httpcache.Syncer // all syncable sources } func newFileSources(whitelist, inbound, outbound []string) *Sources { @@ -78,8 +79,7 @@ func (s *Sources) Fetch() (bool, error) { return anyUpdated, nil } -// Init ensures remotes are ready. For git: clones if missing then syncs. -// For HTTP: fetches each cacher unconditionally on first run. +// Init ensures remotes are ready: clones git if missing, or fetches HTTP files. func (s *Sources) Init() error { if s.gitRepo != nil { _, err := s.gitRepo.Init() @@ -93,23 +93,33 @@ func (s *Sources) Init() error { return nil } -func (s *Sources) LoadWhitelist() (*ipcohort.Cohort, error) { - if len(s.whitelistPaths) == 0 { - return nil, nil +// Datasets builds a dataset.Group backed by this Sources and returns typed +// datasets for whitelist, inbound, and outbound cohorts. Either whitelist or +// outbound may be nil if no paths were configured. +func (s *Sources) Datasets() ( + g *dataset.Group, + whitelist *dataset.Dataset[ipcohort.Cohort], + inbound *dataset.Dataset[ipcohort.Cohort], + outbound *dataset.Dataset[ipcohort.Cohort], +) { + g = dataset.NewGroup(s) + if len(s.whitelistPaths) > 0 { + paths := s.whitelistPaths + whitelist = dataset.Add(g, func() (*ipcohort.Cohort, error) { + return ipcohort.LoadFiles(paths...) + }) } - return ipcohort.LoadFiles(s.whitelistPaths...) -} - -func (s *Sources) LoadInbound() (*ipcohort.Cohort, error) { - if len(s.inboundPaths) == 0 { - return nil, nil + if len(s.inboundPaths) > 0 { + paths := s.inboundPaths + inbound = dataset.Add(g, func() (*ipcohort.Cohort, error) { + return ipcohort.LoadFiles(paths...) + }) } - return ipcohort.LoadFiles(s.inboundPaths...) -} - -func (s *Sources) LoadOutbound() (*ipcohort.Cohort, error) { - if len(s.outboundPaths) == 0 { - return nil, nil + if len(s.outboundPaths) > 0 { + paths := s.outboundPaths + outbound = dataset.Add(g, func() (*ipcohort.Cohort, error) { + return ipcohort.LoadFiles(paths...) + }) } - return ipcohort.LoadFiles(s.outboundPaths...) + return g, whitelist, inbound, outbound } diff --git a/net/ipcohort/cmd/check-ip/main.go b/net/ipcohort/cmd/check-ip/main.go index c9ad092..df8e353 100644 --- a/net/ipcohort/cmd/check-ip/main.go +++ b/net/ipcohort/cmd/check-ip/main.go @@ -8,12 +8,11 @@ import ( "os" "path/filepath" "strings" - "sync/atomic" "time" "github.com/oschwald/geoip2-golang" + "github.com/therootcompany/golib/net/dataset" "github.com/therootcompany/golib/net/geoip" - "github.com/therootcompany/golib/net/httpcache" "github.com/therootcompany/golib/net/ipcohort" ) @@ -75,26 +74,25 @@ func main() { ) } - var whitelist, inbound, outbound atomic.Pointer[ipcohort.Cohort] - + // Build typed datasets from the source. if err := src.Init(); err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } - if err := reloadBlocklists(src, &whitelist, &inbound, &outbound); err != nil { + blGroup, whitelistDS, inboundDS, outboundDS := src.Datasets() + if err := blGroup.Init(); err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n", - cohortSize(&inbound), cohortSize(&outbound)) - - // GeoIP: resolve paths and build cachers if we have credentials. - var cityDB, asnDB atomic.Pointer[geoip2.Reader] - var cityCacher, asnCacher *httpcache.Cacher + cohortSize(inboundDS), cohortSize(outboundDS)) + // GeoIP datasets. resolvedCityPath := *cityDBPath resolvedASNPath := *asnDBPath + var cityDS, asnDS *dataset.Dataset[geoip2.Reader] + if *geoipConf != "" { cfg, err := geoip.ParseConf(*geoipConf) if err != nil { @@ -104,6 +102,9 @@ func main() { if dbDir == "" { dbDir = dataPath } + if err := os.MkdirAll(dbDir, 0o755); err != nil { + fmt.Fprintf(os.Stderr, "warn: mkdir %s: %v\n", dbDir, err) + } d := geoip.New(cfg.AccountID, cfg.LicenseKey) if resolvedCityPath == "" { resolvedCityPath = filepath.Join(dbDir, geoip.CityEdition+".mmdb") @@ -111,37 +112,44 @@ func main() { if resolvedASNPath == "" { resolvedASNPath = filepath.Join(dbDir, geoip.ASNEdition+".mmdb") } - cityCacher = d.NewCacher(geoip.CityEdition, resolvedCityPath) - asnCacher = d.NewCacher(geoip.ASNEdition, resolvedASNPath) - if err := os.MkdirAll(dbDir, 0o755); err != nil { - fmt.Fprintf(os.Stderr, "warn: mkdir %s: %v\n", dbDir, err) - } + cityDS = newGeoIPDataset(d, geoip.CityEdition, resolvedCityPath) + asnDS = newGeoIPDataset(d, geoip.ASNEdition, resolvedASNPath) + } + } else { + // Manual paths: no auto-download, just open existing files. + if resolvedCityPath != "" { + cityDS = newGeoIPDataset(nil, "", resolvedCityPath) + } + if resolvedASNPath != "" { + asnDS = newGeoIPDataset(nil, "", resolvedASNPath) } } - // Fetch GeoIP DBs if we have cachers; otherwise just open existing files. - if cityCacher != nil { - if _, err := cityCacher.Fetch(); err != nil { - fmt.Fprintf(os.Stderr, "warn: city DB fetch: %v\n", err) + if cityDS != nil { + if err := cityDS.Init(); err != nil { + fmt.Fprintf(os.Stderr, "warn: city DB: %v\n", err) } } - if asnCacher != nil { - if _, err := asnCacher.Fetch(); err != nil { - fmt.Fprintf(os.Stderr, "warn: ASN DB fetch: %v\n", err) + if asnDS != nil { + if err := asnDS.Init(); err != nil { + fmt.Fprintf(os.Stderr, "warn: ASN DB: %v\n", err) } } - openGeoIPReader(resolvedCityPath, &cityDB) - openGeoIPReader(resolvedASNPath, &asnDB) - // Keep everything fresh in the background if running as a daemon. + // Keep everything fresh in the background. ctx, cancel := context.WithCancel(context.Background()) defer cancel() - go runLoop(ctx, src, &whitelist, &inbound, &outbound, - cityCacher, asnCacher, &cityDB, &asnDB) + go blGroup.Run(ctx, 47*time.Minute) + if cityDS != nil { + go cityDS.Run(ctx, 47*time.Minute) + } + if asnDS != nil { + go asnDS.Run(ctx, 47*time.Minute) + } // Check and report. - blockedInbound := containsInbound(ipStr, &whitelist, &inbound) - blockedOutbound := containsOutbound(ipStr, &whitelist, &outbound) + blockedInbound := containsInbound(ipStr, whitelistDS, inboundDS) + blockedOutbound := containsOutbound(ipStr, whitelistDS, outboundDS) switch { case blockedInbound && blockedOutbound: @@ -154,149 +162,112 @@ func main() { fmt.Printf("%s is allowed\n", ipStr) } - printGeoInfo(ipStr, &cityDB, &asnDB) + printGeoInfo(ipStr, cityDS, asnDS) if blockedInbound || blockedOutbound { os.Exit(1) } } -func openGeoIPReader(path string, ptr *atomic.Pointer[geoip2.Reader]) { - if path == "" { - return - } - r, err := geoip2.Open(path) - if err != nil { - return - } - if old := ptr.Swap(r); old != nil { - old.Close() +// newGeoIPDataset creates a Dataset[geoip2.Reader]. If d is nil, only +// opens the existing file (no download). Close is wired to Reader.Close. +func newGeoIPDataset(d *geoip.Downloader, edition, path string) *dataset.Dataset[geoip2.Reader] { + var syncer interface{ Fetch() (bool, error) } + if d != nil { + syncer = d.NewCacher(edition, path) + } else { + syncer = &nopSyncer{} } + ds := dataset.New(syncer, func() (*geoip2.Reader, error) { + return geoip2.Open(path) + }) + ds.Name = edition + ds.Close = func(r *geoip2.Reader) { r.Close() } + return ds } -func runLoop(ctx context.Context, src *Sources, - whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort], - cityCacher, asnCacher *httpcache.Cacher, - cityDB, asnDB *atomic.Pointer[geoip2.Reader], -) { - ticker := time.NewTicker(47 * time.Minute) - defer ticker.Stop() +// nopSyncer satisfies httpcache.Syncer for file-only datasets (no download). +type nopSyncer struct{} - for { - select { - case <-ticker.C: - // Blocklists. - if updated, err := src.Fetch(); err != nil { - fmt.Fprintf(os.Stderr, "error: blocklist sync: %v\n", err) - } else if updated { - if err := reloadBlocklists(src, whitelist, inbound, outbound); err != nil { - fmt.Fprintf(os.Stderr, "error: blocklist reload: %v\n", err) - } else { - fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n", - cohortSize(inbound), cohortSize(outbound)) - } - } +func (n *nopSyncer) Fetch() (bool, error) { return false, nil } - // GeoIP DBs. - if cityCacher != nil { - if updated, err := cityCacher.Fetch(); err != nil { - fmt.Fprintf(os.Stderr, "error: city DB sync: %v\n", err) - } else if updated { - openGeoIPReader(cityCacher.Path, cityDB) - fmt.Fprintf(os.Stderr, "reloaded: %s\n", cityCacher.Path) - } - } - if asnCacher != nil { - if updated, err := asnCacher.Fetch(); err != nil { - fmt.Fprintf(os.Stderr, "error: ASN DB sync: %v\n", err) - } else if updated { - openGeoIPReader(asnCacher.Path, asnDB) - fmt.Fprintf(os.Stderr, "reloaded: %s\n", asnCacher.Path) - } - } - case <-ctx.Done(): - return +func containsInbound(ip string, + whitelist, inbound *dataset.Dataset[ipcohort.Cohort], +) bool { + if whitelist != nil { + if wl := whitelist.Load(); wl != nil && wl.Contains(ip) { + return false } } -} - -func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) { - ip, err := netip.ParseAddr(ipStr) - if err != nil { - return - } - stdIP := ip.AsSlice() - - if r := cityDB.Load(); r != nil { - if rec, err := r.City(stdIP); err == nil { - city := rec.City.Names["en"] - country := rec.Country.Names["en"] - iso := rec.Country.IsoCode - var parts []string - if city != "" { - parts = append(parts, city) - } - if len(rec.Subdivisions) > 0 { - if sub := rec.Subdivisions[0].Names["en"]; sub != "" && sub != city { - parts = append(parts, sub) - } - } - if country != "" { - parts = append(parts, fmt.Sprintf("%s (%s)", country, iso)) - } - if len(parts) > 0 { - fmt.Printf(" Location: %s\n", strings.Join(parts, ", ")) - } - } - } - - if r := asnDB.Load(); r != nil { - if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 { - fmt.Printf(" ASN: AS%d %s\n", - rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization) - } - } -} - -func reloadBlocklists(src *Sources, - whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort], -) error { - if wl, err := src.LoadWhitelist(); err != nil { - return err - } else if wl != nil { - whitelist.Store(wl) - } - if in, err := src.LoadInbound(); err != nil { - return err - } else if in != nil { - inbound.Store(in) - } - if out, err := src.LoadOutbound(); err != nil { - return err - } else if out != nil { - outbound.Store(out) - } - return nil -} - -func containsInbound(ip string, whitelist, inbound *atomic.Pointer[ipcohort.Cohort]) bool { - if wl := whitelist.Load(); wl != nil && wl.Contains(ip) { + if inbound == nil { return false } c := inbound.Load() return c != nil && c.Contains(ip) } -func containsOutbound(ip string, whitelist, outbound *atomic.Pointer[ipcohort.Cohort]) bool { - if wl := whitelist.Load(); wl != nil && wl.Contains(ip) { +func containsOutbound(ip string, + whitelist, outbound *dataset.Dataset[ipcohort.Cohort], +) bool { + if whitelist != nil { + if wl := whitelist.Load(); wl != nil && wl.Contains(ip) { + return false + } + } + if outbound == nil { return false } c := outbound.Load() return c != nil && c.Contains(ip) } -func cohortSize(ptr *atomic.Pointer[ipcohort.Cohort]) int { - if c := ptr.Load(); c != nil { +func printGeoInfo(ipStr string, cityDS, asnDS *dataset.Dataset[geoip2.Reader]) { + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return + } + stdIP := ip.AsSlice() + + if cityDS != nil { + if r := cityDS.Load(); r != nil { + if rec, err := r.City(stdIP); err == nil { + city := rec.City.Names["en"] + country := rec.Country.Names["en"] + iso := rec.Country.IsoCode + var parts []string + if city != "" { + parts = append(parts, city) + } + if len(rec.Subdivisions) > 0 { + if sub := rec.Subdivisions[0].Names["en"]; sub != "" && sub != city { + parts = append(parts, sub) + } + } + if country != "" { + parts = append(parts, fmt.Sprintf("%s (%s)", country, iso)) + } + if len(parts) > 0 { + fmt.Printf(" Location: %s\n", strings.Join(parts, ", ")) + } + } + } + } + + if asnDS != nil { + if r := asnDS.Load(); r != nil { + if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 { + fmt.Printf(" ASN: AS%d %s\n", + rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization) + } + } + } +} + +func cohortSize(ds *dataset.Dataset[ipcohort.Cohort]) int { + if ds == nil { + return 0 + } + if c := ds.Load(); c != nil { return c.Size() } return 0