wip: ipcohort: move atomics to gitdataset

This commit is contained in:
AJ ONeal 2026-01-22 02:20:49 -07:00
parent 9cd08ff2b8
commit 73c323b0f2
No known key found for this signature in database
2 changed files with 26 additions and 28 deletions

View File

@ -3,6 +3,9 @@ package main
import (
"fmt"
"os"
"github.com/therootcompany/golib/net/gitdataset"
"github.com/therootcompany/golib/net/ipcohort"
)
func main() {
@ -11,25 +14,35 @@ func main() {
os.Exit(1)
}
path := os.Args[1]
dataPath := os.Args[1]
ipStr := os.Args[2]
gitURL := ""
if len(os.Args) >= 4 {
gitURL = os.Args[3]
}
fmt.Fprintf(os.Stderr, "Loading %q ...\n", path)
fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath)
b := NewBlacklist(gitURL, path)
var b *ipcohort.Cohort
loadFile := func(path string) (*ipcohort.Cohort, error) {
return ipcohort.LoadFile(path, false)
}
blacklist := gitdataset.New(gitURL, dataPath, loadFile)
fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
if n, err := b.Init(false); err != nil {
if updated, err := blacklist.Init(false); err != nil {
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
} else if n > 0 {
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
} else {
b = blacklist.Load()
if updated {
n := b.Size()
if n > 0 {
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
}
}
}
fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
if b.Contains(ipStr) {
if blacklist.Load().Contains(ipStr) {
fmt.Printf("%s is BLOCKED\n", ipStr)
os.Exit(1)
}

View File

@ -11,7 +11,6 @@ import (
"slices"
"sort"
"strings"
"sync/atomic"
)
// Either a subnet or single address (subnet with /32 CIDR prefix)
@ -35,8 +34,7 @@ func (r IPv4Net) Contains(ip uint32) bool {
}
func New() *Cohort {
cohort := &Cohort{}
cohort.Store(&innerCohort{ranges: []IPv4Net{}})
cohort := &Cohort{ranges: []IPv4Net{}}
return cohort
}
@ -55,8 +53,7 @@ func Parse(prefixList []string) (*Cohort, error) {
copy(sizedList, ranges)
sortRanges(ranges)
cohort := &Cohort{}
cohort.Store(&innerCohort{ranges: sizedList})
cohort := &Cohort{ranges: sizedList}
return cohort, nil
}
@ -143,8 +140,7 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
sizedList := make([]IPv4Net, len(ranges))
copy(sizedList, ranges)
cohort := &Cohort{}
cohort.Store(&innerCohort{ranges: sizedList})
cohort := &Cohort{ranges: sizedList}
return cohort, nil
}
@ -159,25 +155,14 @@ func sortRanges(ranges []IPv4Net) {
}
type Cohort struct {
atomic.Pointer[innerCohort]
}
// for ergonomic - so we can access the slice without dereferencing
type innerCohort struct {
ranges []IPv4Net
}
func (c *Cohort) Swap(next *Cohort) {
c.Store(next.Load())
}
func (c *Cohort) Size() int {
return len(c.Load().ranges)
return len(c.ranges)
}
func (c *Cohort) Contains(ipStr string) bool {
cohort := c.Load()
ip, err := netip.ParseAddr(ipStr)
if err != nil {
return true
@ -185,7 +170,7 @@ func (c *Cohort) Contains(ipStr string) bool {
ip4 := ip.As4()
ipU32 := binary.BigEndian.Uint32(ip4[:])
idx, found := slices.BinarySearchFunc(cohort.ranges, ipU32, func(r IPv4Net, target uint32) int {
idx, found := slices.BinarySearchFunc(c.ranges, ipU32, func(r IPv4Net, target uint32) int {
if r.networkBE < target {
return -1
}
@ -200,7 +185,7 @@ func (c *Cohort) Contains(ipStr string) bool {
// Check the range immediately before the insertion point
if idx > 0 {
if cohort.ranges[idx-1].Contains(ipU32) {
if c.ranges[idx-1].Contains(ipU32) {
return true
}
}