wip: ipcohort: move atomics to gitdataset

This commit is contained in:
AJ ONeal 2026-01-22 02:20:49 -07:00
parent 9cd08ff2b8
commit 73c323b0f2
No known key found for this signature in database
2 changed files with 26 additions and 28 deletions

View File

@ -3,6 +3,9 @@ package main
import ( import (
"fmt" "fmt"
"os" "os"
"github.com/therootcompany/golib/net/gitdataset"
"github.com/therootcompany/golib/net/ipcohort"
) )
func main() { func main() {
@ -11,25 +14,35 @@ func main() {
os.Exit(1) os.Exit(1)
} }
path := os.Args[1] dataPath := os.Args[1]
ipStr := os.Args[2] ipStr := os.Args[2]
gitURL := "" gitURL := ""
if len(os.Args) >= 4 { if len(os.Args) >= 4 {
gitURL = os.Args[3] gitURL = os.Args[3]
} }
fmt.Fprintf(os.Stderr, "Loading %q ...\n", path) fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath)
b := NewBlacklist(gitURL, path) var b *ipcohort.Cohort
loadFile := func(path string) (*ipcohort.Cohort, error) {
return ipcohort.LoadFile(path, false)
}
blacklist := gitdataset.New(gitURL, dataPath, loadFile)
fmt.Fprintf(os.Stderr, "Syncing git repo ...\n") fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
if n, err := b.Init(false); err != nil { if updated, err := blacklist.Init(false); err != nil {
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err) fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
} else if n > 0 { } else {
b = blacklist.Load()
if updated {
n := b.Size()
if n > 0 {
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n) fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
} }
}
}
fmt.Fprintf(os.Stderr, "Checking blacklist ...\n") fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
if b.Contains(ipStr) { if blacklist.Load().Contains(ipStr) {
fmt.Printf("%s is BLOCKED\n", ipStr) fmt.Printf("%s is BLOCKED\n", ipStr)
os.Exit(1) os.Exit(1)
} }

View File

@ -11,7 +11,6 @@ import (
"slices" "slices"
"sort" "sort"
"strings" "strings"
"sync/atomic"
) )
// Either a subnet or single address (subnet with /32 CIDR prefix) // Either a subnet or single address (subnet with /32 CIDR prefix)
@ -35,8 +34,7 @@ func (r IPv4Net) Contains(ip uint32) bool {
} }
func New() *Cohort { func New() *Cohort {
cohort := &Cohort{} cohort := &Cohort{ranges: []IPv4Net{}}
cohort.Store(&innerCohort{ranges: []IPv4Net{}})
return cohort return cohort
} }
@ -55,8 +53,7 @@ func Parse(prefixList []string) (*Cohort, error) {
copy(sizedList, ranges) copy(sizedList, ranges)
sortRanges(ranges) sortRanges(ranges)
cohort := &Cohort{} cohort := &Cohort{ranges: sizedList}
cohort.Store(&innerCohort{ranges: sizedList})
return cohort, nil return cohort, nil
} }
@ -143,8 +140,7 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
sizedList := make([]IPv4Net, len(ranges)) sizedList := make([]IPv4Net, len(ranges))
copy(sizedList, ranges) copy(sizedList, ranges)
cohort := &Cohort{} cohort := &Cohort{ranges: sizedList}
cohort.Store(&innerCohort{ranges: sizedList})
return cohort, nil return cohort, nil
} }
@ -159,25 +155,14 @@ func sortRanges(ranges []IPv4Net) {
} }
type Cohort struct { type Cohort struct {
atomic.Pointer[innerCohort]
}
// for ergonomic - so we can access the slice without dereferencing
type innerCohort struct {
ranges []IPv4Net ranges []IPv4Net
} }
func (c *Cohort) Swap(next *Cohort) {
c.Store(next.Load())
}
func (c *Cohort) Size() int { func (c *Cohort) Size() int {
return len(c.Load().ranges) return len(c.ranges)
} }
func (c *Cohort) Contains(ipStr string) bool { func (c *Cohort) Contains(ipStr string) bool {
cohort := c.Load()
ip, err := netip.ParseAddr(ipStr) ip, err := netip.ParseAddr(ipStr)
if err != nil { if err != nil {
return true return true
@ -185,7 +170,7 @@ func (c *Cohort) Contains(ipStr string) bool {
ip4 := ip.As4() ip4 := ip.As4()
ipU32 := binary.BigEndian.Uint32(ip4[:]) ipU32 := binary.BigEndian.Uint32(ip4[:])
idx, found := slices.BinarySearchFunc(cohort.ranges, ipU32, func(r IPv4Net, target uint32) int { idx, found := slices.BinarySearchFunc(c.ranges, ipU32, func(r IPv4Net, target uint32) int {
if r.networkBE < target { if r.networkBE < target {
return -1 return -1
} }
@ -200,7 +185,7 @@ func (c *Cohort) Contains(ipStr string) bool {
// Check the range immediately before the insertion point // Check the range immediately before the insertion point
if idx > 0 { if idx > 0 {
if cohort.ranges[idx-1].Contains(ipU32) { if c.ranges[idx-1].Contains(ipU32) {
return true return true
} }
} }