mirror of
https://github.com/therootcompany/golib.git
synced 2026-01-27 23:18:05 +00:00
wip: ipcohort: move atomics to gitdataset
This commit is contained in:
parent
9cd08ff2b8
commit
73c323b0f2
@ -3,6 +3,9 @@ package main
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/therootcompany/golib/net/gitdataset"
|
||||||
|
"github.com/therootcompany/golib/net/ipcohort"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@ -11,25 +14,35 @@ func main() {
|
|||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
path := os.Args[1]
|
dataPath := os.Args[1]
|
||||||
ipStr := os.Args[2]
|
ipStr := os.Args[2]
|
||||||
gitURL := ""
|
gitURL := ""
|
||||||
if len(os.Args) >= 4 {
|
if len(os.Args) >= 4 {
|
||||||
gitURL = os.Args[3]
|
gitURL = os.Args[3]
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "Loading %q ...\n", path)
|
fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath)
|
||||||
|
|
||||||
b := NewBlacklist(gitURL, path)
|
var b *ipcohort.Cohort
|
||||||
|
loadFile := func(path string) (*ipcohort.Cohort, error) {
|
||||||
|
return ipcohort.LoadFile(path, false)
|
||||||
|
}
|
||||||
|
blacklist := gitdataset.New(gitURL, dataPath, loadFile)
|
||||||
fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
|
fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
|
||||||
if n, err := b.Init(false); err != nil {
|
if updated, err := blacklist.Init(false); err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
|
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
|
||||||
} else if n > 0 {
|
} else {
|
||||||
|
b = blacklist.Load()
|
||||||
|
if updated {
|
||||||
|
n := b.Size()
|
||||||
|
if n > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
|
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
|
fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
|
||||||
if b.Contains(ipStr) {
|
if blacklist.Load().Contains(ipStr) {
|
||||||
fmt.Printf("%s is BLOCKED\n", ipStr)
|
fmt.Printf("%s is BLOCKED\n", ipStr)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,7 +11,6 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"sync/atomic"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Either a subnet or single address (subnet with /32 CIDR prefix)
|
// Either a subnet or single address (subnet with /32 CIDR prefix)
|
||||||
@ -35,8 +34,7 @@ func (r IPv4Net) Contains(ip uint32) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func New() *Cohort {
|
func New() *Cohort {
|
||||||
cohort := &Cohort{}
|
cohort := &Cohort{ranges: []IPv4Net{}}
|
||||||
cohort.Store(&innerCohort{ranges: []IPv4Net{}})
|
|
||||||
return cohort
|
return cohort
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,8 +53,7 @@ func Parse(prefixList []string) (*Cohort, error) {
|
|||||||
copy(sizedList, ranges)
|
copy(sizedList, ranges)
|
||||||
sortRanges(ranges)
|
sortRanges(ranges)
|
||||||
|
|
||||||
cohort := &Cohort{}
|
cohort := &Cohort{ranges: sizedList}
|
||||||
cohort.Store(&innerCohort{ranges: sizedList})
|
|
||||||
return cohort, nil
|
return cohort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,8 +140,7 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
|
|||||||
sizedList := make([]IPv4Net, len(ranges))
|
sizedList := make([]IPv4Net, len(ranges))
|
||||||
copy(sizedList, ranges)
|
copy(sizedList, ranges)
|
||||||
|
|
||||||
cohort := &Cohort{}
|
cohort := &Cohort{ranges: sizedList}
|
||||||
cohort.Store(&innerCohort{ranges: sizedList})
|
|
||||||
return cohort, nil
|
return cohort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -159,25 +155,14 @@ func sortRanges(ranges []IPv4Net) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Cohort struct {
|
type Cohort struct {
|
||||||
atomic.Pointer[innerCohort]
|
|
||||||
}
|
|
||||||
|
|
||||||
// for ergonomic - so we can access the slice without dereferencing
|
|
||||||
type innerCohort struct {
|
|
||||||
ranges []IPv4Net
|
ranges []IPv4Net
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cohort) Swap(next *Cohort) {
|
|
||||||
c.Store(next.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Cohort) Size() int {
|
func (c *Cohort) Size() int {
|
||||||
return len(c.Load().ranges)
|
return len(c.ranges)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Cohort) Contains(ipStr string) bool {
|
func (c *Cohort) Contains(ipStr string) bool {
|
||||||
cohort := c.Load()
|
|
||||||
|
|
||||||
ip, err := netip.ParseAddr(ipStr)
|
ip, err := netip.ParseAddr(ipStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true
|
return true
|
||||||
@ -185,7 +170,7 @@ func (c *Cohort) Contains(ipStr string) bool {
|
|||||||
ip4 := ip.As4()
|
ip4 := ip.As4()
|
||||||
ipU32 := binary.BigEndian.Uint32(ip4[:])
|
ipU32 := binary.BigEndian.Uint32(ip4[:])
|
||||||
|
|
||||||
idx, found := slices.BinarySearchFunc(cohort.ranges, ipU32, func(r IPv4Net, target uint32) int {
|
idx, found := slices.BinarySearchFunc(c.ranges, ipU32, func(r IPv4Net, target uint32) int {
|
||||||
if r.networkBE < target {
|
if r.networkBE < target {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
@ -200,7 +185,7 @@ func (c *Cohort) Contains(ipStr string) bool {
|
|||||||
|
|
||||||
// Check the range immediately before the insertion point
|
// Check the range immediately before the insertion point
|
||||||
if idx > 0 {
|
if idx > 0 {
|
||||||
if cohort.ranges[idx-1].Contains(ipU32) {
|
if c.ranges[idx-1].Contains(ipU32) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user