feat: add net/httpcache; wire git+http+file into Blacklist

This commit is contained in:
AJ ONeal 2026-04-19 22:57:36 -06:00
parent 4b0f943bd7
commit a9adc3dc18
No known key found for this signature in database
3 changed files with 166 additions and 12 deletions

125
net/httpcache/httpcache.go Normal file
View File

@ -0,0 +1,125 @@
package httpcache
import (
"fmt"
"io"
"net/http"
"os"
"sync"
"time"
)
const defaultTimeout = 30 * time.Second
// Cacher fetches a URL to a local file, using ETag/Last-Modified to skip
// unchanged responses. Calls registered callbacks when the file changes.
type Cacher struct {
URL string
Path string
Timeout time.Duration // 0 uses 30s
mu sync.Mutex
etag string
lastMod string
callbacks []func() error
}
// New creates a Cacher that fetches URL and writes it to path.
func New(url, path string) *Cacher {
return &Cacher{URL: url, Path: path}
}
// Register adds a callback invoked after each successful fetch.
func (c *Cacher) Register(fn func() error) {
c.callbacks = append(c.callbacks, fn)
}
// Init fetches the URL unconditionally (no cached headers yet) and invokes
// all callbacks, ensuring files are loaded on startup.
func (c *Cacher) Init() error {
if _, err := c.fetch(); err != nil {
return err
}
return c.invokeCallbacks()
}
// Sync sends a conditional GET. If the server returns new content, writes it
// to Path and invokes callbacks. Returns whether the file was updated.
func (c *Cacher) Sync() (updated bool, err error) {
updated, err = c.fetch()
if err != nil || !updated {
return updated, err
}
return true, c.invokeCallbacks()
}
func (c *Cacher) fetch() (updated bool, err error) {
c.mu.Lock()
defer c.mu.Unlock()
timeout := c.Timeout
if timeout == 0 {
timeout = defaultTimeout
}
req, err := http.NewRequest(http.MethodGet, c.URL, nil)
if err != nil {
return false, err
}
if c.etag != "" {
req.Header.Set("If-None-Match", c.etag)
} else if c.lastMod != "" {
req.Header.Set("If-Modified-Since", c.lastMod)
}
client := &http.Client{Timeout: timeout}
resp, err := client.Do(req)
if err != nil {
return false, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotModified {
return false, nil
}
if resp.StatusCode != http.StatusOK {
return false, fmt.Errorf("unexpected status %d fetching %s", resp.StatusCode, c.URL)
}
// Write to a temp file then rename for an atomic swap.
tmp := c.Path + ".tmp"
f, err := os.Create(tmp)
if err != nil {
return false, err
}
if _, err := io.Copy(f, resp.Body); err != nil {
f.Close()
os.Remove(tmp)
return false, err
}
f.Close()
if err := os.Rename(tmp, c.Path); err != nil {
os.Remove(tmp)
return false, err
}
if etag := resp.Header.Get("ETag"); etag != "" {
c.etag = etag
}
if lm := resp.Header.Get("Last-Modified"); lm != "" {
c.lastMod = lm
}
return true, nil
}
func (c *Cacher) invokeCallbacks() error {
for _, fn := range c.callbacks {
if err := fn(); err != nil {
fmt.Fprintf(os.Stderr, "error: reload callback: %v\n", err)
}
}
return nil
}

View File

@ -9,13 +9,15 @@ import (
"time" "time"
"github.com/therootcompany/golib/net/gitshallow" "github.com/therootcompany/golib/net/gitshallow"
"github.com/therootcompany/golib/net/httpcache"
"github.com/therootcompany/golib/net/ipcohort" "github.com/therootcompany/golib/net/ipcohort"
) )
type Blacklist struct { type Blacklist struct {
atomic.Pointer[ipcohort.Cohort] atomic.Pointer[ipcohort.Cohort]
path string path string
repo *gitshallow.Repo // nil if file-only git *gitshallow.Repo
http *httpcache.Cacher
} }
func NewBlacklist(path string) *Blacklist { func NewBlacklist(path string) *Blacklist {
@ -24,16 +26,27 @@ func NewBlacklist(path string) *Blacklist {
func NewGitBlacklist(gitURL, path string) *Blacklist { func NewGitBlacklist(gitURL, path string) *Blacklist {
repo := gitshallow.New(gitURL, filepath.Dir(path), 1, "") repo := gitshallow.New(gitURL, filepath.Dir(path), 1, "")
b := &Blacklist{path: path, repo: repo} b := &Blacklist{path: path, git: repo}
repo.Register(b.reload) repo.Register(b.reload)
return b return b
} }
func NewHTTPBlacklist(url, path string) *Blacklist {
cacher := httpcache.New(url, path)
b := &Blacklist{path: path, http: cacher}
cacher.Register(b.reload)
return b
}
func (b *Blacklist) Init(lightGC bool) error { func (b *Blacklist) Init(lightGC bool) error {
if b.repo != nil { switch {
return b.repo.Init(lightGC) case b.git != nil:
} return b.git.Init(lightGC)
case b.http != nil:
return b.http.Init()
default:
return b.reload() return b.reload()
}
} }
func (b *Blacklist) Run(ctx context.Context, lightGC bool) { func (b *Blacklist) Run(ctx context.Context, lightGC bool) {
@ -43,7 +56,8 @@ func (b *Blacklist) Run(ctx context.Context, lightGC bool) {
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if updated, err := b.repo.Sync(lightGC); err != nil { updated, err := b.sync(lightGC)
if err != nil {
fmt.Fprintf(os.Stderr, "error: blacklist sync: %v\n", err) fmt.Fprintf(os.Stderr, "error: blacklist sync: %v\n", err)
} else if updated { } else if updated {
fmt.Fprintf(os.Stderr, "blacklist: reloaded %d entries\n", b.Size()) fmt.Fprintf(os.Stderr, "blacklist: reloaded %d entries\n", b.Size())
@ -54,6 +68,17 @@ func (b *Blacklist) Run(ctx context.Context, lightGC bool) {
} }
} }
func (b *Blacklist) sync(lightGC bool) (bool, error) {
switch {
case b.git != nil:
return b.git.Sync(lightGC)
case b.http != nil:
return b.http.Sync()
default:
return false, nil
}
}
func (b *Blacklist) Contains(ipStr string) bool { func (b *Blacklist) Contains(ipStr string) bool {
return b.Load().Contains(ipStr) return b.Load().Contains(ipStr)
} }

View File

@ -3,25 +3,29 @@ package main
import ( import (
"fmt" "fmt"
"os" "os"
"strings"
) )
func main() { func main() {
if len(os.Args) < 3 { if len(os.Args) < 3 {
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address> [git-url]\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address> [git-url|http-url]\n", os.Args[0])
os.Exit(1) os.Exit(1)
} }
dataPath := os.Args[1] dataPath := os.Args[1]
ipStr := os.Args[2] ipStr := os.Args[2]
gitURL := "" remoteURL := ""
if len(os.Args) >= 4 { if len(os.Args) >= 4 {
gitURL = os.Args[3] remoteURL = os.Args[3]
} }
var bl *Blacklist var bl *Blacklist
if gitURL != "" { switch {
bl = NewGitBlacklist(gitURL, dataPath) case strings.HasPrefix(remoteURL, "http://") || strings.HasPrefix(remoteURL, "https://"):
} else { bl = NewHTTPBlacklist(remoteURL, dataPath)
case remoteURL != "":
bl = NewGitBlacklist(remoteURL, dataPath)
default:
bl = NewBlacklist(dataPath) bl = NewBlacklist(dataPath)
} }