feat: add net/ipcohort (for blacklisting, whitelisting, etc)

This commit is contained in:
AJ ONeal 2026-01-21 20:38:14 -07:00
parent 8d1354f0da
commit 3f19dd7768
No known key found for this signature in database
3 changed files with 344 additions and 0 deletions

53
net/ipcohort/README.md Normal file
View File

@ -0,0 +1,53 @@
# [ipcohort](https://github.com/therootcompany/golib/tree/main/net/ipcohort)
A memory-efficient, fast IP cohort checker for blacklists, whitelists, and ad cohorts.
- 8 bytes per IP range (4-byte network + 1-byte prefix + 1-byte shift, padded to 8 for alignment)
- binary search (not as fast as a trie, but memory is linear)
- atomic swaps for updates
## Example
Check if an IP address belongs to a cohort (such as a blacklist):
```go
func main() {
ipStr := "92.255.85.72"
path := "/opt/github.com/bitwire-it/ipblocklist/inbound.txt"
unsorted := false
blacklist, err := ipcohort.LoadFile(path, unsorted)
if err != nil {
log.Fatalf("Failed to load blacklist: %v", err)
}
if blacklist.Contains(ipStr) {
fmt.Printf("%s is BLOCKED\n", ipStr)
os.Exit(1)
}
fmt.Printf("%s is allowed\n", ipStr)
}
```
Update the list periodically:
```go
func backgroundUpdate(path string, c *ipcohort.Cohort) {
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for range ticker.C {
needsSort := false
nextCohort, err := ipcohort.LoadFile(path, needsSort)
if err != nil {
log.Printf("reload failed: %v", err)
continue
}
c.Swap(nextCohort)
log.Printf("reloaded %d blacklist entries", c.Size())
}
}
```

View File

@ -0,0 +1,119 @@
package main
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"time"
"github.com/therootcompany/golib/net/gitshallow"
"github.com/therootcompany/golib/net/ipcohort"
)
// main loads an IP blacklist from a file (optionally kept fresh via a
// shallow git clone) and checks a single IP address against it.
//
// Usage: <prog> <blacklist.csv> <ip-address>
// Exits 1 when the IP is blocked, or when no blacklist could be loaded.
func main() {
	if len(os.Args) != 3 {
		fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address>\n", os.Args[0])
		os.Exit(1)
	}
	path := os.Args[1]
	ipStr := os.Args[2]

	fmt.Fprintf(os.Stderr, "Loading %q ...\n", path)
	gitURL := ""
	r := NewReloader(gitURL, path)

	fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
	if n, err := r.Init(); err != nil {
		fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
		if r.Blacklist == nil {
			// The initial file load failed, so there is no cohort at
			// all; calling Contains below would panic on a nil pointer.
			os.Exit(1)
		}
	} else if n > 0 {
		fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
	}

	fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
	if r.Blacklist.Contains(ipStr) {
		fmt.Printf("%s is BLOCKED\n", ipStr)
		os.Exit(1)
	}
	fmt.Printf("%s is allowed\n", ipStr)
}
// Reloader keeps an IP cohort loaded from a file that lives inside a
// shallow git clone, and can refresh both the clone and the cohort.
type Reloader struct {
	Blacklist *ipcohort.Cohort // currently loaded cohort; nil until Init succeeds
	gitRepo string // directory containing the clone (parent directory of path)
	shallowRepo *gitshallow.ShallowRepo // manages the shallow clone/sync
	path string // blacklist file within the repo
}
// NewReloader builds a Reloader for the blacklist file at path, treating
// the file's parent directory as a shallow clone of gitURL.
// The cohort itself is not loaded until Init is called.
func NewReloader(gitURL, path string) *Reloader {
	repoDir := filepath.Dir(path)
	const depth = 1
	branch := "" // default branch
	return &Reloader{
		Blacklist:   nil, // populated by Init
		gitRepo:     repoDir,
		shallowRepo: gitshallow.New(gitURL, repoDir, depth, branch),
		path:        path,
	}
}
// Init performs the initial load of the blacklist and, when the containing
// directory is a git repository, an initial sync + reload.
// It returns the number of entries currently loaded.
func (r *Reloader) Init() (int, error) {
	blacklist, err := ipcohort.LoadFile(r.path, false)
	if err != nil {
		return 0, err
	}
	r.Blacklist = blacklist

	// A missing .git directory is benign: there is simply nothing to
	// sync, so the freshly loaded file is the final answer. Do NOT treat
	// it as fatal (a log.Fatalf here would kill the process and make the
	// skip-sync return unreachable).
	gitDir := filepath.Join(r.gitRepo, ".git")
	if _, err := os.Stat(gitDir); err != nil {
		log.Printf("%q is not a git repo, skipping sync", r.gitRepo)
		return blacklist.Size(), nil
	}
	return r.reload()
}
// Run periodically re-syncs the git repo and reloads the blacklist until
// ctx is canceled. Errors are logged and do not stop the loop.
//
// Pointer receiver: Init already uses *Reloader; mixing value and pointer
// receivers on one type is inconsistent and copies the struct per call.
func (r *Reloader) Run(ctx context.Context) {
	// NOTE(review): 47 minutes — presumably an odd interval chosen to
	// avoid aligning with on-the-hour cron jobs; confirm.
	ticker := time.NewTicker(47 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if n, err := r.reload(); err != nil {
				fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
			} else {
				fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
			}
		case <-ctx.Done():
			return
		}
	}
}
// reload syncs the shallow git repo and, when it changed upstream,
// atomically swaps in a freshly parsed cohort.
// It returns the number of entries currently loaded.
//
// Pointer receiver for consistency with Init (*Reloader).
func (r *Reloader) reload() (int, error) {
	laxGC := false
	lazyPrune := false
	updated, err := r.shallowRepo.Sync(laxGC, lazyPrune)
	if err != nil {
		return 0, fmt.Errorf("git sync: %w", err)
	}
	if !updated {
		// Nothing changed upstream; keep the current cohort and report
		// its actual size rather than a misleading zero.
		return r.Blacklist.Size(), nil
	}
	needsSort := false
	nextCohort, err := ipcohort.LoadFile(r.path, needsSort)
	if err != nil {
		return 0, fmt.Errorf("ip cohort: %w", err)
	}
	r.Blacklist.Swap(nextCohort)
	return r.Blacklist.Size(), nil
}

