mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 20:58:00 +00:00
feat: add net/httpcache; wire git+http+file into Blacklist
This commit is contained in:
parent
4b0f943bd7
commit
a9adc3dc18
125
net/httpcache/httpcache.go
Normal file
125
net/httpcache/httpcache.go
Normal file
@ -0,0 +1,125 @@
|
||||
package httpcache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultTimeout = 30 * time.Second
|
||||
|
||||
// Cacher fetches a URL to a local file, using ETag/Last-Modified to skip
|
||||
// unchanged responses. Calls registered callbacks when the file changes.
|
||||
type Cacher struct {
|
||||
URL string
|
||||
Path string
|
||||
Timeout time.Duration // 0 uses 30s
|
||||
|
||||
mu sync.Mutex
|
||||
etag string
|
||||
lastMod string
|
||||
callbacks []func() error
|
||||
}
|
||||
|
||||
// New creates a Cacher that fetches URL and writes it to path.
|
||||
func New(url, path string) *Cacher {
|
||||
return &Cacher{URL: url, Path: path}
|
||||
}
|
||||
|
||||
// Register adds a callback invoked after each successful fetch.
|
||||
func (c *Cacher) Register(fn func() error) {
|
||||
c.callbacks = append(c.callbacks, fn)
|
||||
}
|
||||
|
||||
// Init fetches the URL unconditionally (no cached headers yet) and invokes
|
||||
// all callbacks, ensuring files are loaded on startup.
|
||||
func (c *Cacher) Init() error {
|
||||
if _, err := c.fetch(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.invokeCallbacks()
|
||||
}
|
||||
|
||||
// Sync sends a conditional GET. If the server returns new content, writes it
|
||||
// to Path and invokes callbacks. Returns whether the file was updated.
|
||||
func (c *Cacher) Sync() (updated bool, err error) {
|
||||
updated, err = c.fetch()
|
||||
if err != nil || !updated {
|
||||
return updated, err
|
||||
}
|
||||
return true, c.invokeCallbacks()
|
||||
}
|
||||
|
||||
func (c *Cacher) fetch() (updated bool, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
timeout := c.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = defaultTimeout
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, c.URL, nil)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if c.etag != "" {
|
||||
req.Header.Set("If-None-Match", c.etag)
|
||||
} else if c.lastMod != "" {
|
||||
req.Header.Set("If-Modified-Since", c.lastMod)
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNotModified {
|
||||
return false, nil
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return false, fmt.Errorf("unexpected status %d fetching %s", resp.StatusCode, c.URL)
|
||||
}
|
||||
|
||||
// Write to a temp file then rename for an atomic swap.
|
||||
tmp := c.Path + ".tmp"
|
||||
f, err := os.Create(tmp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if _, err := io.Copy(f, resp.Body); err != nil {
|
||||
f.Close()
|
||||
os.Remove(tmp)
|
||||
return false, err
|
||||
}
|
||||
f.Close()
|
||||
|
||||
if err := os.Rename(tmp, c.Path); err != nil {
|
||||
os.Remove(tmp)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if etag := resp.Header.Get("ETag"); etag != "" {
|
||||
c.etag = etag
|
||||
}
|
||||
if lm := resp.Header.Get("Last-Modified"); lm != "" {
|
||||
c.lastMod = lm
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (c *Cacher) invokeCallbacks() error {
|
||||
for _, fn := range c.callbacks {
|
||||
if err := fn(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: reload callback: %v\n", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -9,13 +9,15 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/therootcompany/golib/net/gitshallow"
|
||||
"github.com/therootcompany/golib/net/httpcache"
|
||||
"github.com/therootcompany/golib/net/ipcohort"
|
||||
)
|
||||
|
||||
type Blacklist struct {
|
||||
atomic.Pointer[ipcohort.Cohort]
|
||||
path string
|
||||
repo *gitshallow.Repo // nil if file-only
|
||||
git *gitshallow.Repo
|
||||
http *httpcache.Cacher
|
||||
}
|
||||
|
||||
func NewBlacklist(path string) *Blacklist {
|
||||
@ -24,17 +26,28 @@ func NewBlacklist(path string) *Blacklist {
|
||||
|
||||
func NewGitBlacklist(gitURL, path string) *Blacklist {
|
||||
repo := gitshallow.New(gitURL, filepath.Dir(path), 1, "")
|
||||
b := &Blacklist{path: path, repo: repo}
|
||||
b := &Blacklist{path: path, git: repo}
|
||||
repo.Register(b.reload)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Blacklist) Init(lightGC bool) error {
|
||||
if b.repo != nil {
|
||||
return b.repo.Init(lightGC)
|
||||
func NewHTTPBlacklist(url, path string) *Blacklist {
|
||||
cacher := httpcache.New(url, path)
|
||||
b := &Blacklist{path: path, http: cacher}
|
||||
cacher.Register(b.reload)
|
||||
return b
|
||||
}
|
||||
|
||||
func (b *Blacklist) Init(lightGC bool) error {
|
||||
switch {
|
||||
case b.git != nil:
|
||||
return b.git.Init(lightGC)
|
||||
case b.http != nil:
|
||||
return b.http.Init()
|
||||
default:
|
||||
return b.reload()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Blacklist) Run(ctx context.Context, lightGC bool) {
|
||||
ticker := time.NewTicker(47 * time.Minute)
|
||||
@ -43,7 +56,8 @@ func (b *Blacklist) Run(ctx context.Context, lightGC bool) {
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if updated, err := b.repo.Sync(lightGC); err != nil {
|
||||
updated, err := b.sync(lightGC)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "error: blacklist sync: %v\n", err)
|
||||
} else if updated {
|
||||
fmt.Fprintf(os.Stderr, "blacklist: reloaded %d entries\n", b.Size())
|
||||
@ -54,6 +68,17 @@ func (b *Blacklist) Run(ctx context.Context, lightGC bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Blacklist) sync(lightGC bool) (bool, error) {
|
||||
switch {
|
||||
case b.git != nil:
|
||||
return b.git.Sync(lightGC)
|
||||
case b.http != nil:
|
||||
return b.http.Sync()
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Blacklist) Contains(ipStr string) bool {
|
||||
return b.Load().Contains(ipStr)
|
||||
}
|
||||
|
||||
@ -3,25 +3,29 @@ package main
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if len(os.Args) < 3 {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address> [git-url]\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s <blacklist.csv> <ip-address> [git-url|http-url]\n", os.Args[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
dataPath := os.Args[1]
|
||||
ipStr := os.Args[2]
|
||||
gitURL := ""
|
||||
remoteURL := ""
|
||||
if len(os.Args) >= 4 {
|
||||
gitURL = os.Args[3]
|
||||
remoteURL = os.Args[3]
|
||||
}
|
||||
|
||||
var bl *Blacklist
|
||||
if gitURL != "" {
|
||||
bl = NewGitBlacklist(gitURL, dataPath)
|
||||
} else {
|
||||
switch {
|
||||
case strings.HasPrefix(remoteURL, "http://") || strings.HasPrefix(remoteURL, "https://"):
|
||||
bl = NewHTTPBlacklist(remoteURL, dataPath)
|
||||
case remoteURL != "":
|
||||
bl = NewGitBlacklist(remoteURL, dataPath)
|
||||
default:
|
||||
bl = NewBlacklist(dataPath)
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user