golib/net/ipcohort/ipcohort.go
AJ ONeal 896031b6a8
fix: idiomatic Go cleanup across net packages
- gitshallow: replace in-place Depth mutation with effectiveDepth() method;
  remove depth normalisation from New() since it was masking the bug
- ipcohort: extract sortNets() helper using cmp.Compare, eliminating 3 identical
  sort closures; add ContainsAddr(netip.Addr) for pre-parsed callers; guard
  Contains() against IPv6 panic (As4 panics on non-v4); add IPv6 test
- dataset: Add() now sets NopSyncer{} so callers cannot panic by accidentally
  calling Init/Sync/Run on a Group-managed Dataset
2026-04-20 09:47:50 -06:00

229 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ipcohort
import (
"cmp"
"encoding/binary"
"encoding/csv"
"fmt"
"io"
"log"
"net/netip"
"os"
"slices"
"strings"
)
// IPv4Net represents an IPv4 subnet or a single address (/32).
//
// Field data is 6 bytes (networkBE uint32 + prefix uint8 + shift uint8);
// uint32 alignment pads the struct to 8 bytes.
//
// networkBE is the network address as a big-endian-ordered uint32 with every
// host bit (beyond prefix) cleared. shift is 32-prefix, cached so Contains
// can build the mask without a subtraction.
type IPv4Net struct {
	networkBE uint32
	prefix    uint8
	shift     uint8
}

// NewIPv4Net returns the /prefix network containing ip4be, where ip4be is an
// IPv4 address in big-endian numeric form (as produced by
// binary.BigEndian.Uint32 on the 4-byte address).
//
// Host bits of ip4be beyond prefix are cleared, so NewIPv4Net(10.1.2.3, 8)
// describes 10.0.0.0/8. A prefix greater than 32 is clamped to 32 instead of
// letting 32-prefix wrap around in uint8 arithmetic.
func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net {
	if prefix > 32 {
		prefix = 32
	}
	shift := 32 - prefix
	// Shifting a uint32 by 32 yields 0 in Go, so /0 gets an all-zero mask.
	mask := uint32(0xFFFFFFFF) << shift
	return IPv4Net{
		networkBE: ip4be & mask,
		prefix:    prefix,
		shift:     shift,
	}
}

