fix: idiomatic Go cleanup across net packages

- gitshallow: replace in-place Depth mutation with effectiveDepth() method;
  remove depth normalisation from New() since it was masking the bug
- ipcohort: extract sortNets() helper using cmp.Compare, eliminating 3 identical
  sort closures; add ContainsAddr(netip.Addr) for pre-parsed callers; guard
  Contains() against IPv6 panic (As4 panics on non-v4); add IPv6 test
- dataset: Add() now sets NopSyncer{} so callers cannot panic by accidentally
  calling Init/Sync/Run on a Group-managed Dataset
This commit is contained in:
AJ ONeal 2026-04-20 09:47:50 -06:00
parent 410b52f72c
commit 896031b6a8
No known key found for this signature in database
4 changed files with 51 additions and 48 deletions

View File

@ -124,11 +124,11 @@ func NewGroup(syncer httpcache.Syncer) *Group {
return &Group{syncer: syncer}
}
// Add registers a new dataset in g and returns it. Call Init or Run on g —
// not on the returned dataset — to drive updates.
// Add registers a new dataset in g and returns it. Fetch and reload are driven
// by the Group — call Init/Run/Sync on g, not on the returned Dataset.
// load is a closure capturing whatever paths or config it needs.
func Add[T any](g *Group, load func() (*T, error)) *Dataset[T] {
d := &Dataset[T]{load: load}
d := &Dataset[T]{syncer: httpcache.NopSyncer{}, load: load}
g.members = append(g.members, d)
return d
}

View File

@ -28,9 +28,6 @@ type Repo struct {
// New creates a new Repo instance.
func New(url, path string, depth int, branch string) *Repo {
if depth == 0 {
depth = 1
}
return &Repo{
URL: url,
Path: path,
@ -39,6 +36,15 @@ func New(url, path string, depth int, branch string) *Repo {
}
}
// effectiveDepth returns the depth to use for clone/pull.
// 0 means unset — defaults to 1. -1 means full history.
func (r *Repo) effectiveDepth() int {
if r.Depth == 0 {
return 1
}
return r.Depth
}
// Init clones the repo if missing, then syncs once.
// Returns whether anything new was fetched.
func (r *Repo) Init() (bool, error) {
@ -70,11 +76,8 @@ func (r *Repo) clone() (bool, error) {
}
args := []string{"clone", "--no-tags"}
if r.Depth == 0 {
r.Depth = 1
}
if r.Depth >= 0 {
args = append(args, "--depth", fmt.Sprintf("%d", r.Depth))
if depth := r.effectiveDepth(); depth >= 0 {
args = append(args, "--depth", fmt.Sprintf("%d", depth))
}
args = append(args, "--single-branch")
if r.Branch != "" {
@ -125,11 +128,8 @@ func (r *Repo) pull() (updated bool, err error) {
}
pullArgs := []string{"pull", "--ff-only", "--no-tags"}
if r.Depth == 0 {
r.Depth = 1
}
if r.Depth >= 0 {
pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", r.Depth))
if depth := r.effectiveDepth(); depth >= 0 {
pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", depth))
}
if r.Branch != "" {
pullArgs = append(pullArgs, "origin", r.Branch)

View File

@ -1,6 +1,7 @@
package ipcohort
import (
"cmp"
"encoding/binary"
"encoding/csv"
"fmt"
@ -47,26 +48,41 @@ func New() *Cohort {
return &Cohort{}
}
func sortNets(nets []IPv4Net) {
slices.SortFunc(nets, func(a, b IPv4Net) int {
return cmp.Compare(a.networkBE, b.networkBE)
})
}
// Size returns the total number of entries (hosts + nets).
func (c *Cohort) Size() int {
return len(c.hosts) + len(c.nets)
}
// Contains reports whether ipStr falls within any host or subnet in the cohort.
// Returns true on parse error (fail-closed).
// Returns true on parse error (fail-closed): unparseable input is treated as
// blocked so that garbage strings never accidentally bypass a blocklist check.
// IPv6 addresses are not stored and always return false.
func (c *Cohort) Contains(ipStr string) bool {
ip, err := netip.ParseAddr(ipStr)
if err != nil {
return true
return true // fail-closed
}
return c.ContainsAddr(ip)
}
// ContainsAddr reports whether ip falls within any host or subnet in the cohort.
// IPv6 addresses always return false (cohort is IPv4-only).
func (c *Cohort) ContainsAddr(ip netip.Addr) bool {
if !ip.Is4() {
return false
}
ip4 := ip.As4()
ipU32 := binary.BigEndian.Uint32(ip4[:])
_, found := slices.BinarySearch(c.hosts, ipU32)
if found {
if _, found := slices.BinarySearch(c.hosts, ipU32); found {
return true
}
for _, net := range c.nets {
if net.Contains(ipU32) {
return true
@ -93,15 +109,7 @@ func Parse(prefixList []string) (*Cohort, error) {
}
slices.Sort(hosts)
slices.SortFunc(nets, func(a, b IPv4Net) int {
if a.networkBE < b.networkBE {
return -1
}
if a.networkBE > b.networkBE {
return 1
}
return 0
})
sortNets(nets)
return &Cohort{hosts: hosts, nets: nets}, nil
}
@ -160,15 +168,7 @@ func LoadFiles(paths ...string) (*Cohort, error) {
}
slices.Sort(hosts)
slices.SortFunc(nets, func(a, b IPv4Net) int {
if a.networkBE < b.networkBE {
return -1
}
if a.networkBE > b.networkBE {
return 1
}
return 0
})
sortNets(nets)
return &Cohort{hosts: hosts, nets: nets}, nil
}
@ -222,15 +222,7 @@ func ReadAll(r *csv.Reader) (*Cohort, error) {
}
slices.Sort(hosts)
slices.SortFunc(nets, func(a, b IPv4Net) int {
if a.networkBE < b.networkBE {
return -1
}
if a.networkBE > b.networkBE {
return 1
}
return 0
})
sortNets(nets)
return &Cohort{hosts: hosts, nets: nets}, nil
}

View File

@ -104,6 +104,17 @@ func TestContains_FailClosed(t *testing.T) {
}
}
func TestContains_IPv6NeverBlocked(t *testing.T) {
c, _ := ipcohort.Parse([]string{"1.2.3.4", "10.0.0.0/8"})
// IPv6 addresses are not stored; should return false, not panic.
if c.Contains("::1") {
t.Error("IPv6 address should not be in an IPv4-only cohort")
}
if c.Contains("2001:db8::1") {
t.Error("IPv6 address should not be in an IPv4-only cohort")
}
}
func TestContains_Empty(t *testing.T) {
c, err := ipcohort.Parse(nil)
if err != nil {