test: ipcohort + dataset; fix ParseIPv4 panic on IPv6

- ParseIPv4 now returns an error instead of panicking on IPv6 addrs
- Add ipcohort tests: ParseIPv4, Contains (host/CIDR/mixed/fail-closed/empty), Size, LoadFile, LoadFiles, IPv6 skip
- Add dataset tests: Init, Sync (updated/no-update), error paths, Close hook, Run tick, Group (single fetch drives all loaders)
This commit is contained in:
AJ ONeal 2026-04-20 09:36:13 -06:00
parent aeb94fc26b
commit 410b52f72c
No known key found for this signature in database
3 changed files with 468 additions and 1 deletions

284
net/dataset/dataset_test.go Normal file
View File

@ -0,0 +1,284 @@
package dataset_test
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/therootcompany/golib/net/dataset"
"github.com/therootcompany/golib/net/httpcache"
)
// countSyncer counts Fetch calls and optionally reports updated.
type countSyncer struct {
calls atomic.Int32
updated bool
err error
}
func (s *countSyncer) Fetch() (bool, error) {
s.calls.Add(1)
return s.updated, s.err
}
func TestDataset_Init(t *testing.T) {
syn := &countSyncer{updated: false}
calls := 0
ds := dataset.New(syn, func() (*string, error) {
calls++
v := "hello"
return &v, nil
})
if err := ds.Init(); err != nil {
t.Fatal(err)
}
if got := ds.Load(); got == nil || *got != "hello" {
t.Fatalf("Load() = %v, want \"hello\"", got)
}
if calls != 1 {
t.Errorf("loader called %d times, want 1", calls)
}
if syn.calls.Load() != 1 {
t.Errorf("Fetch called %d times, want 1", syn.calls.Load())
}
}
func TestDataset_LoadBeforeInit(t *testing.T) {
syn := httpcache.NopSyncer{}
ds := dataset.New(syn, func() (*string, error) {
v := "x"
return &v, nil
})
if ds.Load() != nil {
t.Error("Load() before Init should return nil")
}
}
func TestDataset_SyncNoUpdate(t *testing.T) {
syn := &countSyncer{updated: false}
calls := 0
ds := dataset.New(syn, func() (*string, error) {
calls++
v := "hello"
return &v, nil
})
if err := ds.Init(); err != nil {
t.Fatal(err)
}
calls = 0
updated, err := ds.Sync()
if err != nil {
t.Fatal(err)
}
if updated {
t.Error("Sync() reported updated=true but syncer returned false")
}
if calls != 0 {
t.Errorf("loader called %d times on no-update Sync, want 0", calls)
}
}
func TestDataset_SyncWithUpdate(t *testing.T) {
syn := &countSyncer{updated: true}
n := 0
ds := dataset.New(syn, func() (*string, error) {
n++
v := "v" + string(rune('0'+n))
return &v, nil
})
if err := ds.Init(); err != nil {
t.Fatal(err)
}
updated, err := ds.Sync()
if err != nil {
t.Fatal(err)
}
if !updated {
t.Error("Sync() reported updated=false but syncer returned true")
}
if got := ds.Load(); got == nil || *got != "v2" {
t.Errorf("Load() after Sync = %v, want \"v2\"", got)
}
}
func TestDataset_InitError(t *testing.T) {
syn := &countSyncer{err: errors.New("fetch failed")}
ds := dataset.New(syn, func() (*string, error) {
v := "x"
return &v, nil
})
if err := ds.Init(); err == nil {
t.Error("expected error from Init when syncer fails")
}
if ds.Load() != nil {
t.Error("Load() should be nil after failed Init")
}
}
func TestDataset_LoaderError(t *testing.T) {
syn := httpcache.NopSyncer{}
ds := dataset.New(syn, func() (*string, error) {
return nil, errors.New("load failed")
})
if err := ds.Init(); err == nil {
t.Error("expected error from Init when loader fails")
}
}
func TestDataset_Close(t *testing.T) {
syn := &countSyncer{updated: true}
var closed []string
n := 0
ds := dataset.New(syn, func() (*string, error) {
n++
v := "v" + string(rune('0'+n))
return &v, nil
})
ds.Close = func(s *string) { closed = append(closed, *s) }
if err := ds.Init(); err != nil {
t.Fatal(err)
}
// First swap: old is nil, Close should not be called.
if len(closed) != 0 {
t.Errorf("Close called %d times on Init, want 0", len(closed))
}
if _, err := ds.Sync(); err != nil {
t.Fatal(err)
}
if len(closed) != 1 || closed[0] != "v1" {
t.Errorf("Close got %v, want [\"v1\"]", closed)
}
}
func TestDataset_Run(t *testing.T) {
syn := &countSyncer{updated: true}
n := atomic.Int32{}
ds := dataset.New(syn, func() (*int32, error) {
v := n.Add(1)
return &v, nil
})
if err := ds.Init(); err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
ds.Run(ctx, 10*time.Millisecond)
close(done)
}()
time.Sleep(60 * time.Millisecond)
cancel()
<-done
if n.Load() < 2 {
t.Errorf("Run did not tick: loader called %d times", n.Load())
}
}
// --- Group tests ---
func TestGroup_Init(t *testing.T) {
syn := &countSyncer{}
g := dataset.NewGroup(syn)
callsA, callsB := 0, 0
dsA := dataset.Add(g, func() (*string, error) {
callsA++
v := "a"
return &v, nil
})
dsB := dataset.Add(g, func() (*int, error) {
callsB++
v := 42
return &v, nil
})
if err := g.Init(); err != nil {
t.Fatal(err)
}
if syn.calls.Load() != 1 {
t.Errorf("Fetch called %d times, want 1", syn.calls.Load())
}
if callsA != 1 || callsB != 1 {
t.Errorf("loaders called (%d,%d), want (1,1)", callsA, callsB)
}
if got := dsA.Load(); got == nil || *got != "a" {
t.Errorf("dsA.Load() = %v", got)
}
if got := dsB.Load(); got == nil || *got != 42 {
t.Errorf("dsB.Load() = %v", got)
}
}
func TestGroup_SyncNoUpdate(t *testing.T) {
syn := &countSyncer{updated: false}
g := dataset.NewGroup(syn)
calls := 0
dataset.Add(g, func() (*string, error) {
calls++
v := "x"
return &v, nil
})
if err := g.Init(); err != nil {
t.Fatal(err)
}
calls = 0
updated, err := g.Sync()
if err != nil {
t.Fatal(err)
}
if updated || calls != 0 {
t.Errorf("Sync() updated=%v calls=%d, want false/0", updated, calls)
}
}
func TestGroup_SyncWithUpdate(t *testing.T) {
syn := &countSyncer{updated: true}
g := dataset.NewGroup(syn)
n := 0
ds := dataset.Add(g, func() (*int, error) {
n++
return &n, nil
})
if err := g.Init(); err != nil {
t.Fatal(err)
}
if _, err := g.Sync(); err != nil {
t.Fatal(err)
}
if got := ds.Load(); got == nil || *got != 2 {
t.Errorf("ds.Load() = %v, want 2", got)
}
}
func TestGroup_FetchError(t *testing.T) {
syn := &countSyncer{err: errors.New("network down")}
g := dataset.NewGroup(syn)
dataset.Add(g, func() (*string, error) {
v := "x"
return &v, nil
})
if err := g.Init(); err == nil {
t.Error("expected error from Group.Init when syncer fails")
}
}
func TestGroup_LoaderError(t *testing.T) {
syn := httpcache.NopSyncer{}
g := dataset.NewGroup(syn)
dataset.Add(g, func() (*string, error) {
return nil, errors.New("parse error")
})
if err := g.Init(); err == nil {
t.Error("expected error from Group.Init when loader fails")
}
}

