From 4895553a918d2bd541ef91bd1f1cfb444516b9e2 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Sun, 19 Apr 2026 23:36:38 -0600 Subject: [PATCH] refactor: move atomic swaps and polling loop into main MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Sources (blacklist.go) now owns only fetch/load logic — no atomic state. main.go holds the three atomic.Pointer[Cohort] vars, calls reload() on startup, and runs the background ticker directly. This makes the dataset pattern (fetch → load → atomic.Store → poll) visible at the call site. --- .../cmd/check-ip-blacklist/blacklist.go | 205 +++++------------- net/ipcohort/cmd/check-ip-blacklist/main.go | 107 ++++++++- 2 files changed, 153 insertions(+), 159 deletions(-) diff --git a/net/ipcohort/cmd/check-ip-blacklist/blacklist.go b/net/ipcohort/cmd/check-ip-blacklist/blacklist.go index 9bf69c4..0c714c8 100644 --- a/net/ipcohort/cmd/check-ip-blacklist/blacklist.go +++ b/net/ipcohort/cmd/check-ip-blacklist/blacklist.go @@ -1,12 +1,7 @@ package main import ( - "context" - "fmt" - "os" "path/filepath" - "sync/atomic" - "time" "github.com/therootcompany/golib/net/gitshallow" "github.com/therootcompany/golib/net/httpcache" @@ -19,13 +14,9 @@ type HTTPSource struct { Path string } -// IPFilter holds up to three cohorts: a whitelist (IPs never blocked), -// an inbound blocklist, and an outbound blocklist. -type IPFilter struct { - whitelist atomic.Pointer[ipcohort.Cohort] - inbound atomic.Pointer[ipcohort.Cohort] - outbound atomic.Pointer[ipcohort.Cohort] - +// Sources holds the configuration for fetching and loading the three cohorts. +// It knows how to pull data from git or HTTP, but owns no atomic state. +type Sources struct { whitelistPaths []string inboundPaths []string outboundPaths []string @@ -35,18 +26,15 @@ type IPFilter struct { httpOutbound []*httpcache.Cacher } -// NewFileFilter loads inbound/outbound/whitelist from local files. -func NewFileFilter(whitelist, inbound, outbound []string) *IPFilter { - return &IPFilter{ +func newFileSources(whitelist, inbound, outbound []string) *Sources { + return &Sources{ whitelistPaths: whitelist, inboundPaths: inbound, outboundPaths: outbound, } } -// NewGitFilter clones/pulls gitURL into repoDir and loads the given relative -// paths for each cohort on each update. -func NewGitFilter(gitURL, repoDir string, whitelist, inboundRel, outboundRel []string) *IPFilter { +func newGitSources(gitURL, repoDir string, whitelist, inboundRel, outboundRel []string) *Sources { abs := func(rel []string) []string { out := make([]string, len(rel)) for i, p := range rel { @@ -54,7 +42,7 @@ func NewGitFilter(gitURL, repoDir string, whitelist, inboundRel, outboundRel []s } return out } - return &IPFilter{ + return &Sources{ whitelistPaths: whitelist, inboundPaths: abs(inboundRel), outboundPaths: abs(outboundRel), @@ -62,169 +50,86 @@ func NewGitFilter(gitURL, repoDir string, whitelist, inboundRel, outboundRel []s } } -// NewHTTPFilter fetches inbound and outbound sources via HTTP; -// whitelist is always loaded from local files. -func NewHTTPFilter(whitelist []string, inbound, outbound []HTTPSource) *IPFilter { - f := &IPFilter{whitelistPaths: whitelist} +func newHTTPSources(whitelist []string, inbound, outbound []HTTPSource) *Sources { + s := &Sources{whitelistPaths: whitelist} for _, src := range inbound { - f.inboundPaths = append(f.inboundPaths, src.Path) - f.httpInbound = append(f.httpInbound, httpcache.New(src.URL, src.Path)) + s.inboundPaths = append(s.inboundPaths, src.Path) + s.httpInbound = append(s.httpInbound, httpcache.New(src.URL, src.Path)) } for _, src := range outbound { - f.outboundPaths = append(f.outboundPaths, src.Path) - f.httpOutbound = append(f.httpOutbound, httpcache.New(src.URL, src.Path)) + s.outboundPaths = append(s.outboundPaths, src.Path) + s.httpOutbound = append(s.httpOutbound, httpcache.New(src.URL, src.Path)) } - return f + return s } -func (f *IPFilter) Init(lightGC bool) error { +// Fetch pulls updates from the remote (git or HTTP). +// Returns whether any new data was received. +func (s *Sources) Fetch(lightGC bool) (bool, error) { switch { - case f.git != nil: - if _, err := f.git.Init(lightGC); err != nil { - return err - } - case len(f.httpInbound) > 0 || len(f.httpOutbound) > 0: - for _, c := range f.httpInbound { - if _, err := c.Fetch(); err != nil { - return err - } - } - for _, c := range f.httpOutbound { - if _, err := c.Fetch(); err != nil { - return err - } - } - } - return f.reloadAll() -} - -func (f *IPFilter) Run(ctx context.Context, lightGC bool) { - ticker := time.NewTicker(47 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - updated, err := f.sync(lightGC) - if err != nil { - fmt.Fprintf(os.Stderr, "error: filter sync: %v\n", err) - } else if updated { - fmt.Fprintf(os.Stderr, "filter: reloaded — inbound=%d outbound=%d\n", - f.InboundSize(), f.OutboundSize()) - } - case <-ctx.Done(): - return - } - } -} - -func (f *IPFilter) sync(lightGC bool) (bool, error) { - switch { - case f.git != nil: - updated, err := f.git.Sync(lightGC) - if err != nil || !updated { - return updated, err - } - return true, f.reloadAll() - case len(f.httpInbound) > 0 || len(f.httpOutbound) > 0: + case s.git != nil: + return s.git.Sync(lightGC) + case len(s.httpInbound) > 0 || len(s.httpOutbound) > 0: var anyUpdated bool - for _, c := range f.httpInbound { + for _, c := range s.httpInbound { updated, err := c.Fetch() if err != nil { return anyUpdated, err } anyUpdated = anyUpdated || updated } - for _, c := range f.httpOutbound { + for _, c := range s.httpOutbound { updated, err := c.Fetch() if err != nil { return anyUpdated, err } anyUpdated = anyUpdated || updated } - if anyUpdated { - return true, f.reloadAll() - } - return false, nil + return anyUpdated, nil default: return false, nil } } -// ContainsInbound reports whether ip is in the inbound blocklist and not whitelisted. -func (f *IPFilter) ContainsInbound(ip string) bool { - if wl := f.whitelist.Load(); wl != nil && wl.Contains(ip) { - return false - } - c := f.inbound.Load() - return c != nil && c.Contains(ip) -} - -// ContainsOutbound reports whether ip is in the outbound blocklist and not whitelisted. -func (f *IPFilter) ContainsOutbound(ip string) bool { - if wl := f.whitelist.Load(); wl != nil && wl.Contains(ip) { - return false - } - c := f.outbound.Load() - return c != nil && c.Contains(ip) -} - -func (f *IPFilter) InboundSize() int { - if c := f.inbound.Load(); c != nil { - return c.Size() - } - return 0 -} - -func (f *IPFilter) OutboundSize() int { - if c := f.outbound.Load(); c != nil { - return c.Size() - } - return 0 -} - -func (f *IPFilter) reloadAll() error { - if err := f.reloadWhitelist(); err != nil { +// Init ensures the remote is ready (clones if needed, fetches HTTP files). +// Always returns true so the caller knows to load data on startup. +func (s *Sources) Init(lightGC bool) error { + switch { + case s.git != nil: + _, err := s.git.Init(lightGC) return err + case len(s.httpInbound) > 0 || len(s.httpOutbound) > 0: + for _, c := range s.httpInbound { + if _, err := c.Fetch(); err != nil { + return err + } + } + for _, c := range s.httpOutbound { + if _, err := c.Fetch(); err != nil { + return err + } + } } - if err := f.reloadInbound(); err != nil { - return err - } - return f.reloadOutbound() -} - -func (f *IPFilter) reloadWhitelist() error { - if len(f.whitelistPaths) == 0 { - return nil - } - c, err := ipcohort.LoadFiles(f.whitelistPaths...) - if err != nil { - return err - } - f.whitelist.Store(c) return nil } -func (f *IPFilter) reloadInbound() error { - if len(f.inboundPaths) == 0 { - return nil +func (s *Sources) LoadWhitelist() (*ipcohort.Cohort, error) { + if len(s.whitelistPaths) == 0 { + return nil, nil } - c, err := ipcohort.LoadFiles(f.inboundPaths...) - if err != nil { - return err - } - f.inbound.Store(c) - return nil + return ipcohort.LoadFiles(s.whitelistPaths...) } -func (f *IPFilter) reloadOutbound() error { - if len(f.outboundPaths) == 0 { - return nil +func (s *Sources) LoadInbound() (*ipcohort.Cohort, error) { + if len(s.inboundPaths) == 0 { + return nil, nil } - c, err := ipcohort.LoadFiles(f.outboundPaths...) - if err != nil { - return err - } - f.outbound.Store(c) - return nil + return ipcohort.LoadFiles(s.inboundPaths...) +} + +func (s *Sources) LoadOutbound() (*ipcohort.Cohort, error) { + if len(s.outboundPaths) == 0 { + return nil, nil + } + return ipcohort.LoadFiles(s.outboundPaths...) } diff --git a/net/ipcohort/cmd/check-ip-blacklist/main.go b/net/ipcohort/cmd/check-ip-blacklist/main.go index 51305b9..53de620 100644 --- a/net/ipcohort/cmd/check-ip-blacklist/main.go +++ b/net/ipcohort/cmd/check-ip-blacklist/main.go @@ -1,9 +1,14 @@ package main import ( + "context" "fmt" "os" "strings" + "sync/atomic" + "time" + + "github.com/therootcompany/golib/net/ipcohort" ) // inbound blocklist - pre-separated by type for independent ETag caching @@ -34,19 +39,18 @@ func main() { gitURL = os.Args[3] } - var f *IPFilter + var src *Sources switch { case gitURL != "": - f = NewGitFilter(gitURL, dataPath, + src = newGitSources(gitURL, dataPath, nil, []string{"tables/inbound/single_ips.txt", "tables/inbound/networks.txt"}, []string{"tables/outbound/single_ips.txt", "tables/outbound/networks.txt"}, ) case strings.HasSuffix(dataPath, ".txt") || strings.HasSuffix(dataPath, ".csv"): - f = NewFileFilter(nil, []string{dataPath}, nil) + src = newFileSources(nil, []string{dataPath}, nil) default: - // dataPath is a cache directory; fetch the pre-split files via HTTP - f = NewHTTPFilter( + src = newHTTPSources( nil, []HTTPSource{ {inboundSingleURL, dataPath + "/inbound_single_ips.txt"}, @@ -59,15 +63,27 @@ func main() { ) } - if err := f.Init(false); err != nil { + var whitelist, inbound, outbound atomic.Pointer[ipcohort.Cohort] + + if err := src.Init(false); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + if err := reload(src, &whitelist, &inbound, &outbound); err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } - fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n", f.InboundSize(), f.OutboundSize()) + fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n", + size(&inbound), size(&outbound)) - blockedInbound := f.ContainsInbound(ipStr) - blockedOutbound := f.ContainsOutbound(ipStr) + // Keep data fresh in the background if running as a daemon. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go run(ctx, src, &whitelist, &inbound, &outbound) + + blockedInbound := containsInbound(ipStr, &whitelist, &inbound) + blockedOutbound := containsOutbound(ipStr, &whitelist, &outbound) switch { case blockedInbound && blockedOutbound: @@ -83,3 +99,76 @@ func main() { fmt.Printf("%s is allowed\n", ipStr) } } + +func reload(src *Sources, + whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort], +) error { + if wl, err := src.LoadWhitelist(); err != nil { + return err + } else if wl != nil { + whitelist.Store(wl) + } + if in, err := src.LoadInbound(); err != nil { + return err + } else if in != nil { + inbound.Store(in) + } + if out, err := src.LoadOutbound(); err != nil { + return err + } else if out != nil { + outbound.Store(out) + } + return nil +} + +func run(ctx context.Context, src *Sources, + whitelist, inbound, outbound *atomic.Pointer[ipcohort.Cohort], +) { + ticker := time.NewTicker(47 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + updated, err := src.Fetch(false) + if err != nil { + fmt.Fprintf(os.Stderr, "error: sync: %v\n", err) + continue + } + if !updated { + continue + } + if err := reload(src, whitelist, inbound, outbound); err != nil { + fmt.Fprintf(os.Stderr, "error: reload: %v\n", err) + continue + } + fmt.Fprintf(os.Stderr, "reloaded: inbound=%d outbound=%d\n", + size(inbound), size(outbound)) + case <-ctx.Done(): + return + } + } +} + +func containsInbound(ip string, whitelist, inbound *atomic.Pointer[ipcohort.Cohort]) bool { + if wl := whitelist.Load(); wl != nil && wl.Contains(ip) { + return false + } + c := inbound.Load() + return c != nil && c.Contains(ip) +} + +func containsOutbound(ip string, whitelist, outbound *atomic.Pointer[ipcohort.Cohort]) bool { + if wl := whitelist.Load(); wl != nil && wl.Contains(ip) { + return false + } + c := outbound.Load() + return c != nil && c.Contains(ip) +} + +func size(ptr *atomic.Pointer[ipcohort.Cohort]) int { + if c := ptr.Load(); c != nil { + return c.Size() + } + return 0 +}