From 1947b91c1d75707aa829e0a8c469850dbb707c26 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Thu, 22 Jan 2026 00:21:58 -0700 Subject: [PATCH] f: ipcohort / blacklist --- .../cmd/check-ip-blacklist/blacklist.go | 87 +++++++++++++++ net/ipcohort/cmd/check-ip-blacklist/main.go | 97 ++--------------- net/ipcohort/ipcohort.go | 101 ++++++++++++------ 3 files changed, 164 insertions(+), 121 deletions(-) create mode 100644 net/ipcohort/cmd/check-ip-blacklist/blacklist.go diff --git a/net/ipcohort/cmd/check-ip-blacklist/blacklist.go b/net/ipcohort/cmd/check-ip-blacklist/blacklist.go new file mode 100644 index 0000000..4b4b6d4 --- /dev/null +++ b/net/ipcohort/cmd/check-ip-blacklist/blacklist.go @@ -0,0 +1,87 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "path/filepath" + "time" + + "github.com/therootcompany/golib/net/gitshallow" + "github.com/therootcompany/golib/net/ipcohort" +) + +type Blacklist struct { + *ipcohort.Cohort + gitRepo string + shallowRepo *gitshallow.ShallowRepo + path string +} + +func NewBlacklist(gitURL, path string) *Blacklist { + gitRepo := filepath.Dir(path) + gitDepth := 1 + gitBranch := "" + shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch) + + return &Blacklist{ + Cohort: ipcohort.New(), + gitRepo: gitRepo, + shallowRepo: shallowRepo, + path: path, + } +} + +func (b *Blacklist) Init(skipGC bool) (int, error) { + gitDir := filepath.Join(b.gitRepo, ".git") + if _, err := os.Stat(gitDir); err != nil { + if _, err := b.shallowRepo.Clone(); err != nil { + log.Fatalf("Failed to load blacklist: %v", err) + fmt.Printf("%q is not a git repo, skipping sync\n", b.gitRepo) + return b.Size(), nil + } + } + + force := true + return b.reload(skipGC, force) +} + +func (r Blacklist) Run(ctx context.Context) { + ticker := time.NewTicker(47 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if n, err := r.reload(false, false); err != nil { + fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err) + } else { + fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n) + } + case <-ctx.Done(): + return + } + } +} + +func (b Blacklist) reload(skipGC, force bool) (int, error) { + laxGC := skipGC + lazyPrune := skipGC + updated, err := b.shallowRepo.Sync(laxGC, lazyPrune) + if err != nil { + return 0, fmt.Errorf("git sync: %w", err) + } + if !updated && !force { + return 0, nil + } + + needsSort := false + nextCohort, err := ipcohort.LoadFile(b.path, needsSort) + if err != nil { + return 0, fmt.Errorf("ip cohort: %w", err) + } + + b.Swap(nextCohort) + return b.Size(), nil +} diff --git a/net/ipcohort/cmd/check-ip-blacklist/main.go b/net/ipcohort/cmd/check-ip-blacklist/main.go index ce0ef59..c646db3 100644 --- a/net/ipcohort/cmd/check-ip-blacklist/main.go +++ b/net/ipcohort/cmd/check-ip-blacklist/main.go @@ -1,119 +1,38 @@ package main import ( - "context" "fmt" - "log" "os" - "path/filepath" - "time" - - "github.com/therootcompany/golib/net/gitshallow" - "github.com/therootcompany/golib/net/ipcohort" ) func main() { - if len(os.Args) != 3 { + if len(os.Args) < 3 { fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) os.Exit(1) } path := os.Args[1] ipStr := os.Args[2] + gitURL := "" + if len(os.Args) >= 4 { + gitURL = os.Args[3] + } fmt.Fprintf(os.Stderr, "Loading %q ...\n", path) - gitURL := "" - r := NewReloader(gitURL, path) + b := NewBlacklist(gitURL, path) fmt.Fprintf(os.Stderr, "Syncing git repo ...\n") - if n, err := r.Init(); err != nil { + if n, err := b.Init(false); err != nil { fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err) } else if n > 0 { fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n) } fmt.Fprintf(os.Stderr, "Checking blacklist ...\n") - if r.Blacklist.Contains(ipStr) { + if b.Contains(ipStr) { fmt.Printf("%s is BLOCKED\n", ipStr) os.Exit(1) } fmt.Printf("%s is allowed\n", ipStr) } - -type Reloader struct { - Blacklist *ipcohort.Cohort - gitRepo string - shallowRepo *gitshallow.ShallowRepo - path string -} - -func NewReloader(gitURL, path string) *Reloader { - gitRepo := filepath.Dir(path) - gitDepth := 1 - gitBranch := "" - shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch) - - return &Reloader{ - Blacklist: nil, - gitRepo: gitRepo, - shallowRepo: shallowRepo, - path: path, - } -} - -func (r *Reloader) Init() (int, error) { - blacklist, err := ipcohort.LoadFile(r.path, false) - if err != nil { - return 0, err - } - r.Blacklist = blacklist - - gitDir := filepath.Join(r.gitRepo, ".git") - if _, err := os.Stat(gitDir); err != nil { - log.Fatalf("Failed to load blacklist: %v", err) - fmt.Printf("%q is not a git repo, skipping sync\n", r.gitRepo) - return blacklist.Size(), nil - } - - return r.reload() -} - -func (r Reloader) Run(ctx context.Context) { - ticker := time.NewTicker(47 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if n, err := r.reload(); err != nil { - fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err) - } else { - fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n) - } - case <-ctx.Done(): - return - } - } -} - -func (r Reloader) reload() (int, error) { - laxGC := false - lazyPrune := false - updated, err := r.shallowRepo.Sync(laxGC, lazyPrune) - if err != nil { - return 0, fmt.Errorf("git sync: %w", err) - } - if !updated { - return 0, nil - } - - needsSort := false - nextCohort, err := ipcohort.LoadFile(r.path, needsSort) - if err != nil { - return 0, fmt.Errorf("ip cohort: %w", err) - } - - r.Blacklist.Swap(nextCohort) - return r.Blacklist.Size(), nil -} diff --git a/net/ipcohort/ipcohort.go b/net/ipcohort/ipcohort.go index c89db37..f48b1e9 100644 --- a/net/ipcohort/ipcohort.go +++ b/net/ipcohort/ipcohort.go @@ -34,6 +34,56 @@ func (r IPv4Net) Contains(ip uint32) bool { return (ip & mask) == r.networkBE } +func New() *Cohort { + cohort := &Cohort{} + cohort.Store(&innerCohort{ranges: []IPv4Net{}}) + return cohort +} + +func Parse(prefixList []string) (*Cohort, error) { + var ranges []IPv4Net + for _, raw := range prefixList { + ipv4net, err := ParseIPv4(raw) + if err != nil { + log.Printf("skipping invalid entry: %q", raw) + continue + } + ranges = append(ranges, ipv4net) + } + + sizedList := make([]IPv4Net, len(ranges)) + copy(sizedList, ranges) + sortRanges(ranges) + + cohort := &Cohort{} + cohort.Store(&innerCohort{ranges: sizedList}) + return cohort, nil +} + +func ParseIPv4(raw string) (ipv4net IPv4Net, err error) { + var ippre netip.Prefix + var ip netip.Addr + if strings.Contains(raw, "/") { + ippre, err = netip.ParsePrefix(raw) + if err != nil { + return ipv4net, err + } + } else { + ip, err = netip.ParseAddr(raw) + if err != nil { + return ipv4net, err + } + ippre = netip.PrefixFrom(ip, 32) + } + + ip4 := ippre.Addr().As4() + prefix := uint8(ippre.Bits()) // 0-32 + return NewIPv4Net( + binary.BigEndian.Uint32(ip4[:]), + prefix, + ), nil +} + func LoadFile(path string, unsorted bool) (*Cohort, error) { f, err := os.Open(path) if err != nil { @@ -78,47 +128,34 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) { continue } - var ippre netip.Prefix - var ip netip.Addr - if strings.Contains(raw, "/") { - ippre, err = netip.ParsePrefix(raw) - if err != nil { - log.Printf("skipping invalid entry: %q", raw) - continue - } - } else { - ip, err = netip.ParseAddr(raw) - if err != nil { - log.Printf("skipping invalid entry: %q", raw) - continue - } - ippre = netip.PrefixFrom(ip, 32) + ipv4net, err := ParseIPv4(raw) + if err != nil { + log.Printf("skipping invalid entry: %q", raw) + continue } - - ip4 := ippre.Addr().As4() - prefix := uint8(ippre.Bits()) // 0-32 - ranges = append(ranges, NewIPv4Net( - binary.BigEndian.Uint32(ip4[:]), - prefix, - )) + ranges = append(ranges, ipv4net) } if unsorted { - // Sort by network address (required for binary search) - sort.Slice(ranges, func(i, j int) bool { - // Note: we could also sort by prefix (largest first) - return ranges[i].networkBE < ranges[j].networkBE - }) - - // Note: we could also merge ranges here + sortRanges(ranges) } sizedList := make([]IPv4Net, len(ranges)) copy(sizedList, ranges) - ipList := &Cohort{} - ipList.Store(&innerCohort{ranges: sizedList}) - return ipList, nil + cohort := &Cohort{} + cohort.Store(&innerCohort{ranges: sizedList}) + return cohort, nil +} + +func sortRanges(ranges []IPv4Net) { + // Sort by network address (required for binary search) + sort.Slice(ranges, func(i, j int) bool { + // Note: we could also sort by prefix (largest first) + return ranges[i].networkBE < ranges[j].networkBE + }) + + // Note: we could also merge ranges here } type Cohort struct {