refactor(dataset): rename Group to Set, accept variadic fetchers

Set handles both single-fetcher (one git repo) and multi-fetcher
(GeoLite2 City + ASN) cases uniformly. Any fetcher reporting an update
triggers a view reload. This replaces the per-caller FetcherFunc wrapper
that combined the two MaxMind cachers and the ad-hoc atomic.Pointer +
ticker goroutine in cmd/check-ip — geoip now rides on the same
Set/View/Load/Tick surface as the blocklists.
This commit is contained in:
AJ ONeal 2026-04-20 16:50:33 -06:00
parent 01158aee55
commit e329c0f86b
No known key found for this signature in database
4 changed files with 95 additions and 121 deletions

View File

@ -12,7 +12,6 @@ import (
"os"
"os/signal"
"path/filepath"
"sync/atomic"
"syscall"
"time"
@ -44,7 +43,7 @@ type IPCheck struct {
inbound *dataset.View[ipcohort.Cohort]
outbound *dataset.View[ipcohort.Cohort]
geo atomic.Pointer[geoip.Databases]
geo *dataset.View[geoip.Databases]
}
func main() {
@ -114,20 +113,20 @@ func main() {
// Blocklists: git repo with inbound + outbound IP cohort files.
repo := gitshallow.New(cfg.RepoURL, filepath.Join(cfg.CacheDir, "bitwire-it"), 1, "")
group := dataset.NewGroup(repo)
cfg.inbound = dataset.Add(group, func() (*ipcohort.Cohort, error) {
blocklists := dataset.NewSet(repo)
cfg.inbound = dataset.Add(blocklists, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(
repo.FilePath("tables/inbound/single_ips.txt"),
repo.FilePath("tables/inbound/networks.txt"),
)
})
cfg.outbound = dataset.Add(group, func() (*ipcohort.Cohort, error) {
cfg.outbound = dataset.Add(blocklists, func() (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles(
repo.FilePath("tables/outbound/single_ips.txt"),
repo.FilePath("tables/outbound/networks.txt"),
)
})
if err := group.Load(context.Background()); err != nil {
if err := blocklists.Load(context.Background()); err != nil {
log.Fatalf("blocklists: %v", err)
}
@ -138,45 +137,34 @@ func main() {
maxmindDir := filepath.Join(cfg.CacheDir, "maxmind")
cityTarPath := filepath.Join(maxmindDir, "GeoLite2-City.tar.gz")
asnTarPath := filepath.Join(maxmindDir, "GeoLite2-ASN.tar.gz")
var geoFetcher dataset.Fetcher
var geoSet *dataset.Set
if cfg.GeoIPBasicAuth != "" {
city := &httpcache.Cacher{
URL: geoip.DownloadBase + "/GeoLite2-City/download?suffix=tar.gz",
Path: cityTarPath,
MaxAge: 3 * 24 * time.Hour,
AuthHeader: "Authorization",
AuthValue: cfg.GeoIPBasicAuth,
}
asn := &httpcache.Cacher{
URL: geoip.DownloadBase + "/GeoLite2-ASN/download?suffix=tar.gz",
Path: asnTarPath,
MaxAge: 3 * 24 * time.Hour,
AuthHeader: "Authorization",
AuthValue: cfg.GeoIPBasicAuth,
}
geoFetcher = dataset.FetcherFunc(func() (bool, error) {
cityUpdated, err := city.Fetch()
if err != nil {
return false, fmt.Errorf("fetch GeoLite2-City: %w", err)
}
asnUpdated, err := asn.Fetch()
if err != nil {
return false, fmt.Errorf("fetch GeoLite2-ASN: %w", err)
}
return cityUpdated || asnUpdated, nil
})
geoSet = dataset.NewSet(
&httpcache.Cacher{
URL: geoip.DownloadBase + "/GeoLite2-City/download?suffix=tar.gz",
Path: cityTarPath,
MaxAge: 3 * 24 * time.Hour,
AuthHeader: "Authorization",
AuthValue: cfg.GeoIPBasicAuth,
},
&httpcache.Cacher{
URL: geoip.DownloadBase + "/GeoLite2-ASN/download?suffix=tar.gz",
Path: asnTarPath,
MaxAge: 3 * 24 * time.Hour,
AuthHeader: "Authorization",
AuthValue: cfg.GeoIPBasicAuth,
},
)
} else {
geoFetcher = dataset.PollFiles(cityTarPath, asnTarPath)
geoSet = dataset.NewSet(dataset.PollFiles(cityTarPath, asnTarPath))
}
if _, err := geoFetcher.Fetch(); err != nil {
cfg.geo = dataset.Add(geoSet, func() (*geoip.Databases, error) {
return geoip.Open(maxmindDir)
})
if err := geoSet.Load(context.Background()); err != nil {
log.Fatalf("geoip: %v", err)
}
geoDB, err := geoip.Open(maxmindDir)
if err != nil {
log.Fatalf("geoip: %v", err)
}
cfg.geo.Store(geoDB)
defer func() { _ = cfg.geo.Load().Close() }()
defer func() { _ = cfg.geo.Value().Close() }()
for _, ip := range ips {
cfg.writeText(os.Stdout, cfg.lookup(ip))
@ -187,36 +175,12 @@ func main() {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
go group.Tick(ctx, refreshInterval, func(err error) {
go blocklists.Tick(ctx, refreshInterval, func(err error) {
log.Printf("blocklists refresh: %v", err)
})
go func() {
t := time.NewTicker(refreshInterval)
defer t.Stop()
for {
select {
case <-ctx.Done():
return
case <-t.C:
updated, err := geoFetcher.Fetch()
if err != nil {
log.Printf("geoip refresh: %v", err)
continue
}
if !updated {
continue
}
db, err := geoip.Open(maxmindDir)
if err != nil {
log.Printf("geoip refresh: %v", err)
continue
}
if old := cfg.geo.Swap(db); old != nil {
_ = old.Close()
}
}
}
}()
go geoSet.Tick(ctx, refreshInterval, func(err error) {
log.Printf("geoip refresh: %v", err)
})
if err := cfg.serve(ctx); err != nil {
log.Fatalf("serve: %v", err)
}

View File

@ -34,7 +34,7 @@ func (c *IPCheck) lookup(ip string) Result {
Blocked: in || out,
BlockedInbound: in,
BlockedOutbound: out,
Geo: c.geo.Load().Lookup(ip),
Geo: c.geo.Value().Lookup(ip),
}
}

View File

@ -1,17 +1,17 @@
// Package dataset manages values that are periodically re-fetched from an
// upstream source and hot-swapped behind atomic pointers. Consumers read via
// View.Value (lock-free); a single Load drives any number of views off one
// Fetcher, so shared sources (one git pull, one zip download) don't get
// re-fetched per view.
// View.Value (lock-free); a single Load drives any number of views off a
// shared set of Fetchers, so upstreams (one git pull, one tar.gz download)
// don't get re-fetched per view.
//
// Typical lifecycle:
//
// g := dataset.NewGroup(repo) // *gitshallow.Repo satisfies Fetcher
// inbound := dataset.Add(g, func() (*ipcohort.Cohort, error) { ... })
// outbound := dataset.Add(g, func() (*ipcohort.Cohort, error) { ... })
// if err := g.Load(ctx); err != nil { ... } // initial populate
// go g.Tick(ctx, 47*time.Minute) // background refresh
// current := inbound.Value() // lock-free read
// s := dataset.NewSet(repo) // *gitshallow.Repo satisfies Fetcher
// inbound := dataset.Add(s, func() (*ipcohort.Cohort, error) { ... })
// outbound := dataset.Add(s, func() (*ipcohort.Cohort, error) { ... })
// if err := s.Load(ctx); err != nil { ... } // initial populate
// go s.Tick(ctx, 47*time.Minute, onError) // background refresh
// current := inbound.Value() // lock-free read
package dataset
import (
@ -34,7 +34,7 @@ type FetcherFunc func() (bool, error)
// Fetch implements the Fetcher interface by invoking f itself.
func (f FetcherFunc) Fetch() (bool, error) { return f() }
// NopFetcher always reports no update. Use for groups whose source never
// NopFetcher always reports no update. Use for sets whose source never
// changes (test fixtures, embedded data). Views are still populated on the
// first Load, since the first Load never skips the reload step.
type NopFetcher struct{}
@ -44,8 +44,8 @@ func (NopFetcher) Fetch() (bool, error) { return false, nil }
// "updated" whenever any file's size or modtime has changed since the last
// call. The first call always reports updated=true.
//
// Use for Group's whose source is local files that may be edited out of band
// (e.g. a user-provided --inbound list) — pair with Group.Tick to pick up
// Use for Sets whose source is local files that may be edited out of band
// (e.g. a user-provided --inbound list) — pair with Set.Tick to pick up
// changes automatically.
func PollFiles(paths ...string) Fetcher {
return &filePoller{paths: paths, stats: make(map[string]fileStat, len(paths))}
@ -80,13 +80,16 @@ func (p *filePoller) Fetch() (bool, error) {
return changed, nil
}
// Group ties one Fetcher to one or more views. A Load call fetches once and,
// on the first call or when the source reports a change, reloads every view
// and atomically swaps its current value.
type Group struct {
fetcher Fetcher
views []reloader
loaded atomic.Bool
// Set ties one or more Fetchers to one or more views. A Load call fetches
// each source and, on the first call or when any source reports a change,
// reloads every view and atomically swaps its current value. Use multiple
// fetchers when a single logical dataset is spread across several archives
// (e.g. GeoLite2 City + ASN); a single fetcher is the common case (one git
// repo, one tar.gz).
type Set struct {
fetchers []Fetcher // upstream sources; every one is polled on each Load
views []reloader // type-erased views, all reloaded when any fetcher reports a change
loaded atomic.Bool // true after the first successful Load; gates the skip-if-unchanged path
}
// reloader is a type-erased handle to a View's reload function.
@ -94,22 +97,29 @@ type reloader interface {
reload() error
}
// NewGroup creates a Group backed by fetcher.
func NewGroup(fetcher Fetcher) *Group {
return &Group{fetcher: fetcher}
// NewSet creates a Set backed by fetchers. Every fetcher is invoked on each
// Load, and a change reported by any one of them causes the set to reload
// all of its views.
func NewSet(fetchers ...Fetcher) *Set {
	set := &Set{}
	set.fetchers = fetchers
	return set
}
// Load fetches upstream and, on the first call or whenever the fetcher reports
// a change, reloads every view and atomically installs the new values.
func (g *Group) Load(ctx context.Context) error {
updated, err := g.fetcher.Fetch()
if err != nil {
return err
// Load fetches upstream and, on the first call or whenever any fetcher
// reports a change, reloads every view and atomically installs the new values.
func (s *Set) Load(ctx context.Context) error {
updated := false
for _, f := range s.fetchers {
u, err := f.Fetch()
if err != nil {
return err
}
if u {
updated = true
}
}
if g.loaded.Load() && !updated {
if s.loaded.Load() && !updated {
return nil
}
for _, v := range g.views {
for _, v := range s.views {
if err := ctx.Err(); err != nil {
return err
}
@ -117,14 +127,14 @@ func (g *Group) Load(ctx context.Context) error {
return err
}
}
g.loaded.Store(true)
s.loaded.Store(true)
return nil
}
// Tick calls Load every interval until ctx is done. Load errors are passed to
// onError (if non-nil) and do not stop the loop; callers choose whether to log,
// count, page, or ignore. Run in a goroutine: `go g.Tick(ctx, d, onError)`.
func (g *Group) Tick(ctx context.Context, interval time.Duration, onError func(error)) {
// count, page, or ignore. Run in a goroutine: `go s.Tick(ctx, d, onError)`.
func (s *Set) Tick(ctx context.Context, interval time.Duration, onError func(error)) {
t := time.NewTicker(interval)
defer t.Stop()
for {
@ -132,20 +142,20 @@ func (g *Group) Tick(ctx context.Context, interval time.Duration, onError func(e
case <-ctx.Done():
return
case <-t.C:
if err := g.Load(ctx); err != nil && onError != nil {
if err := s.Load(ctx); err != nil && onError != nil {
onError(err)
}
}
}
}
// View is a read-only handle to one dataset inside a Group.
// View is a read-only handle to one dataset inside a Set.
type View[T any] struct {
loader func() (*T, error) // produces a fresh snapshot; run when the Set reloads its views
ptr atomic.Pointer[T] // current snapshot; nil until the Set's first successful Load
}
// Value returns the current snapshot. Nil before the Group is first loaded.
// Value returns the most recently installed snapshot. It is nil until the
// owning Set completes its first Load.
func (v *View[T]) Value() *T {
	snapshot := v.ptr.Load()
	return snapshot
}
@ -159,10 +169,10 @@ func (v *View[T]) reload() error {
return nil
}
// Add registers a new view in g and returns it. Call after NewGroup and
// before the first Load.
func Add[T any](g *Group, loader func() (*T, error)) *View[T] {
// Add registers a new view in s and returns it. Call after NewSet and before
// the first Load.
func Add[T any](s *Set, loader func() (*T, error)) *View[T] {
v := &View[T]{loader: loader}
g.views = append(g.views, v)
s.views = append(s.views, v)
return v
}

View File

@ -21,9 +21,9 @@ func (f *countFetcher) Fetch() (bool, error) {
return f.updated, f.err
}
func TestGroup_LoadPopulatesAllViews(t *testing.T) {
func TestSet_LoadPopulatesAllViews(t *testing.T) {
f := &countFetcher{}
g := dataset.NewGroup(f)
g := dataset.NewSet(f)
var aCalls, bCalls int
a := dataset.Add(g, func() (*string, error) {
@ -54,9 +54,9 @@ func TestGroup_LoadPopulatesAllViews(t *testing.T) {
}
}
func TestGroup_SecondLoadSkipsUnchanged(t *testing.T) {
func TestSet_SecondLoadSkipsUnchanged(t *testing.T) {
f := &countFetcher{updated: false}
g := dataset.NewGroup(f)
g := dataset.NewSet(f)
calls := 0
dataset.Add(g, func() (*string, error) {
calls++
@ -77,9 +77,9 @@ func TestGroup_SecondLoadSkipsUnchanged(t *testing.T) {
}
}
func TestGroup_LoadOnUpdateSwaps(t *testing.T) {
func TestSet_LoadOnUpdateSwaps(t *testing.T) {
f := &countFetcher{updated: true}
g := dataset.NewGroup(f)
g := dataset.NewSet(f)
n := 0
v := dataset.Add(g, func() (*int, error) {
n++
@ -96,8 +96,8 @@ func TestGroup_LoadOnUpdateSwaps(t *testing.T) {
}
}
func TestGroup_ValueBeforeLoad(t *testing.T) {
g := dataset.NewGroup(dataset.NopFetcher{})
func TestSet_ValueBeforeLoad(t *testing.T) {
g := dataset.NewSet(dataset.NopFetcher{})
v := dataset.Add(g, func() (*string, error) {
s := "x"
return &s, nil
@ -107,9 +107,9 @@ func TestGroup_ValueBeforeLoad(t *testing.T) {
}
}
func TestGroup_FetchError(t *testing.T) {
func TestSet_FetchError(t *testing.T) {
f := &countFetcher{err: errors.New("offline")}
g := dataset.NewGroup(f)
g := dataset.NewSet(f)
dataset.Add(g, func() (*string, error) {
s := "x"
return &s, nil
@ -119,8 +119,8 @@ func TestGroup_FetchError(t *testing.T) {
}
}
func TestGroup_LoaderError(t *testing.T) {
g := dataset.NewGroup(dataset.NopFetcher{})
func TestSet_LoaderError(t *testing.T) {
g := dataset.NewSet(dataset.NopFetcher{})
dataset.Add(g, func() (*string, error) {
return nil, errors.New("parse fail")
})