// Package httpcache source (net/httpcache/httpcache.go), introduced by patch
// a9adc3dc "feat: add net/httpcache; wire git+http+file into Blacklist".

// defaultTimeout is the per-fetch timeout used when Cacher.Timeout is zero.
const defaultTimeout = 30 * time.Second

// Cacher fetches a URL to a local file, using ETag/Last-Modified to skip
// unchanged responses. Calls registered callbacks when the file changes.
type Cacher struct {
	URL     string
	Path    string
	Timeout time.Duration // 0 uses 30s

	mu        sync.Mutex // held for a whole fetch; guards etag and lastMod
	etag      string     // ETag of the last 200 response, "" if none
	lastMod   string     // Last-Modified of the last 200 response, "" if none
	callbacks []func() error
}

// New creates a Cacher that fetches url and writes it to path.
func New(url, path string) *Cacher {
	return &Cacher{URL: url, Path: path}
}

// Register adds a callback invoked after each successful fetch.
//
// The callback list is not guarded by the mutex, so Register should be
// called before Init or Sync can run concurrently.
func (c *Cacher) Register(fn func() error) {
	c.callbacks = append(c.callbacks, fn)
}

// Init fetches the URL (unconditionally on a fresh Cacher, since no
// validators are cached yet) and then invokes all callbacks, whether or not
// the fetch changed the file, so that consumers are loaded on startup.
//
// It returns the fetch error, or the joined errors of any callbacks that
// failed.
func (c *Cacher) Init() error {
	if _, err := c.fetch(); err != nil {
		return err
	}
	return c.invokeCallbacks()
}

// Sync sends a conditional GET. If the server returns new content, Sync
// writes it to Path and invokes the callbacks.
//
// updated reports whether the file was replaced. When the file was replaced
// but one or more callbacks failed, Sync returns true together with the
// joined callback errors.
func (c *Cacher) Sync() (updated bool, err error) {
	updated, err = c.fetch()
	if err != nil || !updated {
		return updated, err
	}
	return true, c.invokeCallbacks()
}

// fetch performs one (conditional) GET of URL and, on 200 OK, replaces Path
// with the response body. It reports true when Path was replaced and false
// on 304 Not Modified. Any other status is an error.
func (c *Cacher) fetch() (bool, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	timeout := c.Timeout
	if timeout == 0 {
		timeout = defaultTimeout
	}

	req, err := http.NewRequest(http.MethodGet, c.URL, nil)
	if err != nil {
		return false, err
	}

	// Prefer the ETag validator; fall back to Last-Modified only when the
	// previous response carried no ETag.
	if c.etag != "" {
		req.Header.Set("If-None-Match", c.etag)
	} else if c.lastMod != "" {
		req.Header.Set("If-Modified-Since", c.lastMod)
	}

	client := &http.Client{Timeout: timeout}
	resp, err := client.Do(req)
	if err != nil {
		return false, err
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusOK:
		// New content; fall through to the write below.
	case http.StatusNotModified:
		return false, nil
	default:
		return false, fmt.Errorf("unexpected status %d fetching %s", resp.StatusCode, c.URL)
	}

	if err := c.writeFile(resp.Body); err != nil {
		return false, err
	}

	// Record exactly the validators that describe the file now on disk. An
	// absent header clears the stored value so a validator belonging to an
	// older response is never sent again.
	c.etag = resp.Header.Get("ETag")
	c.lastMod = resp.Header.Get("Last-Modified")

	return true, nil
}

// writeFile streams r into Path+".tmp" and renames it over Path, so Path is
// only ever replaced by a completely written file.
func (c *Cacher) writeFile(r io.Reader) error {
	tmp := c.Path + ".tmp"
	f, err := os.Create(tmp)
	if err != nil {
		return err
	}

	_, err = io.Copy(f, r)
	// Close can report a write error that Copy did not see; treat it as a
	// failed write rather than renaming a possibly truncated file.
	if cerr := f.Close(); err == nil {
		err = cerr
	}
	if err == nil {
		err = os.Rename(tmp, c.Path)
	}
	if err != nil {
		_ = os.Remove(tmp) // best-effort cleanup; the first error is what matters
		return err
	}
	return nil
}

