mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 12:48:00 +00:00
test: ipcohort + dataset; fix ParseIPv4 panic on IPv6
- ParseIPv4 now returns an error instead of panicking on IPv6 addrs - Add ipcohort tests: ParseIPv4, Contains (host/CIDR/mixed/fail-closed/empty), Size, LoadFile, LoadFiles, IPv6 skip - Add dataset tests: Init, Sync (updated/no-update), error paths, Close hook, Run tick, Group (single fetch drives all loaders)
This commit is contained in:
parent
aeb94fc26b
commit
410b52f72c
284
net/dataset/dataset_test.go
Normal file
284
net/dataset/dataset_test.go
Normal file
@ -0,0 +1,284 @@
|
||||
package dataset_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/therootcompany/golib/net/dataset"
|
||||
"github.com/therootcompany/golib/net/httpcache"
|
||||
)
|
||||
|
||||
// countSyncer counts Fetch calls and optionally reports updated.
|
||||
type countSyncer struct {
|
||||
calls atomic.Int32
|
||||
updated bool
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *countSyncer) Fetch() (bool, error) {
|
||||
s.calls.Add(1)
|
||||
return s.updated, s.err
|
||||
}
|
||||
|
||||
func TestDataset_Init(t *testing.T) {
|
||||
syn := &countSyncer{updated: false}
|
||||
calls := 0
|
||||
ds := dataset.New(syn, func() (*string, error) {
|
||||
calls++
|
||||
v := "hello"
|
||||
return &v, nil
|
||||
})
|
||||
|
||||
if err := ds.Init(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := ds.Load(); got == nil || *got != "hello" {
|
||||
t.Fatalf("Load() = %v, want \"hello\"", got)
|
||||
}
|
||||
if calls != 1 {
|
||||
t.Errorf("loader called %d times, want 1", calls)
|
||||
}
|
||||
if syn.calls.Load() != 1 {
|
||||
t.Errorf("Fetch called %d times, want 1", syn.calls.Load())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataset_LoadBeforeInit(t *testing.T) {
|
||||
syn := httpcache.NopSyncer{}
|
||||
ds := dataset.New(syn, func() (*string, error) {
|
||||
v := "x"
|
||||
return &v, nil
|
||||
})
|
||||
if ds.Load() != nil {
|
||||
t.Error("Load() before Init should return nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataset_SyncNoUpdate(t *testing.T) {
|
||||
syn := &countSyncer{updated: false}
|
||||
calls := 0
|
||||
ds := dataset.New(syn, func() (*string, error) {
|
||||
calls++
|
||||
v := "hello"
|
||||
return &v, nil
|
||||
})
|
||||
if err := ds.Init(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
calls = 0
|
||||
|
||||
updated, err := ds.Sync()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if updated {
|
||||
t.Error("Sync() reported updated=true but syncer returned false")
|
||||
}
|
||||
if calls != 0 {
|
||||
t.Errorf("loader called %d times on no-update Sync, want 0", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataset_SyncWithUpdate(t *testing.T) {
|
||||
syn := &countSyncer{updated: true}
|
||||
n := 0
|
||||
ds := dataset.New(syn, func() (*string, error) {
|
||||
n++
|
||||
v := "v" + string(rune('0'+n))
|
||||
return &v, nil
|
||||
})
|
||||
if err := ds.Init(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
updated, err := ds.Sync()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !updated {
|
||||
t.Error("Sync() reported updated=false but syncer returned true")
|
||||
}
|
||||
if got := ds.Load(); got == nil || *got != "v2" {
|
||||
t.Errorf("Load() after Sync = %v, want \"v2\"", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataset_InitError(t *testing.T) {
|
||||
syn := &countSyncer{err: errors.New("fetch failed")}
|
||||
ds := dataset.New(syn, func() (*string, error) {
|
||||
v := "x"
|
||||
return &v, nil
|
||||
})
|
||||
if err := ds.Init(); err == nil {
|
||||
t.Error("expected error from Init when syncer fails")
|
||||
}
|
||||
if ds.Load() != nil {
|
||||
t.Error("Load() should be nil after failed Init")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataset_LoaderError(t *testing.T) {
|
||||
syn := httpcache.NopSyncer{}
|
||||
ds := dataset.New(syn, func() (*string, error) {
|
||||
return nil, errors.New("load failed")
|
||||
})
|
||||
if err := ds.Init(); err == nil {
|
||||
t.Error("expected error from Init when loader fails")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataset_Close(t *testing.T) {
|
||||
syn := &countSyncer{updated: true}
|
||||
var closed []string
|
||||
n := 0
|
||||
ds := dataset.New(syn, func() (*string, error) {
|
||||
n++
|
||||
v := "v" + string(rune('0'+n))
|
||||
return &v, nil
|
||||
})
|
||||
ds.Close = func(s *string) { closed = append(closed, *s) }
|
||||
|
||||
if err := ds.Init(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// First swap: old is nil, Close should not be called.
|
||||
if len(closed) != 0 {
|
||||
t.Errorf("Close called %d times on Init, want 0", len(closed))
|
||||
}
|
||||
|
||||
if _, err := ds.Sync(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(closed) != 1 || closed[0] != "v1" {
|
||||
t.Errorf("Close got %v, want [\"v1\"]", closed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDataset_Run(t *testing.T) {
|
||||
syn := &countSyncer{updated: true}
|
||||
n := atomic.Int32{}
|
||||
ds := dataset.New(syn, func() (*int32, error) {
|
||||
v := n.Add(1)
|
||||
return &v, nil
|
||||
})
|
||||
if err := ds.Init(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
ds.Run(ctx, 10*time.Millisecond)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
cancel()
|
||||
<-done
|
||||
|
||||
if n.Load() < 2 {
|
||||
t.Errorf("Run did not tick: loader called %d times", n.Load())
|
||||
}
|
||||
}
|
||||
|
||||
// --- Group tests ---
|
||||
|
||||
func TestGroup_Init(t *testing.T) {
|
||||
syn := &countSyncer{}
|
||||
g := dataset.NewGroup(syn)
|
||||
|
||||
callsA, callsB := 0, 0
|
||||
dsA := dataset.Add(g, func() (*string, error) {
|
||||
callsA++
|
||||
v := "a"
|
||||
return &v, nil
|
||||
})
|
||||
dsB := dataset.Add(g, func() (*int, error) {
|
||||
callsB++
|
||||
v := 42
|
||||
return &v, nil
|
||||
})
|
||||
|
||||
if err := g.Init(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if syn.calls.Load() != 1 {
|
||||
t.Errorf("Fetch called %d times, want 1", syn.calls.Load())
|
||||
}
|
||||
if callsA != 1 || callsB != 1 {
|
||||
t.Errorf("loaders called (%d,%d), want (1,1)", callsA, callsB)
|
||||
}
|
||||
if got := dsA.Load(); got == nil || *got != "a" {
|
||||
t.Errorf("dsA.Load() = %v", got)
|
||||
}
|
||||
if got := dsB.Load(); got == nil || *got != 42 {
|
||||
t.Errorf("dsB.Load() = %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_SyncNoUpdate(t *testing.T) {
|
||||
syn := &countSyncer{updated: false}
|
||||
g := dataset.NewGroup(syn)
|
||||
calls := 0
|
||||
dataset.Add(g, func() (*string, error) {
|
||||
calls++
|
||||
v := "x"
|
||||
return &v, nil
|
||||
})
|
||||
if err := g.Init(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
calls = 0
|
||||
|
||||
updated, err := g.Sync()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if updated || calls != 0 {
|
||||
t.Errorf("Sync() updated=%v calls=%d, want false/0", updated, calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_SyncWithUpdate(t *testing.T) {
|
||||
syn := &countSyncer{updated: true}
|
||||
g := dataset.NewGroup(syn)
|
||||
n := 0
|
||||
ds := dataset.Add(g, func() (*int, error) {
|
||||
n++
|
||||
return &n, nil
|
||||
})
|
||||
if err := g.Init(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := g.Sync(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := ds.Load(); got == nil || *got != 2 {
|
||||
t.Errorf("ds.Load() = %v, want 2", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_FetchError(t *testing.T) {
|
||||
syn := &countSyncer{err: errors.New("network down")}
|
||||
g := dataset.NewGroup(syn)
|
||||
dataset.Add(g, func() (*string, error) {
|
||||
v := "x"
|
||||
return &v, nil
|
||||
})
|
||||
if err := g.Init(); err == nil {
|
||||
t.Error("expected error from Group.Init when syncer fails")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGroup_LoaderError(t *testing.T) {
|
||||
syn := httpcache.NopSyncer{}
|
||||
g := dataset.NewGroup(syn)
|
||||
dataset.Add(g, func() (*string, error) {
|
||||
return nil, errors.New("parse error")
|
||||
})
|
||||
if err := g.Init(); err == nil {
|
||||
t.Error("expected error from Group.Init when loader fails")
|
||||
}
|
||||
}
|
||||
@ -122,7 +122,11 @@ func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
|
||||
ippre = netip.PrefixFrom(ip, 32)
|
||||
}
|
||||
|
||||
ip4 := ippre.Addr().As4()
|
||||
addr := ippre.Addr()
|
||||
if !addr.Is4() {
|
||||
return ipv4net, fmt.Errorf("IPv6 not supported: %s", raw)
|
||||
}
|
||||
ip4 := addr.As4()
|
||||
prefix := uint8(ippre.Bits()) // 0–32
|
||||
return NewIPv4Net(
|
||||
binary.BigEndian.Uint32(ip4[:]),
|
||||
|
||||
179
net/ipcohort/ipcohort_test.go
Normal file
179
net/ipcohort/ipcohort_test.go
Normal file
@ -0,0 +1,179 @@
|
||||
package ipcohort_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/therootcompany/golib/net/ipcohort"
|
||||
)
|
||||
|
||||
func TestParseIPv4(t *testing.T) {
|
||||
tests := []struct {
|
||||
raw string
|
||||
wantErr bool
|
||||
}{
|
||||
{"1.2.3.4", false},
|
||||
{"1.2.3.4/32", false},
|
||||
{"10.0.0.0/8", false},
|
||||
{"192.168.0.0/16", false},
|
||||
{"0.0.0.0/0", false},
|
||||
{"", true},
|
||||
{"not-an-ip", true},
|
||||
{"1.2.3.4/33", true},
|
||||
{"::1", true}, // IPv6 not supported
|
||||
}
|
||||
for _, tt := range tests {
|
||||
_, err := ipcohort.ParseIPv4(tt.raw)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseIPv4(%q): got err=%v, wantErr=%v", tt.raw, err, tt.wantErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains_SingleHosts(t *testing.T) {
|
||||
c, err := ipcohort.Parse([]string{"1.2.3.4", "5.6.7.8", "10.0.0.1"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hits := []string{"1.2.3.4", "5.6.7.8", "10.0.0.1"}
|
||||
misses := []string{"1.2.3.5", "5.6.7.7", "10.0.0.2", "0.0.0.0"}
|
||||
|
||||
for _, ip := range hits {
|
||||
if !c.Contains(ip) {
|
||||
t.Errorf("expected %s to be in cohort", ip)
|
||||
}
|
||||
}
|
||||
for _, ip := range misses {
|
||||
if c.Contains(ip) {
|
||||
t.Errorf("expected %s NOT to be in cohort", ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains_CIDRRanges(t *testing.T) {
|
||||
c, err := ipcohort.Parse([]string{"10.0.0.0/8", "192.168.1.0/24"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
hits := []string{
|
||||
"10.0.0.0", "10.0.0.1", "10.255.255.255",
|
||||
"192.168.1.0", "192.168.1.1", "192.168.1.254", "192.168.1.255",
|
||||
}
|
||||
misses := []string{
|
||||
"9.255.255.255", "11.0.0.0",
|
||||
"192.168.0.255", "192.168.2.0",
|
||||
}
|
||||
|
||||
for _, ip := range hits {
|
||||
if !c.Contains(ip) {
|
||||
t.Errorf("expected %s to be in cohort (CIDR)", ip)
|
||||
}
|
||||
}
|
||||
for _, ip := range misses {
|
||||
if c.Contains(ip) {
|
||||
t.Errorf("expected %s NOT to be in cohort (CIDR)", ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains_Mixed(t *testing.T) {
|
||||
c, err := ipcohort.Parse([]string{"1.2.3.4", "10.0.0.0/8"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !c.Contains("1.2.3.4") {
|
||||
t.Error("host miss")
|
||||
}
|
||||
if !c.Contains("10.5.5.5") {
|
||||
t.Error("CIDR miss")
|
||||
}
|
||||
if c.Contains("1.2.3.5") {
|
||||
t.Error("false positive for host-adjacent")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains_FailClosed(t *testing.T) {
|
||||
c, _ := ipcohort.Parse([]string{"1.2.3.4"})
|
||||
// Unparseable input should return true (fail-closed).
|
||||
if !c.Contains("not-an-ip") {
|
||||
t.Error("expected fail-closed true for unparseable IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestContains_Empty(t *testing.T) {
|
||||
c, err := ipcohort.Parse(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.Contains("1.2.3.4") {
|
||||
t.Error("empty cohort should not contain anything")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSize(t *testing.T) {
|
||||
c, _ := ipcohort.Parse([]string{"1.2.3.4", "5.6.7.8", "10.0.0.0/8"})
|
||||
if got, want := c.Size(), 3; got != want {
|
||||
t.Errorf("Size() = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFile(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "list.txt")
|
||||
content := "# comment\n1.2.3.4\n10.0.0.0/8\n\n5.6.7.8\n"
|
||||
if err := os.WriteFile(path, []byte(content), 0o644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c, err := ipcohort.LoadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.Size() != 3 {
|
||||
t.Errorf("Size() = %d, want 3", c.Size())
|
||||
}
|
||||
if !c.Contains("1.2.3.4") {
|
||||
t.Error("missing 1.2.3.4")
|
||||
}
|
||||
if !c.Contains("10.5.5.5") {
|
||||
t.Error("missing CIDR member 10.5.5.5")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFiles_Merge(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
f1 := filepath.Join(dir, "singles.txt")
|
||||
f2 := filepath.Join(dir, "networks.txt")
|
||||
os.WriteFile(f1, []byte("1.2.3.4\n5.6.7.8\n"), 0o644)
|
||||
os.WriteFile(f2, []byte("192.168.0.0/24\n"), 0o644)
|
||||
|
||||
c, err := ipcohort.LoadFiles(f1, f2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.Size() != 3 {
|
||||
t.Errorf("Size() = %d, want 3", c.Size())
|
||||
}
|
||||
if !c.Contains("192.168.0.100") {
|
||||
t.Error("missing merged CIDR member")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSVSkipsIPv6(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "mixed.txt")
|
||||
os.WriteFile(path, []byte("1.2.3.4\n::1\n2001:db8::/32\n5.6.7.8\n"), 0o644)
|
||||
|
||||
c, err := ipcohort.LoadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if c.Size() != 2 {
|
||||
t.Errorf("Size() = %d, want 2 (IPv6 should be skipped)", c.Size())
|
||||
}
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user