feat: ipcohort filter with inbound/outbound/whitelist cohorts

Blacklist → IPFilter with three separate atomic cohorts: whitelist
(never blocked), inbound, and outbound. ContainsInbound/ContainsOutbound
each skip the whitelist. HTTP sync fetches all cachers before a single
reload to avoid double-load. Also fixes httpcache.Init calling c.Fetch().
This commit is contained in:
AJ ONeal 2026-04-19 23:17:12 -06:00
parent ff224c5bb1
commit 5f48a9beaa
No known key found for this signature in database
3 changed files with 202 additions and 72 deletions

View File

@ -37,23 +37,27 @@ func (c *Cacher) Register(fn func() error) {
// Init fetches the URL unconditionally (no cached headers yet) and invokes // Init fetches the URL unconditionally (no cached headers yet) and invokes
// all callbacks, ensuring files are loaded on startup. // all callbacks, ensuring files are loaded on startup.
func (c *Cacher) Init() error { func (c *Cacher) Init() error {
if _, err := c.fetch(); err != nil { if _, err := c.Fetch(); err != nil {
return err return err
} }
return c.invokeCallbacks() return c.invokeCallbacks()
} }
// Sync sends a conditional GET. If the server returns new content, writes it // Sync sends a conditional GET, writes updated content, and invokes callbacks.
// to Path and invokes callbacks. Returns whether the file was updated. // Returns whether the file was updated.
func (c *Cacher) Sync() (updated bool, err error) { func (c *Cacher) Sync() (updated bool, err error) {
updated, err = c.fetch() updated, err = c.Fetch()
if err != nil || !updated { if err != nil || !updated {
return updated, err return updated, err
} }
return true, c.invokeCallbacks() return true, c.invokeCallbacks()
} }
func (c *Cacher) fetch() (updated bool, err error) { // Fetch sends a conditional GET and writes new content to Path if the server
// responds with 200. Returns whether the file was updated. Does not invoke
// callbacks — use Sync for the single-cacher case, or call Fetch across
// multiple cachers and handle the reload yourself.
func (c *Cacher) Fetch() (updated bool, err error) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()

View File

@ -19,70 +19,101 @@ type HTTPSource struct {
Path string Path string
} }
type Blacklist struct { // IPFilter holds up to three cohorts: a whitelist (IPs never blocked),
atomic.Pointer[ipcohort.Cohort] // an inbound blocklist, and an outbound blocklist.
paths []string type IPFilter struct {
whitelist atomic.Pointer[ipcohort.Cohort]
inbound atomic.Pointer[ipcohort.Cohort]
outbound atomic.Pointer[ipcohort.Cohort]
whitelistPaths []string
inboundPaths []string
outboundPaths []string
git *gitshallow.Repo git *gitshallow.Repo
http []*httpcache.Cacher httpInbound []*httpcache.Cacher
httpOutbound []*httpcache.Cacher
} }
// NewBlacklist loads from one or more local files. // NewFileFilter loads inbound/outbound/whitelist from local files.
func NewBlacklist(paths ...string) *Blacklist { func NewFileFilter(whitelist, inbound, outbound []string) *IPFilter {
return &Blacklist{paths: paths} return &IPFilter{
whitelistPaths: whitelist,
inboundPaths: inbound,
outboundPaths: outbound,
}
} }
// NewGitBlacklist clones/pulls gitURL into repoDir and loads relPaths on each update. // NewGitFilter clones/pulls gitURL into repoDir and loads the given relative
func NewGitBlacklist(gitURL, repoDir string, relPaths ...string) *Blacklist { // paths for each cohort on each update.
func NewGitFilter(gitURL, repoDir string, whitelist, inboundRel, outboundRel []string) *IPFilter {
repo := gitshallow.New(gitURL, repoDir, 1, "") repo := gitshallow.New(gitURL, repoDir, 1, "")
paths := make([]string, len(relPaths)) abs := func(rel []string) []string {
for i, p := range relPaths { out := make([]string, len(rel))
paths[i] = filepath.Join(repoDir, p) for i, p := range rel {
out[i] = filepath.Join(repoDir, p)
} }
b := &Blacklist{paths: paths, git: repo} return out
repo.Register(b.reload) }
return b f := &IPFilter{
whitelistPaths: whitelist,
inboundPaths: abs(inboundRel),
outboundPaths: abs(outboundRel),
git: repo,
}
repo.Register(f.reloadAll)
return f
} }
// NewHTTPBlacklist fetches each source URL to its local path, reloading on any change. // NewHTTPFilter fetches inbound and outbound sources via HTTP;
func NewHTTPBlacklist(sources ...HTTPSource) *Blacklist { // whitelist is always loaded from local files.
b := &Blacklist{} func NewHTTPFilter(whitelist []string, inbound, outbound []HTTPSource) *IPFilter {
for _, src := range sources { f := &IPFilter{whitelistPaths: whitelist}
b.paths = append(b.paths, src.Path) for _, src := range inbound {
c := httpcache.New(src.URL, src.Path) f.inboundPaths = append(f.inboundPaths, src.Path)
c.Register(b.reload) f.httpInbound = append(f.httpInbound, httpcache.New(src.URL, src.Path))
b.http = append(b.http, c)
} }
return b for _, src := range outbound {
f.outboundPaths = append(f.outboundPaths, src.Path)
f.httpOutbound = append(f.httpOutbound, httpcache.New(src.URL, src.Path))
}
return f
} }
func (b *Blacklist) Init(lightGC bool) error { func (f *IPFilter) Init(lightGC bool) error {
switch { switch {
case b.git != nil: case f.git != nil:
return b.git.Init(lightGC) return f.git.Init(lightGC)
case len(b.http) > 0: case len(f.httpInbound) > 0 || len(f.httpOutbound) > 0:
for _, c := range b.http { for _, c := range f.httpInbound {
if err := c.Init(); err != nil { if _, err := c.Fetch(); err != nil {
return err return err
} }
} }
return nil for _, c := range f.httpOutbound {
if _, err := c.Fetch(); err != nil {
return err
}
}
return f.reloadAll()
default: default:
return b.reload() return f.reloadAll()
} }
} }
func (b *Blacklist) Run(ctx context.Context, lightGC bool) { func (f *IPFilter) Run(ctx context.Context, lightGC bool) {
ticker := time.NewTicker(47 * time.Minute) ticker := time.NewTicker(47 * time.Minute)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
updated, err := b.sync(lightGC) updated, err := f.sync(lightGC)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error: blacklist sync: %v\n", err) fmt.Fprintf(os.Stderr, "error: filter sync: %v\n", err)
} else if updated { } else if updated {
fmt.Fprintf(os.Stderr, "blacklist: reloaded %d entries\n", b.Size()) fmt.Fprintf(os.Stderr, "filter: reloaded — inbound=%d outbound=%d\n",
f.InboundSize(), f.OutboundSize())
} }
case <-ctx.Done(): case <-ctx.Done():
return return
@ -90,38 +121,109 @@ func (b *Blacklist) Run(ctx context.Context, lightGC bool) {
} }
} }
func (b *Blacklist) sync(lightGC bool) (bool, error) { func (f *IPFilter) sync(lightGC bool) (bool, error) {
switch { switch {
case b.git != nil: case f.git != nil:
return b.git.Sync(lightGC) return f.git.Sync(lightGC)
case len(b.http) > 0: case len(f.httpInbound) > 0 || len(f.httpOutbound) > 0:
var anyUpdated bool var anyUpdated bool
for _, c := range b.http { for _, c := range f.httpInbound {
updated, err := c.Sync() updated, err := c.Fetch()
if err != nil { if err != nil {
return anyUpdated, err return anyUpdated, err
} }
anyUpdated = anyUpdated || updated anyUpdated = anyUpdated || updated
} }
return anyUpdated, nil for _, c := range f.httpOutbound {
updated, err := c.Fetch()
if err != nil {
return anyUpdated, err
}
anyUpdated = anyUpdated || updated
}
if anyUpdated {
return true, f.reloadAll()
}
return false, nil
default: default:
return false, nil return false, nil
} }
} }
func (b *Blacklist) Contains(ipStr string) bool { // ContainsInbound reports whether ip is in the inbound blocklist and not whitelisted.
return b.Load().Contains(ipStr) func (f *IPFilter) ContainsInbound(ip string) bool {
if wl := f.whitelist.Load(); wl != nil && wl.Contains(ip) {
return false
}
c := f.inbound.Load()
return c != nil && c.Contains(ip)
} }
func (b *Blacklist) Size() int { // ContainsOutbound reports whether ip is in the outbound blocklist and not whitelisted.
return b.Load().Size() func (f *IPFilter) ContainsOutbound(ip string) bool {
if wl := f.whitelist.Load(); wl != nil && wl.Contains(ip) {
return false
}
c := f.outbound.Load()
return c != nil && c.Contains(ip)
} }
func (b *Blacklist) reload() error { func (f *IPFilter) InboundSize() int {
c, err := ipcohort.LoadFiles(b.paths...) if c := f.inbound.Load(); c != nil {
return c.Size()
}
return 0
}
func (f *IPFilter) OutboundSize() int {
if c := f.outbound.Load(); c != nil {
return c.Size()
}
return 0
}
func (f *IPFilter) reloadAll() error {
if err := f.reloadWhitelist(); err != nil {
return err
}
if err := f.reloadInbound(); err != nil {
return err
}
return f.reloadOutbound()
}
func (f *IPFilter) reloadWhitelist() error {
if len(f.whitelistPaths) == 0 {
return nil
}
c, err := ipcohort.LoadFiles(f.whitelistPaths...)
if err != nil { if err != nil {
return err return err
} }
b.Store(c) f.whitelist.Store(c)
return nil
}
func (f *IPFilter) reloadInbound() error {
if len(f.inboundPaths) == 0 {
return nil
}
c, err := ipcohort.LoadFiles(f.inboundPaths...)
if err != nil {
return err
}
f.inbound.Store(c)
return nil
}
func (f *IPFilter) reloadOutbound() error {
if len(f.outboundPaths) == 0 {
return nil
}
c, err := ipcohort.LoadFiles(f.outboundPaths...)
if err != nil {
return err
}
f.outbound.Store(c)
return nil return nil
} }