172
net/ipcohort/ipcohort.go Normal file
View File

@ -0,0 +1,172 @@
package ipcohort
import (
"encoding/binary"
"encoding/csv"
"fmt"
"io"
"log"
"net/netip"
"os"
"slices"
"sort"
"strings"
"sync/atomic"
)
// IPv4Net is a subnet or a single address (a subnet with a /32 prefix),
// stored as a big-endian network value plus its prefix length. The shift
// (32 - prefix) is precomputed so Contains can build its mask cheaply.
type IPv4Net struct {
	networkBE uint32
	prefix    uint8
	shift     uint8
}

// NewIPv4Net builds an IPv4Net from a big-endian IPv4 value and a prefix
// length in the range 0-32.
func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net {
	n := IPv4Net{networkBE: ip4be, prefix: prefix}
	n.shift = 32 - prefix
	return n
}

// Contains reports whether the big-endian IPv4 value ip falls inside
// this network.
func (r IPv4Net) Contains(ip uint32) bool {
	netMask := ^uint32(0) << r.shift
	return ip&netMask == r.networkBE
}
// LoadFile opens path and parses it as a list of IPv4 addresses and CIDR
// ranges (see ReadAll). Pass unsorted=true when the file is not already
// sorted by network address; sorting is required for the binary search.
func LoadFile(path string, unsorted bool) (*Cohort, error) {
	f, err := os.Open(path)
	if err != nil {
		// %w (not %v) so callers can unwrap to the underlying *fs.PathError.
		return nil, fmt.Errorf("could not load %q: %w", path, err)
	}
	defer f.Close()
	return ParseCSV(f, unsorted)
}
// ParseCSV reads IP/CIDR entries from f; see ReadAll for the format rules.
func ParseCSV(f io.Reader, unsorted bool) (*Cohort, error) {
	reader := csv.NewReader(f)
	reader.FieldsPerRecord = -1 // rows may have any number of fields
	return ReadAll(reader, unsorted)
}
// ReadAll consumes r and builds a Cohort from the first field of each
// record. Blank lines and "#" comments are skipped silently; IPv6 entries
// are skipped (the cohort only stores IPv4); unparsable entries are
// skipped with a log message. A bare address is treated as a /32. When
// unsorted is true the entries are sorted by network address, which the
// binary search in Contains requires.
func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
	var ranges []IPv4Net
	for {
		record, err := r.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("csv read error: %w", err)
		}
		if len(record) == 0 {
			continue
		}
		raw := strings.TrimSpace(record[0])
		// Skip comments/empty
		if raw == "" || strings.HasPrefix(raw, "#") {
			continue
		}
		// skip IPv6
		if strings.Contains(raw, ":") {
			continue
		}
		var ippre netip.Prefix
		if strings.Contains(raw, "/") {
			ippre, err = netip.ParsePrefix(raw)
			if err != nil {
				log.Printf("skipping invalid entry: %q", raw)
				continue
			}
			// Canonicalize: zero the host bits so an entry such as
			// "10.0.0.1/8" is stored as 10.0.0.0/8. Without this the
			// stored network keeps its host bits, and Contains (which
			// masks the queried IP) could never equal it.
			ippre = ippre.Masked()
		} else {
			ip, err := netip.ParseAddr(raw)
			if err != nil {
				log.Printf("skipping invalid entry: %q", raw)
				continue
			}
			ippre = netip.PrefixFrom(ip, 32)
		}
		ip4 := ippre.Addr().As4()
		prefix := uint8(ippre.Bits()) // 0-32
		ranges = append(ranges, NewIPv4Net(
			binary.BigEndian.Uint32(ip4[:]),
			prefix,
		))
	}
	if unsorted {
		// Sort by network address (required for binary search)
		sort.Slice(ranges, func(i, j int) bool {
			// Note: we could also sort by prefix (largest first)
			return ranges[i].networkBE < ranges[j].networkBE
		})
		// Note: we could also merge ranges here
	}
	// Copy into an exactly-sized slice so the cohort does not retain
	// append's over-allocated capacity for its whole lifetime.
	sizedList := make([]IPv4Net, len(ranges))
	copy(sizedList, ranges)
	ipList := &Cohort{}
	ipList.Store(&innerCohort{ranges: sizedList})
	return ipList, nil
}
// Cohort is a set of IPv4 networks supporting lock-free concurrent reads
// and atomic replacement of the entire set (see Swap).
type Cohort struct {
	atomic.Pointer[innerCohort]
}

