From 3f19dd77688068eb112e31f1a3d31d115f66af5c Mon Sep 17 00:00:00 2001
From: AJ ONeal
Date: Wed, 21 Jan 2026 20:38:14 -0700
Subject: [PATCH] feat: add net/ipcohort (for blacklisting, whitelisting, etc)

---
 net/ipcohort/README.md                      |  53 ++++++
 net/ipcohort/cmd/check-ip-blacklist/main.go | 119 ++++++++++++++
 net/ipcohort/ipcohort.go                    | 172 ++++++++++++++++++++
 3 files changed, 344 insertions(+)
 create mode 100644 net/ipcohort/README.md
 create mode 100644 net/ipcohort/cmd/check-ip-blacklist/main.go
 create mode 100644 net/ipcohort/ipcohort.go

diff --git a/net/ipcohort/README.md b/net/ipcohort/README.md
new file mode 100644
index 0000000..8db600f
--- /dev/null
+++ b/net/ipcohort/README.md
@@ -0,0 +1,53 @@
+# [ipcohort](https://github.com/therootcompany/golib/tree/main/net/ipcohort)
+
+A memory-efficient, fast IP cohort checker for blacklists, whitelists, and ad cohorts.
+
+- 8 bytes per entry (4 for the network address, 1 for the prefix length, 1 for a cached shift, plus 2 bytes of padding)
+- binary search (not as fast as a trie, but memory is linear)
+- atomic swaps for updates
+
+## Example
+
+Check if an IP address belongs to a cohort (such as a blacklist):
+
+```go
+func main() {
+	ipStr := "92.255.85.72"
+
+	path := "/opt/github.com/bitwire-it/ipblocklist/inbound.txt"
+	unsorted := false
+
+	blacklist, err := ipcohort.LoadFile(path, unsorted)
+	if err != nil {
+		log.Fatalf("Failed to load blacklist: %v", err)
+	}
+
+	if blacklist.Contains(ipStr) {
+		fmt.Printf("%s is BLOCKED\n", ipStr)
+		os.Exit(1)
+	}
+
+	fmt.Printf("%s is allowed\n", ipStr)
+}
+```
+
+Update the list periodically:
+
+```go
+func backgroundUpdate(path string, c *ipcohort.Cohort) {
+	ticker := time.NewTicker(1 * time.Hour)
+	defer ticker.Stop()
+
+	for range ticker.C {
+		unsorted := false
+		nextCohort, err := ipcohort.LoadFile(path, unsorted)
+		if err != nil {
+			log.Printf("reload failed: %v", err)
+			continue
+		}
+
+		c.Swap(nextCohort)
+		log.Printf("reloaded %d blacklist entries", c.Size())
+	}
+}
+```
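+
+For an HTTP service, the same check can run as middleware. Here is a minimal
+sketch using only `net` and `net/http` from the standard library (the
+`blockMiddleware` wiring is illustrative, not part of this package):
+
+```go
+func blockMiddleware(blacklist *ipcohort.Cohort, next http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+		// RemoteAddr is "host:port"; fall back to the raw value if it isn't
+		ipStr, _, err := net.SplitHostPort(req.RemoteAddr)
+		if err != nil {
+			ipStr = req.RemoteAddr
+		}
+
+		if blacklist.Contains(ipStr) {
+			http.Error(w, "Forbidden", http.StatusForbidden)
+			return
+		}
+
+		next.ServeHTTP(w, req)
+	})
+}
+```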
diff --git a/net/ipcohort/cmd/check-ip-blacklist/main.go b/net/ipcohort/cmd/check-ip-blacklist/main.go
new file mode 100644
index 0000000..ce0ef59
--- /dev/null
+++ b/net/ipcohort/cmd/check-ip-blacklist/main.go
@@ -0,0 +1,119 @@
+package main
+
+import (
+	"context"
+	"fmt"
+	"os"
+	"path/filepath"
+	"time"
+
+	"github.com/therootcompany/golib/net/gitshallow"
+	"github.com/therootcompany/golib/net/ipcohort"
+)
+
+func main() {
+	if len(os.Args) != 3 {
+		fmt.Fprintf(os.Stderr, "Usage: %s <path> <ip>\n", os.Args[0])
+		os.Exit(1)
+	}
+
+	path := os.Args[1]
+	ipStr := os.Args[2]
+
+	fmt.Fprintf(os.Stderr, "Loading %q ...\n", path)
+
+	gitURL := ""
+	r := NewReloader(gitURL, path)
+	fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
+	if n, err := r.Init(); err != nil {
+		fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
+		if r.Blacklist == nil {
+			// the initial load failed, so there is nothing to check against
+			os.Exit(1)
+		}
+	} else if n > 0 {
+		fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
+	}
+
+	fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
+	if r.Blacklist.Contains(ipStr) {
+		fmt.Printf("%s is BLOCKED\n", ipStr)
+		os.Exit(1)
+	}
+
+	fmt.Printf("%s is allowed\n", ipStr)
+}
+
+type Reloader struct {
+	Blacklist   *ipcohort.Cohort
+	gitRepo     string
+	shallowRepo *gitshallow.ShallowRepo
+	path        string
+}
+
+func NewReloader(gitURL, path string) *Reloader {
+	gitRepo := filepath.Dir(path)
+	gitDepth := 1
+	gitBranch := ""
+	shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch)
+
+	return &Reloader{
+		Blacklist:   nil,
+		gitRepo:     gitRepo,
+		shallowRepo: shallowRepo,
+		path:        path,
+	}
+}
+
+func (r *Reloader) Init() (int, error) {
+	blacklist, err := ipcohort.LoadFile(r.path, false)
+	if err != nil {
+		return 0, err
+	}
+	r.Blacklist = blacklist
+
+	gitDir := filepath.Join(r.gitRepo, ".git")
+	if _, err := os.Stat(gitDir); err != nil {
+		fmt.Fprintf(os.Stderr, "%q is not a git repo, skipping sync\n", r.gitRepo)
+		return blacklist.Size(), nil
+	}
+
+	return r.reload()
+}
+
+func (r *Reloader) Run(ctx context.Context) {
+	ticker := time.NewTicker(47 * time.Minute)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ticker.C:
+			if n, err := r.reload(); err != nil {
+				fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
+			} else {
+				fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
+			}
+		case <-ctx.Done():
+			return
+		}
+	}
+}
+
+func (r *Reloader) reload() (int, error) {
+	laxGC := false
+	lazyPrune := false
+	updated, err := r.shallowRepo.Sync(laxGC, lazyPrune)
+	if err != nil {
+		return 0, fmt.Errorf("git sync: %w", err)
+	}
+	if !updated {
+		return 0, nil
+	}
+
+	unsorted := false
+	nextCohort, err := ipcohort.LoadFile(r.path, unsorted)
+	if err != nil {
+		return 0, fmt.Errorf("ip cohort: %w", err)
+	}
+
+	r.Blacklist.Swap(nextCohort)
+	return r.Blacklist.Size(), nil
+}
diff --git a/net/ipcohort/ipcohort.go b/net/ipcohort/ipcohort.go
new file mode 100644
index 0000000..c89db37
--- /dev/null
+++ b/net/ipcohort/ipcohort.go
@@ -0,0 +1,172 @@
+package ipcohort
+
+import (
+	"encoding/binary"
+	"encoding/csv"
+	"fmt"
+	"io"
+	"log"
+	"net/netip"
+	"os"
+	"slices"
+	"sort"
+	"strings"
+	"sync/atomic"
+)
+
+// IPv4Net is either a subnet or a single address (a subnet with a /32 CIDR prefix).
+type IPv4Net struct {
+	networkBE uint32
+	prefix    uint8
+	shift     uint8
+}
+
+// NewIPv4Net builds an IPv4Net from a big-endian IPv4 network address and a
+// prefix length (0-32).
+func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net {
+	return IPv4Net{
+		networkBE: ip4be,
+		prefix:    prefix,
+		shift:     32 - prefix,
+	}
+}
+
+// Contains masks off the host bits of ip and compares the result against the
+// (already masked) network address.
+func (r IPv4Net) Contains(ip uint32) bool {
+	mask := uint32(0xFFFFFFFF << r.shift)
+	return (ip & mask) == r.networkBE
+}
+
+// LoadFile reads one IPv4 address or CIDR range per line (or per first CSV
+// field) from path. Pass unsorted=true if the file is not already sorted by
+// network address.
+func LoadFile(path string, unsorted bool) (*Cohort, error) {
+	f, err := os.Open(path)
+	if err != nil {
+		return nil, fmt.Errorf("could not load %q: %w", path, err)
+	}
+	defer f.Close()
+
+	return ParseCSV(f, unsorted)
+}
+
+func ParseCSV(f io.Reader, unsorted bool) (*Cohort, error) {
+	r := csv.NewReader(f)
+	r.FieldsPerRecord = -1
+
+	return ReadAll(r, unsorted)
+}
+
+func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
+	var ranges []IPv4Net
+	for {
+		record, err := r.Read()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return nil, fmt.Errorf("csv read error: %w", err)
+		}
+
+		if len(record) == 0 {
+			continue
+		}
+
+		raw := strings.TrimSpace(record[0])
+
+		// Skip comments/empty
+		if raw == "" || strings.HasPrefix(raw, "#") {
+			continue
+		}
+
+		// skip IPv6
+		if strings.Contains(raw, ":") {
+			continue
+		}
+
+		var ippre netip.Prefix
+		var ip netip.Addr
+		if strings.Contains(raw, "/") {
+			ippre, err = netip.ParsePrefix(raw)
+			if err != nil {
+				log.Printf("skipping invalid entry: %q", raw)
+				continue
+			}
+		} else {
+			ip, err = netip.ParseAddr(raw)
+			if err != nil {
+				log.Printf("skipping invalid entry: %q", raw)
+				continue
+			}
+			ippre = netip.PrefixFrom(ip, 32)
+		}
+
+		// Clear any host bits so the stored network address matches the
+		// masked comparison in IPv4Net.Contains.
+		ippre = ippre.Masked()
+
+		ip4 := ippre.Addr().As4()
+		prefix := uint8(ippre.Bits()) // 0-32
+		ranges = append(ranges, NewIPv4Net(
+			binary.BigEndian.Uint32(ip4[:]),
+			prefix,
+		))
+	}
+
+	if unsorted {
+		// Sort by network address (required for binary search)
+		sort.Slice(ranges, func(i, j int) bool {
+			// Note: we could also sort by prefix (largest first)
+			return ranges[i].networkBE < ranges[j].networkBE
+		})
+
+		// Note: we could also merge ranges here
+	}
+
+	sizedList := make([]IPv4Net, len(ranges))
+	copy(sizedList, ranges)
+
+	ipList := &Cohort{}
+	ipList.Store(&innerCohort{ranges: sizedList})
+	return ipList, nil
+}
+
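+// Cohort holds a sorted, immutable list of IPv4 networks behind an atomic
+// pointer, so lookups stay lock-free while Swap publishes a replacement list.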
+type Cohort struct {
+	atomic.Pointer[innerCohort]
+}
+
+// innerCohort exists for ergonomics: it lets readers reach the slice through a
+// single atomically loaded pointer.
+type innerCohort struct {
+	ranges []IPv4Net
+}
+
+// Swap publishes next's list so that subsequent lookups see it.
+func (c *Cohort) Swap(next *Cohort) {
+	c.Store(next.Load())
+}
+
+func (c *Cohort) Size() int {
+	return len(c.Load().ranges)
+}
+
+// Contains reports whether ipStr falls within any network in the cohort.
+// Unparseable input is treated as a member (fail closed); IPv6 addresses,
+// which are never loaded, are not.
+func (c *Cohort) Contains(ipStr string) bool {
+	cohort := c.Load()
+
+	ip, err := netip.ParseAddr(ipStr)
+	if err != nil {
+		// fail closed: garbage input counts as a member
+		return true
+	}
+	ip = ip.Unmap() // treat ::ffff:a.b.c.d as IPv4
+	if !ip.Is4() {
+		// only IPv4 networks are tracked
+		return false
+	}
+	ip4 := ip.As4()
+	ipU32 := binary.BigEndian.Uint32(ip4[:])
+
+	// Binary search for the address itself as a network boundary ...
+	idx, found := slices.BinarySearchFunc(cohort.ranges, ipU32, func(r IPv4Net, target uint32) int {
+		if r.networkBE < target {
+			return -1
+		}
+		if r.networkBE > target {
+			return 1
+		}
+		return 0
+	})
+	if found {
+		return true
+	}
+
+	// ... otherwise only the range immediately before the insertion point can
+	// contain the address.
+	if idx > 0 {
+		if cohort.ranges[idx-1].Contains(ipU32) {
+			return true
+		}
+	}
+
+	return false
+}