View File

@ -12,10 +12,16 @@ const (
inboundNetworkURL = "https://github.com/bitwire-it/ipblocklist/raw/refs/heads/main/tables/inbound/networks.txt" inboundNetworkURL = "https://github.com/bitwire-it/ipblocklist/raw/refs/heads/main/tables/inbound/networks.txt"
) )
// outbound blocklist
const (
outboundSingleURL = "https://github.com/bitwire-it/ipblocklist/raw/refs/heads/main/tables/outbound/single_ips.txt"
outboundNetworkURL = "https://github.com/bitwire-it/ipblocklist/raw/refs/heads/main/tables/outbound/networks.txt"
)
func main() { func main() {
if len(os.Args) < 3 { if len(os.Args) < 3 {
fmt.Fprintf(os.Stderr, "Usage: %s <cache-dir|blacklist.csv> <ip-address> [git-url]\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage: %s <cache-dir|blacklist.txt> <ip-address> [git-url]\n", os.Args[0])
fmt.Fprintf(os.Stderr, " No remote: load from <blacklist.csv>\n") fmt.Fprintf(os.Stderr, " No remote: load from <blacklist.txt> (inbound only)\n")
fmt.Fprintf(os.Stderr, " git URL: clone/pull into <cache-dir>\n") fmt.Fprintf(os.Stderr, " git URL: clone/pull into <cache-dir>\n")
fmt.Fprintf(os.Stderr, " (default): fetch via HTTP into <cache-dir>\n") fmt.Fprintf(os.Stderr, " (default): fetch via HTTP into <cache-dir>\n")
os.Exit(1) os.Exit(1)
@ -28,34 +34,52 @@ func main() {
gitURL = os.Args[3] gitURL = os.Args[3]
} }
var bl *Blacklist var f *IPFilter
switch { switch {
case gitURL != "": case gitURL != "":
bl = NewGitBlacklist(gitURL, dataPath, f = NewGitFilter(gitURL, dataPath,
"tables/inbound/single_ips.txt", nil,
"tables/inbound/networks.txt", []string{"tables/inbound/single_ips.txt", "tables/inbound/networks.txt"},
[]string{"tables/outbound/single_ips.txt", "tables/outbound/networks.txt"},
) )
case strings.HasSuffix(dataPath, ".txt") || strings.HasSuffix(dataPath, ".csv"): case strings.HasSuffix(dataPath, ".txt") || strings.HasSuffix(dataPath, ".csv"):
bl = NewBlacklist(dataPath) f = NewFileFilter(nil, []string{dataPath}, nil)
default: default:
// dataPath is a cache directory; fetch the pre-split files via HTTP // dataPath is a cache directory; fetch the pre-split files via HTTP
bl = NewHTTPBlacklist( f = NewHTTPFilter(
HTTPSource{inboundSingleURL, dataPath + "/single_ips.txt"}, nil,
HTTPSource{inboundNetworkURL, dataPath + "/networks.txt"}, []HTTPSource{
{inboundSingleURL, dataPath + "/inbound_single_ips.txt"},
{inboundNetworkURL, dataPath + "/inbound_networks.txt"},
},
[]HTTPSource{
{outboundSingleURL, dataPath + "/outbound_single_ips.txt"},
{outboundNetworkURL, dataPath + "/outbound_networks.txt"},
},
) )
} }
if err := bl.Init(false); err != nil { if err := f.Init(false); err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err) fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
fmt.Fprintf(os.Stderr, "Loaded %d entries\n", bl.Size()) fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n", f.InboundSize(), f.OutboundSize())
if bl.Contains(ipStr) { blockedInbound := f.ContainsInbound(ipStr)
fmt.Printf("%s is BLOCKED\n", ipStr) blockedOutbound := f.ContainsOutbound(ipStr)
switch {
case blockedInbound && blockedOutbound:
fmt.Printf("%s is BLOCKED (inbound + outbound)\n", ipStr)
os.Exit(1) os.Exit(1)
} case blockedInbound:
fmt.Printf("%s is BLOCKED (inbound)\n", ipStr)
os.Exit(1)
case blockedOutbound:
fmt.Printf("%s is BLOCKED (outbound)\n", ipStr)
os.Exit(1)
default:
fmt.Printf("%s is allowed\n", ipStr) fmt.Printf("%s is allowed\n", ipStr)
}
} }