diff --git a/net/httpcache/httpcache.go b/net/httpcache/httpcache.go index 46f02f1..2a76d60 100644 --- a/net/httpcache/httpcache.go +++ b/net/httpcache/httpcache.go @@ -37,23 +37,27 @@ func (c *Cacher) Register(fn func() error) { // Init fetches the URL unconditionally (no cached headers yet) and invokes // all callbacks, ensuring files are loaded on startup. func (c *Cacher) Init() error { - if _, err := c.fetch(); err != nil { + if _, err := c.Fetch(); err != nil { return err } return c.invokeCallbacks() } -// Sync sends a conditional GET. If the server returns new content, writes it -// to Path and invokes callbacks. Returns whether the file was updated. +// Sync sends a conditional GET, writes updated content, and invokes callbacks. +// Returns whether the file was updated. func (c *Cacher) Sync() (updated bool, err error) { - updated, err = c.fetch() + updated, err = c.Fetch() if err != nil || !updated { return updated, err } return true, c.invokeCallbacks() } -func (c *Cacher) fetch() (updated bool, err error) { +// Fetch sends a conditional GET and writes new content to Path if the server +// responds with 200. Returns whether the file was updated. Does not invoke +// callbacks — use Sync for the single-cacher case, or call Fetch across +// multiple cachers and handle the reload yourself. +func (c *Cacher) Fetch() (updated bool, err error) { c.mu.Lock() defer c.mu.Unlock() diff --git a/net/ipcohort/cmd/check-ip-blacklist/blacklist.go b/net/ipcohort/cmd/check-ip-blacklist/blacklist.go index b743474..b184a31 100644 --- a/net/ipcohort/cmd/check-ip-blacklist/blacklist.go +++ b/net/ipcohort/cmd/check-ip-blacklist/blacklist.go @@ -19,70 +19,101 @@ type HTTPSource struct { Path string } -type Blacklist struct { - atomic.Pointer[ipcohort.Cohort] - paths []string - git *gitshallow.Repo - http []*httpcache.Cacher +// IPFilter holds up to three cohorts: a whitelist (IPs never blocked), +// an inbound blocklist, and an outbound blocklist. +type IPFilter struct { + whitelist atomic.Pointer[ipcohort.Cohort] + inbound atomic.Pointer[ipcohort.Cohort] + outbound atomic.Pointer[ipcohort.Cohort] + + whitelistPaths []string + inboundPaths []string + outboundPaths []string + + git *gitshallow.Repo + httpInbound []*httpcache.Cacher + httpOutbound []*httpcache.Cacher } -// NewBlacklist loads from one or more local files. -func NewBlacklist(paths ...string) *Blacklist { - return &Blacklist{paths: paths} +// NewFileFilter loads inbound/outbound/whitelist from local files. +func NewFileFilter(whitelist, inbound, outbound []string) *IPFilter { + return &IPFilter{ + whitelistPaths: whitelist, + inboundPaths: inbound, + outboundPaths: outbound, + } } -// NewGitBlacklist clones/pulls gitURL into repoDir and loads relPaths on each update. -func NewGitBlacklist(gitURL, repoDir string, relPaths ...string) *Blacklist { +// NewGitFilter clones/pulls gitURL into repoDir and loads the given relative +// paths for each cohort on each update. +func NewGitFilter(gitURL, repoDir string, whitelist, inboundRel, outboundRel []string) *IPFilter { repo := gitshallow.New(gitURL, repoDir, 1, "") - paths := make([]string, len(relPaths)) - for i, p := range relPaths { - paths[i] = filepath.Join(repoDir, p) + abs := func(rel []string) []string { + out := make([]string, len(rel)) + for i, p := range rel { + out[i] = filepath.Join(repoDir, p) + } + return out } - b := &Blacklist{paths: paths, git: repo} - repo.Register(b.reload) - return b + f := &IPFilter{ + whitelistPaths: whitelist, + inboundPaths: abs(inboundRel), + outboundPaths: abs(outboundRel), + git: repo, + } + repo.Register(f.reloadAll) + return f } -// NewHTTPBlacklist fetches each source URL to its local path, reloading on any change. -func NewHTTPBlacklist(sources ...HTTPSource) *Blacklist { - b := &Blacklist{} - for _, src := range sources { - b.paths = append(b.paths, src.Path) - c := httpcache.New(src.URL, src.Path) - c.Register(b.reload) - b.http = append(b.http, c) +// NewHTTPFilter fetches inbound and outbound sources via HTTP; +// whitelist is always loaded from local files. +func NewHTTPFilter(whitelist []string, inbound, outbound []HTTPSource) *IPFilter { + f := &IPFilter{whitelistPaths: whitelist} + for _, src := range inbound { + f.inboundPaths = append(f.inboundPaths, src.Path) + f.httpInbound = append(f.httpInbound, httpcache.New(src.URL, src.Path)) } - return b + for _, src := range outbound { + f.outboundPaths = append(f.outboundPaths, src.Path) + f.httpOutbound = append(f.httpOutbound, httpcache.New(src.URL, src.Path)) + } + return f } -func (b *Blacklist) Init(lightGC bool) error { +func (f *IPFilter) Init(lightGC bool) error { switch { - case b.git != nil: - return b.git.Init(lightGC) - case len(b.http) > 0: - for _, c := range b.http { - if err := c.Init(); err != nil { + case f.git != nil: + return f.git.Init(lightGC) + case len(f.httpInbound) > 0 || len(f.httpOutbound) > 0: + for _, c := range f.httpInbound { + if _, err := c.Fetch(); err != nil { return err } } - return nil + for _, c := range f.httpOutbound { + if _, err := c.Fetch(); err != nil { + return err + } + } + return f.reloadAll() default: - return b.reload() + return f.reloadAll() } } -func (b *Blacklist) Run(ctx context.Context, lightGC bool) { +func (f *IPFilter) Run(ctx context.Context, lightGC bool) { ticker := time.NewTicker(47 * time.Minute) defer ticker.Stop() for { select { case <-ticker.C: - updated, err := b.sync(lightGC) + updated, err := f.sync(lightGC) if err != nil { - fmt.Fprintf(os.Stderr, "error: blacklist sync: %v\n", err) + fmt.Fprintf(os.Stderr, "error: filter sync: %v\n", err) } else if updated { - fmt.Fprintf(os.Stderr, "blacklist: reloaded %d entries\n", b.Size()) + fmt.Fprintf(os.Stderr, "filter: reloaded — inbound=%d outbound=%d\n", + f.InboundSize(), f.OutboundSize()) } case <-ctx.Done(): return @@ -90,38 +121,109 @@ func (b *Blacklist) Run(ctx context.Context, lightGC bool) { } } -func (b *Blacklist) sync(lightGC bool) (bool, error) { +func (f *IPFilter) sync(lightGC bool) (bool, error) { switch { - case b.git != nil: - return b.git.Sync(lightGC) - case len(b.http) > 0: + case f.git != nil: + return f.git.Sync(lightGC) + case len(f.httpInbound) > 0 || len(f.httpOutbound) > 0: var anyUpdated bool - for _, c := range b.http { - updated, err := c.Sync() + for _, c := range f.httpInbound { + updated, err := c.Fetch() if err != nil { return anyUpdated, err } anyUpdated = anyUpdated || updated } - return anyUpdated, nil + for _, c := range f.httpOutbound { + updated, err := c.Fetch() + if err != nil { + return anyUpdated, err + } + anyUpdated = anyUpdated || updated + } + if anyUpdated { + return true, f.reloadAll() + } + return false, nil default: return false, nil } } -func (b *Blacklist) Contains(ipStr string) bool { - return b.Load().Contains(ipStr) +// ContainsInbound reports whether ip is in the inbound blocklist and not whitelisted. +func (f *IPFilter) ContainsInbound(ip string) bool { + if wl := f.whitelist.Load(); wl != nil && wl.Contains(ip) { + return false + } + c := f.inbound.Load() + return c != nil && c.Contains(ip) } -func (b *Blacklist) Size() int { - return b.Load().Size() +// ContainsOutbound reports whether ip is in the outbound blocklist and not whitelisted. +func (f *IPFilter) ContainsOutbound(ip string) bool { + if wl := f.whitelist.Load(); wl != nil && wl.Contains(ip) { + return false + } + c := f.outbound.Load() + return c != nil && c.Contains(ip) } -func (b *Blacklist) reload() error { - c, err := ipcohort.LoadFiles(b.paths...) +func (f *IPFilter) InboundSize() int { + if c := f.inbound.Load(); c != nil { + return c.Size() + } + return 0 +} + +func (f *IPFilter) OutboundSize() int { + if c := f.outbound.Load(); c != nil { + return c.Size() + } + return 0 +} + +func (f *IPFilter) reloadAll() error { + if err := f.reloadWhitelist(); err != nil { + return err + } + if err := f.reloadInbound(); err != nil { + return err + } + return f.reloadOutbound() +} + +func (f *IPFilter) reloadWhitelist() error { + if len(f.whitelistPaths) == 0 { + return nil + } + c, err := ipcohort.LoadFiles(f.whitelistPaths...) if err != nil { return err } - b.Store(c) + f.whitelist.Store(c) + return nil +} + +func (f *IPFilter) reloadInbound() error { + if len(f.inboundPaths) == 0 { + return nil + } + c, err := ipcohort.LoadFiles(f.inboundPaths...) + if err != nil { + return err + } + f.inbound.Store(c) + return nil +} + +func (f *IPFilter) reloadOutbound() error { + if len(f.outboundPaths) == 0 { + return nil + } + c, err := ipcohort.LoadFiles(f.outboundPaths...) + if err != nil { + return err + } + f.outbound.Store(c) return nil } diff --git a/net/ipcohort/cmd/check-ip-blacklist/main.go b/net/ipcohort/cmd/check-ip-blacklist/main.go index 54cb3fb..51305b9 100644 --- a/net/ipcohort/cmd/check-ip-blacklist/main.go +++ b/net/ipcohort/cmd/check-ip-blacklist/main.go @@ -12,10 +12,16 @@ const ( inboundNetworkURL = "https://github.com/bitwire-it/ipblocklist/raw/refs/heads/main/tables/inbound/networks.txt" ) +// outbound blocklist +const ( + outboundSingleURL = "https://github.com/bitwire-it/ipblocklist/raw/refs/heads/main/tables/outbound/single_ips.txt" + outboundNetworkURL = "https://github.com/bitwire-it/ipblocklist/raw/refs/heads/main/tables/outbound/networks.txt" +) + func main() { if len(os.Args) < 3 { - fmt.Fprintf(os.Stderr, "Usage: %s [git-url]\n", os.Args[0]) - fmt.Fprintf(os.Stderr, " No remote: load from \n") + fmt.Fprintf(os.Stderr, "Usage: %s [git-url]\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " No remote: load from (inbound only)\n") fmt.Fprintf(os.Stderr, " git URL: clone/pull into \n") fmt.Fprintf(os.Stderr, " (default): fetch via HTTP into \n") os.Exit(1) @@ -28,34 +34,52 @@ func main() { gitURL = os.Args[3] } - var bl *Blacklist + var f *IPFilter switch { case gitURL != "": - bl = NewGitBlacklist(gitURL, dataPath, - "tables/inbound/single_ips.txt", - "tables/inbound/networks.txt", + f = NewGitFilter(gitURL, dataPath, + nil, + []string{"tables/inbound/single_ips.txt", "tables/inbound/networks.txt"}, + []string{"tables/outbound/single_ips.txt", "tables/outbound/networks.txt"}, ) case strings.HasSuffix(dataPath, ".txt") || strings.HasSuffix(dataPath, ".csv"): - bl = NewBlacklist(dataPath) + f = NewFileFilter(nil, []string{dataPath}, nil) default: // dataPath is a cache directory; fetch the pre-split files via HTTP - bl = NewHTTPBlacklist( - HTTPSource{inboundSingleURL, dataPath + "/single_ips.txt"}, - HTTPSource{inboundNetworkURL, dataPath + "/networks.txt"}, + f = NewHTTPFilter( + nil, + []HTTPSource{ + {inboundSingleURL, dataPath + "/inbound_single_ips.txt"}, + {inboundNetworkURL, dataPath + "/inbound_networks.txt"}, + }, + []HTTPSource{ + {outboundSingleURL, dataPath + "/outbound_single_ips.txt"}, + {outboundNetworkURL, dataPath + "/outbound_networks.txt"}, + }, ) } - if err := bl.Init(false); err != nil { + if err := f.Init(false); err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } - fmt.Fprintf(os.Stderr, "Loaded %d entries\n", bl.Size()) + fmt.Fprintf(os.Stderr, "Loaded inbound=%d outbound=%d\n", f.InboundSize(), f.OutboundSize()) - if bl.Contains(ipStr) { - fmt.Printf("%s is BLOCKED\n", ipStr) + blockedInbound := f.ContainsInbound(ipStr) + blockedOutbound := f.ContainsOutbound(ipStr) + + switch { + case blockedInbound && blockedOutbound: + fmt.Printf("%s is BLOCKED (inbound + outbound)\n", ipStr) os.Exit(1) + case blockedInbound: + fmt.Printf("%s is BLOCKED (inbound)\n", ipStr) + os.Exit(1) + case blockedOutbound: + fmt.Printf("%s is BLOCKED (outbound)\n", ipStr) + os.Exit(1) + default: + fmt.Printf("%s is allowed\n", ipStr) } - - fmt.Printf("%s is allowed\n", ipStr) }