f: ipcohort / blacklist

This commit is contained in:
AJ ONeal 2026-01-22 00:21:58 -07:00
parent 3f19dd7768
commit 1947b91c1d
No known key found for this signature in database
3 changed files with 164 additions and 121 deletions

View File

@ -0,0 +1,87 @@
package main
import (
"context"
"fmt"
"log"
"os"
"path/filepath"
"time"
"github.com/therootcompany/golib/net/gitshallow"
"github.com/therootcompany/golib/net/ipcohort"
)
type Blacklist struct {
*ipcohort.Cohort
gitRepo string
shallowRepo *gitshallow.ShallowRepo
path string
}
func NewBlacklist(gitURL, path string) *Blacklist {
gitRepo := filepath.Dir(path)
gitDepth := 1
gitBranch := ""
shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch)
return &Blacklist{
Cohort: ipcohort.New(),
gitRepo: gitRepo,
shallowRepo: shallowRepo,
path: path,
}
}
func (b *Blacklist) Init(skipGC bool) (int, error) {
gitDir := filepath.Join(b.gitRepo, ".git")
if _, err := os.Stat(gitDir); err != nil {
if _, err := b.shallowRepo.Clone(); err != nil {
log.Fatalf("Failed to load blacklist: %v", err)
fmt.Printf("%q is not a git repo, skipping sync\n", b.gitRepo)
return b.Size(), nil
}
}
force := true
return b.reload(skipGC, force)
}
func (r Blacklist) Run(ctx context.Context) {
ticker := time.NewTicker(47 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if n, err := r.reload(false, false); err != nil {
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
} else {
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
}
case <-ctx.Done():
return
}
}
}
func (b Blacklist) reload(skipGC, force bool) (int, error) {
laxGC := skipGC
lazyPrune := skipGC
updated, err := b.shallowRepo.Sync(laxGC, lazyPrune)
if err != nil {
return 0, fmt.Errorf("git sync: %w", err)
}
if !updated && !force {
return 0, nil
}
needsSort := false
nextCohort, err := ipcohort.LoadFile(b.path, needsSort)
if err != nil {
return 0, fmt.Errorf("ip cohort: %w", err)
}
b.Swap(nextCohort)
return b.Size(), nil
}

View File

@ -1,119 +1,38 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"log"
"os" "os"
"path/filepath"
"time"
"github.com/therootcompany/golib/net/gitshallow"
"github.com/therootcompany/golib/net/ipcohort"
) )
func main() { func main() {
if len(os.Args) != 3 { if len(os.Args) < 3 {
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address>\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address>\n", os.Args[0])
os.Exit(1) os.Exit(1)
} }
path := os.Args[1] path := os.Args[1]
ipStr := os.Args[2] ipStr := os.Args[2]
gitURL := ""
if len(os.Args) >= 4 {
gitURL = os.Args[3]
}
fmt.Fprintf(os.Stderr, "Loading %q ...\n", path) fmt.Fprintf(os.Stderr, "Loading %q ...\n", path)
gitURL := "" b := NewBlacklist(gitURL, path)
r := NewReloader(gitURL, path)
fmt.Fprintf(os.Stderr, "Syncing git repo ...\n") fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
if n, err := r.Init(); err != nil { if n, err := b.Init(false); err != nil {
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err) fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
} else if n > 0 { } else if n > 0 {
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n) fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
} }
fmt.Fprintf(os.Stderr, "Checking blacklist ...\n") fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
if r.Blacklist.Contains(ipStr) { if b.Contains(ipStr) {
fmt.Printf("%s is BLOCKED\n", ipStr) fmt.Printf("%s is BLOCKED\n", ipStr)
os.Exit(1) os.Exit(1)
} }
fmt.Printf("%s is allowed\n", ipStr) fmt.Printf("%s is allowed\n", ipStr)
} }
type Reloader struct {
Blacklist *ipcohort.Cohort
gitRepo string
shallowRepo *gitshallow.ShallowRepo
path string
}
func NewReloader(gitURL, path string) *Reloader {
gitRepo := filepath.Dir(path)
gitDepth := 1
gitBranch := ""
shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch)
return &Reloader{
Blacklist: nil,
gitRepo: gitRepo,
shallowRepo: shallowRepo,
path: path,
}
}
func (r *Reloader) Init() (int, error) {
blacklist, err := ipcohort.LoadFile(r.path, false)
if err != nil {
return 0, err
}
r.Blacklist = blacklist
gitDir := filepath.Join(r.gitRepo, ".git")
if _, err := os.Stat(gitDir); err != nil {
log.Fatalf("Failed to load blacklist: %v", err)
fmt.Printf("%q is not a git repo, skipping sync\n", r.gitRepo)
return blacklist.Size(), nil
}
return r.reload()
}
func (r Reloader) Run(ctx context.Context) {
ticker := time.NewTicker(47 * time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if n, err := r.reload(); err != nil {
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
} else {
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
}
case <-ctx.Done():
return
}
}
}
func (r Reloader) reload() (int, error) {
laxGC := false
lazyPrune := false
updated, err := r.shallowRepo.Sync(laxGC, lazyPrune)
if err != nil {
return 0, fmt.Errorf("git sync: %w", err)
}
if !updated {
return 0, nil
}
needsSort := false
nextCohort, err := ipcohort.LoadFile(r.path, needsSort)
if err != nil {
return 0, fmt.Errorf("ip cohort: %w", err)
}
r.Blacklist.Swap(nextCohort)
return r.Blacklist.Size(), nil
}

