refactor: decouple gitdataset/ipcohort for multi-file repos

gitshallow: fix double-fetch (pull already fetches), drop redundant -C flags
gitdataset: split into GitDataset[T] (file+atomic) and GitRepo (git+multi-dataset)
  - NewDataset for file-only use, AddDataset to register with a GitRepo
  - one clone/fetch per repo regardless of how many datasets it has
ipcohort: split Cohort into hosts (sorted /32, binary search) + nets (CIDRs, linear)
  - fixes false negatives when broad CIDRs (e.g. /8) precede specific entries
  - fixes Parse() sort-before-copy order bug
  - ReadAll always sorts; unsorted param removed (was dead code)
This commit is contained in:
AJ ONeal 2026-04-19 22:19:19 -06:00
parent a8e108a05b
commit 8731eaf10b
No known key found for this signature in database
6 changed files with 282 additions and 242 deletions

134
fs/dataset/dataset.go Normal file
View File

@ -0,0 +1,134 @@
package dataset
import (
"context"
"fmt"
"os"
"path/filepath"
"sync/atomic"
"time"
"github.com/therootcompany/golib/net/gitshallow"
)
// File holds an atomically-swappable pointer to a value loaded from a file.
// Reads are lock-free. Use NewFile for file-only use, or AddFile to attach
// to a GitRepo so the value refreshes whenever the repo is updated.
type File[T any] struct {
	atomic.Pointer[T] // current value; Load/Store are safe for concurrent use
	path     string   // location of the backing file on disk
	loadFile func(string) (*T, error) // parses the file at path into a fresh *T
}
// NewFile creates a file-backed dataset with no git dependency.
// The stored value starts as the zero value of T; call Reload to do the
// initial load and again after any file change.
func NewFile[T any](path string, loadFile func(string) (*T, error)) *File[T] {
	f := &File[T]{path: path, loadFile: loadFile}
	f.Store(new(T))
	return f
}
// Reload reads the backing file and atomically replaces the stored value.
// On error the previously stored value is left untouched.
func (d *File[T]) Reload() error {
	next, err := d.loadFile(d.path)
	if err != nil {
		return err
	}
	d.Store(next)
	return nil
}
// reloadFile satisfies the reloader interface so a GitRepo can refresh
// this file without knowing its concrete type parameter T.
func (d *File[T]) reloadFile() error {
	return d.Reload()
}
// reloader is the internal interface GitRepo uses to trigger file reloads.
// It erases the type parameter of File[T] so GitRepo can hold a mixed list.
type reloader interface {
	reloadFile() error
}
// GitRepo manages a shallow git clone and reloads all registered files
// whenever the repo is updated. Multiple files from the same repo share
// one clone and one pull, avoiding git file-lock conflicts.
type GitRepo struct {
	path        string                  // local working-tree directory of the clone
	shallowRepo *gitshallow.ShallowRepo // handles clone/pull/gc
	files       []reloader              // files reloaded after each successful sync
}
// NewRepo creates a GitRepo backed by the given git URL, cloning into repoPath.
// The clone is depth-1 on the remote's default branch.
func NewRepo(gitURL, repoPath string) *GitRepo {
	repo := &GitRepo{path: repoPath}
	repo.shallowRepo = gitshallow.New(gitURL, repoPath, 1, "")
	return repo
}
// AddFile registers a file inside this repo and returns its handle.
// relPath is relative to the repo root. The file is reloaded automatically
// whenever the repo is synced via Init or Run.
func AddFile[T any](repo *GitRepo, relPath string, loadFile func(string) (*T, error)) *File[T] {
	handle := NewFile(filepath.Join(repo.path, relPath), loadFile)
	repo.files = append(repo.files, handle)
	return handle
}
// Init clones the repo if missing, syncs once, and loads all registered files.
// The first sync force-reloads every file even when no new commits arrived,
// and runs aggressive GC when commits did arrive — acceptable as a one-time
// startup cost.
func (r *GitRepo) Init() error {
	gitDir := filepath.Join(r.path, ".git")
	if _, err := os.Stat(gitDir); err != nil {
		if !os.IsNotExist(err) {
			// A stat failure other than "missing" (e.g. permissions) will not
			// be fixed by cloning on top of the directory; surface it instead
			// of letting git produce a confusing secondary error.
			return fmt.Errorf("stat %q: %w", gitDir, err)
		}
		if _, err := r.shallowRepo.Clone(); err != nil {
			return err
		}
	}
	_, err := r.sync(false, true)
	return err
}
// Run periodically syncs the repo and reloads files. Blocks until ctx is done.
// lightGC=false (zero value) runs aggressive GC with immediate pruning to keep footprint minimal.
// Pass true to skip both when the periodic GC is too slow for your workload.
func (r *GitRepo) Run(ctx context.Context, lightGC bool) {
	tick := time.NewTicker(47 * time.Minute)
	defer tick.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
			updated, err := r.sync(lightGC, false)
			switch {
			case err != nil:
				fmt.Fprintf(os.Stderr, "error: git repo sync: %v\n", err)
			case updated:
				fmt.Fprintf(os.Stderr, "git repo: files reloaded\n")
			}
		}
	}
}
// Sync pulls the latest commits and reloads all files if HEAD changed.
// It reports whether an update occurred.
// lightGC=false (zero value) runs aggressive GC with immediate pruning to keep footprint minimal.
func (r *GitRepo) Sync(lightGC bool) (bool, error) {
	return r.sync(lightGC, false)
}
// sync pulls the repo and, when HEAD changed (or force is set), reloads every
// registered file. Per-file reload errors are logged but do not abort the
// remaining reloads, so one bad file cannot block the others; each failed
// file keeps its previously stored value.
func (r *GitRepo) sync(lightGC, force bool) (bool, error) {
	updated, err := r.shallowRepo.Sync(lightGC)
	if err != nil {
		return false, fmt.Errorf("git sync: %w", err)
	}
	// force makes the reload unconditional (used by Init for the first load).
	if !updated && !force {
		return false, nil
	}
	for _, f := range r.files {
		if err := f.reloadFile(); err != nil {
			// Best-effort: log and continue with the remaining files.
			fmt.Fprintf(os.Stderr, "error: reload file: %v\n", err)
		}
	}
	return true, nil
}