// Contains reports whether ip (a big-endian numeric IPv4 address) lies
// within r. Only the bits covered by the prefix are compared, so the result
// is correct even if networkBE was stored with host bits set.
func (r IPv4Net) Contains(ip uint32) bool {
	mask := uint32(0xFFFFFFFF) << r.shift
	return (ip^r.networkBE)&mask == 0
}
// Cohort is an immutable, read-only set of IPv4 hosts and subnets. It is not
// modified after construction, so Contains is safe for concurrent use without
// locks.
//
// Entries are split by prefix length when the cohort is built.
type Cohort struct {
	// hosts are /32 addresses as big-endian uint32 values, kept sorted so a
	// lookup is an O(log n) binary search.
	hosts []uint32

	// nets are CIDR ranges with prefix < 32, checked by an O(k) linear scan;
	// k is typically small.
	nets []IPv4Net
}
// New returns an empty Cohort holding no hosts and no subnets, so it matches
// no parsed address.
func New() *Cohort {
	return new(Cohort)
}
// sortNets orders nets in place by ascending network address. The sort is
// not stable: entries sharing a network address may end up in any relative
// order.
func sortNets(nets []IPv4Net) {
	byNetwork := func(x, y IPv4Net) int {
		return cmp.Compare(x.networkBE, y.networkBE)
	}
	slices.SortFunc(nets, byNetwork)
}
// Size reports how many entries the cohort holds: single hosts plus subnets.
func (c *Cohort) Size() int {
	total := len(c.hosts)
	total += len(c.nets)
	return total
}
// Contains reports whether the address written in ipStr is covered by any
// host or subnet in the cohort.
//
// If ipStr cannot be parsed as an IP address, Contains returns true
// (fail-closed): garbage input is treated as blocked so it can never slip
// past a blocklist check. Parsed addresses are checked by ContainsAddr, which
// defines how non-IPv4 addresses are treated.
func (c *Cohort) Contains(ipStr string) bool {
	addr, parseErr := netip.ParseAddr(ipStr)
	if parseErr == nil {
		return c.ContainsAddr(addr)
	}
	return true // fail-closed
}
// ContainsAddr reports whether ip falls within any host or subnet in the
// cohort.
//
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are unmapped first and checked
// as the embedded IPv4 address. Some sockets and proxies report IPv4 clients
// in this form, and without unmapping such a client would bypass the cohort.
// All other IPv6 addresses, and the zero netip.Addr, return false (the cohort
// is IPv4-only).
func (c *Cohort) ContainsAddr(ip netip.Addr) bool {
	ip = ip.Unmap()
	if !ip.Is4() {
		return false
	}
	ip4 := ip.As4()
	ipU32 := binary.BigEndian.Uint32(ip4[:])

	// hosts is kept sorted by every constructor, so binary search is valid.
	if _, found := slices.BinarySearch(c.hosts, ipU32); found {
		return true
	}
	for _, n := range c.nets {
		if n.Contains(ipU32) {
			return true
		}
	}
	return false
}
// Parse builds a Cohort from a list of IPv4 addresses and/or CIDR prefixes.
//
// Any entry ParseIPv4 rejects (malformed, or IPv6) is logged and skipped.
// The returned error is currently always nil.
func Parse(prefixList []string) (*Cohort, error) {
	c := &Cohort{}
	for _, entry := range prefixList {
		n, parseErr := ParseIPv4(entry)
		if parseErr != nil {
			log.Printf("skipping invalid entry: %q", entry)
			continue
		}
		switch n.prefix {
		case 32:
			c.hosts = append(c.hosts, n.networkBE)
		default:
			c.nets = append(c.nets, n)
		}
	}
	slices.Sort(c.hosts)
	sortNets(c.nets)
	return c, nil
}
// ParseIPv4 parses raw as either a bare IPv4 address ("192.0.2.1", treated
// as a /32) or an IPv4 CIDR prefix ("192.0.2.0/24").
//
// Host bits beyond the prefix length are cleared, so "192.0.2.7/24" yields
// 192.0.2.0/24. netip.ParsePrefix does not zero them itself, and a network
// stored with host bits set would never match.
//
// It returns an error if raw is not a valid address or prefix, or if it is
// IPv6.
func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
	var prefix netip.Prefix
	if strings.Contains(raw, "/") {
		if prefix, err = netip.ParsePrefix(raw); err != nil {
			return IPv4Net{}, err
		}
	} else {
		var addr netip.Addr
		if addr, err = netip.ParseAddr(raw); err != nil {
			return IPv4Net{}, err
		}
		prefix = netip.PrefixFrom(addr, 32)
	}

	if !prefix.Addr().Is4() {
		return IPv4Net{}, fmt.Errorf("IPv6 not supported: %s", raw)
	}

	prefix = prefix.Masked()
	ip4 := prefix.Addr().As4()
	bits := uint8(prefix.Bits()) // 0–32 for a valid IPv4 prefix
	return NewIPv4Net(binary.BigEndian.Uint32(ip4[:]), bits), nil
}
// LoadFile opens the file at path and parses its contents with ParseCSV.
//
// Both open and parse errors are wrapped with %w and include the path. Callers
// can therefore match them with errors.Is / errors.As (e.g. fs.ErrNotExist,
// *csv.ParseError), and can tell which file failed when loading several.
func LoadFile(path string) (*Cohort, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("could not load %q: %w", path, err)
	}
	defer f.Close()

	c, err := ParseCSV(f)
	if err != nil {
		return nil, fmt.Errorf("could not parse %q: %w", path, err)
	}
	return c, nil
}
// LoadFiles loads every file in paths with LoadFile and merges the results
// into a single Cohort. This is handy when hosts and networks live in
// separate files. Loading stops at the first file that fails, and that error
// is returned unchanged.
func LoadFiles(paths ...string) (*Cohort, error) {
	merged := &Cohort{}
	for _, p := range paths {
		part, err := LoadFile(p)
		if err != nil {
			return nil, err
		}
		merged.hosts = append(merged.hosts, part.hosts...)
		merged.nets = append(merged.nets, part.nets...)
	}
	// Each part is sorted on its own; the concatenation must be re-sorted.
	slices.Sort(merged.hosts)
	sortNets(merged.nets)
	return merged, nil
}
// ParseCSV reads a Cohort from CSV data in f. Only the first column of each
// record is used (see ReadAll); rows may have differing numbers of fields.
//
// Lines beginning with '#' are discarded by the CSV reader itself instead of
// being tokenised. Otherwise a quote inside a comment (e.g. `# see "RFC 1918"`)
// would raise csv.ErrBareQuote and abort the entire load. LazyQuotes likewise
// tolerates stray quotes in annotation columns; well-formed quoted fields
// parse the same either way.
func ParseCSV(f io.Reader) (*Cohort, error) {
	r := csv.NewReader(f)
	r.FieldsPerRecord = -1 // ragged rows allowed; extra columns are ignored
	r.Comment = '#'
	r.LazyQuotes = true
	return ReadAll(r)
}
// ReadAll consumes every record from r and builds a Cohort from the first
// field of each.
//
// After trimming whitespace, fields that are empty, start with '#', or
// contain ':' (IPv6) are skipped silently. Fields ParseIPv4 rejects are
// logged and skipped. Any read error other than io.EOF aborts the load and
// is returned wrapped.
func ReadAll(r *csv.Reader) (*Cohort, error) {
	c := &Cohort{}
	for {
		record, readErr := r.Read()
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return nil, fmt.Errorf("csv read error: %w", readErr)
		}
		if len(record) == 0 {
			continue
		}

		entry := strings.TrimSpace(record[0])
		switch {
		case entry == "", strings.HasPrefix(entry, "#"):
			continue // blank line or comment
		case strings.Contains(entry, ":"):
			continue // IPv6 is not supported; skip without logging
		}

		n, parseErr := ParseIPv4(entry)
		if parseErr != nil {
			log.Printf("skipping invalid entry: %q", entry)
			continue
		}
		if n.prefix != 32 {
			c.nets = append(c.nets, n)
			continue
		}
		c.hosts = append(c.hosts, n.networkBE)
	}
	slices.Sort(c.hosts)
	sortNets(c.nets)
	return c, nil
}