refactor: dataset uses closure Loader + Close callback; check-ip uses Dataset/Group

dataset.Loader[T] is now func() (*T, error) — a closure capturing its own
paths/config, so multi-file cases (LoadFiles(paths...)) work naturally.

Dataset.Close func(*T) is called with the old value after each swap, enabling
resource cleanup (e.g. geoip2.Reader.Close).

Sources.Datasets() builds a dataset.Group + three typed *Dataset[ipcohort.Cohort].
main.go now uses blGroup.Run / cityDS.Run / asnDS.Run instead of hand-rolled
atomic.Pointer + polling loops. containsInbound/containsOutbound accept *Dataset[Cohort].
nopSyncer handles file-only GeoIP paths (no download, just open).
This commit is contained in:
AJ ONeal 2026-04-20 09:28:20 -06:00
parent 7c0cd26da1
commit 673d084bd2
No known key found for this signature in database
3 changed files with 192 additions and 203 deletions

View File

@ -2,22 +2,22 @@
// atomic.Pointer (hot-swap), providing a generic periodically-updated
// in-memory dataset with lock-free reads.
//
// Single dataset:
// Standalone dataset (one syncer, one value):
//
// ds := dataset.New(cacher, ipcohort.LoadFile, path)
// ds := dataset.New(cacher, func() (*MyType, error) {
// return mytype.LoadFile(path)
// })
// if err := ds.Init(); err != nil { ... }
// go ds.Run(ctx, 47*time.Minute)
// cohort := ds.Load()
// val := ds.Load() // *MyType, lock-free
//
// Multiple datasets sharing one syncer (e.g. inbound + outbound from one git repo):
// Group (one syncer, multiple values — e.g. inbound+outbound from one git repo):
//
// g := dataset.NewGroup(repo)
// inbound := dataset.Add(g, ipcohort.LoadFile, inboundPath)
// outbound := dataset.Add(g, ipcohort.LoadFile, outboundPath)
// inbound := dataset.Add(g, func() (*ipcohort.Cohort, error) { return ipcohort.LoadFiles(inboundPaths...) })
// outbound := dataset.Add(g, func() (*ipcohort.Cohort, error) { return ipcohort.LoadFiles(outboundPaths...) })
// if err := g.Init(); err != nil { ... }
// go g.Run(ctx, 47*time.Minute)
// in := inbound.Load()
// out := outbound.Load()
package dataset
import (
@ -30,21 +30,24 @@ import (
"github.com/therootcompany/golib/net/httpcache"
)
// Loader reads path and returns the parsed value, or an error.
type Loader[T any] func(path string) (*T, error)
// Dataset couples a Syncer, a Loader, and an atomic.Pointer.
// Dataset couples a Syncer, a load function, and an atomic.Pointer[T].
// Load is safe for concurrent use without locks.
type Dataset[T any] struct {
// Name is used in error messages. Optional.
Name string
// Close is called with the previous value after each successful swap.
// Use this for values that hold resources, e.g. func(r *geoip2.Reader) { r.Close() }.
Close func(*T)
syncer httpcache.Syncer
load Loader[T]
path string
load func() (*T, error)
ptr atomic.Pointer[T]
}
// New creates a Dataset. The syncer fetches updates to path; load parses it.
func New[T any](syncer httpcache.Syncer, load Loader[T], path string) *Dataset[T] {
return &Dataset[T]{syncer: syncer, load: load, path: path}
// New creates a Dataset. The syncer fetches updates; load produces the value.
// load is a closure — it captures whatever paths or config it needs.
func New[T any](syncer httpcache.Syncer, load func() (*T, error)) *Dataset[T] {
return &Dataset[T]{syncer: syncer, load: load}
}
// Load returns the current value. Returns nil before Init is called.
@ -52,7 +55,7 @@ func (d *Dataset[T]) Load() *T {
return d.ptr.Load()
}
// Init fetches (if the syncer needs it) then loads, ensuring the dataset is
// Init fetches (if needed) then always loads, ensuring the dataset is
// populated on startup from an existing local file even if nothing changed.
func (d *Dataset[T]) Init() error {
if _, err := d.syncer.Fetch(); err != nil {
@ -61,8 +64,7 @@ func (d *Dataset[T]) Init() error {
return d.reload()
}
// Sync fetches from the remote and reloads if the content changed.
// Returns whether the value was updated.
// Sync fetches and reloads if the content changed. Returns whether updated.
func (d *Dataset[T]) Sync() (bool, error) {
updated, err := d.syncer.Fetch()
if err != nil || !updated {
@ -80,7 +82,11 @@ func (d *Dataset[T]) Run(ctx context.Context, interval time.Duration) {
select {
case <-ticker.C:
if _, err := d.Sync(); err != nil {
fmt.Fprintf(os.Stderr, "dataset %s: sync error: %v\n", d.path, err)
name := d.Name
if name == "" {
name = "dataset"
}
fmt.Fprintf(os.Stderr, "%s: sync error: %v\n", name, err)
}
case <-ctx.Done():
return
@ -89,27 +95,28 @@ func (d *Dataset[T]) Run(ctx context.Context, interval time.Duration) {
}
func (d *Dataset[T]) reload() error {
val, err := d.load(d.path)
val, err := d.load()
if err != nil {
return err
}
d.ptr.Store(val)
if old := d.ptr.Swap(val); old != nil && d.Close != nil {
d.Close(old)
}
return nil
}
// -- Group: one Syncer driving multiple datasets ---------------------------
// entry is the type-erased reload handle stored in a Group.
type entry interface {
// member is the type-erased reload handle stored in a Group.
type member interface {
reload() error
}
// Group ties one Syncer to multiple datasets so a single Fetch drives all
// reloads — avoiding redundant network calls when datasets share a source
// (e.g. multiple files from the same git repo or HTTP directory).
// reloads — no redundant network calls when datasets share a source.
type Group struct {
syncer httpcache.Syncer
entries []entry
members []member
}
// NewGroup creates a Group backed by syncer.
@ -117,11 +124,12 @@ func NewGroup(syncer httpcache.Syncer) *Group {
return &Group{syncer: syncer}
}
// Add registers a new dataset in g and returns it. Subsequent Init/Sync/Run
// calls on g will reload this dataset whenever the syncer reports an update.
func Add[T any](g *Group, load Loader[T], path string) *Dataset[T] {
d := &Dataset[T]{load: load, path: path}
g.entries = append(g.entries, d)
// Add registers a new dataset in g and returns it. Call Init or Run on g —
// not on the returned dataset — to drive updates.
// load is a closure capturing whatever paths or config it needs.
func Add[T any](g *Group, load func() (*T, error)) *Dataset[T] {
d := &Dataset[T]{load: load}
g.members = append(g.members, d)
return d
}
@ -159,8 +167,8 @@ func (g *Group) Run(ctx context.Context, interval time.Duration) {
}
func (g *Group) reloadAll() error {
for _, e := range g.entries {
if err := e.reload(); err != nil {
for _, m := range g.members {
if err := m.reload(); err != nil {
return err
}
}

View File

@ -3,6 +3,7 @@ package main
import (
"path/filepath"
"github.com/therootcompany/golib/net/dataset"
"github.com/therootcompany/golib/net/gitshallow"
"github.com/therootcompany/golib/net/httpcache"
"github.com/therootcompany/golib/net/ipcohort"
@ -14,15 +15,15 @@ type HTTPSource struct {
Path string
}
// Sources holds the configuration for fetching and loading the three cohorts.
// Sources holds fetch configuration for the three blocklist cohorts.
// It knows how to pull data from git or HTTP, but owns no atomic state.
type Sources struct {
whitelistPaths []string
inboundPaths []string
outboundPaths []string
gitRepo *gitshallow.Repo // non-nil for git source; used by Init for clone-if-missing
syncs []httpcache.Syncer // all syncable sources (git repo or HTTP cachers)
gitRepo *gitshallow.Repo // non-nil for git source; used by Init for clone-if-missing
syncs []httpcache.Syncer // all syncable sources
}
func newFileSources(whitelist, inbound, outbound []string) *Sources {
@ -78,8 +79,7 @@ func (s *Sources) Fetch() (bool, error) {
return anyUpdated, nil
}
// Init ensures remotes are ready. For git: clones if missing then syncs.
// For HTTP: fetches each cacher unconditionally on first run.
// Init ensures remotes are ready: clones git if missing, or fetches HTTP files.
func (s *Sources) Init() error {
if s.gitRepo != nil {
_, err := s.gitRepo.Init()
@ -93,23 +93,33 @@ func (s *Sources) Init() error {
return nil
}
func (s *Sources) LoadWhitelist() (*ipcohort.Cohort, error) {
if len(s.whitelistPaths) == 0 {
return nil, nil
// Datasets builds a dataset.Group backed by this Sources and returns typed
// datasets for whitelist, inbound, and outbound cohorts. Either whitelist or
// outbound may be nil if no paths were configured.
func (s *Sources) Datasets() (
g *dataset.Group,
whitelist *dataset.Dataset[ipcohort.Cohort],
inbound *dataset.Dataset[ipcohort.Cohort],
outbound *dataset.Dataset[ipcohort.Cohort],
) {
g = dataset.NewGroup(s)
if len(s.whitelistPaths) > 0 {
paths := s.whitelistPaths
whitelist = dataset.Add(g, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(paths...)
})
}
return ipcohort.LoadFiles(s.whitelistPaths...)
}
func (s *Sources) LoadInbound() (*ipcohort.Cohort, error) {
if len(s.inboundPaths) == 0 {
return nil, nil
if len(s.inboundPaths) > 0 {
paths := s.inboundPaths
inbound = dataset.Add(g, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(paths...)
})
}
return ipcohort.LoadFiles(s.inboundPaths...)
}
func (s *Sources) LoadOutbound() (*ipcohort.Cohort, error) {
if len(s.outboundPaths) == 0 {
return nil, nil
if len(s.outboundPaths) > 0 {
paths := s.outboundPaths
outbound = dataset.Add(g, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(paths...)
})
}
return ipcohort.LoadFiles(s.outboundPaths...)
return g, whitelist, inbound, outbound
}

View File

@ -8,12 +8,11 @@ import (
"os"
"path/filepath"
"strings"
"sync/atomic"
"time"
"github.com/oschwald/geoip2-golang"
"github.com/therootcompany/golib/net/dataset"
"github.com/therootcompany/golib/net/geoip"
"github.com/therootcompany/golib/net/httpcache"
"github.com/therootcompany/golib/net/ipcohort"
)
@ -75,26 +74,25 @@ func main() {
)
}
var whitelist, inbound, outbound atomic.Pointer[ipcohort.Cohort]
// Build typed datasets from the source.
if err := src.Init(); err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
if err := reloadBlocklists(src, &whitelist, &inbound, &outbound); err != nil {
blGroup, whitelistDS, inboundDS, outboundDS := src.Datasets()
if err := blGroup.Init(); err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n",
cohortSize(&inbound), cohortSize(&outbound))
// GeoIP: resolve paths and build cachers if we have credentials.
var cityDB, asnDB atomic.Pointer[geoip2.Reader]
var cityCacher, asnCacher *httpcache.Cacher
cohortSize(inboundDS), cohortSize(outboundDS))
// GeoIP datasets.
resolvedCityPath := *cityDBPath
resolvedASNPath := *asnDBPath
var cityDS, asnDS *dataset.Dataset[geoip2.Reader]
if *geoipConf != "" {
cfg, err := geoip.ParseConf(*geoipConf)
if err != nil {
@ -104,6 +102,9 @@ func main() {
if dbDir == "" {
dbDir = dataPath
}
if err := os.MkdirAll(dbDir, 0o755); err != nil {
fmt.Fprintf(os.Stderr, "warn: mkdir %s: %v\n", dbDir, err)
}
d := geoip.New(cfg.AccountID, cfg.LicenseKey)
if resolvedCityPath == "" {
resolvedCityPath = filepath.Join(dbDir, geoip.CityEdition+".mmdb")
@ -111,37 +112,44 @@ func main() {
if resolvedASNPath == "" {
resolvedASNPath = filepath.Join(dbDir, geoip.ASNEdition+".mmdb")
}
cityCacher = d.NewCacher(geoip.CityEdition, resolvedCityPath)
asnCacher = d.NewCacher(geoip.ASNEdition, resolvedASNPath)
if err := os.MkdirAll(dbDir, 0o755); err != nil {
fmt.Fprintf(os.Stderr, "warn: mkdir %s: %v\n", dbDir, err)
}
cityDS = newGeoIPDataset(d, geoip.CityEdition, resolvedCityPath)
asnDS = newGeoIPDataset(d, geoip.ASNEdition, resolvedASNPath)
}
} else {
// Manual paths: no auto-download, just open existing files.
if resolvedCityPath != "" {
cityDS = newGeoIPDataset(nil, "", resolvedCityPath)
}
if resolvedASNPath != "" {
asnDS = newGeoIPDataset(nil, "", resolvedASNPath)
}
}
// Fetch GeoIP DBs if we have cachers; otherwise just open existing files.
if cityCacher != nil {
if _, err := cityCacher.Fetch(); err != nil {
fmt.Fprintf(os.Stderr, "warn: city DB fetch: %v\n", err)
if cityDS != nil {
if err := cityDS.Init(); err != nil {
fmt.Fprintf(os.Stderr, "warn: city DB: %v\n", err)
}
}
if asnCacher != nil {
if _, err := asnCacher.Fetch(); err != nil {
fmt.Fprintf(os.Stderr, "warn: ASN DB fetch: %v\n", err)
if asnDS != nil {
if err := asnDS.Init(); err != nil {
fmt.Fprintf(os.Stderr, "warn: ASN DB: %v\n", err)
}
}
openGeoIPReader(resolvedCityPath, &cityDB)
openGeoIPReader(resolvedASNPath, &asnDB)
// Keep everything fresh in the background if running as a daemon.
// Keep everything fresh in the background.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go runLoop(ctx, src, &whitelist, &inbound, &outbound,
cityCacher, asnCacher, &cityDB, &asnDB)
go blGroup.Run(ctx, 47*time.Minute)
if cityDS != nil {
go cityDS.Run(ctx, 47*time.Minute)
}
if asnDS != nil {
go asnDS.Run(ctx, 47*time.Minute)
}
// Check and report.
blockedInbound := containsInbound(ipStr, &whitelist, &inbound)
blockedOutbound := containsOutbound(ipStr, &whitelist, &outbound)
blockedInbound := containsInbound(ipStr, whitelistDS, inboundDS)
blockedOutbound := containsOutbound(ipStr, whitelistDS, outboundDS)
switch {
case blockedInbound && blockedOutbound:
@ -154,149 +162,112 @@ func main() {
fmt.Printf("%s is allowed\n", ipStr)
}
printGeoInfo(ipStr, &cityDB, &asnDB)
printGeoInfo(ipStr, cityDS, asnDS)
if blockedInbound || blockedOutbound {
os.Exit(1)
}
}
func openGeoIPReader(path string, ptr *atomic.Pointer[geoip2.Reader]) {
if path == "" {
return
}
r, err := geoip2.Open(path)
if err != nil {
return
}
if old := ptr.Swap(r); old != nil {
old.Close()
// newGeoIPDataset creates a Dataset[geoip2.Reader] for one MaxMind edition.
// If d is nil, no download is attempted — the dataset only opens the existing
// file at path. The dataset's Close is wired to geoip2.Reader.Close so the
// previous reader is released after each hot-swap.
func newGeoIPDataset(d *geoip.Downloader, edition, path string) *dataset.Dataset[geoip2.Reader] {
	// Default to the file-only no-op syncer; upgrade to a real
	// downloader-backed cacher when credentials were provided (d != nil).
	var syncer interface{ Fetch() (bool, error) } = &nopSyncer{}
	if d != nil {
		syncer = d.NewCacher(edition, path)
	}

	open := func() (*geoip2.Reader, error) { return geoip2.Open(path) }
	ds := dataset.New(syncer, open)
	ds.Name = edition
	ds.Close = func(r *geoip2.Reader) { r.Close() }
	return ds
}
func runLoop(ctx context.Context, src *Sources,
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
cityCacher, asnCacher *httpcache.Cacher,
cityDB, asnDB *atomic.Pointer[geoip2.Reader],
) {
ticker := time.NewTicker(47 * time.Minute)
defer ticker.Stop()
// nopSyncer satisfies httpcache.Syncer for file-only datasets (no download).
// Fetch always reports "no update", so only Dataset.Init (which reloads
// unconditionally) ever opens the file; periodic Sync calls become no-ops.
type nopSyncer struct{}
for {
select {
case <-ticker.C:
// Blocklists.
if updated, err := src.Fetch(); err != nil {
fmt.Fprintf(os.Stderr, "error: blocklist sync: %v\n", err)
} else if updated {
if err := reloadBlocklists(src, whitelist, inbound, outbound); err != nil {
fmt.Fprintf(os.Stderr, "error: blocklist reload: %v\n", err)
} else {
fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n",
cohortSize(inbound), cohortSize(outbound))
}
}
// Fetch reports that nothing changed and never errors; nopSyncer
// performs no network activity.
func (*nopSyncer) Fetch() (bool, error) {
	return false, nil
}
// GeoIP DBs.
if cityCacher != nil {
if updated, err := cityCacher.Fetch(); err != nil {
fmt.Fprintf(os.Stderr, "error: city DB sync: %v\n", err)
} else if updated {
openGeoIPReader(cityCacher.Path, cityDB)
fmt.Fprintf(os.Stderr, "reloaded: %s\n", cityCacher.Path)
}
}
if asnCacher != nil {
if updated, err := asnCacher.Fetch(); err != nil {
fmt.Fprintf(os.Stderr, "error: ASN DB sync: %v\n", err)
} else if updated {
openGeoIPReader(asnCacher.Path, asnDB)
fmt.Fprintf(os.Stderr, "reloaded: %s\n", asnCacher.Path)
}
}
case <-ctx.Done():
return
func containsInbound(ip string,
whitelist, inbound *dataset.Dataset[ipcohort.Cohort],
) bool {
if whitelist != nil {
if wl := whitelist.Load(); wl != nil && wl.Contains(ip) {
return false
}
}
}
func printGeoInfo(ipStr string, cityDB, asnDB *atomic.Pointer[geoip2.Reader]) {
ip, err := netip.ParseAddr(ipStr)
if err != nil {
return
}
stdIP := ip.AsSlice()
if r := cityDB.Load(); r != nil {
if rec, err := r.City(stdIP); err == nil {
city := rec.City.Names["en"]
country := rec.Country.Names["en"]
iso := rec.Country.IsoCode
var parts []string
if city != "" {
parts = append(parts, city)
}
if len(rec.Subdivisions) > 0 {
if sub := rec.Subdivisions[0].Names["en"]; sub != "" && sub != city {
parts = append(parts, sub)
}
}
if country != "" {
parts = append(parts, fmt.Sprintf("%s (%s)", country, iso))
}
if len(parts) > 0 {
fmt.Printf(" Location: %s\n", strings.Join(parts, ", "))
}
}
}
if r := asnDB.Load(); r != nil {
if rec, err := r.ASN(stdIP); err == nil && rec.AutonomousSystemNumber != 0 {
fmt.Printf(" ASN: AS%d %s\n",
rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization)
}
}
}
func reloadBlocklists(src *Sources,
whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort],
) error {
if wl, err := src.LoadWhitelist(); err != nil {
return err
} else if wl != nil {
whitelist.Store(wl)
}
if in, err := src.LoadInbound(); err != nil {
return err
} else if in != nil {
inbound.Store(in)
}
if out, err := src.LoadOutbound(); err != nil {
return err
} else if out != nil {
outbound.Store(out)
}
return nil
}
func containsInbound(ip string, whitelist, inbound *atomic.Pointer[ipcohort.Cohort]) bool {
if wl := whitelist.Load(); wl != nil && wl.Contains(ip) {
if inbound == nil {
return false
}
c := inbound.Load()
return c != nil && c.Contains(ip)
}
func containsOutbound(ip string, whitelist, outbound *atomic.Pointer[ipcohort.Cohort]) bool {
if wl := whitelist.Load(); wl != nil && wl.Contains(ip) {
// containsOutbound reports whether ip is blocked for outbound traffic.
// A whitelisted IP is never blocked; with no outbound cohort configured
// (or none loaded yet), nothing is blocked.
func containsOutbound(ip string,
	whitelist, outbound *dataset.Dataset[ipcohort.Cohort],
) bool {
	// Whitelist wins over any block entry.
	if whitelist != nil {
		wl := whitelist.Load()
		if wl != nil && wl.Contains(ip) {
			return false
		}
	}
	if outbound == nil {
		return false
	}
	cohort := outbound.Load()
	if cohort == nil {
		return false
	}
	return cohort.Contains(ip)
}
func cohortSize(ptr *atomic.Pointer[ipcohort.Cohort]) int {
if c := ptr.Load(); c != nil {
// printGeoInfo looks ipStr up in the city and ASN GeoIP datasets (when
// present) and prints any location/ASN details found. Unparseable
// addresses, nil datasets, and failed lookups are silently skipped.
func printGeoInfo(ipStr string, cityDS, asnDS *dataset.Dataset[geoip2.Reader]) {
	addr, err := netip.ParseAddr(ipStr)
	if err != nil {
		return
	}
	raw := addr.AsSlice()

	if cityDS != nil {
		if reader := cityDS.Load(); reader != nil {
			if rec, lookupErr := reader.City(raw); lookupErr == nil {
				cityName := rec.City.Names["en"]
				countryName := rec.Country.Names["en"]
				var parts []string
				if cityName != "" {
					parts = append(parts, cityName)
				}
				// Include the first subdivision (state/region) unless it
				// merely repeats the city name.
				if len(rec.Subdivisions) > 0 {
					sub := rec.Subdivisions[0].Names["en"]
					if sub != "" && sub != cityName {
						parts = append(parts, sub)
					}
				}
				if countryName != "" {
					parts = append(parts, fmt.Sprintf("%s (%s)", countryName, rec.Country.IsoCode))
				}
				if len(parts) > 0 {
					fmt.Printf(" Location: %s\n", strings.Join(parts, ", "))
				}
			}
		}
	}

	if asnDS != nil {
		if reader := asnDS.Load(); reader != nil {
			// AS number 0 means "no ASN data" — skip it.
			if rec, lookupErr := reader.ASN(raw); lookupErr == nil && rec.AutonomousSystemNumber != 0 {
				fmt.Printf(" ASN: AS%d %s\n",
					rec.AutonomousSystemNumber, rec.AutonomousSystemOrganization)
			}
		}
	}
}
func cohortSize(ds *dataset.Dataset[ipcohort.Cohort]) int {
if ds == nil {
return 0
}
if c := ds.Load(); c != nil {
return c.Size()
}
return 0