View File

@ -122,7 +122,11 @@ func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
ippre = netip.PrefixFrom(ip, 32)
}
ip4 := ippre.Addr().As4()
addr := ippre.Addr()
if !addr.Is4() {
return ipv4net, fmt.Errorf("IPv6 not supported: %s", raw)
}
ip4 := addr.As4()
prefix := uint8(ippre.Bits()) // 032
return NewIPv4Net(
binary.BigEndian.Uint32(ip4[:]),

View File

@ -0,0 +1,179 @@
package ipcohort_test
import (
"os"
"path/filepath"
"testing"
"github.com/therootcompany/golib/net/ipcohort"
)
func TestParseIPv4(t *testing.T) {
tests := []struct {
raw string
wantErr bool
}{
{"1.2.3.4", false},
{"1.2.3.4/32", false},
{"10.0.0.0/8", false},
{"192.168.0.0/16", false},
{"0.0.0.0/0", false},
{"", true},
{"not-an-ip", true},
{"1.2.3.4/33", true},
{"::1", true}, // IPv6 not supported
}
for _, tt := range tests {
_, err := ipcohort.ParseIPv4(tt.raw)
if (err != nil) != tt.wantErr {
t.Errorf("ParseIPv4(%q): got err=%v, wantErr=%v", tt.raw, err, tt.wantErr)
}
}
}
func TestContains_SingleHosts(t *testing.T) {
c, err := ipcohort.Parse([]string{"1.2.3.4", "5.6.7.8", "10.0.0.1"})
if err != nil {
t.Fatal(err)
}
hits := []string{"1.2.3.4", "5.6.7.8", "10.0.0.1"}
misses := []string{"1.2.3.5", "5.6.7.7", "10.0.0.2", "0.0.0.0"}
for _, ip := range hits {
if !c.Contains(ip) {
t.Errorf("expected %s to be in cohort", ip)
}
}
for _, ip := range misses {
if c.Contains(ip) {
t.Errorf("expected %s NOT to be in cohort", ip)
}
}
}
func TestContains_CIDRRanges(t *testing.T) {
c, err := ipcohort.Parse([]string{"10.0.0.0/8", "192.168.1.0/24"})
if err != nil {
t.Fatal(err)
}
hits := []string{
"10.0.0.0", "10.0.0.1", "10.255.255.255",
"192.168.1.0", "192.168.1.1", "192.168.1.254", "192.168.1.255",
}
misses := []string{
"9.255.255.255", "11.0.0.0",
"192.168.0.255", "192.168.2.0",
}
for _, ip := range hits {
if !c.Contains(ip) {
t.Errorf("expected %s to be in cohort (CIDR)", ip)
}
}
for _, ip := range misses {
if c.Contains(ip) {
t.Errorf("expected %s NOT to be in cohort (CIDR)", ip)
}
}
}
func TestContains_Mixed(t *testing.T) {
c, err := ipcohort.Parse([]string{"1.2.3.4", "10.0.0.0/8"})
if err != nil {
t.Fatal(err)
}
if !c.Contains("1.2.3.4") {
t.Error("host miss")
}
if !c.Contains("10.5.5.5") {
t.Error("CIDR miss")
}
if c.Contains("1.2.3.5") {
t.Error("false positive for host-adjacent")
}
}
func TestContains_FailClosed(t *testing.T) {
c, _ := ipcohort.Parse([]string{"1.2.3.4"})
// Unparseable input should return true (fail-closed).
if !c.Contains("not-an-ip") {
t.Error("expected fail-closed true for unparseable IP")
}
}
func TestContains_Empty(t *testing.T) {
c, err := ipcohort.Parse(nil)
if err != nil {
t.Fatal(err)
}
if c.Contains("1.2.3.4") {
t.Error("empty cohort should not contain anything")
}
}
func TestSize(t *testing.T) {
c, _ := ipcohort.Parse([]string{"1.2.3.4", "5.6.7.8", "10.0.0.0/8"})
if got, want := c.Size(), 3; got != want {
t.Errorf("Size() = %d, want %d", got, want)
}
}
func TestLoadFile(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "list.txt")
content := "# comment\n1.2.3.4\n10.0.0.0/8\n\n5.6.7.8\n"
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
t.Fatal(err)
}
c, err := ipcohort.LoadFile(path)
if err != nil {
t.Fatal(err)
}
if c.Size() != 3 {
t.Errorf("Size() = %d, want 3", c.Size())
}
if !c.Contains("1.2.3.4") {
t.Error("missing 1.2.3.4")
}
if !c.Contains("10.5.5.5") {
t.Error("missing CIDR member 10.5.5.5")
}
}
func TestLoadFiles_Merge(t *testing.T) {
dir := t.TempDir()
f1 := filepath.Join(dir, "singles.txt")
f2 := filepath.Join(dir, "networks.txt")
os.WriteFile(f1, []byte("1.2.3.4\n5.6.7.8\n"), 0o644)
os.WriteFile(f2, []byte("192.168.0.0/24\n"), 0o644)
c, err := ipcohort.LoadFiles(f1, f2)
if err != nil {
t.Fatal(err)
}
if c.Size() != 3 {
t.Errorf("Size() = %d, want 3", c.Size())
}
if !c.Contains("192.168.0.100") {
t.Error("missing merged CIDR member")
}
}
func TestCSVSkipsIPv6(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "mixed.txt")
os.WriteFile(path, []byte("1.2.3.4\n::1\n2001:db8::/32\n5.6.7.8\n"), 0o644)
c, err := ipcohort.LoadFile(path)
if err != nil {
t.Fatal(err)
}
if c.Size() != 2 {
t.Errorf("Size() = %d, want 2 (IPv6 should be skipped)", c.Size())
}
}