golib/net/ipcohort/ipcohort.go
AJ ONeal 86ffa2fb23
chore: remove IPv6 special-casing (YAGNI)
Drop the explicit IPv6 early-exit in ReadAll — ParseIPv4 already rejects
non-IPv4 via Is4(). Remove IPv6-specific tests and error message wording.
2026-04-20 09:54:04 -06:00

224 lines
4.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package ipcohort
import (
"cmp"
"encoding/binary"
"encoding/csv"
"fmt"
"io"
"log"
"net/netip"
"os"
"slices"
"strings"
)
// IPv4Net represents an IPv4 subnet, or a single address when prefix is 32.
//
// The fields occupy 6 bytes (networkBE uint32 + prefix uint8 + shift uint8),
// but uint32 alignment pads the struct to 8 bytes.
//
// networkBE is the network address as a big-endian uint32 with the host bits
// cleared; shift caches 32-prefix so Contains can build the mask cheaply.
// Construct values with NewIPv4Net so shift stays consistent with prefix.
type IPv4Net struct {
	networkBE uint32
	prefix    uint8
	shift     uint8
}

// NewIPv4Net builds an IPv4Net from a big-endian IPv4 address and a prefix
// length.
//
// Host bits below the prefix are cleared, so NewIPv4Net(10.0.0.5, 8) describes
// 10.0.0.0/8. A prefix greater than 32 is clamped to 32 (a single host) rather
// than letting 32-prefix wrap around in uint8 arithmetic.
func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net {
	if prefix > 32 {
		prefix = 32
	}
	shift := 32 - prefix
	// Shifting a uint32 by 32 yields 0, so a /0 mask is all zeros (matches all).
	mask := uint32(0xFFFFFFFF) << shift
	return IPv4Net{
		networkBE: ip4be & mask,
		prefix:    prefix,
		shift:     shift,
	}
}

