refactor(dataset)!: plumb ctx into loader callbacks

Loader signature changes from func() (*T, error) to
func(context.Context) (*T, error). Set.Load(ctx) already accepts a
ctx; it now flows through reload() into the loader so long-running
parses or downloads can honor ctx.Err() for graceful shutdown.

check-ip's loaders don't consume ctx yet (ipcohort/geoip are
in-memory and fast), but the hook is in place for future work.

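BREAKING: dataset.Add and dataset.AddInitial signatures changed.
Loaders must now accept a context.Context; callers that don't need it
can name the parameter _ (as check-ip's loaders do).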
This commit is contained in:
AJ ONeal 2026-04-20 20:02:48 -06:00
parent 06e6cfa211
commit da2b230bc3
No known key found for this signature in database
2 changed files with 16 additions and 14 deletions

View File

@ -152,19 +152,19 @@ func main() {
repo.MaxAge = refreshInterval repo.MaxAge = refreshInterval
blocklists := dataset.NewSet(repo) blocklists := dataset.NewSet(repo)
asyncServe := cfg.AsyncLoad && cfg.Bind != "" asyncServe := cfg.AsyncLoad && cfg.Bind != ""
addCohort := func(s *dataset.Set, loader func() (*ipcohort.Cohort, error)) *dataset.View[ipcohort.Cohort] { addCohort := func(s *dataset.Set, loader func(context.Context) (*ipcohort.Cohort, error)) *dataset.View[ipcohort.Cohort] {
if asyncServe { if asyncServe {
return dataset.AddInitial(s, ipcohort.New(), loader) return dataset.AddInitial(s, ipcohort.New(), loader)
} }
return dataset.Add(s, loader) return dataset.Add(s, loader)
} }
cfg.inbound = addCohort(blocklists, func() (*ipcohort.Cohort, error) { cfg.inbound = addCohort(blocklists, func(_ context.Context) (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles( return ipcohort.LoadFiles(
repo.FilePath("tables/inbound/single_ips.txt"), repo.FilePath("tables/inbound/single_ips.txt"),
repo.FilePath("tables/inbound/networks.txt"), repo.FilePath("tables/inbound/networks.txt"),
) )
}) })
cfg.outbound = addCohort(blocklists, func() (*ipcohort.Cohort, error) { cfg.outbound = addCohort(blocklists, func(_ context.Context) (*ipcohort.Cohort, error) {
return ipcohort.LoadFiles( return ipcohort.LoadFiles(
repo.FilePath("tables/outbound/single_ips.txt"), repo.FilePath("tables/outbound/single_ips.txt"),
repo.FilePath("tables/outbound/networks.txt"), repo.FilePath("tables/outbound/networks.txt"),
@ -220,7 +220,7 @@ func main() {
Header: authHeader, Header: authHeader,
}, },
) )
cfg.geo = dataset.Add(geoSet, func() (*geoip.Databases, error) { cfg.geo = dataset.Add(geoSet, func(_ context.Context) (*geoip.Databases, error) {
return geoip.Open(maxmindDir) return geoip.Open(maxmindDir)
}) })
fmt.Fprint(os.Stderr, "Loading geoip... ") fmt.Fprint(os.Stderr, "Loading geoip... ")
@ -238,7 +238,7 @@ func main() {
var loadWhitelist func() var loadWhitelist func()
if cfg.WhitelistPath != "" { if cfg.WhitelistPath != "" {
whitelistSet = dataset.NewSet(dataset.PollFiles(cfg.WhitelistPath)) whitelistSet = dataset.NewSet(dataset.PollFiles(cfg.WhitelistPath))
cfg.whitelist = addCohort(whitelistSet, func() (*ipcohort.Cohort, error) { cfg.whitelist = addCohort(whitelistSet, func(_ context.Context) (*ipcohort.Cohort, error) {
return ipcohort.LoadFile(cfg.WhitelistPath) return ipcohort.LoadFile(cfg.WhitelistPath)
}) })
loadWhitelist = func() { loadWhitelist = func() {

View File

@ -7,8 +7,8 @@
// Typical lifecycle: // Typical lifecycle:
// //
// s := dataset.NewSet(repo) // *gitshallow.Repo satisfies Fetcher // s := dataset.NewSet(repo) // *gitshallow.Repo satisfies Fetcher
// inbound := dataset.Add(s, func() (*ipcohort.Cohort, error) { ... }) // inbound := dataset.Add(s, func(ctx context.Context) (*ipcohort.Cohort, error) { ... })
// outbound := dataset.Add(s, func() (*ipcohort.Cohort, error) { ... }) // outbound := dataset.Add(s, func(ctx context.Context) (*ipcohort.Cohort, error) { ... })
// if err := s.Load(ctx); err != nil { ... } // initial populate // if err := s.Load(ctx); err != nil { ... } // initial populate
// go s.Tick(ctx, 47*time.Minute, onError) // background refresh // go s.Tick(ctx, 47*time.Minute, onError) // background refresh
// current := inbound.Value() // lock-free read // current := inbound.Value() // lock-free read
@ -95,7 +95,7 @@ type Set struct {
// reloader is a type-erased handle to a View's reload function. // reloader is a type-erased handle to a View's reload function.
type reloader interface { type reloader interface {
reload() error reload(ctx context.Context) error
} }
// NewSet creates a Set backed by fetchers. All fetchers are called on every // NewSet creates a Set backed by fetchers. All fetchers are called on every
@ -129,7 +129,7 @@ func (s *Set) Load(ctx context.Context) error {
if err := ctx.Err(); err != nil { if err := ctx.Err(); err != nil {
return err return err
} }
if err := v.reload(); err != nil { if err := v.reload(ctx); err != nil {
return err return err
} }
} }
@ -157,7 +157,7 @@ func (s *Set) Tick(ctx context.Context, interval time.Duration, onError func(err
// View is a read-only handle to one dataset inside a Set. // View is a read-only handle to one dataset inside a Set.
type View[T any] struct { type View[T any] struct {
loader func() (*T, error) loader func(ctx context.Context) (*T, error)
ptr atomic.Pointer[T] ptr atomic.Pointer[T]
loadedAt atomic.Pointer[time.Time] // nil until first successful reload loadedAt atomic.Pointer[time.Time] // nil until first successful reload
} }
@ -177,8 +177,8 @@ func (v *View[T]) LoadedAt() time.Time {
return time.Time{} return time.Time{}
} }
func (v *View[T]) reload() error { func (v *View[T]) reload(ctx context.Context) error {
t, err := v.loader() t, err := v.loader(ctx)
if err != nil { if err != nil {
return err return err
} }
@ -197,7 +197,9 @@ func (v *View[T]) reload() error {
// Add registers a new view in s and returns it. Call after NewSet and before // Add registers a new view in s and returns it. Call after NewSet and before
// the first Load. View.Value() returns nil until Set.Load succeeds. // the first Load. View.Value() returns nil until Set.Load succeeds.
func Add[T any](s *Set, loader func() (*T, error)) *View[T] { // The loader receives the ctx passed to Set.Load, so long-running parses
// should honor ctx.Err() to support graceful shutdown.
func Add[T any](s *Set, loader func(ctx context.Context) (*T, error)) *View[T] {
v := &View[T]{loader: loader} v := &View[T]{loader: loader}
s.views = append(s.views, v) s.views = append(s.views, v)
return v return v
@ -208,7 +210,7 @@ func Add[T any](s *Set, loader func() (*T, error)) *View[T] {
// Load completes. Use when the initial state is benign (e.g. an empty // Load completes. Use when the initial state is benign (e.g. an empty
// cohort matches nothing) and you want to start serving before the // cohort matches nothing) and you want to start serving before the
// first load finishes. // first load finishes.
func AddInitial[T any](s *Set, initial *T, loader func() (*T, error)) *View[T] { func AddInitial[T any](s *Set, initial *T, loader func(ctx context.Context) (*T, error)) *View[T] {
v := &View[T]{loader: loader} v := &View[T]{loader: loader}
v.ptr.Store(initial) v.ptr.Store(initial)
s.views = append(s.views, v) s.views = append(s.views, v)