mirror of
https://github.com/therootcompany/golib.git
synced 2026-01-27 15:08:05 +00:00
wip: ipcohort: move atomics to gitdataset
This commit is contained in:
parent
9cd08ff2b8
commit
73c323b0f2
@ -3,6 +3,9 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/therootcompany/golib/net/gitdataset"
|
||||
"github.com/therootcompany/golib/net/ipcohort"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -11,25 +14,35 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
path := os.Args[1]
|
||||
dataPath := os.Args[1]
|
||||
ipStr := os.Args[2]
|
||||
gitURL := ""
|
||||
if len(os.Args) >= 4 {
|
||||
gitURL = os.Args[3]
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Loading %q ...\n", path)
|
||||
fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath)
|
||||
|
||||
b := NewBlacklist(gitURL, path)
|
||||
var b *ipcohort.Cohort
|
||||
loadFile := func(path string) (*ipcohort.Cohort, error) {
|
||||
return ipcohort.LoadFile(path, false)
|
||||
}
|
||||
blacklist := gitdataset.New(gitURL, dataPath, loadFile)
|
||||
fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
|
||||
if n, err := b.Init(false); err != nil {
|
||||
if updated, err := blacklist.Init(false); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
|
||||
} else if n > 0 {
|
||||
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
|
||||
} else {
|
||||
b = blacklist.Load()
|
||||
if updated {
|
||||
n := b.Size()
|
||||
if n > 0 {
|
||||
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
|
||||
if b.Contains(ipStr) {
|
||||
if blacklist.Load().Contains(ipStr) {
|
||||
fmt.Printf("%s is BLOCKED\n", ipStr)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@ -11,7 +11,6 @@ import (
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// Either a subnet or single address (subnet with /32 CIDR prefix)
|
||||
@ -35,8 +34,7 @@ func (r IPv4Net) Contains(ip uint32) bool {
|
||||
}
|
||||
|
||||
func New() *Cohort {
|
||||
cohort := &Cohort{}
|
||||
cohort.Store(&innerCohort{ranges: []IPv4Net{}})
|
||||
cohort := &Cohort{ranges: []IPv4Net{}}
|
||||
return cohort
|
||||
}
|
||||
|
||||
@ -55,8 +53,7 @@ func Parse(prefixList []string) (*Cohort, error) {
|
||||
copy(sizedList, ranges)
|
||||
sortRanges(ranges)
|
||||
|
||||
cohort := &Cohort{}
|
||||
cohort.Store(&innerCohort{ranges: sizedList})
|
||||
cohort := &Cohort{ranges: sizedList}
|
||||
return cohort, nil
|
||||
}
|
||||
|
||||
@ -143,8 +140,7 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
|
||||
sizedList := make([]IPv4Net, len(ranges))
|
||||
copy(sizedList, ranges)
|
||||
|
||||
cohort := &Cohort{}
|
||||
cohort.Store(&innerCohort{ranges: sizedList})
|
||||
cohort := &Cohort{ranges: sizedList}
|
||||
return cohort, nil
|
||||
}
|
||||
|
||||
@ -159,25 +155,14 @@ func sortRanges(ranges []IPv4Net) {
|
||||
}
|
||||
|
||||
type Cohort struct {
|
||||
atomic.Pointer[innerCohort]
|
||||
}
|
||||
|
||||
// for ergonomic - so we can access the slice without dereferencing
|
||||
type innerCohort struct {
|
||||
ranges []IPv4Net
|
||||
}
|
||||
|
||||
func (c *Cohort) Swap(next *Cohort) {
|
||||
c.Store(next.Load())
|
||||
}
|
||||
|
||||
func (c *Cohort) Size() int {
|
||||
return len(c.Load().ranges)
|
||||
return len(c.ranges)
|
||||
}
|
||||
|
||||
func (c *Cohort) Contains(ipStr string) bool {
|
||||
cohort := c.Load()
|
||||
|
||||
ip, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
return true
|
||||
@ -185,7 +170,7 @@ func (c *Cohort) Contains(ipStr string) bool {
|
||||
ip4 := ip.As4()
|
||||
ipU32 := binary.BigEndian.Uint32(ip4[:])
|
||||
|
||||
idx, found := slices.BinarySearchFunc(cohort.ranges, ipU32, func(r IPv4Net, target uint32) int {
|
||||
idx, found := slices.BinarySearchFunc(c.ranges, ipU32, func(r IPv4Net, target uint32) int {
|
||||
if r.networkBE < target {
|
||||
return -1
|
||||
}
|
||||
@ -200,7 +185,7 @@ func (c *Cohort) Contains(ipStr string) bool {
|
||||
|
||||
// Check the range immediately before the insertion point
|
||||
if idx > 0 {
|
||||
if cohort.ranges[idx-1].Contains(ipU32) {
|
||||
if c.ranges[idx-1].Contains(ipU32) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user