mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 20:58:00 +00:00
refactor: decouple gitdataset/ipcohort for multi-file repos
gitshallow: fix double-fetch (pull already fetches), drop redundant -C flags

gitdataset: split into GitDataset[T] (file+atomic) and GitRepo (git+multi-dataset)
- NewDataset for file-only use, AddDataset to register with a GitRepo
- one clone/fetch per repo regardless of how many datasets it has

ipcohort: split Cohort into hosts (sorted /32, binary search) + nets (CIDRs, linear)
- fixes false negatives when broad CIDRs (e.g. /8) precede specific entries
- fixes Parse() sort-before-copy order bug
- ReadAll always sorts; unsorted param removed (was dead code)
This commit is contained in:
parent
a8e108a05b
commit
8731eaf10b
134
fs/dataset/dataset.go
Normal file
134
fs/dataset/dataset.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
package dataset
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/therootcompany/golib/net/gitshallow"
|
||||||
|
)
|
||||||
|
|
||||||
|
// File holds an atomically swappable pointer to a value parsed from a single
// file on disk. The embedded atomic.Pointer supplies Load, so readers never
// take a lock. Create one with NewFile for stand-alone use, or with AddFile
// to have it reloaded whenever its GitRepo is synced.
type File[T any] struct {
	atomic.Pointer[T]

	// path is the location handed to loadFile on every reload.
	path string
	// loadFile parses the file at the given path into a fresh *T.
	loadFile func(string) (*T, error)
}

// NewFile creates a file-backed dataset with no git dependency.
//
// The handle starts out holding a zero-valued T, so Load never returns nil,
// even before the first successful Reload. Nothing is read from disk here:
// call Reload for the initial load and again after any change to the file.
func NewFile[T any](path string, loadFile func(string) (*T, error)) *File[T] {
	d := &File[T]{path: path, loadFile: loadFile}
	d.Store(new(T))
	return d
}

// Reload reads the file and atomically replaces the stored value.
//
// On failure the previously stored value stays in place and the loader's
// error is returned unchanged. A loader that reports success but yields a nil
// value is treated as a failure as well: storing nil would make Load return
// nil and break the non-nil guarantee that NewFile establishes.
func (d *File[T]) Reload() error {
	v, err := d.loadFile(d.path)
	if err != nil {
		return err
	}
	if v == nil {
		return fmt.Errorf("load %q: loader returned a nil value without an error", d.path)
	}
	d.Store(v)
	return nil
}

