AJ ONeal 5f48a9beaa
feat: ipcohort filter with inbound/outbound/whitelist cohorts
Blacklist → IPFilter with three separate atomic cohorts: whitelist
(never blocked), inbound, and outbound. ContainsInbound/ContainsOutbound
each skip the whitelist. HTTP sync fetches all cachers before a single
reload to avoid double-load. Also fixes httpcache.Init calling c.Fetch().
2026-04-19 23:17:12 -06:00

230 lines
5.2 KiB
Go

package main
import (
"context"
"fmt"
"os"
"path/filepath"
"sync/atomic"
"time"
"github.com/therootcompany/golib/net/gitshallow"
"github.com/therootcompany/golib/net/httpcache"
"github.com/therootcompany/golib/net/ipcohort"
)
// HTTPSource describes one remotely hosted list: the URL it is fetched from
// and the local file path used as its cache.
type HTTPSource struct {
	// URL is the remote location of the list.
	URL string
	// Path is the local cache file for the list.
	Path string
}
// IPFilter bundles up to three IP cohorts — a whitelist of addresses that are
// never reported as blocked, an inbound blocklist, and an outbound blocklist —
// together with the sources they are loaded from.
type IPFilter struct {
	// Currently loaded cohorts. Each pointer is nil until its first
	// successful load and is replaced atomically on every reload.
	whitelist, inbound, outbound atomic.Pointer[ipcohort.Cohort]

	// Files each cohort is loaded from. A cohort with no paths is never loaded.
	whitelistPaths, inboundPaths, outboundPaths []string

	// git is set only by NewGitFilter; when non-nil, Init and sync delegate
	// to the repo.
	git *gitshallow.Repo

	// HTTP cachers created by NewHTTPFilter, one per configured source.
	httpInbound, httpOutbound []*httpcache.Cacher
}
// NewFileFilter returns a filter whose whitelist, inbound, and outbound
// cohorts come from the given local files. The constructor only records the
// paths; nothing is read from disk until Init is called.
func NewFileFilter(whitelist, inbound, outbound []string) *IPFilter {
	filter := new(IPFilter)
	filter.whitelistPaths = whitelist
	filter.inboundPaths = inbound
	filter.outboundPaths = outbound
	return filter
}
// NewGitFilter returns a filter backed by a shallow git checkout of gitURL in
// repoDir. inboundRel and outboundRel are paths relative to repoDir and are
// joined with it here; whitelist paths are used exactly as given (they are
// not resolved against repoDir). The filter's reloadAll is registered with
// the repo as a callback.
func NewGitFilter(gitURL, repoDir string, whitelist, inboundRel, outboundRel []string) *IPFilter {
	// inRepo resolves repo-relative paths against the checkout directory.
	inRepo := func(rels []string) []string {
		joined := make([]string, 0, len(rels))
		for _, rel := range rels {
			joined = append(joined, filepath.Join(repoDir, rel))
		}
		return joined
	}

	filter := &IPFilter{
		// NOTE(review): the trailing 1 and "" are presumably clone depth and
		// default branch — confirm against gitshallow.New.
		git:            gitshallow.New(gitURL, repoDir, 1, ""),
		whitelistPaths: whitelist,
		inboundPaths:   inRepo(inboundRel),
		outboundPaths:  inRepo(outboundRel),
	}
	filter.git.Register(filter.reloadAll)
	return filter
}
// NewHTTPFilter returns a filter whose inbound and outbound lists are fetched
// over HTTP, one cacher per source, each caching to the source's Path. The
// whitelist is always read from the given local files. Nothing is fetched or
// loaded until Init is called.
func NewHTTPFilter(whitelist []string, inbound, outbound []HTTPSource) *IPFilter {
	// collect builds the cache-path list and the cachers for one direction.
	// Both results stay nil when srcs is empty.
	collect := func(srcs []HTTPSource) (paths []string, cachers []*httpcache.Cacher) {
		for _, s := range srcs {
			paths = append(paths, s.Path)
			cachers = append(cachers, httpcache.New(s.URL, s.Path))
		}
		return paths, cachers
	}

	filter := &IPFilter{whitelistPaths: whitelist}
	filter.inboundPaths, filter.httpInbound = collect(inbound)
	filter.outboundPaths, filter.httpOutbound = collect(outbound)
	return filter
}
// Init performs the first load of the filter.
//
// A git-backed filter delegates entirely to the repo's Init. Otherwise every
// HTTP cacher (inbound first, then outbound) is fetched, stopping at the
// first error, and then all cohorts are loaded from their files. A file-only
// filter has no cachers and goes straight to the load.
func (f *IPFilter) Init(lightGC bool) error {
	if f.git != nil {
		return f.git.Init(lightGC)
	}

	for _, group := range [][]*httpcache.Cacher{f.httpInbound, f.httpOutbound} {
		for _, cacher := range group {
			// Only the error matters here: a full load always follows.
			if _, err := cacher.Fetch(); err != nil {
				return err
			}
		}
	}
	return f.reloadAll()
}
// Run syncs the filter every 47 minutes until ctx is done. Sync errors and
// successful reloads are reported on stderr; Run itself never returns an
// error and is meant to be run in its own goroutine.
func (f *IPFilter) Run(ctx context.Context, lightGC bool) {
	const interval = 47 * time.Minute

	tick := time.NewTicker(interval)
	defer tick.Stop()

	for {
		// Block until either the next tick or cancellation.
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
		}

		changed, err := f.sync(lightGC)
		switch {
		case err != nil:
			fmt.Fprintf(os.Stderr, "error: filter sync: %v\n", err)
		case changed:
			fmt.Fprintf(os.Stderr, "filter: reloaded — inbound=%d outbound=%d\n",
				f.InboundSize(), f.OutboundSize())
		}
	}
}
// sync refreshes the filter from its remote source, if any, and reports
// whether new data was obtained.
//
// A git-backed filter delegates to the repo's Sync. An HTTP-backed filter
// fetches the inbound cachers, then the outbound ones, stopping at the first
// fetch error, and reloads all cohorts once if any fetch reported an update.
// A file-only filter has nothing to sync and returns (false, nil).
//
// The reload also runs when a later fetch failed. Cache files that were
// already refreshed would otherwise never reach memory: the next sync would
// presumably see those sources as unchanged (assumes Fetch reports an update
// only when content changed — TODO confirm against httpcache) and skip the
// reload. The fetch error is still returned, so the returned values match the
// previous behaviour; only the reload side effect is new.
func (f *IPFilter) sync(lightGC bool) (bool, error) {
	switch {
	case f.git != nil:
		return f.git.Sync(lightGC)

	case len(f.httpInbound) > 0 || len(f.httpOutbound) > 0:
		var anyUpdated bool
		// fetchGroup fetches each cacher in order, recording updates, and
		// stops at the first failure.
		fetchGroup := func(cachers []*httpcache.Cacher) error {
			for _, c := range cachers {
				updated, err := c.Fetch()
				if err != nil {
					return err
				}
				anyUpdated = anyUpdated || updated
			}
			return nil
		}

		fetchErr := fetchGroup(f.httpInbound)
		if fetchErr == nil {
			fetchErr = fetchGroup(f.httpOutbound)
		}

		if !anyUpdated {
			return false, fetchErr
		}
		// Something changed on disk: load it even if a later fetch failed.
		if err := f.reloadAll(); err != nil {
			if fetchErr != nil {
				return true, fmt.Errorf("%w (reload of already-fetched sources also failed: %v)", fetchErr, err)
			}
			return true, err
		}
		return true, fetchErr

	default:
		return false, nil
	}
}
// ContainsInbound reports whether ip should be blocked for inbound traffic:
// it must be present in the inbound cohort and absent from the whitelist.
// Cohorts that have not been loaded are treated as empty.
func (f *IPFilter) ContainsInbound(ip string) bool {
	// The whitelist always wins, so consult it first.
	allowed := f.whitelist.Load()
	if allowed != nil && allowed.Contains(ip) {
		return false
	}

	if blocked := f.inbound.Load(); blocked != nil {
		return blocked.Contains(ip)
	}
	return false
}
// ContainsOutbound reports whether ip should be blocked for outbound traffic:
// it must be present in the outbound cohort and absent from the whitelist.
// Cohorts that have not been loaded are treated as empty.
func (f *IPFilter) ContainsOutbound(ip string) bool {
	// The whitelist always wins, so consult it first.
	allowed := f.whitelist.Load()
	if allowed != nil && allowed.Contains(ip) {
		return false
	}

	if blocked := f.outbound.Load(); blocked != nil {
		return blocked.Contains(ip)
	}
	return false
}
// InboundSize returns the Size of the currently loaded inbound cohort, or 0
// when no inbound cohort has been loaded.
func (f *IPFilter) InboundSize() int {
	cohort := f.inbound.Load()
	if cohort == nil {
		return 0
	}
	return cohort.Size()
}
// OutboundSize returns the Size of the currently loaded outbound cohort, or 0
// when no outbound cohort has been loaded.
func (f *IPFilter) OutboundSize() int {
	cohort := f.outbound.Load()
	if cohort == nil {
		return 0
	}
	return cohort.Size()
}
// reloadAll reloads the whitelist, inbound, and outbound cohorts, in that
// order, and returns the first error. Cohorts reloaded before a failure keep
// their new contents; cohorts after it are left untouched.
func (f *IPFilter) reloadAll() error {
	steps := []func() error{
		f.reloadWhitelist,
		f.reloadInbound,
		f.reloadOutbound,
	}
	for _, reload := range steps {
		if err := reload(); err != nil {
			return err
		}
	}
	return nil
}
// reloadWhitelist loads the whitelist cohort from whitelistPaths and swaps it
// in atomically. With no paths configured it does nothing; on a load error
// the previously stored cohort is kept.
func (f *IPFilter) reloadWhitelist() error {
	paths := f.whitelistPaths
	if len(paths) == 0 {
		return nil
	}

	loaded, err := ipcohort.LoadFiles(paths...)
	if err != nil {
		return err
	}

	f.whitelist.Store(loaded)
	return nil
}
// reloadInbound loads the inbound cohort from inboundPaths and swaps it in
// atomically. With no paths configured it does nothing; on a load error the
// previously stored cohort is kept.
func (f *IPFilter) reloadInbound() error {
	paths := f.inboundPaths
	if len(paths) == 0 {
		return nil
	}

	loaded, err := ipcohort.LoadFiles(paths...)
	if err != nil {
		return err
	}

	f.inbound.Store(loaded)
	return nil
}
// reloadOutbound loads the outbound cohort from outboundPaths and swaps it in
// atomically. With no paths configured it does nothing; on a load error the
// previously stored cohort is kept.
func (f *IPFilter) reloadOutbound() error {
	paths := f.outboundPaths
	if len(paths) == 0 {
		return nil
	}

	loaded, err := ipcohort.LoadFiles(paths...)
	if err != nil {
		return err
	}

	f.outbound.Store(loaded)
	return nil
}