// invokeCallbacks runs every registered callback, continuing past failures
// so one broken consumer does not starve the others, and returns the joined
// errors (nil when all succeeded).
func (c *Cacher) invokeCallbacks() error {
	var errs []error
	for _, fn := range c.callbacks {
		if err := fn(); err != nil {
			errs = append(errs, fmt.Errorf("reload callback: %w", err))
		}
	}
	return errors.Join(errs...)
}
"github.com/therootcompany/golib/net/gitshallow" + "github.com/therootcompany/golib/net/httpcache" "github.com/therootcompany/golib/net/ipcohort" ) type Blacklist struct { atomic.Pointer[ipcohort.Cohort] path string - repo *gitshallow.Repo // nil if file-only + git *gitshallow.Repo + http *httpcache.Cacher } func NewBlacklist(path string) *Blacklist { @@ -24,16 +26,27 @@ func NewBlacklist(path string) *Blacklist { func NewGitBlacklist(gitURL, path string) *Blacklist { repo := gitshallow.New(gitURL, filepath.Dir(path), 1, "") - b := &Blacklist{path: path, repo: repo} + b := &Blacklist{path: path, git: repo} repo.Register(b.reload) return b } +func NewHTTPBlacklist(url, path string) *Blacklist { + cacher := httpcache.New(url, path) + b := &Blacklist{path: path, http: cacher} + cacher.Register(b.reload) + return b +} + func (b *Blacklist) Init(lightGC bool) error { - if b.repo != nil { - return b.repo.Init(lightGC) + switch { + case b.git != nil: + return b.git.Init(lightGC) + case b.http != nil: + return b.http.Init() + default: + return b.reload() } - return b.reload() } func (b *Blacklist) Run(ctx context.Context, lightGC bool) { @@ -43,7 +56,8 @@ func (b *Blacklist) Run(ctx context.Context, lightGC bool) { for { select { case <-ticker.C: - if updated, err := b.repo.Sync(lightGC); err != nil { + updated, err := b.sync(lightGC) + if err != nil { fmt.Fprintf(os.Stderr, "error: blacklist sync: %v\n", err) } else if updated { fmt.Fprintf(os.Stderr, "blacklist: reloaded %d entries\n", b.Size()) @@ -54,6 +68,17 @@ func (b *Blacklist) Run(ctx context.Context, lightGC bool) { } } +func (b *Blacklist) sync(lightGC bool) (bool, error) { + switch { + case b.git != nil: + return b.git.Sync(lightGC) + case b.http != nil: + return b.http.Sync() + default: + return false, nil + } +} + func (b *Blacklist) Contains(ipStr string) bool { return b.Load().Contains(ipStr) } diff --git a/net/ipcohort/cmd/check-ip-blacklist/main.go b/net/ipcohort/cmd/check-ip-blacklist/main.go index 
50a4133..b086afc 100644 --- a/net/ipcohort/cmd/check-ip-blacklist/main.go +++ b/net/ipcohort/cmd/check-ip-blacklist/main.go @@ -3,25 +3,29 @@ package main import ( "fmt" "os" + "strings" ) func main() { if len(os.Args) < 3 { - fmt.Fprintf(os.Stderr, "Usage: %s [git-url]\n", os.Args[0]) + fmt.Fprintf(os.Stderr, "Usage: %s [git-url|http-url]\n", os.Args[0]) os.Exit(1) } dataPath := os.Args[1] ipStr := os.Args[2] - gitURL := "" + remoteURL := "" if len(os.Args) >= 4 { - gitURL = os.Args[3] + remoteURL = os.Args[3] } var bl *Blacklist - if gitURL != "" { - bl = NewGitBlacklist(gitURL, dataPath) - } else { + switch { + case strings.HasPrefix(remoteURL, "http://") || strings.HasPrefix(remoteURL, "https://"): + bl = NewHTTPBlacklist(remoteURL, dataPath) + case remoteURL != "": + bl = NewGitBlacklist(remoteURL, dataPath) + default: bl = NewBlacklist(dataPath) }