mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 12:48:00 +00:00
fix: idiomatic Go cleanup across net packages
- gitshallow: replace in-place Depth mutation with effectiveDepth() method;
remove depth normalisation from New() since it was masking the bug
- ipcohort: extract sortNets() helper using cmp.Compare, eliminating 3 identical
sort closures; add ContainsAddr(netip.Addr) for pre-parsed callers; guard
Contains() against IPv6 panic (As4 panics on non-v4); add IPv6 test
- dataset: Add() now sets NopSyncer{} so callers cannot panic by accidentally
calling Init/Sync/Run on a Group-managed Dataset
This commit is contained in:
parent
410b52f72c
commit
896031b6a8
@ -124,11 +124,11 @@ func NewGroup(syncer httpcache.Syncer) *Group {
|
|||||||
return &Group{syncer: syncer}
|
return &Group{syncer: syncer}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add registers a new dataset in g and returns it. Call Init or Run on g —
|
// Add registers a new dataset in g and returns it. Fetch and reload are driven
|
||||||
// not on the returned dataset — to drive updates.
|
// by the Group — call Init/Run/Sync on g, not on the returned Dataset.
|
||||||
// load is a closure capturing whatever paths or config it needs.
|
// load is a closure capturing whatever paths or config it needs.
|
||||||
func Add[T any](g *Group, load func() (*T, error)) *Dataset[T] {
|
func Add[T any](g *Group, load func() (*T, error)) *Dataset[T] {
|
||||||
d := &Dataset[T]{load: load}
|
d := &Dataset[T]{syncer: httpcache.NopSyncer{}, load: load}
|
||||||
g.members = append(g.members, d)
|
g.members = append(g.members, d)
|
||||||
return d
|
return d
|
||||||
}
|
}
|
||||||
|
|||||||
@ -28,9 +28,6 @@ type Repo struct {
|
|||||||
|
|
||||||
// New creates a new Repo instance.
|
// New creates a new Repo instance.
|
||||||
func New(url, path string, depth int, branch string) *Repo {
|
func New(url, path string, depth int, branch string) *Repo {
|
||||||
if depth == 0 {
|
|
||||||
depth = 1
|
|
||||||
}
|
|
||||||
return &Repo{
|
return &Repo{
|
||||||
URL: url,
|
URL: url,
|
||||||
Path: path,
|
Path: path,
|
||||||
@ -39,6 +36,15 @@ func New(url, path string, depth int, branch string) *Repo {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// effectiveDepth returns the depth to use for clone/pull.
|
||||||
|
// 0 means unset — defaults to 1. -1 means full history.
|
||||||
|
func (r *Repo) effectiveDepth() int {
|
||||||
|
if r.Depth == 0 {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return r.Depth
|
||||||
|
}
|
||||||
|
|
||||||
// Init clones the repo if missing, then syncs once.
|
// Init clones the repo if missing, then syncs once.
|
||||||
// Returns whether anything new was fetched.
|
// Returns whether anything new was fetched.
|
||||||
func (r *Repo) Init() (bool, error) {
|
func (r *Repo) Init() (bool, error) {
|
||||||
@ -70,11 +76,8 @@ func (r *Repo) clone() (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
args := []string{"clone", "--no-tags"}
|
args := []string{"clone", "--no-tags"}
|
||||||
if r.Depth == 0 {
|
if depth := r.effectiveDepth(); depth >= 0 {
|
||||||
r.Depth = 1
|
args = append(args, "--depth", fmt.Sprintf("%d", depth))
|
||||||
}
|
|
||||||
if r.Depth >= 0 {
|
|
||||||
args = append(args, "--depth", fmt.Sprintf("%d", r.Depth))
|
|
||||||
}
|
}
|
||||||
args = append(args, "--single-branch")
|
args = append(args, "--single-branch")
|
||||||
if r.Branch != "" {
|
if r.Branch != "" {
|
||||||
@ -125,11 +128,8 @@ func (r *Repo) pull() (updated bool, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pullArgs := []string{"pull", "--ff-only", "--no-tags"}
|
pullArgs := []string{"pull", "--ff-only", "--no-tags"}
|
||||||
if r.Depth == 0 {
|
if depth := r.effectiveDepth(); depth >= 0 {
|
||||||
r.Depth = 1
|
pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", depth))
|
||||||
}
|
|
||||||
if r.Depth >= 0 {
|
|
||||||
pullArgs = append(pullArgs, "--depth", fmt.Sprintf("%d", r.Depth))
|
|
||||||
}
|
}
|
||||||
if r.Branch != "" {
|
if r.Branch != "" {
|
||||||
pullArgs = append(pullArgs, "origin", r.Branch)
|
pullArgs = append(pullArgs, "origin", r.Branch)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package ipcohort
|
package ipcohort
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"cmp"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"encoding/csv"
|
"encoding/csv"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -47,26 +48,41 @@ func New() *Cohort {
|
|||||||
return &Cohort{}
|
return &Cohort{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sortNets(nets []IPv4Net) {
|
||||||
|
slices.SortFunc(nets, func(a, b IPv4Net) int {
|
||||||
|
return cmp.Compare(a.networkBE, b.networkBE)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Size returns the total number of entries (hosts + nets).
|
// Size returns the total number of entries (hosts + nets).
|
||||||
func (c *Cohort) Size() int {
|
func (c *Cohort) Size() int {
|
||||||
return len(c.hosts) + len(c.nets)
|
return len(c.hosts) + len(c.nets)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Contains reports whether ipStr falls within any host or subnet in the cohort.
|
// Contains reports whether ipStr falls within any host or subnet in the cohort.
|
||||||
// Returns true on parse error (fail-closed).
|
// Returns true on parse error (fail-closed): unparseable input is treated as
|
||||||
|
// blocked so that garbage strings never accidentally bypass a blocklist check.
|
||||||
|
// IPv6 addresses are not stored and always return false.
|
||||||
func (c *Cohort) Contains(ipStr string) bool {
|
func (c *Cohort) Contains(ipStr string) bool {
|
||||||
ip, err := netip.ParseAddr(ipStr)
|
ip, err := netip.ParseAddr(ipStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return true
|
return true // fail-closed
|
||||||
|
}
|
||||||
|
return c.ContainsAddr(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContainsAddr reports whether ip falls within any host or subnet in the cohort.
|
||||||
|
// IPv6 addresses always return false (cohort is IPv4-only).
|
||||||
|
func (c *Cohort) ContainsAddr(ip netip.Addr) bool {
|
||||||
|
if !ip.Is4() {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
ip4 := ip.As4()
|
ip4 := ip.As4()
|
||||||
ipU32 := binary.BigEndian.Uint32(ip4[:])
|
ipU32 := binary.BigEndian.Uint32(ip4[:])
|
||||||
|
|
||||||
_, found := slices.BinarySearch(c.hosts, ipU32)
|
if _, found := slices.BinarySearch(c.hosts, ipU32); found {
|
||||||
if found {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, net := range c.nets {
|
for _, net := range c.nets {
|
||||||
if net.Contains(ipU32) {
|
if net.Contains(ipU32) {
|
||||||
return true
|
return true
|
||||||
@ -93,15 +109,7 @@ func Parse(prefixList []string) (*Cohort, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
slices.Sort(hosts)
|
slices.Sort(hosts)
|
||||||
slices.SortFunc(nets, func(a, b IPv4Net) int {
|
sortNets(nets)
|
||||||
if a.networkBE < b.networkBE {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
if a.networkBE > b.networkBE {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
})
|
|
||||||
|
|
||||||
return &Cohort{hosts: hosts, nets: nets}, nil
|
return &Cohort{hosts: hosts, nets: nets}, nil
|
||||||
}
|
}
|
||||||
@ -160,15 +168,7 @@ func LoadFiles(paths ...string) (*Cohort, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
slices.Sort(hosts)
|
slices.Sort(hosts)
|
||||||
slices.SortFunc(nets, func(a, b IPv4Net) int {
|
sortNets(nets)
|
||||||
if a.networkBE < b.networkBE {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
if a.networkBE > b.networkBE {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
})
|
|
||||||
|
|
||||||
return &Cohort{hosts: hosts, nets: nets}, nil
|
return &Cohort{hosts: hosts, nets: nets}, nil
|
||||||
}
|
}
|
||||||
@ -222,15 +222,7 @@ func ReadAll(r *csv.Reader) (*Cohort, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
slices.Sort(hosts)
|
slices.Sort(hosts)
|
||||||
slices.SortFunc(nets, func(a, b IPv4Net) int {
|
sortNets(nets)
|
||||||
if a.networkBE < b.networkBE {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
if a.networkBE > b.networkBE {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
})
|
|
||||||
|
|
||||||
return &Cohort{hosts: hosts, nets: nets}, nil
|
return &Cohort{hosts: hosts, nets: nets}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@ -104,6 +104,17 @@ func TestContains_FailClosed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestContains_IPv6NeverBlocked(t *testing.T) {
|
||||||
|
c, _ := ipcohort.Parse([]string{"1.2.3.4", "10.0.0.0/8"})
|
||||||
|
// IPv6 addresses are not stored; should return false, not panic.
|
||||||
|
if c.Contains("::1") {
|
||||||
|
t.Error("IPv6 address should not be in an IPv4-only cohort")
|
||||||
|
}
|
||||||
|
if c.Contains("2001:db8::1") {
|
||||||
|
t.Error("IPv6 address should not be in an IPv4-only cohort")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestContains_Empty(t *testing.T) {
|
func TestContains_Empty(t *testing.T) {
|
||||||
c, err := ipcohort.Parse(nil)
|
c, err := ipcohort.Parse(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user