diff --git a/net/gitshallow/gitshallow.go b/net/gitshallow/gitshallow.go index 80ae863..3a68628 100644 --- a/net/gitshallow/gitshallow.go +++ b/net/gitshallow/gitshallow.go @@ -11,10 +11,11 @@ import ( // Repo manages a shallow git clone used as a periodically-updated data source. type Repo struct { - URL string - Path string - Depth int // 0 defaults to 1, -1 for all - Branch string // Optional: specific branch to clone/pull + URL string + Path string + Depth int // 0 defaults to 1, -1 for all + Branch string // Optional: specific branch to clone/pull + LightGC bool // true = skip aggressive GC; false (default) = aggressive+prune mu sync.Mutex } @@ -178,6 +179,11 @@ func (r *Repo) Sync(lightGC bool) (bool, error) { return r.syncGit(lightGC) } +// Fetch satisfies httpcache.Syncer using the Repo's LightGC setting. +func (r *Repo) Fetch() (bool, error) { + return r.syncGit(r.LightGC) +} + func (r *Repo) syncGit(lightGC bool) (updated bool, err error) { r.mu.Lock() defer r.mu.Unlock() diff --git a/net/httpcache/httpcache.go b/net/httpcache/httpcache.go index cee6387..df9f786 100644 --- a/net/httpcache/httpcache.go +++ b/net/httpcache/httpcache.go @@ -11,6 +11,12 @@ import ( const defaultTimeout = 30 * time.Second +// Syncer is implemented by any value that can fetch a remote resource and +// report whether it changed. Both *Cacher and *gitshallow.Repo satisfy this. +type Syncer interface { + Fetch() (updated bool, err error) +} + // Cacher fetches a URL to a local file, using ETag/Last-Modified to skip // unchanged responses. // @@ -129,12 +135,16 @@ func (c *Cacher) Fetch() (updated bool, err error) { if err != nil { return false, err } - if _, err := io.Copy(f, resp.Body); err != nil { - f.Close() + n, err := io.Copy(f, resp.Body) + f.Close() + if err != nil { os.Remove(tmp) return false, err } - f.Close() + if n == 0 { + os.Remove(tmp) + return false, fmt.Errorf("empty response from %s", c.URL) + } if err := os.Rename(tmp, c.Path); err != nil { os.Remove(tmp) return false, err diff --git a/net/ipcohort/cmd/check-ip/blacklist.go b/net/ipcohort/cmd/check-ip/blacklist.go index 0c714c8..6485470 100644 --- a/net/ipcohort/cmd/check-ip/blacklist.go +++ b/net/ipcohort/cmd/check-ip/blacklist.go @@ -21,9 +21,8 @@ type Sources struct { inboundPaths []string outboundPaths []string - git *gitshallow.Repo - httpInbound []*httpcache.Cacher - httpOutbound []*httpcache.Cacher + gitRepo *gitshallow.Repo // non-nil for git source; used by Init for clone-if-missing + syncs []httpcache.Syncer // all syncable sources (git repo or HTTP cachers) } func newFileSources(whitelist, inbound, outbound []string) *Sources { @@ -42,11 +41,13 @@ func newGitSources(gitURL, repoDir string, whitelist, inboundRel, outboundRel [] } return out } + repo := gitshallow.New(gitURL, repoDir, 1, "") return &Sources{ whitelistPaths: whitelist, inboundPaths: abs(inboundRel), outboundPaths: abs(outboundRel), - git: gitshallow.New(gitURL, repoDir, 1, ""), + gitRepo: repo, + syncs: []httpcache.Syncer{repo}, } } @@ -54,60 +55,39 @@ func newHTTPSources(whitelist []string, inbound, outbound []HTTPSource) *Sources s := &Sources{whitelistPaths: whitelist} for _, src := range inbound { s.inboundPaths = append(s.inboundPaths, src.Path) - s.httpInbound = append(s.httpInbound, httpcache.New(src.URL, src.Path)) + s.syncs = append(s.syncs, httpcache.New(src.URL, src.Path)) } for _, src := range outbound { s.outboundPaths = append(s.outboundPaths, src.Path) - s.httpOutbound = append(s.httpOutbound, httpcache.New(src.URL, src.Path)) + s.syncs = append(s.syncs, httpcache.New(src.URL, src.Path)) } return s } -// Fetch pulls updates from the remote (git or HTTP). -// Returns whether any new data was received. -func (s *Sources) Fetch(lightGC bool) (bool, error) { - switch { - case s.git != nil: - return s.git.Sync(lightGC) - case len(s.httpInbound) > 0 || len(s.httpOutbound) > 0: - var anyUpdated bool - for _, c := range s.httpInbound { - updated, err := c.Fetch() - if err != nil { - return anyUpdated, err - } - anyUpdated = anyUpdated || updated +// Fetch pulls updates from all sources. Returns whether any new data arrived. +// Satisfies httpcache.Syncer. +func (s *Sources) Fetch() (bool, error) { + var anyUpdated bool + for _, syn := range s.syncs { + updated, err := syn.Fetch() + if err != nil { + return anyUpdated, err } - for _, c := range s.httpOutbound { - updated, err := c.Fetch() - if err != nil { - return anyUpdated, err - } - anyUpdated = anyUpdated || updated - } - return anyUpdated, nil - default: - return false, nil + anyUpdated = anyUpdated || updated } + return anyUpdated, nil } -// Init ensures the remote is ready (clones if needed, fetches HTTP files). -// Always returns true so the caller knows to load data on startup. -func (s *Sources) Init(lightGC bool) error { - switch { - case s.git != nil: - _, err := s.git.Init(lightGC) +// Init ensures remotes are ready. For git: clones if missing then syncs. +// For HTTP: fetches each cacher unconditionally on first run. +func (s *Sources) Init() error { + if s.gitRepo != nil { + _, err := s.gitRepo.Init(s.gitRepo.LightGC) return err - case len(s.httpInbound) > 0 || len(s.httpOutbound) > 0: - for _, c := range s.httpInbound { - if _, err := c.Fetch(); err != nil { - return err - } - } - for _, c := range s.httpOutbound { - if _, err := c.Fetch(); err != nil { - return err - } + } + for _, syn := range s.syncs { + if _, err := syn.Fetch(); err != nil { + return err } } return nil diff --git a/net/ipcohort/cmd/check-ip/main.go b/net/ipcohort/cmd/check-ip/main.go index e11c288..c9ad092 100644 --- a/net/ipcohort/cmd/check-ip/main.go +++ b/net/ipcohort/cmd/check-ip/main.go @@ -77,7 +77,7 @@ func main() { var whitelist, inbound, outbound atomic.Pointer[ipcohort.Cohort] - if err := src.Init(false); err != nil { + if err := src.Init(); err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } @@ -186,7 +186,7 @@ func runLoop(ctx context.Context, src *Sources, select { case <-ticker.C: // Blocklists. - if updated, err := src.Fetch(false); err != nil { + if updated, err := src.Fetch(); err != nil { fmt.Fprintf(os.Stderr, "error: blocklist sync: %v\n", err) } else if updated { if err := reloadBlocklists(src, whitelist, inbound, outbound); err != nil {