From 73c323b0f2665d9d9f3190e59774809adc085fa1 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Thu, 22 Jan 2026 02:20:49 -0700 Subject: [PATCH] wip: ipcohort: move atomics to gitdataset --- net/ipcohort/cmd/check-ip-blacklist/main.go | 27 +++++++++++++++------ net/ipcohort/ipcohort.go | 27 +++++---------------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/net/ipcohort/cmd/check-ip-blacklist/main.go b/net/ipcohort/cmd/check-ip-blacklist/main.go index c646db3..eea9f6d 100644 --- a/net/ipcohort/cmd/check-ip-blacklist/main.go +++ b/net/ipcohort/cmd/check-ip-blacklist/main.go @@ -3,6 +3,9 @@ package main import ( "fmt" "os" + + "github.com/therootcompany/golib/net/gitdataset" + "github.com/therootcompany/golib/net/ipcohort" ) func main() { @@ -11,25 +14,35 @@ func main() { os.Exit(1) } - path := os.Args[1] + dataPath := os.Args[1] ipStr := os.Args[2] gitURL := "" if len(os.Args) >= 4 { gitURL = os.Args[3] } - fmt.Fprintf(os.Stderr, "Loading %q ...\n", path) + fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath) - b := NewBlacklist(gitURL, path) + var b *ipcohort.Cohort + loadFile := func(path string) (*ipcohort.Cohort, error) { + return ipcohort.LoadFile(path, false) + } + blacklist := gitdataset.New(gitURL, dataPath, loadFile) fmt.Fprintf(os.Stderr, "Syncing git repo ...\n") - if n, err := b.Init(false); err != nil { + if updated, err := blacklist.Init(false); err != nil { fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err) - } else if n > 0 { - fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n) + } else { + b = blacklist.Load() + if updated { + n := b.Size() + if n > 0 { + fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n) + } + } } fmt.Fprintf(os.Stderr, "Checking blacklist ...\n") - if b.Contains(ipStr) { + if blacklist.Load().Contains(ipStr) { fmt.Printf("%s is BLOCKED\n", ipStr) os.Exit(1) } diff --git a/net/ipcohort/ipcohort.go b/net/ipcohort/ipcohort.go index f48b1e9..fb860a9 100644 --- a/net/ipcohort/ipcohort.go +++ b/net/ipcohort/ipcohort.go @@ -11,7 +11,6 @@ import ( "slices" "sort" "strings" - "sync/atomic" ) // Either a subnet or single address (subnet with /32 CIDR prefix) @@ -35,8 +34,7 @@ func (r IPv4Net) Contains(ip uint32) bool { } func New() *Cohort { - cohort := &Cohort{} - cohort.Store(&innerCohort{ranges: []IPv4Net{}}) + cohort := &Cohort{ranges: []IPv4Net{}} return cohort } @@ -55,8 +53,7 @@ func Parse(prefixList []string) (*Cohort, error) { copy(sizedList, ranges) sortRanges(ranges) - cohort := &Cohort{} - cohort.Store(&innerCohort{ranges: sizedList}) + cohort := &Cohort{ranges: sizedList} return cohort, nil } @@ -143,8 +140,7 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) { sizedList := make([]IPv4Net, len(ranges)) copy(sizedList, ranges) - cohort := &Cohort{} - cohort.Store(&innerCohort{ranges: sizedList}) + cohort := &Cohort{ranges: sizedList} return cohort, nil } @@ -159,25 +155,14 @@ func sortRanges(ranges []IPv4Net) { } type Cohort struct { - atomic.Pointer[innerCohort] -} - -// for ergonomic - so we can access the slice without dereferencing -type innerCohort struct { ranges []IPv4Net } -func (c *Cohort) Swap(next *Cohort) { - c.Store(next.Load()) -} - func (c *Cohort) Size() int { - return len(c.Load().ranges) + return len(c.ranges) } func (c *Cohort) Contains(ipStr string) bool { - cohort := c.Load() - ip, err := netip.ParseAddr(ipStr) if err != nil { return true @@ -185,7 +170,7 @@ func (c *Cohort) Contains(ipStr string) bool { ip4 := ip.As4() ipU32 := binary.BigEndian.Uint32(ip4[:]) - idx, found := slices.BinarySearchFunc(cohort.ranges, ipU32, func(r IPv4Net, target uint32) int { + idx, found := slices.BinarySearchFunc(c.ranges, ipU32, func(r IPv4Net, target uint32) int { if r.networkBE < target { return -1 } @@ -200,7 +185,7 @@ func (c *Cohort) Contains(ipStr string) bool { // Check the range immediately before the insertion point if idx > 0 { - if cohort.ranges[idx-1].Contains(ipU32) { + if c.ranges[idx-1].Contains(ipU32) { return true } }