package ipcohort import ( "encoding/binary" "encoding/csv" "fmt" "io" "log" "net/netip" "os" "slices" "sort" "strings" ) // Either a subnet or single address (subnet with /32 CIDR prefix) type IPv4Net struct { networkBE uint32 prefix uint8 shift uint8 } func NewIPv4Net(ip4be uint32, prefix uint8) IPv4Net { return IPv4Net{ networkBE: ip4be, prefix: prefix, shift: 32 - prefix, } } func (r IPv4Net) Contains(ip uint32) bool { mask := uint32(0xFFFFFFFF << (r.shift)) return (ip & mask) == r.networkBE } func New() *Cohort { cohort := &Cohort{ranges: []IPv4Net{}} return cohort } func Parse(prefixList []string) (*Cohort, error) { var ranges []IPv4Net for _, raw := range prefixList { ipv4net, err := ParseIPv4(raw) if err != nil { log.Printf("skipping invalid entry: %q", raw) continue } ranges = append(ranges, ipv4net) } sizedList := make([]IPv4Net, len(ranges)) copy(sizedList, ranges) sortRanges(ranges) cohort := &Cohort{ranges: sizedList} return cohort, nil } func ParseIPv4(raw string) (ipv4net IPv4Net, err error) { var ippre netip.Prefix var ip netip.Addr if strings.Contains(raw, "/") { ippre, err = netip.ParsePrefix(raw) if err != nil { return ipv4net, err } } else { ip, err = netip.ParseAddr(raw) if err != nil { return ipv4net, err } ippre = netip.PrefixFrom(ip, 32) } ip4 := ippre.Addr().As4() prefix := uint8(ippre.Bits()) // 0-32 return NewIPv4Net( binary.BigEndian.Uint32(ip4[:]), prefix, ), nil } func LoadFile(path string, unsorted bool) (*Cohort, error) { f, err := os.Open(path) if err != nil { return nil, fmt.Errorf("could not load %q: %v", path, err) } defer f.Close() return ParseCSV(f, unsorted) } func ParseCSV(f io.Reader, unsorted bool) (*Cohort, error) { r := csv.NewReader(f) r.FieldsPerRecord = -1 return ReadAll(r, unsorted) } func ReadAll(r *csv.Reader, unsorted bool) (*Cohort, error) { var ranges []IPv4Net for { record, err := r.Read() if err == io.EOF { break } if err != nil { return nil, fmt.Errorf("csv read error: %w", err) } if len(record) == 0 { continue } raw := strings.TrimSpace(record[0]) // Skip comments/empty if raw == "" || strings.HasPrefix(raw, "#") { continue } // skip IPv6 if strings.Contains(raw, ":") { continue } ipv4net, err := ParseIPv4(raw) if err != nil { log.Printf("skipping invalid entry: %q", raw) continue } ranges = append(ranges, ipv4net) } if unsorted { sortRanges(ranges) } sizedList := make([]IPv4Net, len(ranges)) copy(sizedList, ranges) cohort := &Cohort{ranges: sizedList} return cohort, nil } func sortRanges(ranges []IPv4Net) { // Sort by network address (required for binary search) sort.Slice(ranges, func(i, j int) bool { // Note: we could also sort by prefix (largest first) return ranges[i].networkBE < ranges[j].networkBE }) // Note: we could also merge ranges here } type Cohort struct { ranges []IPv4Net } func (c *Cohort) Size() int { return len(c.ranges) } func (c *Cohort) Contains(ipStr string) bool { ip, err := netip.ParseAddr(ipStr) if err != nil { return true } ip4 := ip.As4() ipU32 := binary.BigEndian.Uint32(ip4[:]) idx, found := slices.BinarySearchFunc(c.ranges, ipU32, func(r IPv4Net, target uint32) int { if r.networkBE < target { return -1 } if r.networkBE > target { return 1 } return 0 }) if found { return true } // Check the range immediately before the insertion point if idx > 0 { if c.ranges[idx-1].Contains(ipU32) { return true } } return false }