fix: idiomatic Go cleanup across net packages

- gitshallow: replace in-place Depth mutation with effectiveDepth() method;
  remove depth normalisation from New() since it was masking the bug
- ipcohort: extract sortNets() helper using cmp.Compare, eliminating 3 identical
  sort closures; add ContainsAddr(netip.Addr) for pre-parsed callers; guard
  Contains() against IPv6 panic (As4 panics on non-v4); add IPv6 test
- dataset: Add() now sets NopSyncer{} so callers cannot panic by accidentally
  calling Init/Sync/Run on a Group-managed Dataset
This commit is contained in:
AJ ONeal 2026-04-20 09:47:50 -06:00
parent 410b52f72c
commit 896031b6a8
No known key found for this signature in database
4 changed files with 51 additions and 48 deletions

View File

@ -124,11 +124,11 @@ func NewGroup(syncer httpcache.Syncer) *Group {
return &Group{syncer: syncer} return &Group{syncer: syncer}
} }
// Add registers a new dataset in g and returns it. Call Init or Run on g — // Add registers a new dataset in g and returns it. Fetch and reload are driven
// not on the returned dataset — to drive updates. // by the Group — call Init/Run/Sync on g, not on the returned Dataset.
// load is a closure capturing whatever paths or config it needs. // load is a closure capturing whatever paths or config it needs.
func Add[T any](g *Group, load func() (*T, error)) *Dataset[T] { func Add[T any](g *Group, load func() (*T, error)) *Dataset[T] {
d := &Dataset[T]{load: load} d := &Dataset[T]{syncer: httpcache.NopSyncer{}, load: load}
g.members = append(g.members, d) g.members = append(g.members, d)
return d return d
} }

View File

