feat: add net/ipcohort (for blacklisting, whitelisting, etc)
parent 8d1354f0da
commit 3f19dd7768
net/ipcohort/README.md (new file, 53 lines)
@@ -0,0 +1,53 @@

# [ipcohort](https://github.com/therootcompany/golib/tree/main/net/ipcohort)

A memory-efficient, fast IP cohort checker for blacklists, whitelists, and ad cohorts.

- 6 bytes per IP address (5 + 1 for alignment; see the sketch below)
- binary search (not as fast as a trie, but memory use stays linear)
- atomic swaps for updates

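Each entry is stored as a small fixed-size record, which is why a membership test boils down to a single mask-and-compare. Below is a minimal sketch of the idea; the `entry` type is illustrative, and its fields mirror the `IPv4Net` type in `ipcohort.go`:

```go
// entry is an illustrative stand-in for ipcohort's internal IPv4Net type.
type entry struct {
	networkBE uint32 // network address as a big-endian uint32, e.g. 0x0A000000 for 10.0.0.0/8
	prefix    uint8  // CIDR prefix length, 0-32
	shift     uint8  // precomputed 32 - prefix, used to build the mask
}

// contains reports whether ipBE (a big-endian IPv4 address) falls in the range.
func (e entry) contains(ipBE uint32) bool {
	mask := uint32(0xFFFFFFFF) << e.shift // clear the host bits
	return ipBE&mask == e.networkBE
}
```
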
## Example

Check if an IP address belongs to a cohort (such as a blacklist):

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/therootcompany/golib/net/ipcohort"
)

func main() {
	ipStr := "92.255.85.72"

	path := "/opt/github.com/bitwire-it/ipblocklist/inbound.txt"
	unsorted := false

	blacklist, err := ipcohort.LoadFile(path, unsorted)
	if err != nil {
		log.Fatalf("Failed to load blacklist: %v", err)
	}

	if blacklist.Contains(ipStr) {
		fmt.Printf("%s is BLOCKED\n", ipStr)
		os.Exit(1)
	}

	fmt.Printf("%s is allowed\n", ipStr)
}
```

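`LoadFile` reads from a file on disk; `ParseCSV` (also exported by this package) accepts any `io.Reader`, so a list can just as easily come from memory or a network response. A minimal sketch, assuming the usual `fmt`, `log`, and `strings` imports:

```go
list := "198.51.100.7\n10.0.0.0/8\n# comments and IPv6 lines are skipped\n"
unsorted := true // the entries above are not sorted by network address

cohort, err := ipcohort.ParseCSV(strings.NewReader(list), unsorted)
if err != nil {
	log.Fatalf("Failed to parse list: %v", err)
}

fmt.Println(cohort.Contains("10.1.2.3"))     // true: covered by 10.0.0.0/8
fmt.Println(cohort.Contains("198.51.100.8")) // false: only 198.51.100.7 is listed
```
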
Update the list periodically:

```go
func backgroundUpdate(path string, c *ipcohort.Cohort) {
	ticker := time.NewTicker(1 * time.Hour)
	defer ticker.Stop()

	for range ticker.C {
		needsSort := false
		nextCohort, err := ipcohort.LoadFile(path, needsSort)
		if err != nil {
			log.Printf("reload failed: %v", err)
			continue
		}

		c.Swap(nextCohort)
		log.Printf("reloaded %d blacklist entries", c.Size())
	}
}
```

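`Swap` is safe to call while other goroutines are calling `Contains`: `Cohort` embeds an `atomic.Pointer` to its range slice, so every lookup sees either the old list or the new one, never a partial update. A minimal usage sketch (variable names are illustrative):

```go
blacklist, err := ipcohort.LoadFile(path, false)
if err != nil {
	log.Fatalf("Failed to load blacklist: %v", err)
}

// Refresh the same *Cohort in the background (see backgroundUpdate above).
go backgroundUpdate(path, blacklist)

// Lookups never block on the updater.
if blacklist.Contains("203.0.113.7") {
	fmt.Println("blocked")
}
```
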
net/ipcohort/cmd/check-ip-blacklist/main.go (new file, 119 lines)
@@ -0,0 +1,119 @@

package main

import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"time"

	"github.com/therootcompany/golib/net/gitshallow"
	"github.com/therootcompany/golib/net/ipcohort"
)

func main() {
	if len(os.Args) != 3 {
		fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address>\n", os.Args[0])
		os.Exit(1)
	}

	path := os.Args[1]
	ipStr := os.Args[2]

	fmt.Fprintf(os.Stderr, "Loading %q ...\n", path)

	gitURL := ""
	r := NewReloader(gitURL, path)
	fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
	if n, err := r.Init(); err != nil {
		fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
		if r.Blacklist == nil {
			// Nothing was loaded, so there is nothing to check against.
			os.Exit(1)
		}
	} else if n > 0 {
		fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
	}

	fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
	if r.Blacklist.Contains(ipStr) {
		fmt.Printf("%s is BLOCKED\n", ipStr)
		os.Exit(1)
	}

	fmt.Printf("%s is allowed\n", ipStr)
}

type Reloader struct {
	Blacklist   *ipcohort.Cohort
	gitRepo     string
	shallowRepo *gitshallow.ShallowRepo
	path        string
}

func NewReloader(gitURL, path string) *Reloader {
	gitRepo := filepath.Dir(path)
	gitDepth := 1
	gitBranch := ""
	shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch)

	return &Reloader{
		Blacklist:   nil,
		gitRepo:     gitRepo,
		shallowRepo: shallowRepo,
		path:        path,
	}
}

// Init loads the blacklist from disk and, when the list lives in a git
// checkout, syncs the repo and reloads once.
func (r *Reloader) Init() (int, error) {
	blacklist, err := ipcohort.LoadFile(r.path, false)
	if err != nil {
		return 0, err
	}
	r.Blacklist = blacklist

	gitDir := filepath.Join(r.gitRepo, ".git")
	if _, err := os.Stat(gitDir); err != nil {
		fmt.Printf("%q is not a git repo, skipping sync\n", r.gitRepo)
		return blacklist.Size(), nil
	}

	return r.reload()
}

// Run reloads the blacklist on a ticker until the context is canceled.
func (r *Reloader) Run(ctx context.Context) {
	ticker := time.NewTicker(47 * time.Minute)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			if n, err := r.reload(); err != nil {
				fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
			} else {
				fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
			}
		case <-ctx.Done():
			return
		}
	}
}

func (r *Reloader) reload() (int, error) {
	laxGC := false
	lazyPrune := false
	updated, err := r.shallowRepo.Sync(laxGC, lazyPrune)
	if err != nil {
		return 0, fmt.Errorf("git sync: %w", err)
	}
	if !updated {
		// Nothing changed upstream, so keep the current list.
		return 0, nil
	}

	needsSort := false
	nextCohort, err := ipcohort.LoadFile(r.path, needsSort)
	if err != nil {
		return 0, fmt.Errorf("ip cohort: %w", err)
	}

	r.Blacklist.Swap(nextCohort)
	return r.Blacklist.Size(), nil
}
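// Usage sketch (the list path and IP address below are illustrative):
//
//	go run ./net/ipcohort/cmd/check-ip-blacklist /path/to/inbound.txt 92.255.85.72
//
// The command prints "<ip> is BLOCKED" and exits 1 when the address matches the
// blacklist (or when the arguments are wrong or the list cannot be loaded);
// otherwise it prints "<ip> is allowed" and exits 0.
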
net/ipcohort/ipcohort.go (new file, 172 lines)
@@ -0,0 +1,172 @@

package ipcohort

import (
	"encoding/binary"
	"encoding/csv"
	"fmt"
	"io"
	"log"
	"net/netip"
	"os"
	"slices"
	"sort"
	"strings"
	"sync/atomic"
)

// IPv4Net is either a subnet or a single address (a subnet with a /32 CIDR prefix).
type IPv4Net struct {
	networkBE uint32
	prefix    uint8
	shift     uint8
}

func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net {
	return IPv4Net{
		networkBE: ip4be,
		prefix:    prefix,
		shift:     32 - prefix,
	}
}

// Contains reports whether the big-endian IPv4 address ip falls within the range.
func (r IPv4Net) Contains(ip uint32) bool {
	mask := uint32(0xFFFFFFFF) << r.shift
	return (ip & mask) == r.networkBE
}

// LoadFile opens path and parses it with ParseCSV. Set unsorted to true if the
// entries are not already sorted by network address.
func LoadFile(path string, unsorted bool) (*Cohort, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("could not load %q: %w", path, err)
	}
	defer f.Close()

	return ParseCSV(f, unsorted)
}

// ParseCSV parses one IPv4 address or CIDR range per line; extra CSV fields,
// comments, and IPv6 entries are ignored.
func ParseCSV(f io.Reader, unsorted bool) (*Cohort, error) {
	r := csv.NewReader(f)
	r.FieldsPerRecord = -1

	return ReadAll(r, unsorted)
}

// ReadAll consumes the reader and builds a Cohort. When unsorted is true the
// entries are sorted by network address, which binary search requires.
func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
	var ranges []IPv4Net
	for {
		record, err := r.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("csv read error: %w", err)
		}

		if len(record) == 0 {
			continue
		}

		raw := strings.TrimSpace(record[0])

		// Skip comments/empty
		if raw == "" || strings.HasPrefix(raw, "#") {
			continue
		}

		// skip IPv6
		if strings.Contains(raw, ":") {
			continue
		}

		var ippre netip.Prefix
		var ip netip.Addr
		if strings.Contains(raw, "/") {
			ippre, err = netip.ParsePrefix(raw)
			if err != nil {
				log.Printf("skipping invalid entry: %q", raw)
				continue
			}
			// Clear any host bits so the mask comparison in Contains matches.
			ippre = ippre.Masked()
		} else {
			ip, err = netip.ParseAddr(raw)
			if err != nil {
				log.Printf("skipping invalid entry: %q", raw)
				continue
			}
			ippre = netip.PrefixFrom(ip, 32)
		}

		ip4 := ippre.Addr().As4()
		prefix := uint8(ippre.Bits()) // 0-32
		ranges = append(ranges, NewIPv4Net(
			binary.BigEndian.Uint32(ip4[:]),
			prefix,
		))
	}

	if unsorted {
		// Sort by network address (required for binary search)
		sort.Slice(ranges, func(i, j int) bool {
			// Note: we could also sort by prefix (largest first)
			return ranges[i].networkBE < ranges[j].networkBE
		})

		// Note: we could also merge ranges here
	}

	sizedList := make([]IPv4Net, len(ranges))
	copy(sizedList, ranges)

	ipList := &Cohort{}
	ipList.Store(&innerCohort{ranges: sizedList})
	return ipList, nil
}

// Cohort is an atomically swappable set of IPv4 addresses and ranges.
type Cohort struct {
	atomic.Pointer[innerCohort]
}

// innerCohort exists for ergonomics: the range slice can be swapped behind a
// single pointer without callers having to dereference it themselves.
type innerCohort struct {
	ranges []IPv4Net
}

// Swap atomically replaces the current entries with those of next.
func (c *Cohort) Swap(next *Cohort) {
	c.Store(next.Load())
}

// Size reports the number of loaded entries.
func (c *Cohort) Size() int {
	return len(c.Load().ranges)
}

// Contains reports whether ipStr falls within any loaded address or range.
func (c *Cohort) Contains(ipStr string) bool {
	cohort := c.Load()

	ip, err := netip.ParseAddr(ipStr)
	if err != nil {
		// Fail closed: an unparsable address is reported as a member.
		return true
	}
	ip = ip.Unmap()
	if !ip.Is4() {
		// Only IPv4 ranges are stored, so IPv6 addresses are never members.
		return false
	}
	ip4 := ip.As4()
	ipU32 := binary.BigEndian.Uint32(ip4[:])

	idx, found := slices.BinarySearchFunc(cohort.ranges, ipU32, func(r IPv4Net, target uint32) int {
		if r.networkBE < target {
			return -1
		}
		if r.networkBE > target {
			return 1
		}
		return 0
	})
	if found {
		return true
	}

	// Check the range immediately before the insertion point
	if idx > 0 {
		if cohort.ranges[idx-1].Contains(ipU32) {
			return true
		}
	}

	return false
}