// Contains reports whether ip (a big-endian IPv4 address) lies within r.
//
// Only the bits covered by the mask are compared, so the result is correct
// even for a value whose networkBE still carries host bits.
func (r IPv4Net) Contains(ip uint32) bool {
	mask := uint32(0xFFFFFFFF) << r.shift
	return (ip^r.networkBE)&mask == 0
}
// Cohort is an immutable, read-only set of IPv4 hosts and subnets.
// Because it is never modified after construction, Contains may be called
// from many goroutines at once without locking.
//
// Entries are stored in two forms, split by prefix length:
//   - hosts: every /32 address, kept sorted so lookups use binary search
//     (O(log n)).
//   - nets: every range with prefix < 32, checked by linear scan (O(k));
//     this list is expected to be short.
type Cohort struct {
	hosts []uint32  // sorted /32 addresses as big-endian uint32
	nets  []IPv4Net // CIDR ranges with prefix < 32
}
// New returns an empty Cohort holding no hosts and no subnets.
func New() *Cohort {
	return new(Cohort)
}
// sortNets orders nets in place by ascending network address.
func sortNets(nets []IPv4Net) {
	byNetwork := func(x, y IPv4Net) int {
		return cmp.Compare(x.networkBE, y.networkBE)
	}
	slices.SortFunc(nets, byNetwork)
}
// Size reports how many entries the cohort holds, counting /32 hosts and
// subnets together.
func (c *Cohort) Size() int {
	n := len(c.hosts)
	n += len(c.nets)
	return n
}
// Contains parses ipStr and reports whether the address falls within any host
// or subnet in the cohort, applying the same matching rules as ContainsAddr.
//
// Any string that netip.ParseAddr rejects yields true (fail-closed), so
// malformed input can never slip past a blocklist check by accident.
func (c *Cohort) Contains(ipStr string) bool {
	if addr, err := netip.ParseAddr(ipStr); err == nil {
		return c.ContainsAddr(addr)
	}
	return true // fail-closed on parse error
}
// ContainsAddr reports whether ip falls within any host or subnet in the cohort.
//
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are unmapped first and matched
// as the IPv4 address they carry. Such addresses are produced, for example, by
// netip.AddrFromSlice on a 16-byte net.IP, and would otherwise bypass the
// check. Every other IPv6 address returns false, because the cohort is
// IPv4-only.
func (c *Cohort) ContainsAddr(ip netip.Addr) bool {
	ip = ip.Unmap()
	if !ip.Is4() {
		return false
	}
	b := ip.As4()
	ipU32 := binary.BigEndian.Uint32(b[:])
	// Exact /32 entries: binary search over the sorted hosts slice.
	if _, found := slices.BinarySearch(c.hosts, ipU32); found {
		return true
	}
	// Subnets: linear scan; the list is expected to be short.
	for _, n := range c.nets {
		if n.Contains(ipU32) {
			return true
		}
	}
	return false
}
// Parse builds a Cohort from a list of IPv4 addresses and/or CIDR prefixes.
//
// Each entry is whitespace-trimmed before parsing, the same way ReadAll
// treats CSV fields:
//   - Blank entries and entries starting with "#" are skipped silently.
//   - Entries that ParseIPv4 rejects are logged and skipped.
//
// The returned error is always nil.
func Parse(prefixList []string) (*Cohort, error) {
	var hosts []uint32
	var nets []IPv4Net
	for _, entry := range prefixList {
		raw := strings.TrimSpace(entry)
		if raw == "" || strings.HasPrefix(raw, "#") {
			continue
		}
		ipv4net, err := ParseIPv4(raw)
		if err != nil {
			log.Printf("skipping invalid entry: %q", raw)
			continue
		}
		// /32 entries go to the binary-searched hosts slice; wider ranges are scanned.
		if ipv4net.prefix == 32 {
			hosts = append(hosts, ipv4net.networkBE)
		} else {
			nets = append(nets, ipv4net)
		}
	}
	slices.Sort(hosts)
	sortNets(nets)
	return &Cohort{hosts: hosts, nets: nets}, nil
}
// ParseIPv4 parses raw as either a bare IPv4 address or an IPv4 CIDR prefix.
// A bare address such as "192.0.2.1" is treated as a /32; a prefix looks like
// "192.0.2.0/24".
//
// The prefix is canonicalized by clearing the host bits, so "10.0.0.5/8" is
// stored as the network 10.0.0.0/8.
//
// An error is returned if raw does not parse, or if it parses as anything
// other than IPv4 (including IPv4-mapped IPv6 forms).
func ParseIPv4(raw string) (ipv4net IPv4Net, err error) {
	var ippre netip.Prefix
	if strings.Contains(raw, "/") {
		ippre, err = netip.ParsePrefix(raw)
	} else {
		var ip netip.Addr
		if ip, err = netip.ParseAddr(raw); err == nil {
			ippre = netip.PrefixFrom(ip, 32)
		}
	}
	if err != nil {
		return ipv4net, err
	}
	if !ippre.Addr().Is4() {
		return ipv4net, fmt.Errorf("not an IPv4 address: %s", raw)
	}
	// Zero the host bits so the stored network address is canonical.
	ippre = ippre.Masked()
	ip4 := ippre.Addr().As4()
	prefix := uint8(ippre.Bits()) // 0..32 for a valid IPv4 prefix
	return NewIPv4Net(binary.BigEndian.Uint32(ip4[:]), prefix), nil
}
// LoadFile opens the file at path and parses it with ParseCSV.
//
// Both open and parse errors are wrapped with %w and the file path. Callers
// can therefore test them with errors.Is (e.g. against fs.ErrNotExist) and
// still see which file failed.
func LoadFile(path string) (*Cohort, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("could not load %q: %w", path, err)
	}
	// Read-only handle: a Close error carries no data-loss risk, so it is ignored.
	defer f.Close()
	c, err := ParseCSV(f)
	if err != nil {
		return nil, fmt.Errorf("could not parse %q: %w", path, err)
	}
	return c, nil
}
// LoadFiles loads each path with LoadFile and merges every result into one
// Cohort. It stops at the first load error and returns it unchanged.
// This suits setups where hosts and networks are kept in separate files.
func LoadFiles(paths ...string) (*Cohort, error) {
	merged := &Cohort{}
	for _, p := range paths {
		part, err := LoadFile(p)
		if err != nil {
			return nil, err
		}
		merged.hosts = append(merged.hosts, part.hosts...)
		merged.nets = append(merged.nets, part.nets...)
	}
	// Re-sort: each part is sorted on its own, but the concatenation is not.
	slices.Sort(merged.hosts)
	sortNets(merged.nets)
	return merged, nil
}
// ParseCSV reads a Cohort from CSV text in f. Only the first field of each
// record is used, and records may have any number of fields.
//
// Lines whose very first character is '#' are dropped by the csv.Reader
// itself, before quote parsing. Without this, a comment such as
// `# "legacy" ranges` triggers a bare-quote error that aborts the whole load.
// ReadAll still skips '#' entries that follow leading whitespace.
func ParseCSV(f io.Reader) (*Cohort, error) {
	r := csv.NewReader(f)
	r.FieldsPerRecord = -1 // allow records with varying field counts
	r.Comment = '#'
	return ReadAll(r)
}
// ReadAll consumes r until io.EOF and builds a Cohort from the first field of
// every record.
//
// How each field is handled:
//   - It is whitespace-trimmed first.
//   - Blank fields and fields starting with "#" are ignored.
//   - Fields that ParseIPv4 rejects are logged and skipped.
//
// Any other read error aborts the load and is returned wrapped.
func ReadAll(r *csv.Reader) (*Cohort, error) {
	var (
		hosts []uint32
		nets  []IPv4Net
	)
	for {
		rec, readErr := r.Read()
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return nil, fmt.Errorf("csv read error: %w", readErr)
		}
		if len(rec) == 0 {
			continue
		}
		entry := strings.TrimSpace(rec[0])
		if entry == "" || entry[0] == '#' {
			continue
		}
		n, parseErr := ParseIPv4(entry)
		if parseErr != nil {
			log.Printf("skipping invalid entry: %q", entry)
			continue
		}
		if n.prefix != 32 {
			nets = append(nets, n)
			continue
		}
		hosts = append(hosts, n.networkBE)
	}
	slices.Sort(hosts)
	sortNets(nets)
	return &Cohort{hosts: hosts, nets: nets}, nil
}