// for ergonomics - wrapping the slice in a struct lets the atomic pointer
// hand out an immutable snapshot that methods can read without copying
type innerCohort struct {
	ranges []IPv4Net // sorted by networkBE; required by the binary search in Contains
}
// Swap atomically replaces this cohort's contents with those of next.
// In-flight readers keep their previous snapshot.
func (c *Cohort) Swap(next *Cohort) {
	inner := next.Load()
	c.Store(inner)
}
// Size returns the number of entries in the current snapshot.
func (c *Cohort) Size() int {
	current := c.Load()
	return len(current.ranges)
}
// Contains reports whether ipStr falls inside any network in the cohort.
//
// Fail-safe: inputs that cannot be checked against an IPv4 list —
// unparsable strings and non-IPv4 addresses — are reported as contained
// (i.e. blocked) rather than silently passed through.
//
// NOTE(review): the lookup matches an exact network start or the single
// range immediately before the insertion point; an IP covered only by a
// wider range earlier in the sorted list (overlapping ranges) can be
// missed — confirm that inputs are non-overlapping or merged.
func (c *Cohort) Contains(ipStr string) bool {
	cohort := c.Load()
	ip, err := netip.ParseAddr(ipStr)
	if err != nil {
		return true // unparsable input: fail safe
	}
	// Accept 4-in-6 mapped addresses (::ffff:a.b.c.d) as IPv4.
	ip = ip.Unmap()
	if !ip.Is4() {
		// The cohort only holds IPv4, and As4 would panic on a true
		// IPv6 address; treat it like unparsable input (fail safe).
		return true
	}
	ip4 := ip.As4()
	ipU32 := binary.BigEndian.Uint32(ip4[:])
	idx, found := slices.BinarySearchFunc(cohort.ranges, ipU32, func(r IPv4Net, target uint32) int {
		if r.networkBE < target {
			return -1
		}
		if r.networkBE > target {
			return 1
		}
		return 0
	})
	if found {
		return true
	}
	// Not an exact network-address match: the only candidate covering
	// range is the entry just before the insertion point.
	if idx > 0 && cohort.ranges[idx-1].Contains(ipU32) {
		return true
	}
	return false
}