From 896031b6a81a9468b98b1a3f1d2598c4873a3328 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 20 Apr 2026 09:47:50 -0600 Subject: [PATCH] fix: idiomatic Go cleanup across net packages - gitshallow: replace in-place Depth mutation with effectiveDepth() method; remove depth normalisation from New() since it was masking the bug - ipcohort: extract sortNets() helper using cmp.Compare, eliminating 3 identical sort closures; add ContainsAddr(netip.Addr) for pre-parsed callers; guard Contains() against IPv6 panic (As4 panics on non-v4); add IPv6 test - dataset: Add() now sets NopSyncer{} so callers cannot panic by accidentally calling Init/Sync/Run on a Group-managed Dataset --- net/dataset/dataset.go | 6 ++-- net/gitshallow/gitshallow.go | 26 ++++++++-------- net/ipcohort/ipcohort.go | 56 +++++++++++++++-------------------- net/ipcohort/ipcohort_test.go | 11 +++++++ 4 files changed, 51 insertions(+), 48 deletions(-) diff --git a/net/dataset/dataset.go b/net/dataset/dataset.go index 26d481e..c14d10a 100644 --- a/net/dataset/dataset.go +++ b/net/dataset/dataset.go @@ -124,11 +124,11 @@ func NewGroup(syncer httpcache.Syncer) *Group { return &Group{syncer: syncer} } -// Add registers a new dataset in g and returns it. Call Init or Run on g — -// not on the returned dataset — to drive updates. +// Add registers a new dataset in g and returns it. Fetch and reload are driven +// by the Group — call Init/Run/Sync on g, not on the returned Dataset. // load is a closure capturing whatever paths or config it needs. func Add[T any](g *Group, load func() (*T, error)) *Dataset[T] { - d := &Dataset[T]{load: load} + d := &Dataset[T]{syncer: httpcache.NopSyncer{}, load: load} g.members = append(g.members, d) return d } diff --git a/net/gitshallow/gitshallow.go b/net/gitshallow/gitshallow.go index 655ab45..3b2a6ed 100644 --- a/net/gitshallow/gitshallow.go +++ b/net/gitshallow/gitshallow.go @@ -28,9 +28,6 @@ type Repo struct { // New creates a new Repo instance. func New(url, path string, depth int, branch string) *Repo { - if depth == 0 { - depth = 1 - } return &Repo{ URL: url, Path: path, @@ -39,6 +36,15 @@ func New(url, path string, depth int, branch string) *Repo { } } +// effectiveDepth returns the depth to use for clone/pull. +// 0 means unset — defaults to 1. -1 means full history. +func (r *Repo) effectiveDepth() int { + if r.Depth == 0 { + return 1 + } + return r.Depth +} + // Init clones the repo if missing, then syncs once. // Returns whether anything new was fetched. func (r *Repo) Init() (bool, error) { @@ -70,11 +76,8 @@ func (r *Repo) clone() (bool, error) { } args := []string{"clone", "--no-tags"} - if r.Depth == 0 { - r.Depth = 1 - } - if r.Depth >= 0 { - args = append(args, "--depth", fmt.Sprintf("%d", r.Depth)) + if depth := r.effectiveDepth(); depth >= 0 { + args = append(args, "--depth", fmt.Sprintf("%d", depth)) } args = append(args, "--single-branch") if r.Branch != "" { @@ -125,11 +128,8 @@ func (r *Repo) pull() (updated bool, err error) { } pullArgs := []string{"pull", "--ff-only", "--no-tags"} - if r.Depth == 0 { - r.Depth = 1 - } - if r.Depth >= 0 { - pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", r.Depth)) + if depth := r.effectiveDepth(); depth >= 0 { + pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", depth)) } if r.Branch != "" { pullArgs = append(pullArgs, "origin", r.Branch) diff --git a/net/ipcohort/ipcohort.go b/net/ipcohort/ipcohort.go index 9861e68..4cef8cb 100644 --- a/net/ipcohort/ipcohort.go +++ b/net/ipcohort/ipcohort.go @@ -1,6 +1,7 @@ package ipcohort import ( + "cmp" "encoding/binary" "encoding/csv" "fmt" @@ -47,26 +48,41 @@ func New() *Cohort { return &Cohort{} } +func sortNets(nets []IPv4Net) { + slices.SortFunc(nets, func(a, b IPv4Net) int { + return cmp.Compare(a.networkBE, b.networkBE) + }) +} + // Size returns the total number of entries (hosts + nets). func (c *Cohort) Size() int { return len(c.hosts) + len(c.nets) } // Contains reports whether ipStr falls within any host or subnet in the cohort. -// Returns true on parse error (fail-closed). +// Returns true on parse error (fail-closed): unparseable input is treated as +// blocked so that garbage strings never accidentally bypass a blocklist check. +// IPv6 addresses are not stored and always return false. func (c *Cohort) Contains(ipStr string) bool { ip, err := netip.ParseAddr(ipStr) if err != nil { - return true + return true // fail-closed + } + return c.ContainsAddr(ip) +} + +// ContainsAddr reports whether ip falls within any host or subnet in the cohort. +// IPv6 addresses always return false (cohort is IPv4-only). +func (c *Cohort) ContainsAddr(ip netip.Addr) bool { + if !ip.Is4() { + return false } ip4 := ip.As4() ipU32 := binary.BigEndian.Uint32(ip4[:]) - _, found := slices.BinarySearch(c.hosts, ipU32) - if found { + if _, found := slices.BinarySearch(c.hosts, ipU32); found { return true } - for _, net := range c.nets { if net.Contains(ipU32) { return true @@ -93,15 +109,7 @@ func Parse(prefixList []string) (*Cohort, error) { } slices.Sort(hosts) - slices.SortFunc(nets, func(a, b IPv4Net) int { - if a.networkBE < b.networkBE { - return -1 - } - if a.networkBE > b.networkBE { - return 1 - } - return 0 - }) + sortNets(nets) return &Cohort{hosts: hosts, nets: nets}, nil } @@ -160,15 +168,7 @@ func LoadFiles(paths ...string) (*Cohort, error) { } slices.Sort(hosts) - slices.SortFunc(nets, func(a, b IPv4Net) int { - if a.networkBE < b.networkBE { - return -1 - } - if a.networkBE > b.networkBE { - return 1 - } - return 0 - }) + sortNets(nets) return &Cohort{hosts: hosts, nets: nets}, nil } @@ -222,15 +222,7 @@ func ReadAll(r *csv.Reader) (*Cohort, error) { } slices.Sort(hosts) - slices.SortFunc(nets, func(a, b IPv4Net) int { - if a.networkBE < b.networkBE { - return -1 - } - if a.networkBE > b.networkBE { - return 1 - } - return 0 - }) + sortNets(nets) return &Cohort{hosts: hosts, nets: nets}, nil } diff --git a/net/ipcohort/ipcohort_test.go b/net/ipcohort/ipcohort_test.go index 68b215e..eb37e35 100644 --- a/net/ipcohort/ipcohort_test.go +++ b/net/ipcohort/ipcohort_test.go @@ -104,6 +104,17 @@ func TestContains_FailClosed(t *testing.T) { } } +func TestContains_IPv6NeverBlocked(t *testing.T) { + c, _ := ipcohort.Parse([]string{"1.2.3.4", "10.0.0.0/8"}) + // IPv6 addresses are not stored; should return false, not panic. + if c.Contains("::1") { + t.Error("IPv6 address should not be in an IPv4-only cohort") + } + if c.Contains("2001:db8::1") { + t.Error("IPv6 address should not be in an IPv4-only cohort") + } +} + func TestContains_Empty(t *testing.T) { c, err := ipcohort.Parse(nil) if err != nil {