View File

@ -34,6 +34,56 @@ func (r IPv4Net) Contains(ip uint32) bool {
return (ip & mask) == r.networkBE return (ip & mask) == r.networkBE
} }
func New() *Cohort {
cohort := &Cohort{}
cohort.Store(&innerCohort{ranges: []IPv4Net{}})
return cohort
}
func Parse(prefixList []string) (*Cohort, error) {
var ranges []IPv4Net
for _, raw := range prefixList {
ipv4net, err := ParseIPv4(raw)
if err != nil {
log.Printf("skipping invalid entry: %q", raw)
continue
}
ranges = append(ranges, ipv4net)
}
sizedList := make([]IPv4Net, len(ranges))
copy(sizedList, ranges)
sortRanges(ranges)
cohort := &Cohort{}
cohort.Store(&innerCohort{ranges: sizedList})
return cohort, nil
}
func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
var ippre netip.Prefix
var ip netip.Addr
if strings.Contains(raw, "/") {
ippre, err = netip.ParsePrefix(raw)
if err != nil {
return ipv4net, err
}
} else {
ip, err = netip.ParseAddr(raw)
if err != nil {
return ipv4net, err
}
ippre = netip.PrefixFrom(ip, 32)
}
ip4 := ippre.Addr().As4()
prefix := uint8(ippre.Bits()) // 0-32
return NewIPv4Net(
binary.BigEndian.Uint32(ip4[:]),
prefix,
), nil
}
func LoadFile(path string, unsorted bool) (*Cohort, error) { func LoadFile(path string, unsorted bool) (*Cohort, error) {
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
@ -78,32 +128,27 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
continue continue
} }
var ippre netip.Prefix ipv4net, err := ParseIPv4(raw)
var ip netip.Addr
if strings.Contains(raw, "/") {
ippre, err = netip.ParsePrefix(raw)
if err != nil { if err != nil {
log.Printf("skipping invalid entry: %q", raw) log.Printf("skipping invalid entry: %q", raw)
continue continue
} }
} else { ranges = append(ranges, ipv4net)
ip, err = netip.ParseAddr(raw)
if err != nil {
log.Printf("skipping invalid entry: %q", raw)
continue
}
ippre = netip.PrefixFrom(ip, 32)
}
ip4 := ippre.Addr().As4()
prefix := uint8(ippre.Bits()) // 0-32
ranges = append(ranges, NewIPv4Net(
binary.BigEndian.Uint32(ip4[:]),
prefix,
))
} }
if unsorted { if unsorted {
sortRanges(ranges)
}
sizedList := make([]IPv4Net, len(ranges))
copy(sizedList, ranges)
cohort := &Cohort{}
cohort.Store(&innerCohort{ranges: sizedList})
return cohort, nil
}
func sortRanges(ranges []IPv4Net) {
// Sort by network address (required for binary search) // Sort by network address (required for binary search)
sort.Slice(ranges, func(i, j int) bool { sort.Slice(ranges, func(i, j int) bool {
// Note: we could also sort by prefix (largest first) // Note: we could also sort by prefix (largest first)
@ -111,14 +156,6 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
}) })
// Note: we could also merge ranges here // Note: we could also merge ranges here
}
sizedList := make([]IPv4Net, len(ranges))
copy(sizedList, ranges)
ipList := &Cohort{}
ipList.Store(&innerCohort{ranges: sizedList})
return ipList, nil
} }
type Cohort struct { type Cohort struct {