View File

@ -1,92 +0,0 @@
package gitdataset
import (
"context"
"fmt"
"os"
"path/filepath"
"sync/atomic"
"time"
"github.com/therootcompany/golib/net/gitshallow"
)
// TODO maybe a GitRepo should contain GitDatasets such that loading
// multiple datasets from the same GitRepo won't cause issues with file locking?
// GitDataset is a value of type T kept in sync with a single file inside a
// shallow git clone. Reads via Load are lock-free.
type GitDataset[T any] struct {
	LoadFile func(path string) (*T, error) // parses the file at path into a fresh *T
	atomic.Pointer[T] // current value; Load/Swap are safe for concurrent use
	gitRepo     string // clone directory (parent of path)
	shallowRepo *gitshallow.ShallowRepo // handles clone/fetch/gc
	path        string // full path of the tracked data file
}
// New creates a GitDataset for the file at path, cloning gitURL into path's
// parent directory. The stored value starts as the zero value of T until
// Init (or a later reload) runs.
func New[T any](gitURL, path string, loadFile func(path string) (*T, error)) *GitDataset[T] {
	gitRepo := filepath.Dir(path) // NOTE(review): assumes the file sits at the repo root — confirm
	gitDepth := 1
	gitBranch := "" // empty = remote's default branch
	shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch)
	b := &GitDataset[T]{
		Pointer:     atomic.Pointer[T]{},
		LoadFile:    loadFile,
		gitRepo:     gitRepo,
		shallowRepo: shallowRepo,
		path:        path,
	}
	b.Store(new(T))
	return b
}
// Init clones the repo when the .git directory is missing, then syncs and
// force-loads the file regardless of whether new commits arrived.
// skipGC relaxes garbage collection during the initial sync.
func (b *GitDataset[T]) Init(skipGC bool) (updated bool, err error) {
	gitDir := filepath.Join(b.gitRepo, ".git")
	if _, err := os.Stat(gitDir); err != nil {
		// Any stat failure is treated as "not cloned yet".
		if _, err := b.shallowRepo.Clone(); err != nil {
			return false, err
		}
	}
	force := true // reload the file even when HEAD did not change
	return b.reload(skipGC, force)
}
// Run re-syncs the dataset every 47 minutes until ctx is canceled.
// Outcomes and errors are reported to stderr; Run never returns them.
func (b *GitDataset[T]) Run(ctx context.Context) {
	ticker := time.NewTicker(47 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if ok, err := b.reload(false, false); err != nil {
				fmt.Fprintf(os.Stderr, "error: git data: %v\n", err)
			} else if ok {
				fmt.Fprintf(os.Stderr, "git data: loaded repo\n")
			} else {
				fmt.Fprintf(os.Stderr, "git data: already up-to-date\n")
			}
		case <-ctx.Done():
			return
		}
	}
}
// reload syncs the git repo and, when HEAD changed (or force is set),
// re-parses the file and atomically swaps in the new value.
// A parse error leaves the previously stored value in place.
func (b *GitDataset[T]) reload(skipGC, force bool) (updated bool, err error) {
	laxGC := skipGC     // skipGC disables aggressive GC
	lazyPrune := skipGC // skipGC defers pruning
	updated, err = b.shallowRepo.Sync(laxGC, lazyPrune)
	if err != nil {
		return false, fmt.Errorf("git sync: %w", err)
	}
	if !updated && !force {
		return false, nil
	}
	nextDataset, err := b.LoadFile(b.path)
	if err != nil {
		return false, err
	}
	// Previous value is discarded; concurrent readers see the swap atomically.
	_ = b.Swap(nextDataset)
	return true, nil
}

