From 410b52f72c7253367663970fc236e7e5f9ee4f91 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 20 Apr 2026 09:36:13 -0600 Subject: [PATCH] test: ipcohort + dataset; fix ParseIPv4 panic on IPv6 - ParseIPv4 now returns an error instead of panicking on IPv6 addrs - Add ipcohort tests: ParseIPv4, Contains (host/CIDR/mixed/fail-closed/empty), Size, LoadFile, LoadFiles, IPv6 skip - Add dataset tests: Init, Sync (updated/no-update), error paths, Close hook, Run tick, Group (single fetch drives all loaders) --- net/dataset/dataset_test.go | 284 ++++++++++++++++++++++++++++++++++ net/ipcohort/ipcohort.go | 6 +- net/ipcohort/ipcohort_test.go | 179 +++++++++++++++++++++ 3 files changed, 468 insertions(+), 1 deletion(-) create mode 100644 net/dataset/dataset_test.go create mode 100644 net/ipcohort/ipcohort_test.go diff --git a/net/dataset/dataset_test.go b/net/dataset/dataset_test.go new file mode 100644 index 0000000..d06af39 --- /dev/null +++ b/net/dataset/dataset_test.go @@ -0,0 +1,284 @@ +package dataset_test + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/therootcompany/golib/net/dataset" + "github.com/therootcompany/golib/net/httpcache" +) + +// countSyncer counts Fetch calls and optionally reports updated. +type countSyncer struct { + calls atomic.Int32 + updated bool + err error +} + +func (s *countSyncer) Fetch() (bool, error) { + s.calls.Add(1) + return s.updated, s.err +} + +func TestDataset_Init(t *testing.T) { + syn := &countSyncer{updated: false} + calls := 0 + ds := dataset.New(syn, func() (*string, error) { + calls++ + v := "hello" + return &v, nil + }) + + if err := ds.Init(); err != nil { + t.Fatal(err) + } + if got := ds.Load(); got == nil || *got != "hello" { + t.Fatalf("Load() = %v, want \"hello\"", got) + } + if calls != 1 { + t.Errorf("loader called %d times, want 1", calls) + } + if syn.calls.Load() != 1 { + t.Errorf("Fetch called %d times, want 1", syn.calls.Load()) + } +} + +func TestDataset_LoadBeforeInit(t *testing.T) { + syn := httpcache.NopSyncer{} + ds := dataset.New(syn, func() (*string, error) { + v := "x" + return &v, nil + }) + if ds.Load() != nil { + t.Error("Load() before Init should return nil") + } +} + +func TestDataset_SyncNoUpdate(t *testing.T) { + syn := &countSyncer{updated: false} + calls := 0 + ds := dataset.New(syn, func() (*string, error) { + calls++ + v := "hello" + return &v, nil + }) + if err := ds.Init(); err != nil { + t.Fatal(err) + } + calls = 0 + + updated, err := ds.Sync() + if err != nil { + t.Fatal(err) + } + if updated { + t.Error("Sync() reported updated=true but syncer returned false") + } + if calls != 0 { + t.Errorf("loader called %d times on no-update Sync, want 0", calls) + } +} + +func TestDataset_SyncWithUpdate(t *testing.T) { + syn := &countSyncer{updated: true} + n := 0 + ds := dataset.New(syn, func() (*string, error) { + n++ + v := "v" + string(rune('0'+n)) + return &v, nil + }) + if err := ds.Init(); err != nil { + t.Fatal(err) + } + updated, err := ds.Sync() + if err != nil { + t.Fatal(err) + } + if !updated { + t.Error("Sync() reported updated=false but syncer returned true") + } + if got := ds.Load(); got == nil || *got != "v2" { + t.Errorf("Load() after Sync = %v, want \"v2\"", got) + } +} + +func TestDataset_InitError(t *testing.T) { + syn := &countSyncer{err: errors.New("fetch failed")} + ds := dataset.New(syn, func() (*string, error) { + v := "x" + return &v, nil + }) + if err := ds.Init(); err == nil { + t.Error("expected error from Init when syncer fails") + } + if ds.Load() != nil { + t.Error("Load() should be nil after failed Init") + } +} + +func TestDataset_LoaderError(t *testing.T) { + syn := httpcache.NopSyncer{} + ds := dataset.New(syn, func() (*string, error) { + return nil, errors.New("load failed") + }) + if err := ds.Init(); err == nil { + t.Error("expected error from Init when loader fails") + } +} + +func TestDataset_Close(t *testing.T) { + syn := &countSyncer{updated: true} + var closed []string + n := 0 + ds := dataset.New(syn, func() (*string, error) { + n++ + v := "v" + string(rune('0'+n)) + return &v, nil + }) + ds.Close = func(s *string) { closed = append(closed, *s) } + + if err := ds.Init(); err != nil { + t.Fatal(err) + } + // First swap: old is nil, Close should not be called. + if len(closed) != 0 { + t.Errorf("Close called %d times on Init, want 0", len(closed)) + } + + if _, err := ds.Sync(); err != nil { + t.Fatal(err) + } + if len(closed) != 1 || closed[0] != "v1" { + t.Errorf("Close got %v, want [\"v1\"]", closed) + } +} + +func TestDataset_Run(t *testing.T) { + syn := &countSyncer{updated: true} + n := atomic.Int32{} + ds := dataset.New(syn, func() (*int32, error) { + v := n.Add(1) + return &v, nil + }) + if err := ds.Init(); err != nil { + t.Fatal(err) + } + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + ds.Run(ctx, 10*time.Millisecond) + close(done) + }() + + time.Sleep(60 * time.Millisecond) + cancel() + <-done + + if n.Load() < 2 { + t.Errorf("Run did not tick: loader called %d times", n.Load()) + } +} + +// --- Group tests --- + +func TestGroup_Init(t *testing.T) { + syn := &countSyncer{} + g := dataset.NewGroup(syn) + + callsA, callsB := 0, 0 + dsA := dataset.Add(g, func() (*string, error) { + callsA++ + v := "a" + return &v, nil + }) + dsB := dataset.Add(g, func() (*int, error) { + callsB++ + v := 42 + return &v, nil + }) + + if err := g.Init(); err != nil { + t.Fatal(err) + } + if syn.calls.Load() != 1 { + t.Errorf("Fetch called %d times, want 1", syn.calls.Load()) + } + if callsA != 1 || callsB != 1 { + t.Errorf("loaders called (%d,%d), want (1,1)", callsA, callsB) + } + if got := dsA.Load(); got == nil || *got != "a" { + t.Errorf("dsA.Load() = %v", got) + } + if got := dsB.Load(); got == nil || *got != 42 { + t.Errorf("dsB.Load() = %v", got) + } +} + +func TestGroup_SyncNoUpdate(t *testing.T) { + syn := &countSyncer{updated: false} + g := dataset.NewGroup(syn) + calls := 0 + dataset.Add(g, func() (*string, error) { + calls++ + v := "x" + return &v, nil + }) + if err := g.Init(); err != nil { + t.Fatal(err) + } + calls = 0 + + updated, err := g.Sync() + if err != nil { + t.Fatal(err) + } + if updated || calls != 0 { + t.Errorf("Sync() updated=%v calls=%d, want false/0", updated, calls) + } +} + +func TestGroup_SyncWithUpdate(t *testing.T) { + syn := &countSyncer{updated: true} + g := dataset.NewGroup(syn) + n := 0 + ds := dataset.Add(g, func() (*int, error) { + n++ + return &n, nil + }) + if err := g.Init(); err != nil { + t.Fatal(err) + } + if _, err := g.Sync(); err != nil { + t.Fatal(err) + } + if got := ds.Load(); got == nil || *got != 2 { + t.Errorf("ds.Load() = %v, want 2", got) + } +} + +func TestGroup_FetchError(t *testing.T) { + syn := &countSyncer{err: errors.New("network down")} + g := dataset.NewGroup(syn) + dataset.Add(g, func() (*string, error) { + v := "x" + return &v, nil + }) + if err := g.Init(); err == nil { + t.Error("expected error from Group.Init when syncer fails") + } +} + +func TestGroup_LoaderError(t *testing.T) { + syn := httpcache.NopSyncer{} + g := dataset.NewGroup(syn) + dataset.Add(g, func() (*string, error) { + return nil, errors.New("parse error") + }) + if err := g.Init(); err == nil { + t.Error("expected error from Group.Init when loader fails") + } +} diff --git a/net/ipcohort/ipcohort.go b/net/ipcohort/ipcohort.go index 8887466..9861e68 100644 --- a/net/ipcohort/ipcohort.go +++ b/net/ipcohort/ipcohort.go @@ -122,7 +122,11 @@ func ParseIPv4(raw string) (ipv4net IPv4Net, err error) { ippre = netip.PrefixFrom(ip, 32) } - ip4 := ippre.Addr().As4() + addr := ippre.Addr() + if !addr.Is4() { + return ipv4net, fmt.Errorf("IPv6 not supported: %s", raw) + } + ip4 := addr.As4() prefix := uint8(ippre.Bits()) // 0–32 return NewIPv4Net( binary.BigEndian.Uint32(ip4[:]), diff --git a/net/ipcohort/ipcohort_test.go b/net/ipcohort/ipcohort_test.go new file mode 100644 index 0000000..68b215e --- /dev/null +++ b/net/ipcohort/ipcohort_test.go @@ -0,0 +1,179 @@ +package ipcohort_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/therootcompany/golib/net/ipcohort" +) + +func TestParseIPv4(t *testing.T) { + tests := []struct { + raw string + wantErr bool + }{ + {"1.2.3.4", false}, + {"1.2.3.4/32", false}, + {"10.0.0.0/8", false}, + {"192.168.0.0/16", false}, + {"0.0.0.0/0", false}, + {"", true}, + {"not-an-ip", true}, + {"1.2.3.4/33", true}, + {"::1", true}, // IPv6 not supported + } + for _, tt := range tests { + _, err := ipcohort.ParseIPv4(tt.raw) + if (err != nil) != tt.wantErr { + t.Errorf("ParseIPv4(%q): got err=%v, wantErr=%v", tt.raw, err, tt.wantErr) + } + } +} + +func TestContains_SingleHosts(t *testing.T) { + c, err := ipcohort.Parse([]string{"1.2.3.4", "5.6.7.8", "10.0.0.1"}) + if err != nil { + t.Fatal(err) + } + + hits := []string{"1.2.3.4", "5.6.7.8", "10.0.0.1"} + misses := []string{"1.2.3.5", "5.6.7.7", "10.0.0.2", "0.0.0.0"} + + for _, ip := range hits { + if !c.Contains(ip) { + t.Errorf("expected %s to be in cohort", ip) + } + } + for _, ip := range misses { + if c.Contains(ip) { + t.Errorf("expected %s NOT to be in cohort", ip) + } + } +} + +func TestContains_CIDRRanges(t *testing.T) { + c, err := ipcohort.Parse([]string{"10.0.0.0/8", "192.168.1.0/24"}) + if err != nil { + t.Fatal(err) + } + + hits := []string{ + "10.0.0.0", "10.0.0.1", "10.255.255.255", + "192.168.1.0", "192.168.1.1", "192.168.1.254", "192.168.1.255", + } + misses := []string{ + "9.255.255.255", "11.0.0.0", + "192.168.0.255", "192.168.2.0", + } + + for _, ip := range hits { + if !c.Contains(ip) { + t.Errorf("expected %s to be in cohort (CIDR)", ip) + } + } + for _, ip := range misses { + if c.Contains(ip) { + t.Errorf("expected %s NOT to be in cohort (CIDR)", ip) + } + } +} + +func TestContains_Mixed(t *testing.T) { + c, err := ipcohort.Parse([]string{"1.2.3.4", "10.0.0.0/8"}) + if err != nil { + t.Fatal(err) + } + + if !c.Contains("1.2.3.4") { + t.Error("host miss") + } + if !c.Contains("10.5.5.5") { + t.Error("CIDR miss") + } + if c.Contains("1.2.3.5") { + t.Error("false positive for host-adjacent") + } +} + +func TestContains_FailClosed(t *testing.T) { + c, _ := ipcohort.Parse([]string{"1.2.3.4"}) + // Unparseable input should return true (fail-closed). + if !c.Contains("not-an-ip") { + t.Error("expected fail-closed true for unparseable IP") + } +} + +func TestContains_Empty(t *testing.T) { + c, err := ipcohort.Parse(nil) + if err != nil { + t.Fatal(err) + } + if c.Contains("1.2.3.4") { + t.Error("empty cohort should not contain anything") + } +} + +func TestSize(t *testing.T) { + c, _ := ipcohort.Parse([]string{"1.2.3.4", "5.6.7.8", "10.0.0.0/8"}) + if got, want := c.Size(), 3; got != want { + t.Errorf("Size() = %d, want %d", got, want) + } +} + +func TestLoadFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "list.txt") + content := "# comment\n1.2.3.4\n10.0.0.0/8\n\n5.6.7.8\n" + if err := os.WriteFile(path, []byte(content), 0o644); err != nil { + t.Fatal(err) + } + + c, err := ipcohort.LoadFile(path) + if err != nil { + t.Fatal(err) + } + if c.Size() != 3 { + t.Errorf("Size() = %d, want 3", c.Size()) + } + if !c.Contains("1.2.3.4") { + t.Error("missing 1.2.3.4") + } + if !c.Contains("10.5.5.5") { + t.Error("missing CIDR member 10.5.5.5") + } +} + +func TestLoadFiles_Merge(t *testing.T) { + dir := t.TempDir() + + f1 := filepath.Join(dir, "singles.txt") + f2 := filepath.Join(dir, "networks.txt") + os.WriteFile(f1, []byte("1.2.3.4\n5.6.7.8\n"), 0o644) + os.WriteFile(f2, []byte("192.168.0.0/24\n"), 0o644) + + c, err := ipcohort.LoadFiles(f1, f2) + if err != nil { + t.Fatal(err) + } + if c.Size() != 3 { + t.Errorf("Size() = %d, want 3", c.Size()) + } + if !c.Contains("192.168.0.100") { + t.Error("missing merged CIDR member") + } +} + +func TestCSVSkipsIPv6(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "mixed.txt") + os.WriteFile(path, []byte("1.2.3.4\n::1\n2001:db8::/32\n5.6.7.8\n"), 0o644) + + c, err := ipcohort.LoadFile(path) + if err != nil { + t.Fatal(err) + } + if c.Size() != 2 { + t.Errorf("Size() = %d, want 2 (IPv6 should be skipped)", c.Size()) + } +}