mirror of
https://github.com/therootcompany/golib.git
synced 2026-01-27 23:18:05 +00:00
feat: add net/ipcohort (for blacklisting, whitelisting, etc)
This commit is contained in:
parent
8d1354f0da
commit
3f19dd7768
53
net/ipcohort/README.md
Normal file
53
net/ipcohort/README.md
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
# [ipcohort](https://github.com/therootcompany/golib/tree/main/net/ipcohort)
|
||||||
|
|
||||||
|
A memory-efficient, fast IP cohort checker for blacklists, whitelists, and ad cohorts.
|
||||||
|
|
||||||
|
- 6 bytes per IP address (5 + 1 for alignment)
|
||||||
|
- binary search (not as fast as a trie, but memory is linear)
|
||||||
|
- atomic swaps for updates
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
Check if an IP address belongs to a cohort (such as a blacklist):
|
||||||
|
|
||||||
|
```go
|
||||||
|
func main() {
|
||||||
|
ipStr := "92.255.85.72"
|
||||||
|
|
||||||
|
path := "/opt/github.com/bitwire-it/ipblocklist/inbound.txt"
|
||||||
|
unsorted := false
|
||||||
|
|
||||||
|
blacklist, err := ipcohort.LoadFile(path, unsorted)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load blacklist: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if blacklist.Contains(ipStr) {
|
||||||
|
fmt.Printf("%s is BLOCKED\n", ipStr)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%s is allowed\n", ipStr)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Update the list periodically:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func backgroundUpdate(path string, c *ipcohort.Cohort) {
|
||||||
|
ticker := time.NewTicker(1 * time.Hour)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
needsSort := false
|
||||||
|
nextCohort, err := ipcohort.LoadFile(path, needsSort)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("reload failed: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("reloaded %d blacklist entries", c.Size())
|
||||||
|
c.Swap(nextCohort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
119
net/ipcohort/cmd/check-ip-blacklist/main.go
Normal file
119
net/ipcohort/cmd/check-ip-blacklist/main.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/therootcompany/golib/net/gitshallow"
|
||||||
|
"github.com/therootcompany/golib/net/ipcohort"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if len(os.Args) != 3 {
|
||||||
|
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address>\n", os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
path := os.Args[1]
|
||||||
|
ipStr := os.Args[2]
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "Loading %q ...\n", path)
|
||||||
|
|
||||||
|
gitURL := ""
|
||||||
|
r := NewReloader(gitURL, path)
|
||||||
|
fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
|
||||||
|
if n, err := r.Init(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
|
||||||
|
} else if n > 0 {
|
||||||
|
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
|
||||||
|
if r.Blacklist.Contains(ipStr) {
|
||||||
|
fmt.Printf("%s is BLOCKED\n", ipStr)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("%s is allowed\n", ipStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Reloader struct {
|
||||||
|
Blacklist *ipcohort.Cohort
|
||||||
|
gitRepo string
|
||||||
|
shallowRepo *gitshallow.ShallowRepo
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewReloader(gitURL, path string) *Reloader {
|
||||||
|
gitRepo := filepath.Dir(path)
|
||||||
|
gitDepth := 1
|
||||||
|
gitBranch := ""
|
||||||
|
shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch)
|
||||||
|
|
||||||
|
return &Reloader{
|
||||||
|
Blacklist: nil,
|
||||||
|
gitRepo: gitRepo,
|
||||||
|
shallowRepo: shallowRepo,
|
||||||
|
path: path,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Reloader) Init() (int, error) {
|
||||||
|
blacklist, err := ipcohort.LoadFile(r.path, false)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
r.Blacklist = blacklist
|
||||||
|
|
||||||
|
gitDir := filepath.Join(r.gitRepo, ".git")
|
||||||
|
if _, err := os.Stat(gitDir); err != nil {
|
||||||
|
log.Fatalf("Failed to load blacklist: %v", err)
|
||||||
|
fmt.Printf("%q is not a git repo, skipping sync\n", r.gitRepo)
|
||||||
|
return blacklist.Size(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.reload()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Reloader) Run(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(47 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
if n, err := r.reload(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Reloader) reload() (int, error) {
|
||||||
|
laxGC := false
|
||||||
|
lazyPrune := false
|
||||||
|
updated, err := r.shallowRepo.Sync(laxGC, lazyPrune)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("git sync: %w", err)
|
||||||
|
}
|
||||||
|
if !updated {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
needsSort := false
|
||||||
|
nextCohort, err := ipcohort.LoadFile(r.path, needsSort)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("ip cohort: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.Blacklist.Swap(nextCohort)
|
||||||
|
return r.Blacklist.Size(), nil
|
||||||
|
}
|
||||||
172
net/ipcohort/ipcohort.go
Normal file
172
net/ipcohort/ipcohort.go
Normal file
@ -0,0 +1,172 @@
|
|||||||
|
package ipcohort
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/csv"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"slices"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Either a subnet or single address (subnet with /32 CIDR prefix)
|
||||||
|
type IPv4Net struct {
|
||||||
|
networkBE uint32
|
||||||
|
prefix uint8
|
||||||
|
shift uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net {
|
||||||
|
return IPv4Net{
|
||||||
|
networkBE: ip4be,
|
||||||
|
prefix: prefix,
|
||||||
|
shift: 32 - prefix,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r IPv4Net) Contains(ip uint32) bool {
|
||||||
|
mask := uint32(0xFFFFFFFF << (r.shift))
|
||||||
|
return (ip & mask) == r.networkBE
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadFile(path string, unsorted bool) (*Cohort, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not load %q: %v", path, err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
return ParseCSV(f, unsorted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseCSV(f io.Reader, unsorted bool) (*Cohort, error) {
|
||||||
|
r := csv.NewReader(f)
|
||||||
|
r.FieldsPerRecord = -1
|
||||||
|
|
||||||
|
return ReadAll(r, unsorted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
|
||||||
|
var ranges []IPv4Net
|
||||||
|
for {
|
||||||
|
record, err := r.Read()
|
||||||
|
if err == io.EOF {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("csv read error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(record) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
raw := strings.TrimSpace(record[0])
|
||||||
|
|
||||||
|
// Skip comments/empty
|
||||||
|
if raw == "" || strings.HasPrefix(raw, "#") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// skip IPv6
|
||||||
|
if strings.Contains(raw, ":") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var ippre netip.Prefix
|
||||||
|
var ip netip.Addr
|
||||||
|
if strings.Contains(raw, "/") {
|
||||||
|
ippre, err = netip.ParsePrefix(raw)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("skipping invalid entry: %q", raw)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ip, err = netip.ParseAddr(raw)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("skipping invalid entry: %q", raw)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
ippre = netip.PrefixFrom(ip, 32)
|
||||||
|
}
|
||||||
|
|
||||||
|
ip4 := ippre.Addr().As4()
|
||||||
|
prefix := uint8(ippre.Bits()) // 0-32
|
||||||
|
ranges = append(ranges, NewIPv4Net(
|
||||||
|
binary.BigEndian.Uint32(ip4[:]),
|
||||||
|
prefix,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
if unsorted {
|
||||||
|
// Sort by network address (required for binary search)
|
||||||
|
sort.Slice(ranges, func(i, j int) bool {
|
||||||
|
// Note: we could also sort by prefix (largest first)
|
||||||
|
return ranges[i].networkBE < ranges[j].networkBE
|
||||||
|
})
|
||||||
|
|
||||||
|
// Note: we could also merge ranges here
|
||||||
|
}
|
||||||
|
|
||||||
|
sizedList := make([]IPv4Net, len(ranges))
|
||||||
|
copy(sizedList, ranges)
|
||||||
|
|
||||||
|
ipList := &Cohort{}
|
||||||
|
ipList.Store(&innerCohort{ranges: sizedList})
|
||||||
|
return ipList, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type Cohort struct {
|
||||||
|
atomic.Pointer[innerCohort]
|
||||||
|
}
|
||||||
|
|
||||||
|
// for ergonomic - so we can access the slice without dereferencing
|
||||||
|
type innerCohort struct {
|
||||||
|
ranges []IPv4Net
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cohort) Swap(next *Cohort) {
|
||||||
|
c.Store(next.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cohort) Size() int {
|
||||||
|
return len(c.Load().ranges)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cohort) Contains(ipStr string) bool {
|
||||||
|
cohort := c.Load()
|
||||||
|
|
||||||
|
ip, err := netip.ParseAddr(ipStr)
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ip4 := ip.As4()
|
||||||
|
ipU32 := binary.BigEndian.Uint32(ip4[:])
|
||||||
|
|
||||||
|
idx, found := slices.BinarySearchFunc(cohort.ranges, ipU32, func(r IPv4Net, target uint32) int {
|
||||||
|
if r.networkBE < target {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if r.networkBE > target {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
})
|
||||||
|
if found {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the range immediately before the insertion point
|
||||||
|
if idx > 0 {
|
||||||
|
if cohort.ranges[idx-1].Contains(ipU32) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user