diff --git a/fs/dataset/dataset.go b/fs/dataset/dataset.go new file mode 100644 index 0000000..dc94c36 --- /dev/null +++ b/fs/dataset/dataset.go @@ -0,0 +1,134 @@ +package dataset + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sync/atomic" + "time" + + "github.com/therootcompany/golib/net/gitshallow" +) + +// File holds an atomically-swappable pointer to a value loaded from a file. +// Reads are lock-free. Use NewFile for file-only use, or AddFile to attach +// to a GitRepo so the value refreshes whenever the repo is updated. +type File[T any] struct { + atomic.Pointer[T] + path string + loadFile func(string) (*T, error) +} + +// NewFile creates a file-backed dataset with no git dependency. +// Call Reload to do the initial load and after any file change. +func NewFile[T any](path string, loadFile func(string) (*T, error)) *File[T] { + d := &File[T]{ + path: path, + loadFile: loadFile, + } + d.Store(new(T)) + return d +} + +// Reload reads the file and atomically replaces the stored value. +func (d *File[T]) Reload() error { + v, err := d.loadFile(d.path) + if err != nil { + return err + } + d.Store(v) + return nil +} + +func (d *File[T]) reloadFile() error { + return d.Reload() +} + +// reloader is the internal interface GitRepo uses to trigger file reloads. +type reloader interface { + reloadFile() error +} + +// GitRepo manages a shallow git clone and reloads all registered files +// whenever the repo is updated. Multiple files from the same repo share +// one clone and one pull, avoiding git file-lock conflicts. +type GitRepo struct { + path string + shallowRepo *gitshallow.ShallowRepo + files []reloader +} + +// NewRepo creates a GitRepo backed by the given git URL, cloning into repoPath. +func NewRepo(gitURL, repoPath string) *GitRepo { + return &GitRepo{ + path: repoPath, + shallowRepo: gitshallow.New(gitURL, repoPath, 1, ""), + } +} + +// AddFile registers a file inside this repo and returns its handle. +// relPath is relative to the repo root. The file is reloaded automatically +// whenever the repo is synced via Init or Run. +func AddFile[T any](repo *GitRepo, relPath string, loadFile func(string) (*T, error)) *File[T] { + d := NewFile(filepath.Join(repo.path, relPath), loadFile) + repo.files = append(repo.files, d) + return d +} + +// Init clones the repo if missing, syncs once, and loads all registered files. +// Always runs aggressive GC — acceptable as a one-time startup cost. +func (r *GitRepo) Init() error { + gitDir := filepath.Join(r.path, ".git") + if _, err := os.Stat(gitDir); err != nil { + if _, err := r.shallowRepo.Clone(); err != nil { + return err + } + } + _, err := r.sync(false, true) + return err +} + +// Run periodically syncs the repo and reloads files. Blocks until ctx is done. +// lightGC=false (zero value) runs aggressive GC with immediate pruning to keep footprint minimal. +// Pass true to skip both when the periodic GC is too slow for your workload. +func (r *GitRepo) Run(ctx context.Context, lightGC bool) { + ticker := time.NewTicker(47 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + if updated, err := r.sync(lightGC, false); err != nil { + fmt.Fprintf(os.Stderr, "error: git repo sync: %v\n", err) + } else if updated { + fmt.Fprintf(os.Stderr, "git repo: files reloaded\n") + } + case <-ctx.Done(): + return + } + } +} + +// Sync pulls the latest commits and reloads all files if HEAD changed. +// lightGC=false (zero value) runs aggressive GC with immediate pruning to keep footprint minimal. +func (r *GitRepo) Sync(lightGC bool) (bool, error) { + return r.sync(lightGC, false) +} + +func (r *GitRepo) sync(lightGC, force bool) (bool, error) { + updated, err := r.shallowRepo.Sync(lightGC) + if err != nil { + return false, fmt.Errorf("git sync: %w", err) + } + if !updated && !force { + return false, nil + } + + for _, f := range r.files { + if err := f.reloadFile(); err != nil { + fmt.Fprintf(os.Stderr, "error: reload file: %v\n", err) + } + } + return true, nil +} diff --git a/net/gitdataset/gitdataset.go b/net/gitdataset/gitdataset.go deleted file mode 100644 index fad3cae..0000000 --- a/net/gitdataset/gitdataset.go +++ /dev/null @@ -1,92 +0,0 @@ -package gitdataset - -import ( - "context" - "fmt" - "os" - "path/filepath" - "sync/atomic" - "time" - - "github.com/therootcompany/golib/net/gitshallow" -) - -// TODO maybe a GitRepo should contain GitDatasets such that loading -// multiple datasets from the same GitRepo won't cause issues with file locking? - -type GitDataset[T any] struct { - LoadFile func(path string) (*T, error) - atomic.Pointer[T] - gitRepo string - shallowRepo *gitshallow.ShallowRepo - path string -} - -func New[T any](gitURL, path string, loadFile func(path string) (*T, error)) *GitDataset[T] { - gitRepo := filepath.Dir(path) - gitDepth := 1 - gitBranch := "" - shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch) - - b := &GitDataset[T]{ - Pointer: atomic.Pointer[T]{}, - LoadFile: loadFile, - gitRepo: gitRepo, - shallowRepo: shallowRepo, - path: path, - } - b.Store(new(T)) - return b -} - -func (b *GitDataset[T]) Init(skipGC bool) (updated bool, err error) { - gitDir := filepath.Join(b.gitRepo, ".git") - if _, err := os.Stat(gitDir); err != nil { - if _, err := b.shallowRepo.Clone(); err != nil { - return false, err - } - } - - force := true - return b.reload(skipGC, force) -} - -func (b *GitDataset[T]) Run(ctx context.Context) { - ticker := time.NewTicker(47 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if ok, err := b.reload(false, false); err != nil { - fmt.Fprintf(os.Stderr, "error: git data: %v\n", err) - } else if ok { - fmt.Fprintf(os.Stderr, "git data: loaded repo\n") - } else { - fmt.Fprintf(os.Stderr, "git data: already up-to-date\n") - } - case <-ctx.Done(): - return - } - } -} - -func (b *GitDataset[T]) reload(skipGC, force bool) (updated bool, err error) { - laxGC := skipGC - lazyPrune := skipGC - updated, err = b.shallowRepo.Sync(laxGC, lazyPrune) - if err != nil { - return false, fmt.Errorf("git sync: %w", err) - } - if !updated && !force { - return false, nil - } - - nextDataset, err := b.LoadFile(b.path) - if err != nil { - return false, err - } - - _ = b.Swap(nextDataset) - return true, nil -} diff --git a/net/gitshallow/cmd/git-shallow-sync/main.go b/net/gitshallow/cmd/git-shallow-sync/main.go index c7093ed..f7389ef 100644 --- a/net/gitshallow/cmd/git-shallow-sync/main.go +++ b/net/gitshallow/cmd/git-shallow-sync/main.go @@ -19,10 +19,8 @@ import ( ) const ( - defaultDepth = 1 // shallow by default - defaultBranch = "" // empty = default branch + --single-branch - laxGC = false // false = --aggressive - lazyPrune = false // false = --prune=now + defaultDepth = 1 // shallow by default + defaultBranch = "" // empty = default branch + --single-branch ) func main() { @@ -37,7 +35,7 @@ func main() { url := os.Args[1] path := os.Args[2] - // Expand ~ to home directory for Windows + // Expand ~ to home directory if path[0] == '~' { home, err := os.UserHomeDir() if err != nil { @@ -47,7 +45,6 @@ func main() { path = filepath.Join(home, path[1:]) } - // Make path absolute absPath, err := filepath.Abs(path) if err != nil { fmt.Fprintf(os.Stderr, "Invalid path: %v\n", err) @@ -60,14 +57,14 @@ func main() { repo := gitshallow.New(url, absPath, defaultDepth, defaultBranch) - updated, err := repo.Sync(laxGC, lazyPrune) + updated, err := repo.Sync(false) if err != nil { fmt.Fprintf(os.Stderr, "Sync failed: %v\n", err) os.Exit(1) } if updated { - fmt.Println("Repository was updated (new commits fetched).") + fmt.Println("Repository was updated (new commits pulled).") } else { fmt.Println("Repository is already up to date.") } diff --git a/net/gitshallow/gitshallow.go b/net/gitshallow/gitshallow.go index 8d05b07..c263ae1 100644 --- a/net/gitshallow/gitshallow.go +++ b/net/gitshallow/gitshallow.go @@ -14,11 +14,9 @@ type ShallowRepo struct { URL string Path string Depth int // 0 defaults to 1, -1 for all - Branch string // Optional: specific branch to clone/fetch - //WithBranches bool - //WithTags bool + Branch string // Optional: specific branch to clone/pull - mu sync.Mutex // Mutex for in-process locking + mu sync.Mutex } // New creates a new ShallowRepo instance. @@ -30,11 +28,11 @@ func New(url, path string, depth int, branch string) *ShallowRepo { URL: url, Path: path, Depth: depth, - Branch: strings.TrimSpace(branch), // clean up accidental whitespace + Branch: strings.TrimSpace(branch), } } -// Clone performs a shallow clone (default --depth 0 --single-branch, --no-tags, etc). +// Clone performs a shallow clone (--depth N --single-branch --no-tags). func (r *ShallowRepo) Clone() (bool, error) { r.mu.Lock() defer r.mu.Unlock() @@ -77,8 +75,7 @@ func (r *ShallowRepo) exists() bool { return err == nil } -// runGit executes a git command. -// For clone it runs in the parent directory; otherwise inside the repo. +// runGit executes a git command in the repo directory (or parent for clone). func (r *ShallowRepo) runGit(args ...string) (string, error) { cmd := exec.Command("git", args...) @@ -96,51 +93,39 @@ func (r *ShallowRepo) runGit(args ...string) (string, error) { return strings.TrimSpace(string(output)), nil } -// Fetch performs a shallow fetch and updates the working branch. -// Returns true if HEAD changed (i.e. meaningful update occurred). -// Uses --depth on fetch; branch filtering only when Branch is set. -func (r *ShallowRepo) Fetch() (updated bool, err error) { +// Pull performs a shallow pull (--ff-only) and reports whether HEAD changed. +func (r *ShallowRepo) Pull() (updated bool, err error) { r.mu.Lock() defer r.mu.Unlock() - return r.fetch() + return r.pull() } -func (r *ShallowRepo) fetch() (updated bool, err error) { +func (r *ShallowRepo) pull() (updated bool, err error) { if !r.exists() { return false, fmt.Errorf("repository does not exist at %s", r.Path) } - // Remember current HEAD - oldHead, err := r.runGit("-C", r.Path, "rev-parse", "HEAD") + oldHead, err := r.runGit("rev-parse", "HEAD") if err != nil { return false, err } - // Update local branch (git pull --ff-only is safer in shallow context) - pullArgs := []string{"-C", r.Path, "pull", "--ff-only"} - if r.Branch != "" { - pullArgs = append(pullArgs, "origin", r.Branch) - } - _, err = r.runGit(pullArgs...) - if err != nil { - return false, err - } - - // Fetch - fetchArgs := []string{"-C", r.Path, "fetch", "--no-tags"} + pullArgs := []string{"pull", "--ff-only", "--no-tags"} if r.Depth == 0 { r.Depth = 1 } if r.Depth >= 0 { - fetchArgs = append(fetchArgs, "--depth", fmt.Sprintf("%d", r.Depth)) + pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", r.Depth)) } - _, err = r.runGit(fetchArgs...) - if err != nil { + if r.Branch != "" { + pullArgs = append(pullArgs, "origin", r.Branch) + } + if _, err = r.runGit(pullArgs...); err != nil { return false, err } - newHead, err := r.runGit("-C", r.Path, "rev-parse", "HEAD") + newHead, err := r.runGit("rev-parse", "HEAD") if err != nil { return false, err } @@ -148,24 +133,24 @@ func (r *ShallowRepo) fetch() (updated bool, err error) { return oldHead != newHead, nil } -// GC runs git gc, defaulting to pruning immediately and aggressively -func (r *ShallowRepo) GC(lax, lazy bool) error { +// GC runs git gc. aggressiveGC adds --aggressive; pruneNow adds --prune=now. +func (r *ShallowRepo) GC(aggressiveGC, pruneNow bool) error { r.mu.Lock() defer r.mu.Unlock() - return r.gc(lax, lazy) + return r.gc(aggressiveGC, pruneNow) } -func (r *ShallowRepo) gc(lax, lazy bool) error { +func (r *ShallowRepo) gc(aggressiveGC, pruneNow bool) error { if !r.exists() { return fmt.Errorf("repository does not exist at %s", r.Path) } - args := []string{"-C", r.Path, "gc"} - if !lax { + args := []string{"gc"} + if aggressiveGC { args = append(args, "--aggressive") } - if !lazy { + if pruneNow { args = append(args, "--prune=now") } @@ -173,27 +158,30 @@ func (r *ShallowRepo) gc(lax, lazy bool) error { return err } -// Sync clones if missing, fetches, and runs GC. -// Returns whether fetch caused an update. -func (r *ShallowRepo) Sync(laxGC, lazyPrune bool) (updated bool, err error) { +// Sync clones if missing, pulls, and runs GC. +// lightGC=false (zero value) runs --aggressive GC with --prune=now to minimize disk use. +// Pass true to skip both when speed matters more than footprint. +func (r *ShallowRepo) Sync(lightGC bool) (updated bool, err error) { r.mu.Lock() defer r.mu.Unlock() - if updated, err := r.clone(); err != nil { + if cloned, err := r.clone(); err != nil { return false, err - } else if updated { - return updated, nil + } else if cloned { + return true, nil } - if updated, err := r.fetch(); err != nil { - return updated, err - } else if !updated { + updated, err = r.pull() + if err != nil { + return false, err + } + if !updated { return false, nil } - if err := r.gc(laxGC, lazyPrune); err != nil { - return updated, fmt.Errorf("gc failed but fetch succeeded: %w", err) + if err := r.gc(!lightGC, !lightGC); err != nil { + return true, fmt.Errorf("gc failed but pull succeeded: %w", err) } - return updated, nil + return true, nil } diff --git a/net/ipcohort/cmd/check-ip-blacklist/main.go b/net/ipcohort/cmd/check-ip-blacklist/main.go index eea9f6d..68b38c2 100644 --- a/net/ipcohort/cmd/check-ip-blacklist/main.go +++ b/net/ipcohort/cmd/check-ip-blacklist/main.go @@ -3,14 +3,15 @@ package main import ( "fmt" "os" + "path/filepath" - "github.com/therootcompany/golib/net/gitdataset" + "github.com/therootcompany/golib/fs/dataset" "github.com/therootcompany/golib/net/ipcohort" ) func main() { if len(os.Args) < 3 { - fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) + fmt.Fprintf(os.Stderr, "Usage: %s [git-url]\n", os.Args[0]) os.Exit(1) } @@ -21,28 +22,31 @@ func main() { gitURL = os.Args[3] } - fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath) + var blacklist *dataset.File[ipcohort.Cohort] - var b *ipcohort.Cohort - loadFile := func(path string) (*ipcohort.Cohort, error) { - return ipcohort.LoadFile(path, false) - } - blacklist := gitdataset.New(gitURL, dataPath, loadFile) - fmt.Fprintf(os.Stderr, "Syncing git repo ...\n") - if updated, err := blacklist.Init(false); err != nil { - fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err) + if gitURL != "" { + repoDir := filepath.Dir(dataPath) + relPath := filepath.Base(dataPath) + repo := dataset.NewRepo(gitURL, repoDir) + blacklist = dataset.AddFile(repo, relPath, ipcohort.LoadFile) + fmt.Fprintf(os.Stderr, "Syncing %q ...\n", repoDir) + if err := repo.Init(); err != nil { + fmt.Fprintf(os.Stderr, "error: git sync: %v\n", err) + os.Exit(1) + } } else { - b = blacklist.Load() - if updated { - n := b.Size() - if n > 0 { - fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n) - } + blacklist = dataset.NewFile(dataPath, ipcohort.LoadFile) + fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath) + if err := blacklist.Reload(); err != nil { + fmt.Fprintf(os.Stderr, "error: load: %v\n", err) + os.Exit(1) } } - fmt.Fprintf(os.Stderr, "Checking blacklist ...\n") - if blacklist.Load().Contains(ipStr) { + c := blacklist.Load() + fmt.Fprintf(os.Stderr, "Loaded %d entries\n", c.Size()) + + if c.Contains(ipStr) { fmt.Printf("%s is BLOCKED\n", ipStr) os.Exit(1) } diff --git a/net/ipcohort/ipcohort.go b/net/ipcohort/ipcohort.go index fb860a9..61ff414 100644 --- a/net/ipcohort/ipcohort.go +++ b/net/ipcohort/ipcohort.go @@ -9,11 +9,11 @@ import ( "net/netip" "os" "slices" - "sort" "strings" ) -// Either a subnet or single address (subnet with /32 CIDR prefix) +// IPv4Net represents a subnet or single address (/32). +// 6 bytes: networkBE uint32 + prefix uint8 + shift uint8. type IPv4Net struct { networkBE uint32 prefix uint8 @@ -29,32 +29,81 @@ func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net { } func (r IPv4Net) Contains(ip uint32) bool { - mask := uint32(0xFFFFFFFF << (r.shift)) + mask := uint32(0xFFFFFFFF << r.shift) return (ip & mask) == r.networkBE } +// Cohort is an immutable, read-only set of IPv4 addresses and subnets. +// Contains is safe for concurrent use without locks. +// +// hosts holds sorted /32 addresses for O(log n) binary search. +// nets holds CIDR ranges (prefix < 32) for O(k) linear scan — typically small. +type Cohort struct { + hosts []uint32 + nets []IPv4Net +} + func New() *Cohort { - cohort := &Cohort{ranges: []IPv4Net{}} - return cohort + return &Cohort{} +} + +// Size returns the total number of entries (hosts + nets). +func (c *Cohort) Size() int { + return len(c.hosts) + len(c.nets) +} + +// Contains reports whether ipStr falls within any host or subnet in the cohort. +// Returns true on parse error (fail-closed). +func (c *Cohort) Contains(ipStr string) bool { + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return true + } + ip4 := ip.As4() + ipU32 := binary.BigEndian.Uint32(ip4[:]) + + _, found := slices.BinarySearch(c.hosts, ipU32) + if found { + return true + } + + for _, net := range c.nets { + if net.Contains(ipU32) { + return true + } + } + return false } func Parse(prefixList []string) (*Cohort, error) { - var ranges []IPv4Net + var hosts []uint32 + var nets []IPv4Net + for _, raw := range prefixList { ipv4net, err := ParseIPv4(raw) if err != nil { log.Printf("skipping invalid entry: %q", raw) continue } - ranges = append(ranges, ipv4net) + if ipv4net.prefix == 32 { + hosts = append(hosts, ipv4net.networkBE) + } else { + nets = append(nets, ipv4net) + } } - sizedList := make([]IPv4Net, len(ranges)) - copy(sizedList, ranges) - sortRanges(ranges) + slices.Sort(hosts) + slices.SortFunc(nets, func(a, b IPv4Net) int { + if a.networkBE < b.networkBE { + return -1 + } + if a.networkBE > b.networkBE { + return 1 + } + return 0 + }) - cohort := &Cohort{ranges: sizedList} - return cohort, nil + return &Cohort{hosts: hosts, nets: nets}, nil } func ParseIPv4(raw string) (ipv4net IPv4Net, err error) { @@ -74,32 +123,34 @@ func ParseIPv4(raw string) (ipv4net IPv4Net, err error) { } ip4 := ippre.Addr().As4() - prefix := uint8(ippre.Bits()) // 0-32 + prefix := uint8(ippre.Bits()) // 0–32 return NewIPv4Net( binary.BigEndian.Uint32(ip4[:]), prefix, ), nil } -func LoadFile(path string, unsorted bool) (*Cohort, error) { +func LoadFile(path string) (*Cohort, error) { f, err := os.Open(path) if err != nil { return nil, fmt.Errorf("could not load %q: %v", path, err) } defer f.Close() - return ParseCSV(f, unsorted) + return ParseCSV(f) } -func ParseCSV(f io.Reader, unsorted bool) (*Cohort, error) { +func ParseCSV(f io.Reader) (*Cohort, error) { r := csv.NewReader(f) r.FieldsPerRecord = -1 - return ReadAll(r, unsorted) + return ReadAll(r) } -func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) { - var ranges []IPv4Net +func ReadAll(r *csv.Reader) (*Cohort, error) { + var hosts []uint32 + var nets []IPv4Net + for { record, err := r.Read() if err == io.EOF { @@ -115,7 +166,6 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) { raw := strings.TrimSpace(record[0]) - // Skip comments/empty if raw == "" || strings.HasPrefix(raw, "#") { continue } @@ -130,65 +180,24 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) { log.Printf("skipping invalid entry: %q", raw) continue } - ranges = append(ranges, ipv4net) + + if ipv4net.prefix == 32 { + hosts = append(hosts, ipv4net.networkBE) + } else { + nets = append(nets, ipv4net) + } } - if unsorted { - sortRanges(ranges) - } - - sizedList := make([]IPv4Net, len(ranges)) - copy(sizedList, ranges) - - cohort := &Cohort{ranges: sizedList} - return cohort, nil -} - -func sortRanges(ranges []IPv4Net) { - // Sort by network address (required for binary search) - sort.Slice(ranges, func(i, j int) bool { - // Note: we could also sort by prefix (largest first) - return ranges[i].networkBE < ranges[j].networkBE - }) - - // Note: we could also merge ranges here -} - -type Cohort struct { - ranges []IPv4Net -} - -func (c *Cohort) Size() int { - return len(c.ranges) -} - -func (c *Cohort) Contains(ipStr string) bool { - ip, err := netip.ParseAddr(ipStr) - if err != nil { - return true - } - ip4 := ip.As4() - ipU32 := binary.BigEndian.Uint32(ip4[:]) - - idx, found := slices.BinarySearchFunc(c.ranges, ipU32, func(r IPv4Net, target uint32) int { - if r.networkBE < target { + slices.Sort(hosts) + slices.SortFunc(nets, func(a, b IPv4Net) int { + if a.networkBE < b.networkBE { return -1 } - if r.networkBE > target { + if a.networkBE > b.networkBE { return 1 } return 0 }) - if found { - return true - } - // Check the range immediately before the insertion point - if idx > 0 { - if c.ranges[idx-1].Contains(ipU32) { - return true - } - } - - return false + return &Cohort{hosts: hosts, nets: nets}, nil }