mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 20:58:00 +00:00
refactor: decouple gitdataset/ipcohort for multi-file repos
gitshallow: fix double-fetch (pull already fetches), drop redundant -C flags gitdataset: split into GitDataset[T] (file+atomic) and GitRepo (git+multi-dataset) - NewDataset for file-only use, AddDataset to register with a GitRepo - one clone/fetch per repo regardless of how many datasets it has ipcohort: split Cohort into hosts (sorted /32, binary search) + nets (CIDRs, linear) - fixes false negatives when broad CIDRs (e.g. /8) precede specific entries - fixes Parse() sort-before-copy order bug - ReadAll always sorts; unsorted param removed (was dead code)
This commit is contained in:
parent
a8e108a05b
commit
8731eaf10b
134
fs/dataset/dataset.go
Normal file
134
fs/dataset/dataset.go
Normal file
@ -0,0 +1,134 @@
|
||||
package dataset
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/therootcompany/golib/net/gitshallow"
|
||||
)
|
||||
|
||||
// File holds an atomically-swappable pointer to a value loaded from a file.
|
||||
// Reads are lock-free. Use NewFile for file-only use, or AddFile to attach
|
||||
// to a GitRepo so the value refreshes whenever the repo is updated.
|
||||
type File[T any] struct {
|
||||
atomic.Pointer[T]
|
||||
path string
|
||||
loadFile func(string) (*T, error)
|
||||
}
|
||||
|
||||
// NewFile creates a file-backed dataset with no git dependency.
|
||||
// Call Reload to do the initial load and after any file change.
|
||||
func NewFile[T any](path string, loadFile func(string) (*T, error)) *File[T] {
|
||||
d := &File[T]{
|
||||
path: path,
|
||||
loadFile: loadFile,
|
||||
}
|
||||
d.Store(new(T))
|
||||
return d
|
||||
}
|
||||
|
||||
// Reload reads the file and atomically replaces the stored value.
|
||||
func (d *File[T]) Reload() error {
|
||||
v, err := d.loadFile(d.path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.Store(v)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *File[T]) reloadFile() error {
|
||||
return d.Reload()
|
||||
}
|
||||
|
||||
// reloader is the internal interface GitRepo uses to trigger file reloads.
|
||||
type reloader interface {
|
||||
reloadFile() error
|
||||
}
|
||||
|
||||
// GitRepo manages a shallow git clone and reloads all registered files
|
||||
// whenever the repo is updated. Multiple files from the same repo share
|
||||
// one clone and one pull, avoiding git file-lock conflicts.
|
||||
type GitRepo struct {
|
||||
path string
|
||||
shallowRepo *gitshallow.ShallowRepo
|
||||
files []reloader
|
||||
}
|
||||
|
||||
// NewRepo creates a GitRepo backed by the given git URL, cloning into repoPath.
|
||||
func NewRepo(gitURL, repoPath string) *GitRepo {
|
||||
return &GitRepo{
|
||||
path: repoPath,
|
||||
shallowRepo: gitshallow.New(gitURL, repoPath, 1, ""),
|
||||
}
|
||||
}
|
||||
|
||||
// AddFile registers a file inside this repo and returns its handle.
|
||||
// relPath is relative to the repo root. The file is reloaded automatically
|
||||
// whenever the repo is synced via Init or Run.
|
||||
func AddFile[T any](repo *GitRepo, relPath string, loadFile func(string) (*T, error)) *File[T] {
|
||||
d := NewFile(filepath.Join(repo.path, relPath), loadFile)
|
||||
repo.files = append(repo.files, d)
|
||||
return d
|
||||
}
|
||||
|
||||
// Init clones the repo if missing, syncs once, and loads all registered files.
|
||||
// Always runs aggressive GC — acceptable as a one-time startup cost.
|
||||
func (r *GitRepo) Init() error {
|
||||
gitDir := filepath.Join(r.path, ".git")
|
||||
if _, err := os.Stat(gitDir); err != nil {
|
||||
if _, err := r.shallowRepo.Clone(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
_, err := r.sync(false, true)
|
||||
return err
|
||||
}
|
||||
|
||||
// Run periodically syncs the repo and reloads files. Blocks until ctx is done.
|
||||
// lightGC=false (zero value) runs aggressive GC with immediate pruning to keep footprint minimal.
|
||||
// Pass true to skip both when the periodic GC is too slow for your workload.
|
||||
func (r *GitRepo) Run(ctx context.Context, lightGC bool) {
|
||||
ticker := time.NewTicker(47 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if updated, err := r.sync(lightGC, false); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: git repo sync: %v\n", err)
|
||||
} else if updated {
|
||||
fmt.Fprintf(os.Stderr, "git repo: files reloaded\n")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sync pulls the latest commits and reloads all files if HEAD changed.
|
||||
// lightGC=false (zero value) runs aggressive GC with immediate pruning to keep footprint minimal.
|
||||
func (r *GitRepo) Sync(lightGC bool) (bool, error) {
|
||||
return r.sync(lightGC, false)
|
||||
}
|
||||
|
||||
func (r *GitRepo) sync(lightGC, force bool) (bool, error) {
|
||||
updated, err := r.shallowRepo.Sync(lightGC)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("git sync: %w", err)
|
||||
}
|
||||
if !updated && !force {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
for _, f := range r.files {
|
||||
if err := f.reloadFile(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: reload file: %v\n", err)
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
@ -1,92 +0,0 @@
|
||||
package gitdataset
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/therootcompany/golib/net/gitshallow"
|
||||
)
|
||||
|
||||
// TODO maybe a GitRepo should contain GitDatasets such that loading
|
||||
// multiple datasets from the same GitRepo won't cause issues with file locking?
|
||||
|
||||
type GitDataset[T any] struct {
|
||||
LoadFile func(path string) (*T, error)
|
||||
atomic.Pointer[T]
|
||||
gitRepo string
|
||||
shallowRepo *gitshallow.ShallowRepo
|
||||
path string
|
||||
}
|
||||
|
||||
func New[T any](gitURL, path string, loadFile func(path string) (*T, error)) *GitDataset[T] {
|
||||
gitRepo := filepath.Dir(path)
|
||||
gitDepth := 1
|
||||
gitBranch := ""
|
||||
shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch)
|
||||
|
||||
b := &GitDataset[T]{
|
||||
Pointer: atomic.Pointer[T]{},
|
||||
LoadFile: loadFile,
|
||||
gitRepo: gitRepo,
|
||||
shallowRepo: shallowRepo,
|
||||
path: path,
|
||||
}
|
||||
b.Store(new(T))
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *GitDataset[T]) Init(skipGC bool) (updated bool, err error) {
|
||||
gitDir := filepath.Join(b.gitRepo, ".git")
|
||||
if _, err := os.Stat(gitDir); err != nil {
|
||||
if _, err := b.shallowRepo.Clone(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
force := true
|
||||
return b.reload(skipGC, force)
|
||||
}
|
||||
|
||||
func (b *GitDataset[T]) Run(ctx context.Context) {
|
||||
ticker := time.NewTicker(47 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if ok, err := b.reload(false, false); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: git data: %v\n", err)
|
||||
} else if ok {
|
||||
fmt.Fprintf(os.Stderr, "git data: loaded repo\n")
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "git data: already up-to-date\n")
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *GitDataset[T]) reload(skipGC, force bool) (updated bool, err error) {
|
||||
laxGC := skipGC
|
||||
lazyPrune := skipGC
|
||||
updated, err = b.shallowRepo.Sync(laxGC, lazyPrune)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("git sync: %w", err)
|
||||
}
|
||||
if !updated && !force {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
nextDataset, err := b.LoadFile(b.path)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
_ = b.Swap(nextDataset)
|
||||
return true, nil
|
||||
}
|
||||
@ -19,10 +19,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDepth = 1 // shallow by default
|
||||
defaultBranch = "" // empty = default branch + --single-branch
|
||||
laxGC = false // false = --aggressive
|
||||
lazyPrune = false // false = --prune=now
|
||||
defaultDepth = 1 // shallow by default
|
||||
defaultBranch = "" // empty = default branch + --single-branch
|
||||
)
|
||||
|
||||
func main() {
|
||||
@ -37,7 +35,7 @@ func main() {
|
||||
url := os.Args[1]
|
||||
path := os.Args[2]
|
||||
|
||||
// Expand ~ to home directory for Windows
|
||||
// Expand ~ to home directory
|
||||
if path[0] == '~' {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
@ -47,7 +45,6 @@ func main() {
|
||||
path = filepath.Join(home, path[1:])
|
||||
}
|
||||
|
||||
// Make path absolute
|
||||
absPath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Invalid path: %v\n", err)
|
||||
@ -60,14 +57,14 @@ func main() {
|
||||
|
||||
repo := gitshallow.New(url, absPath, defaultDepth, defaultBranch)
|
||||
|
||||
updated, err := repo.Sync(laxGC, lazyPrune)
|
||||
updated, err := repo.Sync(false)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Sync failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if updated {
|
||||
fmt.Println("Repository was updated (new commits fetched).")
|
||||
fmt.Println("Repository was updated (new commits pulled).")
|
||||
} else {
|
||||
fmt.Println("Repository is already up to date.")
|
||||
}
|
||||
|
||||
@ -14,11 +14,9 @@ type ShallowRepo struct {
|
||||
URL string
|
||||
Path string
|
||||
Depth int // 0 defaults to 1, -1 for all
|
||||
Branch string // Optional: specific branch to clone/fetch
|
||||
//WithBranches bool
|
||||
//WithTags bool
|
||||
Branch string // Optional: specific branch to clone/pull
|
||||
|
||||
mu sync.Mutex // Mutex for in-process locking
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// New creates a new ShallowRepo instance.
|
||||
@ -30,11 +28,11 @@ func New(url, path string, depth int, branch string) *ShallowRepo {
|
||||
URL: url,
|
||||
Path: path,
|
||||
Depth: depth,
|
||||
Branch: strings.TrimSpace(branch), // clean up accidental whitespace
|
||||
Branch: strings.TrimSpace(branch),
|
||||
}
|
||||
}
|
||||
|
||||
// Clone performs a shallow clone (default --depth 0 --single-branch, --no-tags, etc).
|
||||
// Clone performs a shallow clone (--depth N --single-branch --no-tags).
|
||||
func (r *ShallowRepo) Clone() (bool, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
@ -77,8 +75,7 @@ func (r *ShallowRepo) exists() bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// runGit executes a git command.
|
||||
// For clone it runs in the parent directory; otherwise inside the repo.
|
||||
// runGit executes a git command in the repo directory (or parent for clone).
|
||||
func (r *ShallowRepo) runGit(args ...string) (string, error) {
|
||||
cmd := exec.Command("git", args...)
|
||||
|
||||
@ -96,51 +93,39 @@ func (r *ShallowRepo) runGit(args ...string) (string, error) {
|
||||
return strings.TrimSpace(string(output)), nil
|
||||
}
|
||||
|
||||
// Fetch performs a shallow fetch and updates the working branch.
|
||||
// Returns true if HEAD changed (i.e. meaningful update occurred).
|
||||
// Uses --depth on fetch; branch filtering only when Branch is set.
|
||||
func (r *ShallowRepo) Fetch() (updated bool, err error) {
|
||||
// Pull performs a shallow pull (--ff-only) and reports whether HEAD changed.
|
||||
func (r *ShallowRepo) Pull() (updated bool, err error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
return r.fetch()
|
||||
return r.pull()
|
||||
}
|
||||
|
||||
func (r *ShallowRepo) fetch() (updated bool, err error) {
|
||||
func (r *ShallowRepo) pull() (updated bool, err error) {
|
||||
if !r.exists() {
|
||||
return false, fmt.Errorf("repository does not exist at %s", r.Path)
|
||||
}
|
||||
|
||||
// Remember current HEAD
|
||||
oldHead, err := r.runGit("-C", r.Path, "rev-parse", "HEAD")
|
||||
oldHead, err := r.runGit("rev-parse", "HEAD")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Update local branch (git pull --ff-only is safer in shallow context)
|
||||
pullArgs := []string{"-C", r.Path, "pull", "--ff-only"}
|
||||
if r.Branch != "" {
|
||||
pullArgs = append(pullArgs, "origin", r.Branch)
|
||||
}
|
||||
_, err = r.runGit(pullArgs...)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// Fetch
|
||||
fetchArgs := []string{"-C", r.Path, "fetch", "--no-tags"}
|
||||
pullArgs := []string{"pull", "--ff-only", "--no-tags"}
|
||||
if r.Depth == 0 {
|
||||
r.Depth = 1
|
||||
}
|
||||
if r.Depth >= 0 {
|
||||
fetchArgs = append(fetchArgs, "--depth", fmt.Sprintf("%d", r.Depth))
|
||||
pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", r.Depth))
|
||||
}
|
||||
_, err = r.runGit(fetchArgs...)
|
||||
if err != nil {
|
||||
if r.Branch != "" {
|
||||
pullArgs = append(pullArgs, "origin", r.Branch)
|
||||
}
|
||||
if _, err = r.runGit(pullArgs...); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
newHead, err := r.runGit("-C", r.Path, "rev-parse", "HEAD")
|
||||
newHead, err := r.runGit("rev-parse", "HEAD")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -148,24 +133,24 @@ func (r *ShallowRepo) fetch() (updated bool, err error) {
|
||||
return oldHead != newHead, nil
|
||||
}
|
||||
|
||||
// GC runs git gc, defaulting to pruning immediately and aggressively
|
||||
func (r *ShallowRepo) GC(lax, lazy bool) error {
|
||||
// GC runs git gc. aggressiveGC adds --aggressive; pruneNow adds --prune=now.
|
||||
func (r *ShallowRepo) GC(aggressiveGC, pruneNow bool) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
return r.gc(lax, lazy)
|
||||
return r.gc(aggressiveGC, pruneNow)
|
||||
}
|
||||
|
||||
func (r *ShallowRepo) gc(lax, lazy bool) error {
|
||||
func (r *ShallowRepo) gc(aggressiveGC, pruneNow bool) error {
|
||||
if !r.exists() {
|
||||
return fmt.Errorf("repository does not exist at %s", r.Path)
|
||||
}
|
||||
|
||||
args := []string{"-C", r.Path, "gc"}
|
||||
if !lax {
|
||||
args := []string{"gc"}
|
||||
if aggressiveGC {
|
||||
args = append(args, "--aggressive")
|
||||
}
|
||||
if !lazy {
|
||||
if pruneNow {
|
||||
args = append(args, "--prune=now")
|
||||
}
|
||||
|
||||
@ -173,27 +158,30 @@ func (r *ShallowRepo) gc(lax, lazy bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Sync clones if missing, fetches, and runs GC.
|
||||
// Returns whether fetch caused an update.
|
||||
func (r *ShallowRepo) Sync(laxGC, lazyPrune bool) (updated bool, err error) {
|
||||
// Sync clones if missing, pulls, and runs GC.
|
||||
// lightGC=false (zero value) runs --aggressive GC with --prune=now to minimize disk use.
|
||||
// Pass true to skip both when speed matters more than footprint.
|
||||
func (r *ShallowRepo) Sync(lightGC bool) (updated bool, err error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if updated, err := r.clone(); err != nil {
|
||||
if cloned, err := r.clone(); err != nil {
|
||||
return false, err
|
||||
} else if updated {
|
||||
return updated, nil
|
||||
} else if cloned {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if updated, err := r.fetch(); err != nil {
|
||||
return updated, err
|
||||
} else if !updated {
|
||||
updated, err = r.pull()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !updated {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := r.gc(laxGC, lazyPrune); err != nil {
|
||||
return updated, fmt.Errorf("gc failed but fetch succeeded: %w", err)
|
||||
if err := r.gc(!lightGC, !lightGC); err != nil {
|
||||
return true, fmt.Errorf("gc failed but pull succeeded: %w", err)
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
@ -3,14 +3,15 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/therootcompany/golib/net/gitdataset"
|
||||
"github.com/therootcompany/golib/fs/dataset"
|
||||
"github.com/therootcompany/golib/net/ipcohort"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 3 {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address>\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address> [git-url]\n", os.Args[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@ -21,28 +22,31 @@ func main() {
|
||||
gitURL = os.Args[3]
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath)
|
||||
var blacklist *dataset.File[ipcohort.Cohort]
|
||||
|
||||
var b *ipcohort.Cohort
|
||||
loadFile := func(path string) (*ipcohort.Cohort, error) {
|
||||
return ipcohort.LoadFile(path, false)
|
||||
}
|
||||
blacklist := gitdataset.New(gitURL, dataPath, loadFile)
|
||||
fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
|
||||
if updated, err := blacklist.Init(false); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
|
||||
if gitURL != "" {
|
||||
repoDir := filepath.Dir(dataPath)
|
||||
relPath := filepath.Base(dataPath)
|
||||
repo := dataset.NewRepo(gitURL, repoDir)
|
||||
blacklist = dataset.AddFile(repo, relPath, ipcohort.LoadFile)
|
||||
fmt.Fprintf(os.Stderr, "Syncing %q ...\n", repoDir)
|
||||
if err := repo.Init(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: git sync: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
b = blacklist.Load()
|
||||
if updated {
|
||||
n := b.Size()
|
||||
if n > 0 {
|
||||
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
|
||||
}
|
||||
blacklist = dataset.NewFile(dataPath, ipcohort.LoadFile)
|
||||
fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath)
|
||||
if err := blacklist.Reload(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: load: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
|
||||
if blacklist.Load().Contains(ipStr) {
|
||||
c := blacklist.Load()
|
||||
fmt.Fprintf(os.Stderr, "Loaded %d entries\n", c.Size())
|
||||
|
||||
if c.Contains(ipStr) {
|
||||
fmt.Printf("%s is BLOCKED\n", ipStr)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@ -9,11 +9,11 @@ import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Either a subnet or single address (subnet with /32 CIDR prefix)
|
||||
// IPv4Net represents a subnet or single address (/32).
|
||||
// 6 bytes: networkBE uint32 + prefix uint8 + shift uint8.
|
||||
type IPv4Net struct {
|
||||
networkBE uint32
|
||||
prefix uint8
|
||||
@ -29,32 +29,81 @@ func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net {
|
||||
}
|
||||
|
||||
func (r IPv4Net) Contains(ip uint32) bool {
|
||||
mask := uint32(0xFFFFFFFF << (r.shift))
|
||||
mask := uint32(0xFFFFFFFF << r.shift)
|
||||
return (ip & mask) == r.networkBE
|
||||
}
|
||||
|
||||
// Cohort is an immutable, read-only set of IPv4 addresses and subnets.
|
||||
// Contains is safe for concurrent use without locks.
|
||||
//
|
||||
// hosts holds sorted /32 addresses for O(log n) binary search.
|
||||
// nets holds CIDR ranges (prefix < 32) for O(k) linear scan — typically small.
|
||||
type Cohort struct {
|
||||
hosts []uint32
|
||||
nets []IPv4Net
|
||||
}
|
||||
|
||||
func New() *Cohort {
|
||||
cohort := &Cohort{ranges: []IPv4Net{}}
|
||||
return cohort
|
||||
return &Cohort{}
|
||||
}
|
||||
|
||||
// Size returns the total number of entries (hosts + nets).
|
||||
func (c *Cohort) Size() int {
|
||||
return len(c.hosts) + len(c.nets)
|
||||
}
|
||||
|
||||
// Contains reports whether ipStr falls within any host or subnet in the cohort.
|
||||
// Returns true on parse error (fail-closed).
|
||||
func (c *Cohort) Contains(ipStr string) bool {
|
||||
ip, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
ip4 := ip.As4()
|
||||
ipU32 := binary.BigEndian.Uint32(ip4[:])
|
||||
|
||||
_, found := slices.BinarySearch(c.hosts, ipU32)
|
||||
if found {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, net := range c.nets {
|
||||
if net.Contains(ipU32) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func Parse(prefixList []string) (*Cohort, error) {
|
||||
var ranges []IPv4Net
|
||||
var hosts []uint32
|
||||
var nets []IPv4Net
|
||||
|
||||
for _, raw := range prefixList {
|
||||
ipv4net, err := ParseIPv4(raw)
|
||||
if err != nil {
|
||||
log.Printf("skipping invalid entry: %q", raw)
|
||||
continue
|
||||
}
|
||||
ranges = append(ranges, ipv4net)
|
||||
if ipv4net.prefix == 32 {
|
||||
hosts = append(hosts, ipv4net.networkBE)
|
||||
} else {
|
||||
nets = append(nets, ipv4net)
|
||||
}
|
||||
}
|
||||
|
||||
sizedList := make([]IPv4Net, len(ranges))
|
||||
copy(sizedList, ranges)
|
||||
sortRanges(ranges)
|
||||
slices.Sort(hosts)
|
||||
slices.SortFunc(nets, func(a, b IPv4Net) int {
|
||||
if a.networkBE < b.networkBE {
|
||||
return -1
|
||||
}
|
||||
if a.networkBE > b.networkBE {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
|
||||
cohort := &Cohort{ranges: sizedList}
|
||||
return cohort, nil
|
||||
return &Cohort{hosts: hosts, nets: nets}, nil
|
||||
}
|
||||
|
||||
func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
|
||||
@ -74,32 +123,34 @@ func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
|
||||
}
|
||||
|
||||
ip4 := ippre.Addr().As4()
|
||||
prefix := uint8(ippre.Bits()) // 0-32
|
||||
prefix := uint8(ippre.Bits()) // 0–32
|
||||
return NewIPv4Net(
|
||||
binary.BigEndian.Uint32(ip4[:]),
|
||||
prefix,
|
||||
), nil
|
||||
}
|
||||
|
||||
func LoadFile(path string, unsorted bool) (*Cohort, error) {
|
||||
func LoadFile(path string) (*Cohort, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not load %q: %v", path, err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
return ParseCSV(f, unsorted)
|
||||
return ParseCSV(f)
|
||||
}
|
||||
|
||||
func ParseCSV(f io.Reader, unsorted bool) (*Cohort, error) {
|
||||
func ParseCSV(f io.Reader) (*Cohort, error) {
|
||||
r := csv.NewReader(f)
|
||||
r.FieldsPerRecord = -1
|
||||
|
||||
return ReadAll(r, unsorted)
|
||||
return ReadAll(r)
|
||||
}
|
||||
|
||||
func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
|
||||
var ranges []IPv4Net
|
||||
func ReadAll(r *csv.Reader) (*Cohort, error) {
|
||||
var hosts []uint32
|
||||
var nets []IPv4Net
|
||||
|
||||
for {
|
||||
record, err := r.Read()
|
||||
if err == io.EOF {
|
||||
@ -115,7 +166,6 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
|
||||
|
||||
raw := strings.TrimSpace(record[0])
|
||||
|
||||
// Skip comments/empty
|
||||
if raw == "" || strings.HasPrefix(raw, "#") {
|
||||
continue
|
||||
}
|
||||
@ -130,65 +180,24 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
|
||||
log.Printf("skipping invalid entry: %q", raw)
|
||||
continue
|
||||
}
|
||||
ranges = append(ranges, ipv4net)
|
||||
|
||||
if ipv4net.prefix == 32 {
|
||||
hosts = append(hosts, ipv4net.networkBE)
|
||||
} else {
|
||||
nets = append(nets, ipv4net)
|
||||
}
|
||||
}
|
||||
|
||||
if unsorted {
|
||||
sortRanges(ranges)
|
||||
}
|
||||
|
||||
sizedList := make([]IPv4Net, len(ranges))
|
||||
copy(sizedList, ranges)
|
||||
|
||||
cohort := &Cohort{ranges: sizedList}
|
||||
return cohort, nil
|
||||
}
|
||||
|
||||
func sortRanges(ranges []IPv4Net) {
|
||||
// Sort by network address (required for binary search)
|
||||
sort.Slice(ranges, func(i, j int) bool {
|
||||
// Note: we could also sort by prefix (largest first)
|
||||
return ranges[i].networkBE < ranges[j].networkBE
|
||||
})
|
||||
|
||||
// Note: we could also merge ranges here
|
||||
}
|
||||
|
||||
type Cohort struct {
|
||||
ranges []IPv4Net
|
||||
}
|
||||
|
||||
func (c *Cohort) Size() int {
|
||||
return len(c.ranges)
|
||||
}
|
||||
|
||||
func (c *Cohort) Contains(ipStr string) bool {
|
||||
ip, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
ip4 := ip.As4()
|
||||
ipU32 := binary.BigEndian.Uint32(ip4[:])
|
||||
|
||||
idx, found := slices.BinarySearchFunc(c.ranges, ipU32, func(r IPv4Net, target uint32) int {
|
||||
if r.networkBE < target {
|
||||
slices.Sort(hosts)
|
||||
slices.SortFunc(nets, func(a, b IPv4Net) int {
|
||||
if a.networkBE < b.networkBE {
|
||||
return -1
|
||||
}
|
||||
if r.networkBE > target {
|
||||
if a.networkBE > b.networkBE {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
})
|
||||
if found {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check the range immediately before the insertion point
|
||||
if idx > 0 {
|
||||
if c.ranges[idx-1].Contains(ipU32) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return &Cohort{hosts: hosts, nets: nets}, nil
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user