// reloadFile lets a *File[T] of any T be stored as a reloader by GitRepo.
func (d *File[T]) reloadFile() error {
	return d.Reload()
}
|
||||||
|
|
||||||
|
// reloader is the type-erased view of a File[T] that GitRepo keeps, so that
// handles with different element types can sit in one slice and all be
// refreshed after a sync.
type reloader interface {
	// reloadFile re-reads the backing file and swaps in the new value.
	reloadFile() error
}
|
||||||
|
|
||||||
|
// GitRepo manages a shallow git clone and reloads all registered files
|
||||||
|
// whenever the repo is updated. Multiple files from the same repo share
|
||||||
|
// one clone and one pull, avoiding git file-lock conflicts.
|
||||||
|
type GitRepo struct {
|
||||||
|
path string
|
||||||
|
shallowRepo *gitshallow.ShallowRepo
|
||||||
|
files []reloader
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRepo creates a GitRepo backed by the given git URL, cloning into repoPath.
|
||||||
|
func NewRepo(gitURL, repoPath string) *GitRepo {
|
||||||
|
return &GitRepo{
|
||||||
|
path: repoPath,
|
||||||
|
shallowRepo: gitshallow.New(gitURL, repoPath, 1, ""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFile registers a file inside this repo and returns its handle.
|
||||||
|
// relPath is relative to the repo root. The file is reloaded automatically
|
||||||
|
// whenever the repo is synced via Init or Run.
|
||||||
|
func AddFile[T any](repo *GitRepo, relPath string, loadFile func(string) (*T, error)) *File[T] {
|
||||||
|
d := NewFile(filepath.Join(repo.path, relPath), loadFile)
|
||||||
|
repo.files = append(repo.files, d)
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// Init clones the repo if missing, syncs once, and loads all registered files.
|
||||||
|
// Always runs aggressive GC — acceptable as a one-time startup cost.
|
||||||
|
func (r *GitRepo) Init() error {
|
||||||
|
gitDir := filepath.Join(r.path, ".git")
|
||||||
|
if _, err := os.Stat(gitDir); err != nil {
|
||||||
|
if _, err := r.shallowRepo.Clone(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_, err := r.sync(false, true)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run periodically syncs the repo and reloads files. Blocks until ctx is done.
|
||||||
|
// lightGC=false (zero value) runs aggressive GC with immediate pruning to keep footprint minimal.
|
||||||
|
// Pass true to skip both when the periodic GC is too slow for your workload.
|
||||||
|
func (r *GitRepo) Run(ctx context.Context, lightGC bool) {
|
||||||
|
ticker := time.NewTicker(47 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
if updated, err := r.sync(lightGC, false); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: git repo sync: %v\n", err)
|
||||||
|
} else if updated {
|
||||||
|
fmt.Fprintf(os.Stderr, "git repo: files reloaded\n")
|
||||||
|
}
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sync pulls the latest commits and reloads all files if HEAD changed.
|
||||||
|
// lightGC=false (zero value) runs aggressive GC with immediate pruning to keep footprint minimal.
|
||||||
|
func (r *GitRepo) Sync(lightGC bool) (bool, error) {
|
||||||
|
return r.sync(lightGC, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *GitRepo) sync(lightGC, force bool) (bool, error) {
|
||||||
|
updated, err := r.shallowRepo.Sync(lightGC)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("git sync: %w", err)
|
||||||
|
}
|
||||||
|
if !updated && !force {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, f := range r.files {
|
||||||
|
if err := f.reloadFile(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: reload file: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
@ -1,92 +0,0 @@
|
|||||||
package gitdataset
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/therootcompany/golib/net/gitshallow"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO maybe a GitRepo should contain GitDatasets such that loading
|
|
||||||
// multiple datasets from the same GitRepo won't cause issues with file locking?
|
|
||||||
|
|
||||||
type GitDataset[T any] struct {
|
|
||||||
LoadFile func(path string) (*T, error)
|
|
||||||
atomic.Pointer[T]
|
|
||||||
gitRepo string
|
|
||||||
shallowRepo *gitshallow.ShallowRepo
|
|
||||||
path string
|
|
||||||
}
|
|
||||||
|
|
||||||
func New[T any](gitURL, path string, loadFile func(path string) (*T, error)) *GitDataset[T] {
|
|
||||||
gitRepo := filepath.Dir(path)
|
|
||||||
gitDepth := 1
|
|
||||||
gitBranch := ""
|
|
||||||
shallowRepo := gitshallow.New(gitURL, gitRepo, gitDepth, gitBranch)
|
|
||||||
|
|
||||||
b := &GitDataset[T]{
|
|
||||||
Pointer: atomic.Pointer[T]{},
|
|
||||||
LoadFile: loadFile,
|
|
||||||
gitRepo: gitRepo,
|
|
||||||
shallowRepo: shallowRepo,
|
|
||||||
path: path,
|
|
||||||
}
|
|
||||||
b.Store(new(T))
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *GitDataset[T]) Init(skipGC bool) (updated bool, err error) {
|
|
||||||
gitDir := filepath.Join(b.gitRepo, ".git")
|
|
||||||
if _, err := os.Stat(gitDir); err != nil {
|
|
||||||
if _, err := b.shallowRepo.Clone(); err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
force := true
|
|
||||||
return b.reload(skipGC, force)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *GitDataset[T]) Run(ctx context.Context) {
|
|
||||||
ticker := time.NewTicker(47 * time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
if ok, err := b.reload(false, false); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "error: git data: %v\n", err)
|
|
||||||
} else if ok {
|
|
||||||
fmt.Fprintf(os.Stderr, "git data: loaded repo\n")
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(os.Stderr, "git data: already up-to-date\n")
|
|
||||||
}
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *GitDataset[T]) reload(skipGC, force bool) (updated bool, err error) {
|
|
||||||
laxGC := skipGC
|
|
||||||
lazyPrune := skipGC
|
|
||||||
updated, err = b.shallowRepo.Sync(laxGC, lazyPrune)
|
|
||||||
if err != nil {
|
|
||||||
return false, fmt.Errorf("git sync: %w", err)
|
|
||||||
}
|
|
||||||
if !updated && !force {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
nextDataset, err := b.LoadFile(b.path)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = b.Swap(nextDataset)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
@ -21,8 +21,6 @@ import (
|
|||||||
const (
|
const (
|
||||||
defaultDepth = 1 // shallow by default
|
defaultDepth = 1 // shallow by default
|
||||||
defaultBranch = "" // empty = default branch + --single-branch
|
defaultBranch = "" // empty = default branch + --single-branch
|
||||||
laxGC = false // false = --aggressive
|
|
||||||
lazyPrune = false // false = --prune=now
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@ -37,7 +35,7 @@ func main() {
|
|||||||
url := os.Args[1]
|
url := os.Args[1]
|
||||||
path := os.Args[2]
|
path := os.Args[2]
|
||||||
|
|
||||||
// Expand ~ to home directory for Windows
|
// Expand ~ to home directory
|
||||||
if path[0] == '~' {
|
if path[0] == '~' {
|
||||||
home, err := os.UserHomeDir()
|
home, err := os.UserHomeDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -47,7 +45,6 @@ func main() {
|
|||||||
path = filepath.Join(home, path[1:])
|
path = filepath.Join(home, path[1:])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make path absolute
|
|
||||||
absPath, err := filepath.Abs(path)
|
absPath, err := filepath.Abs(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Invalid path: %v\n", err)
|
fmt.Fprintf(os.Stderr, "Invalid path: %v\n", err)
|
||||||
@ -60,14 +57,14 @@ func main() {
|
|||||||
|
|
||||||
repo := gitshallow.New(url, absPath, defaultDepth, defaultBranch)
|
repo := gitshallow.New(url, absPath, defaultDepth, defaultBranch)
|
||||||
|
|
||||||
updated, err := repo.Sync(laxGC, lazyPrune)
|
updated, err := repo.Sync(false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Fprintf(os.Stderr, "Sync failed: %v\n", err)
|
fmt.Fprintf(os.Stderr, "Sync failed: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
if updated {
|
if updated {
|
||||||
fmt.Println("Repository was updated (new commits fetched).")
|
fmt.Println("Repository was updated (new commits pulled).")
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("Repository is already up to date.")
|
fmt.Println("Repository is already up to date.")
|
||||||
}
|
}
|
||||||
|
|||||||
@ -14,11 +14,9 @@ type ShallowRepo struct {
|
|||||||
URL string
|
URL string
|
||||||
Path string
|
Path string
|
||||||
Depth int // 0 defaults to 1, -1 for all
|
Depth int // 0 defaults to 1, -1 for all
|
||||||
Branch string // Optional: specific branch to clone/fetch
|
Branch string // Optional: specific branch to clone/pull
|
||||||
//WithBranches bool
|
|
||||||
//WithTags bool
|
|
||||||
|
|
||||||
mu sync.Mutex // Mutex for in-process locking
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new ShallowRepo instance.
|
// New creates a new ShallowRepo instance.
|
||||||
@ -30,11 +28,11 @@ func New(url, path string, depth int, branch string) *ShallowRepo {
|
|||||||
URL: url,
|
URL: url,
|
||||||
Path: path,
|
Path: path,
|
||||||
Depth: depth,
|
Depth: depth,
|
||||||
Branch: strings.TrimSpace(branch), // clean up accidental whitespace
|
Branch: strings.TrimSpace(branch),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clone performs a shallow clone (default --depth 0 --single-branch, --no-tags, etc).
|
// Clone performs a shallow clone (--depth N --single-branch --no-tags).
|
||||||
func (r *ShallowRepo) Clone() (bool, error) {
|
func (r *ShallowRepo) Clone() (bool, error) {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
@ -77,8 +75,7 @@ func (r *ShallowRepo) exists() bool {
|
|||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// runGit executes a git command.
|
// runGit executes a git command in the repo directory (or parent for clone).
|
||||||
// For clone it runs in the parent directory; otherwise inside the repo.
|
|
||||||
func (r *ShallowRepo) runGit(args ...string) (string, error) {
|
func (r *ShallowRepo) runGit(args ...string) (string, error) {
|
||||||
cmd := exec.Command("git", args...)
|
cmd := exec.Command("git", args...)
|
||||||
|
|
||||||
@ -96,51 +93,39 @@ func (r *ShallowRepo) runGit(args ...string) (string, error) {
|
|||||||
return strings.TrimSpace(string(output)), nil
|
return strings.TrimSpace(string(output)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch performs a shallow fetch and updates the working branch.
|
// Pull performs a shallow pull (--ff-only) and reports whether HEAD changed.
|
||||||
// Returns true if HEAD changed (i.e. meaningful update occurred).
|
func (r *ShallowRepo) Pull() (updated bool, err error) {
|
||||||
// Uses --depth on fetch; branch filtering only when Branch is set.
|
|
||||||
func (r *ShallowRepo) Fetch() (updated bool, err error) {
|
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
return r.fetch()
|
return r.pull()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ShallowRepo) fetch() (updated bool, err error) {
|
func (r *ShallowRepo) pull() (updated bool, err error) {
|
||||||
if !r.exists() {
|
if !r.exists() {
|
||||||
return false, fmt.Errorf("repository does not exist at %s", r.Path)
|
return false, fmt.Errorf("repository does not exist at %s", r.Path)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remember current HEAD
|
oldHead, err := r.runGit("rev-parse", "HEAD")
|
||||||
oldHead, err := r.runGit("-C", r.Path, "rev-parse", "HEAD")
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update local branch (git pull --ff-only is safer in shallow context)
|
pullArgs := []string{"pull", "--ff-only", "--no-tags"}
|
||||||
pullArgs := []string{"-C", r.Path, "pull", "--ff-only"}
|
|
||||||
if r.Branch != "" {
|
|
||||||
pullArgs = append(pullArgs, "origin", r.Branch)
|
|
||||||
}
|
|
||||||
_, err = r.runGit(pullArgs...)
|
|
||||||
if err != nil {
|
|
||||||
return false, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fetch
|
|
||||||
fetchArgs := []string{"-C", r.Path, "fetch", "--no-tags"}
|
|
||||||
if r.Depth == 0 {
|
if r.Depth == 0 {
|
||||||
r.Depth = 1
|
r.Depth = 1
|
||||||
}
|
}
|
||||||
if r.Depth >= 0 {
|
if r.Depth >= 0 {
|
||||||
fetchArgs = append(fetchArgs, "--depth", fmt.Sprintf("%d", r.Depth))
|
pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", r.Depth))
|
||||||
}
|
}
|
||||||
_, err = r.runGit(fetchArgs...)
|
if r.Branch != "" {
|
||||||
if err != nil {
|
pullArgs = append(pullArgs, "origin", r.Branch)
|
||||||
|
}
|
||||||
|
if _, err = r.runGit(pullArgs...); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
newHead, err := r.runGit("-C", r.Path, "rev-parse", "HEAD")
|
newHead, err := r.runGit("rev-parse", "HEAD")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -148,24 +133,24 @@ func (r *ShallowRepo) fetch() (updated bool, err error) {
|
|||||||
return oldHead != newHead, nil
|
return oldHead != newHead, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GC runs git gc, defaulting to pruning immediately and aggressively
|
// GC runs git gc. aggressiveGC adds --aggressive; pruneNow adds --prune=now.
|
||||||
func (r *ShallowRepo) GC(lax, lazy bool) error {
|
func (r *ShallowRepo) GC(aggressiveGC, pruneNow bool) error {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
return r.gc(lax, lazy)
|
return r.gc(aggressiveGC, pruneNow)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *ShallowRepo) gc(lax, lazy bool) error {
|
func (r *ShallowRepo) gc(aggressiveGC, pruneNow bool) error {
|
||||||
if !r.exists() {
|
if !r.exists() {
|
||||||
return fmt.Errorf("repository does not exist at %s", r.Path)
|
return fmt.Errorf("repository does not exist at %s", r.Path)
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []string{"-C", r.Path, "gc"}
|
args := []string{"gc"}
|
||||||
if !lax {
|
if aggressiveGC {
|
||||||
args = append(args, "--aggressive")
|
args = append(args, "--aggressive")
|
||||||
}
|
}
|
||||||
if !lazy {
|
if pruneNow {
|
||||||
args = append(args, "--prune=now")
|
args = append(args, "--prune=now")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -173,27 +158,30 @@ func (r *ShallowRepo) gc(lax, lazy bool) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sync clones if missing, fetches, and runs GC.
|
// Sync clones if missing, pulls, and runs GC.
|
||||||
// Returns whether fetch caused an update.
|
// lightGC=false (zero value) runs --aggressive GC with --prune=now to minimize disk use.
|
||||||
func (r *ShallowRepo) Sync(laxGC, lazyPrune bool) (updated bool, err error) {
|
// Pass true to skip both when speed matters more than footprint.
|
||||||
|
func (r *ShallowRepo) Sync(lightGC bool) (updated bool, err error) {
|
||||||
r.mu.Lock()
|
r.mu.Lock()
|
||||||
defer r.mu.Unlock()
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
if updated, err := r.clone(); err != nil {
|
if cloned, err := r.clone(); err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
} else if updated {
|
} else if cloned {
|
||||||
return updated, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if updated, err := r.fetch(); err != nil {
|
updated, err = r.pull()
|
||||||
return updated, err
|
if err != nil {
|
||||||
} else if !updated {
|
return false, err
|
||||||
|
}
|
||||||
|
if !updated {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.gc(laxGC, lazyPrune); err != nil {
|
if err := r.gc(!lightGC, !lightGC); err != nil {
|
||||||
return updated, fmt.Errorf("gc failed but fetch succeeded: %w", err)
|
return true, fmt.Errorf("gc failed but pull succeeded: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return updated, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -3,14 +3,15 @@ package main
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
"github.com/therootcompany/golib/net/gitdataset"
|
"github.com/therootcompany/golib/fs/dataset"
|
||||||
"github.com/therootcompany/golib/net/ipcohort"
|
"github.com/therootcompany/golib/net/ipcohort"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if len(os.Args) < 3 {
|
if len(os.Args) < 3 {
|
||||||
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address>\n", os.Args[0])
|
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address> [git-url]\n", os.Args[0])
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -21,28 +22,31 @@ func main() {
|
|||||||
gitURL = os.Args[3]
|
gitURL = os.Args[3]
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath)
|
var blacklist *dataset.File[ipcohort.Cohort]
|
||||||
|
|
||||||
var b *ipcohort.Cohort
|
if gitURL != "" {
|
||||||
loadFile := func(path string) (*ipcohort.Cohort, error) {
|
repoDir := filepath.Dir(dataPath)
|
||||||
return ipcohort.LoadFile(path, false)
|
relPath := filepath.Base(dataPath)
|
||||||
|
repo := dataset.NewRepo(gitURL, repoDir)
|
||||||
|
blacklist = dataset.AddFile(repo, relPath, ipcohort.LoadFile)
|
||||||
|
fmt.Fprintf(os.Stderr, "Syncing %q ...\n", repoDir)
|
||||||
|
if err := repo.Init(); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "error: git sync: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
blacklist := gitdataset.New(gitURL, dataPath, loadFile)
|
|
||||||
fmt.Fprintf(os.Stderr, "Syncing git repo ...\n")
|
|
||||||
if updated, err := blacklist.Init(false); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "error: ip cohort: %v\n", err)
|
|
||||||
} else {
|
} else {
|
||||||
b = blacklist.Load()
|
blacklist = dataset.NewFile(dataPath, ipcohort.LoadFile)
|
||||||
if updated {
|
fmt.Fprintf(os.Stderr, "Loading %q ...\n", dataPath)
|
||||||
n := b.Size()
|
if err := blacklist.Reload(); err != nil {
|
||||||
if n > 0 {
|
fmt.Fprintf(os.Stderr, "error: load: %v\n", err)
|
||||||
fmt.Fprintf(os.Stderr, "ip cohort: loaded %d blacklist entries\n", n)
|
os.Exit(1)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(os.Stderr, "Checking blacklist ...\n")
|
c := blacklist.Load()
|
||||||
if blacklist.Load().Contains(ipStr) {
|
fmt.Fprintf(os.Stderr, "Loaded %d entries\n", c.Size())
|
||||||
|
|
||||||
|
if c.Contains(ipStr) {
|
||||||
fmt.Printf("%s is BLOCKED\n", ipStr)
|
fmt.Printf("%s is BLOCKED\n", ipStr)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -9,11 +9,11 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Either a subnet or single address (subnet with /32 CIDR prefix)
|
// IPv4Net represents a subnet or single address (/32).
|
||||||
|
// 6 bytes: networkBE uint32 + prefix uint8 + shift uint8.
|
||||||
type IPv4Net struct {
|
type IPv4Net struct {
|
||||||
networkBE uint32
|
networkBE uint32
|
||||||
prefix uint8
|
prefix uint8
|
||||||
@ -29,32 +29,81 @@ func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r IPv4Net) Contains(ip uint32) bool {
|
func (r IPv4Net) Contains(ip uint32) bool {
|
||||||
mask := uint32(0xFFFFFFFF << (r.shift))
|
mask := uint32(0xFFFFFFFF << r.shift)
|
||||||
return (ip & mask) == r.networkBE
|
return (ip & mask) == r.networkBE
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Cohort is an immutable, read-only set of IPv4 addresses and subnets.
|
||||||
|
// Contains is safe for concurrent use without locks.
|
||||||
|
//
|
||||||
|
// hosts holds sorted /32 addresses for O(log n) binary search.
|
||||||
|
// nets holds CIDR ranges (prefix < 32) for O(k) linear scan — typically small.
|
||||||
|
type Cohort struct {
|
||||||
|
hosts []uint32
|
||||||
|
nets []IPv4Net
|
||||||
|
}
|
||||||
|
|
||||||
func New() *Cohort {
|
func New() *Cohort {
|
||||||
cohort := &Cohort{ranges: []IPv4Net{}}
|
return &Cohort{}
|
||||||
return cohort
|
}
|
||||||
|
|
||||||
|
// Size returns the total number of entries (hosts + nets).
|
||||||
|
func (c *Cohort) Size() int {
|
||||||
|
return len(c.hosts) + len(c.nets)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Contains reports whether ipStr falls within any host or subnet in the cohort.
|
||||||
|
// Returns true on parse error (fail-closed).
|
||||||
|
func (c *Cohort) Contains(ipStr string) bool {
|
||||||
|
ip, err := netip.ParseAddr(ipStr)
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
ip4 := ip.As4()
|
||||||
|
ipU32 := binary.BigEndian.Uint32(ip4[:])
|
||||||
|
|
||||||
|
_, found := slices.BinarySearch(c.hosts, ipU32)
|
||||||
|
if found {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, net := range c.nets {
|
||||||
|
if net.Contains(ipU32) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func Parse(prefixList []string) (*Cohort, error) {
|
func Parse(prefixList []string) (*Cohort, error) {
|
||||||
var ranges []IPv4Net
|
var hosts []uint32
|
||||||
|
var nets []IPv4Net
|
||||||
|
|
||||||
for _, raw := range prefixList {
|
for _, raw := range prefixList {
|
||||||
ipv4net, err := ParseIPv4(raw)
|
ipv4net, err := ParseIPv4(raw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("skipping invalid entry: %q", raw)
|
log.Printf("skipping invalid entry: %q", raw)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ranges = append(ranges, ipv4net)
|
if ipv4net.prefix == 32 {
|
||||||
|
hosts = append(hosts, ipv4net.networkBE)
|
||||||
|
} else {
|
||||||
|
nets = append(nets, ipv4net)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sizedList := make([]IPv4Net, len(ranges))
|
slices.Sort(hosts)
|
||||||
copy(sizedList, ranges)
|
slices.SortFunc(nets, func(a, b IPv4Net) int {
|
||||||
sortRanges(ranges)
|
if a.networkBE < b.networkBE {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
if a.networkBE > b.networkBE {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
})
|
||||||
|
|
||||||
cohort := &Cohort{ranges: sizedList}
|
return &Cohort{hosts: hosts, nets: nets}, nil
|
||||||
return cohort, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
|
func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
|
||||||
@ -74,32 +123,34 @@ func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ip4 := ippre.Addr().As4()
|
ip4 := ippre.Addr().As4()
|
||||||
prefix := uint8(ippre.Bits()) // 0-32
|
prefix := uint8(ippre.Bits()) // 0–32
|
||||||
return NewIPv4Net(
|
return NewIPv4Net(
|
||||||
binary.BigEndian.Uint32(ip4[:]),
|
binary.BigEndian.Uint32(ip4[:]),
|
||||||
prefix,
|
prefix,
|
||||||
), nil
|
), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadFile(path string, unsorted bool) (*Cohort, error) {
|
func LoadFile(path string) (*Cohort, error) {
|
||||||
f, err := os.Open(path)
|
f, err := os.Open(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not load %q: %v", path, err)
|
return nil, fmt.Errorf("could not load %q: %v", path, err)
|
||||||
}
|
}
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
return ParseCSV(f, unsorted)
|
return ParseCSV(f)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseCSV(f io.Reader, unsorted bool) (*Cohort, error) {
|
func ParseCSV(f io.Reader) (*Cohort, error) {
|
||||||
r := csv.NewReader(f)
|
r := csv.NewReader(f)
|
||||||
r.FieldsPerRecord = -1
|
r.FieldsPerRecord = -1
|
||||||
|
|
||||||
return ReadAll(r, unsorted)
|
return ReadAll(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
|
func ReadAll(r *csv.Reader) (*Cohort, error) {
|
||||||
var ranges []IPv4Net
|
var hosts []uint32
|
||||||
|
var nets []IPv4Net
|
||||||
|
|
||||||
for {
|
for {
|
||||||
record, err := r.Read()
|
record, err := r.Read()
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
@ -115,7 +166,6 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
|
|||||||
|
|
||||||
raw := strings.TrimSpace(record[0])
|
raw := strings.TrimSpace(record[0])
|
||||||
|
|
||||||
// Skip comments/empty
|
|
||||||
if raw == "" || strings.HasPrefix(raw, "#") {
|
if raw == "" || strings.HasPrefix(raw, "#") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -130,65 +180,24 @@ func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) {
|
|||||||
log.Printf("skipping invalid entry: %q", raw)
|
log.Printf("skipping invalid entry: %q", raw)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
ranges = append(ranges, ipv4net)
|
|
||||||
|
if ipv4net.prefix == 32 {
|
||||||
|
hosts = append(hosts, ipv4net.networkBE)
|
||||||
|
} else {
|
||||||
|
nets = append(nets, ipv4net)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if unsorted {
|
slices.Sort(hosts)
|
||||||
sortRanges(ranges)
|
slices.SortFunc(nets, func(a, b IPv4Net) int {
|
||||||
}
|
if a.networkBE < b.networkBE {
|
||||||
|
|
||||||
sizedList := make([]IPv4Net, len(ranges))
|
|
||||||
copy(sizedList, ranges)
|
|
||||||
|
|
||||||
cohort := &Cohort{ranges: sizedList}
|
|
||||||
return cohort, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func sortRanges(ranges []IPv4Net) {
|
|
||||||
// Sort by network address (required for binary search)
|
|
||||||
sort.Slice(ranges, func(i, j int) bool {
|
|
||||||
// Note: we could also sort by prefix (largest first)
|
|
||||||
return ranges[i].networkBE < ranges[j].networkBE
|
|
||||||
})
|
|
||||||
|
|
||||||
// Note: we could also merge ranges here
|
|
||||||
}
|
|
||||||
|
|
||||||
type Cohort struct {
|
|
||||||
ranges []IPv4Net
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Cohort) Size() int {
|
|
||||||
return len(c.ranges)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Cohort) Contains(ipStr string) bool {
|
|
||||||
ip, err := netip.ParseAddr(ipStr)
|
|
||||||
if err != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
ip4 := ip.As4()
|
|
||||||
ipU32 := binary.BigEndian.Uint32(ip4[:])
|
|
||||||
|
|
||||||
idx, found := slices.BinarySearchFunc(c.ranges, ipU32, func(r IPv4Net, target uint32) int {
|
|
||||||
if r.networkBE < target {
|
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
if r.networkBE > target {
|
if a.networkBE > b.networkBE {
|
||||||
return 1
|
return 1
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
})
|
})
|
||||||
if found {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check the range immediately before the insertion point
|
return &Cohort{hosts: hosts, nets: nets}, nil
|
||||||
if idx > 0 {
|
|
||||||
if c.ranges[idx-1].Contains(ipU32) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user