@ -28,9 +28,6 @@ type Repo struct {
// New creates a new Repo instance. // New creates a new Repo instance.
func New(url, path string, depth int, branch string) *Repo { func New(url, path string, depth int, branch string) *Repo {
if depth == 0 {
depth = 1
}
return &Repo{ return &Repo{
URL: url, URL: url,
Path: path, Path: path,
@ -39,6 +36,15 @@ func New(url, path string, depth int, branch string) *Repo {
} }
} }
// effectiveDepth returns the depth to use for clone/pull.
// 0 means unset — defaults to 1. -1 means full history.
func (r *Repo) effectiveDepth() int {
if r.Depth == 0 {
return 1
}
return r.Depth
}
// Init clones the repo if missing, then syncs once. // Init clones the repo if missing, then syncs once.
// Returns whether anything new was fetched. // Returns whether anything new was fetched.
func (r *Repo) Init() (bool, error) { func (r *Repo) Init() (bool, error) {
@ -70,11 +76,8 @@ func (r *Repo) clone() (bool, error) {
} }
args := []string{"clone", "--no-tags"} args := []string{"clone", "--no-tags"}
if r.Depth == 0 { if depth := r.effectiveDepth(); depth >= 0 {
r.Depth = 1 args = append(args, "--depth", fmt.Sprintf("%d", depth))
}
if r.Depth >= 0 {
args = append(args, "--depth", fmt.Sprintf("%d", r.Depth))
} }
args = append(args, "--single-branch") args = append(args, "--single-branch")
if r.Branch != "" { if r.Branch != "" {
@ -125,11 +128,8 @@ func (r *Repo) pull() (updated bool, err error) {
} }
pullArgs := []string{"pull", "--ff-only", "--no-tags"} pullArgs := []string{"pull", "--ff-only", "--no-tags"}
if r.Depth == 0 { if depth := r.effectiveDepth(); depth >= 0 {
r.Depth = 1 pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", depth))
}
if r.Depth >= 0 {
pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", r.Depth))
} }
if r.Branch != "" { if r.Branch != "" {
pullArgs = append(pullArgs, "origin", r.Branch) pullArgs = append(pullArgs, "origin", r.Branch)

View File

@ -1,6 +1,7 @@
package ipcohort package ipcohort
import ( import (
"cmp"
"encoding/binary" "encoding/binary"
"encoding/csv" "encoding/csv"
"fmt" "fmt"
@ -47,26 +48,41 @@ func New() *Cohort {
return &Cohort{} return &Cohort{}
} }
func sortNets(nets []IPv4Net) {
slices.SortFunc(nets, func(a, b IPv4Net) int {
return cmp.Compare(a.networkBE, b.networkBE)
})
}
// Size returns the total number of entries (hosts + nets). // Size returns the total number of entries (hosts + nets).
func (c *Cohort) Size() int { func (c *Cohort) Size() int {
return len(c.hosts) + len(c.nets) return len(c.hosts) + len(c.nets)
} }
// Contains reports whether ipStr falls within any host or subnet in the cohort. // Contains reports whether ipStr falls within any host or subnet in the cohort.
// Returns true on parse error (fail-closed). // Returns true on parse error (fail-closed): unparseable input is treated as
// blocked so that garbage strings never accidentally bypass a blocklist check.
// IPv6 addresses are not stored and always return false.
func (c *Cohort) Contains(ipStr string) bool { func (c *Cohort) Contains(ipStr string) bool {
ip, err := netip.ParseAddr(ipStr) ip, err := netip.ParseAddr(ipStr)
if err != nil { if err != nil {
return true return true // fail-closed
}
return c.ContainsAddr(ip)
}
// ContainsAddr reports whether ip falls within any host or subnet in the cohort.
// IPv6 addresses always return false (cohort is IPv4-only).
func (c *Cohort) ContainsAddr(ip netip.Addr) bool {
if !ip.Is4() {
return false
} }
ip4 := ip.As4() ip4 := ip.As4()
ipU32 := binary.BigEndian.Uint32(ip4[:]) ipU32 := binary.BigEndian.Uint32(ip4[:])
_, found := slices.BinarySearch(c.hosts, ipU32) if _, found := slices.BinarySearch(c.hosts, ipU32); found {
if found {
return true return true
} }
for _, net := range c.nets { for _, net := range c.nets {
if net.Contains(ipU32) { if net.Contains(ipU32) {
return true return true
@ -93,15 +109,7 @@ func Parse(prefixList []string) (*Cohort, error) {
} }
slices.Sort(hosts) slices.Sort(hosts)
slices.SortFunc(nets, func(a, b IPv4Net) int { sortNets(nets)
if a.networkBE < b.networkBE {
return -1
}
if a.networkBE > b.networkBE {
return 1
}
return 0
})
return &Cohort{hosts: hosts, nets: nets}, nil return &Cohort{hosts: hosts, nets: nets}, nil
} }
@ -160,15 +168,7 @@ func LoadFiles(paths ...string) (*Cohort, error) {
} }
slices.Sort(hosts) slices.Sort(hosts)
slices.SortFunc(nets, func(a, b IPv4Net) int { sortNets(nets)
if a.networkBE < b.networkBE {
return -1
}
if a.networkBE > b.networkBE {
return 1
}
return 0
})
return &Cohort{hosts: hosts, nets: nets}, nil return &Cohort{hosts: hosts, nets: nets}, nil
} }
@ -222,15 +222,7 @@ func ReadAll(r *csv.Reader) (*Cohort, error) {
} }
slices.Sort(hosts) slices.Sort(hosts)
slices.SortFunc(nets, func(a, b IPv4Net) int { sortNets(nets)
if a.networkBE < b.networkBE {
return -1
}
if a.networkBE > b.networkBE {
return 1
}
return 0
})
return &Cohort{hosts: hosts, nets: nets}, nil return &Cohort{hosts: hosts, nets: nets}, nil
} }

View File

@ -104,6 +104,17 @@ func TestContains_FailClosed(t *testing.T) {
} }
} }
func TestContains_IPv6NeverBlocked(t *testing.T) {
c, _ := ipcohort.Parse([]string{"1.2.3.4", "10.0.0.0/8"})
// IPv6 addresses are not stored; should return false, not panic.
if c.Contains("::1") {
t.Error("IPv6 address should not be in an IPv4-only cohort")
}
if c.Contains("2001:db8::1") {
t.Error("IPv6 address should not be in an IPv4-only cohort")
}
}
func TestContains_Empty(t *testing.T) { func TestContains_Empty(t *testing.T) {
c, err := ipcohort.Parse(nil) c, err := ipcohort.Parse(nil)
if err != nil { if err != nil {