View File

@ -19,10 +19,8 @@ import (
)
const (
defaultDepth = 1 // shallow by default
defaultBranch = "" // empty = default branch + --single-branch
laxGC = false // false = --aggressive
lazyPrune = false // false = --prune=now
defaultDepth = 1 // shallow by default
defaultBranch = "" // empty = default branch + --single-branch
)
func main() {
@ -37,7 +35,7 @@ func main() {
url := os.Args[1]
path := os.Args[2]
// Expand ~ to home directory for Windows
// Expand ~ to home directory
if path[0] == '~' {
home, err := os.UserHomeDir()
if err != nil {
@ -47,7 +45,6 @@ func main() {
path = filepath.Join(home, path[1:])
}
// Make path absolute
absPath, err := filepath.Abs(path)
if err != nil {
fmt.Fprintf(os.Stderr, "Invalid path: %v\n", err)
@ -60,14 +57,14 @@ func main() {
repo := gitshallow.New(url, absPath, defaultDepth, defaultBranch)
updated, err := repo.Sync(laxGC, lazyPrune)
updated, err := repo.Sync(false)
if err != nil {
fmt.Fprintf(os.Stderr, "Sync failed: %v\n", err)
os.Exit(1)
}
if updated {
fmt.Println("Repository was updated (new commits fetched).")
fmt.Println("Repository was updated (new commits pulled).")
} else {
fmt.Println("Repository is already up to date.")
}

View File

@ -14,11 +14,9 @@ type ShallowRepo struct {
URL string
Path string
Depth int // 0 defaults to 1, -1 for all
Branch string // Optional: specific branch to clone/fetch
//WithBranches bool
//WithTags bool
Branch string // Optional: specific branch to clone/pull
mu sync.Mutex // Mutex for in-process locking
mu sync.Mutex
}
// New creates a new ShallowRepo instance.
@ -30,11 +28,11 @@ func New(url, path string, depth int, branch string) *ShallowRepo {
URL: url,
Path: path,
Depth: depth,
Branch: strings.TrimSpace(branch), // clean up accidental whitespace
Branch: strings.TrimSpace(branch),
}
}
// Clone performs a shallow clone (default --depth 0 --single-branch, --no-tags, etc).
// Clone performs a shallow clone (--depth N --single-branch --no-tags).
func (r *ShallowRepo) Clone() (bool, error) {
r.mu.Lock()
defer r.mu.Unlock()
@ -77,8 +75,7 @@ func (r *ShallowRepo) exists() bool {
return err == nil
}
// runGit executes a git command.
// For clone it runs in the parent directory; otherwise inside the repo.
// runGit executes a git command in the repo directory (or parent for clone).
func (r *ShallowRepo) runGit(args ...string) (string, error) {
cmd := exec.Command("git", args...)
@ -96,51 +93,39 @@ func (r *ShallowRepo) runGit(args ...string) (string, error) {
return strings.TrimSpace(string(output)), nil
}
// Fetch performs a shallow fetch and updates the working branch.
// Returns true if HEAD changed (i.e. meaningful update occurred).
// Uses --depth on fetch; branch filtering only when Branch is set.
func (r *ShallowRepo) Fetch() (updated bool, err error) {
// Pull performs a shallow pull (--ff-only) and reports whether HEAD changed.
func (r *ShallowRepo) Pull() (updated bool, err error) {
r.mu.Lock()
defer r.mu.Unlock()
return r.fetch()
return r.pull()
}
func (r *ShallowRepo) fetch() (updated bool, err error) {
func (r *ShallowRepo) pull() (updated bool, err error) {
if !r.exists() {
return false, fmt.Errorf("repository does not exist at %s", r.Path)
}
// Remember current HEAD
oldHead, err := r.runGit("-C", r.Path, "rev-parse", "HEAD")
oldHead, err := r.runGit("rev-parse", "HEAD")
if err != nil {
return false, err
}
// Update local branch (git pull --ff-only is safer in shallow context)
pullArgs := []string{"-C", r.Path, "pull", "--ff-only"}
if r.Branch != "" {
pullArgs = append(pullArgs, "origin", r.Branch)
}
_, err = r.runGit(pullArgs...)
if err != nil {
return false, err
}
// Fetch
fetchArgs := []string{"-C", r.Path, "fetch", "--no-tags"}
pullArgs := []string{"pull", "--ff-only", "--no-tags"}
if r.Depth == 0 {
r.Depth = 1
}
if r.Depth >= 0 {
fetchArgs = append(fetchArgs, "--depth", fmt.Sprintf("%d", r.Depth))
pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", r.Depth))
}
_, err = r.runGit(fetchArgs...)
if err != nil {
if r.Branch != "" {
pullArgs = append(pullArgs, "origin", r.Branch)
}
if _, err = r.runGit(pullArgs...); err != nil {
return false, err
}
newHead, err := r.runGit("-C", r.Path, "rev-parse", "HEAD")
newHead, err := r.runGit("rev-parse", "HEAD")
if err != nil {
return false, err
}
@ -148,24 +133,24 @@ func (r *ShallowRepo) fetch() (updated bool, err error) {
return oldHead != newHead, nil
}
// GC runs git gc, defaulting to pruning immediately and aggressively
func (r *ShallowRepo) GC(lax, lazy bool) error {
// GC runs git gc. aggressiveGC adds --aggressive; pruneNow adds --prune=now.
func (r *ShallowRepo) GC(aggressiveGC, pruneNow bool) error {
r.mu.Lock()
defer r.mu.Unlock()
return r.gc(lax, lazy)
return r.gc(aggressiveGC, pruneNow)
}
func (r *ShallowRepo) gc(lax, lazy bool) error {
func (r *ShallowRepo) gc(aggressiveGC, pruneNow bool) error {
if !r.exists() {
return fmt.Errorf("repository does not exist at %s", r.Path)
}
args := []string{"-C", r.Path, "gc"}
if !lax {
args := []string{"gc"}
if aggressiveGC {
args = append(args, "--aggressive")
}
if !lazy {
if pruneNow {
args = append(args, "--prune=now")
}
@ -173,27 +158,30 @@ func (r *ShallowRepo) gc(lax, lazy bool) error {
return err
}
// Sync clones if missing, fetches, and runs GC.
// Returns whether fetch caused an update.
func (r *ShallowRepo) Sync(laxGC, lazyPrune bool) (updated bool, err error) {
// Sync clones if missing, pulls, and runs GC.
// lightGC=false (zero value) runs --aggressive GC with --prune=now to minimize disk use.
// Pass true to skip both when speed matters more than footprint.
func (r *ShallowRepo) Sync(lightGC bool) (updated bool, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if updated, err := r.clone(); err != nil {
if cloned, err := r.clone(); err != nil {
return false, err
} else if updated {
return updated, nil
} else if cloned {
return true, nil
}
if updated, err := r.fetch(); err != nil {
return updated, err
} else if !updated {
updated, err = r.pull()
if err != nil {
return false, err
}
if !updated {
return false, nil
}
if err := r.gc(laxGC, lazyPrune); err != nil {
return updated, fmt.Errorf("gc failed but fetch succeeded: %w", err)
if err := r.gc(!lightGC, !lightGC); err != nil {
return true, fmt.Errorf("gc failed but pull succeeded: %w", err)
}
return updated, nil
return true, nil
}

View File

@ -3,14 +3,15 @@ package main
import (
"fmt"
"os"
"path/filepath"
"github.com/therootcompany/golib/net/gitdataset"
"github.com/therootcompany/golib/fs/dataset"
"github.com/therootcompany/golib/net/ipcohort"
)
func main() {
if len(os.Args) < 3 {
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address>\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address> [git-url]\n", os.Args[0])
os.Exit(1)
}
@ -21,28 +22,31 @@ func main() {
gitURL = os.Args[3]
}
fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath)
var blacklist *dataset.File[ipcohort.Cohort]
var b *ipcohort.Cohort
loadFile := func(path string) (*ipcohort.Cohort, error) {
return ipcohort.LoadFile(path, false)
}
blacklist := gitdataset.New(gitURL, dataPath, loadFile)
fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
if updated, err := blacklist.Init(false); err != nil {
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
if gitURL != "" {
repoDir := filepath.Dir(dataPath)
relPath := filepath.Base(dataPath)
repo := dataset.NewRepo(gitURL, repoDir)
blacklist = dataset.AddFile(repo, relPath, ipcohort.LoadFile)
fmt.Fprintf(os.Stderr, "Syncing %q ...\n", repoDir)
if err := repo.Init(); err != nil {
fmt.Fprintf(os.Stderr, "error: git sync: %v\n", err)
os.Exit(1)
}
} else {
b = blacklist.Load()
if updated {
n := b.Size()
if n > 0 {
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
}
blacklist = dataset.NewFile(dataPath, ipcohort.LoadFile)
fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath)
if err := blacklist.Reload(); err != nil {
fmt.Fprintf(os.Stderr, "error: load: %v\n", err)
os.Exit(1)
}
}
fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
if blacklist.Load().Contains(ipStr) {
c := blacklist.Load()
fmt.Fprintf(os.Stderr, "Loaded %d entries\n", c.Size())
if c.Contains(ipStr) {
fmt.Printf("%s is BLOCKED\n", ipStr)
os.Exit(1)
}

View File

@ -9,11 +9,11 @@ import (
"net/netip"
"os"
"slices"
"sort"
"strings"
)
// Either a subnet or single address (subnet with /32 CIDR prefix)
// IPv4Net represents a subnet or single address (/32).
// 6 bytes: networkBE uint32 + prefix uint8 + shift uint8.
type IPv4Net struct {
networkBE uint32
prefix uint8
@ -29,32 +29,81 @@ func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net {
}
func (r IPv4Net) Contains(ip uint32) bool {
mask := uint32(0xFFFFFFFF << (r.shift))
mask := uint32(0xFFFFFFFF << r.shift)
return (ip & mask) == r.networkBE
}
// Cohort is an immutable, read-only set of IPv4 addresses and subnets.
// Contains is safe for concurrent use without locks.
//
// hosts holds sorted /32 addresses for O(log n) binary search.
// nets holds CIDR ranges (prefix < 32) for O(k) linear scan — typically small.
type Cohort struct {
hosts []uint32
nets []IPv4Net
}
func New() *Cohort {
cohort := &Cohort{ranges: []IPv4Net{}}
return cohort
return &Cohort{}
}
// Size returns the total number of entries (hosts + nets).
func (c *Cohort) Size() int {
return len(c.hosts) + len(c.nets)
}
// Contains reports whether ipStr falls within any host or subnet in the cohort.
// Returns true on parse error (fail-closed).
func (c *Cohort) Contains(ipStr string) bool {
ip, err := netip.ParseAddr(ipStr)
if err != nil {
return true
}
ip4 := ip.As4()
ipU32 := binary.BigEndian.Uint32(ip4[:])
_, found := slices.BinarySearch(c.hosts, ipU32)
if found {
return true
}
for _, net := range c.nets {
if net.Contains(ipU32) {
return true
}
}
return false
}
func Parse(prefixList []string) (*Cohort, error) {
var ranges []IPv4Net
var hosts []uint32
var nets []IPv4Net
for _, raw := range prefixList {
ipv4net, err := ParseIPv4(raw)
if err != nil {
log.Printf("skipping invalid entry: %q", raw)
continue
}
ranges = append(ranges, ipv4net)
if ipv4net.prefix == 32 {
hosts = append(hosts, ipv4net.networkBE)
} else {
nets = append(nets, ipv4net)
}
}
sizedList := make([]IPv4Net, len(ranges))
copy(sizedList, ranges)
sortRanges(ranges)
slices.Sort(hosts)
slices.SortFunc(nets, func(a, b IPv4Net) int {
if a.networkBE < b.networkBE {
return -1
}
if a.networkBE > b.networkBE {
return 1
}
return 0
})
cohort := &Cohort{ranges: sizedList}
return cohort, nil
return &Cohort{hosts: hosts, nets: nets}, nil
}
func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
@ -74,32 +123,34 @@ func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
}
ip4 := ippre.Addr().As4()
prefix := uint8(ippre.Bits()) // 0-32
prefix := uint8(ippre.Bits()) // 0–32
return NewIPv4Net(
binary.BigEndian.Uint32(ip4[:]),
prefix,
), nil
}
func LoadFile(path string, unsorted bool) (*Cohort, error) {
func LoadFile(path string) (*Cohort, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("could not load %q: %v", path, err)
}
defer f.Close()
return ParseCSV(f, unsorted)
return ParseCSV(f)
}
func ParseCSV(f io.Reader, unsorted bool) (*Cohort, error) {
func ParseCSV(f io.Reader) (*Cohort, error) {
r := csv.NewReader(f)
r.FieldsPerRecord = -1
return ReadAll(r, unsorted)
return ReadAll(r)
}
func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
var ranges []IPv4Net
func ReadAll(r *csv.Reader) (*Cohort, error) {
var hosts []uint32
var nets []IPv4Net
for {
record, err := r.Read()
if err == io.EOF {
@ -115,7 +166,6 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
raw := strings.TrimSpace(record[0])
// Skip comments/empty
if raw == "" || strings.HasPrefix(raw, "#") {
continue
}
@ -130,65 +180,24 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
log.Printf("skipping invalid entry: %q", raw)
continue
}
ranges = append(ranges, ipv4net)
if ipv4net.prefix == 32 {
hosts = append(hosts, ipv4net.networkBE)
} else {
nets = append(nets, ipv4net)
}
}
if unsorted {
sortRanges(ranges)
}
sizedList := make([]IPv4Net, len(ranges))
copy(sizedList, ranges)
cohort := &Cohort{ranges: sizedList}
return cohort, nil
}
func sortRanges(ranges []IPv4Net) {
// Sort by network address (required for binary search)
sort.Slice(ranges, func(i, j int) bool {
// Note: we could also sort by prefix (largest first)
return ranges[i].networkBE < ranges[j].networkBE
})
// Note: we could also merge ranges here
}
type Cohort struct {
ranges []IPv4Net
}
func (c *Cohort) Size() int {
return len(c.ranges)
}
func (c *Cohort) Contains(ipStr string) bool {
ip, err := netip.ParseAddr(ipStr)
if err != nil {
return true
}
ip4 := ip.As4()
ipU32 := binary.BigEndian.Uint32(ip4[:])
idx, found := slices.BinarySearchFunc(c.ranges, ipU32, func(r IPv4Net, target uint32) int {
if r.networkBE < target {
slices.Sort(hosts)
slices.SortFunc(nets, func(a, b IPv4Net) int {
if a.networkBE < b.networkBE {
return -1
}
if r.networkBE > target {
if a.networkBE > b.networkBE {
return 1
}
return 0
})
if found {
return true
}
// Check the range immediately before the insertion point
if idx > 0 {
if c.ranges[idx-1].Contains(ipU32) {
return true
}
}
return false
return &Cohort{hosts: hosts, nets: nets}, nil
}