mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 12:48:00 +00:00
feat(check-ip): accept IP args, require --serve or args
Positional args are IPs to check and print; at least one IP or --serve must be provided. Refactor server.go: split handle into lookup + writeText methods so main can reuse them. geoip is no longer managed via dataset.Group — it's a single atomic.Pointer[geoip.Databases]. The Fetcher (httpcache or PollFiles) still drives refresh, but via an inline ticker in the serve branch that fetches, reopens, and swaps.
This commit is contained in:
parent
6bcb493d02
commit
b9295608db
@ -12,6 +12,7 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@ -43,7 +44,7 @@ type IPCheck struct {
|
||||
|
||||
inbound *dataset.View[ipcohort.Cohort]
|
||||
outbound *dataset.View[ipcohort.Cohort]
|
||||
geo *dataset.View[geoip.Databases]
|
||||
geo atomic.Pointer[geoip.Databases]
|
||||
}
|
||||
|
||||
func main() {
|
||||
@ -54,7 +55,8 @@ func main() {
|
||||
fs.StringVar(&cfg.RepoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
|
||||
fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <ip> [ip...]\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, " %s --serve <bind> [flags]\n", os.Args[0])
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
|
||||
@ -77,6 +79,12 @@ func main() {
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
ips := fs.Args()
|
||||
if cfg.Bind == "" && len(ips) == 0 {
|
||||
fmt.Fprintln(os.Stderr, "error: provide at least one IP argument or --serve <bind>")
|
||||
fs.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
if cfg.CacheDir == "" {
|
||||
d, err := os.UserCacheDir()
|
||||
if err != nil {
|
||||
@ -160,15 +168,19 @@ func main() {
|
||||
} else {
|
||||
geoFetcher = dataset.PollFiles(cityTarPath, asnTarPath)
|
||||
}
|
||||
geoGroup := dataset.NewGroup(geoFetcher)
|
||||
cfg.geo = dataset.Add(geoGroup, func() (*geoip.Databases, error) {
|
||||
return geoip.Open(maxmindDir)
|
||||
})
|
||||
if err := geoGroup.Load(context.Background()); err != nil {
|
||||
if _, err := geoFetcher.Fetch(); err != nil {
|
||||
log.Fatalf("geoip: %v", err)
|
||||
}
|
||||
defer func() { _ = cfg.geo.Value().Close() }()
|
||||
geoDB, err := geoip.Open(maxmindDir)
|
||||
if err != nil {
|
||||
log.Fatalf("geoip: %v", err)
|
||||
}
|
||||
cfg.geo.Store(geoDB)
|
||||
defer func() { _ = cfg.geo.Load().Close() }()
|
||||
|
||||
for _, ip := range ips {
|
||||
cfg.writeText(os.Stdout, cfg.lookup(ip))
|
||||
}
|
||||
if cfg.Bind == "" {
|
||||
return
|
||||
}
|
||||
@ -178,9 +190,33 @@ func main() {
|
||||
go group.Tick(ctx, refreshInterval, func(err error) {
|
||||
log.Printf("blocklists refresh: %v", err)
|
||||
})
|
||||
go geoGroup.Tick(ctx, refreshInterval, func(err error) {
|
||||
go func() {
|
||||
t := time.NewTicker(refreshInterval)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-t.C:
|
||||
updated, err := geoFetcher.Fetch()
|
||||
if err != nil {
|
||||
log.Printf("geoip refresh: %v", err)
|
||||
})
|
||||
continue
|
||||
}
|
||||
if !updated {
|
||||
continue
|
||||
}
|
||||
db, err := geoip.Open(maxmindDir)
|
||||
if err != nil {
|
||||
log.Printf("geoip refresh: %v", err)
|
||||
continue
|
||||
}
|
||||
if old := cfg.geo.Swap(db); old != nil {
|
||||
_ = old.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
if err := cfg.serve(ctx); err != nil {
|
||||
log.Fatalf("serve: %v", err)
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
@ -23,47 +24,31 @@ type Result struct {
|
||||
Geo geoip.Info `json:"geo,omitzero"`
|
||||
}
|
||||
|
||||
func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) {
|
||||
ip := strings.TrimSpace(r.URL.Query().Get("ip"))
|
||||
if ip == "" {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
first, _, _ := strings.Cut(xff, ",")
|
||||
ip = strings.TrimSpace(first)
|
||||
} else if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
ip = host
|
||||
} else {
|
||||
ip = r.RemoteAddr
|
||||
}
|
||||
}
|
||||
// lookup builds a Result for ip against the currently loaded blocklists
|
||||
// and GeoIP databases.
|
||||
func (c *IPCheck) lookup(ip string) Result {
|
||||
in := c.inbound.Value().Contains(ip)
|
||||
out := c.outbound.Value().Contains(ip)
|
||||
res := Result{
|
||||
return Result{
|
||||
IP: ip,
|
||||
Blocked: in || out,
|
||||
BlockedInbound: in,
|
||||
BlockedOutbound: out,
|
||||
Geo: c.geo.Value().Lookup(ip),
|
||||
Geo: c.geo.Load().Lookup(ip),
|
||||
}
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("format") == "json" ||
|
||||
strings.Contains(r.Header.Get("Accept"), "application/json") {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
_ = enc.Encode(res)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
// writeText renders res as human-readable plain text.
|
||||
func (c *IPCheck) writeText(w io.Writer, res Result) {
|
||||
switch {
|
||||
case in && out:
|
||||
fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", ip)
|
||||
case in:
|
||||
fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", ip)
|
||||
case out:
|
||||
fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", ip)
|
||||
case res.BlockedInbound && res.BlockedOutbound:
|
||||
fmt.Fprintf(w, "%s is BLOCKED (inbound + outbound)\n", res.IP)
|
||||
case res.BlockedInbound:
|
||||
fmt.Fprintf(w, "%s is BLOCKED (inbound)\n", res.IP)
|
||||
case res.BlockedOutbound:
|
||||
fmt.Fprintf(w, "%s is BLOCKED (outbound)\n", res.IP)
|
||||
default:
|
||||
fmt.Fprintf(w, "%s is allowed\n", ip)
|
||||
fmt.Fprintf(w, "%s is allowed\n", res.IP)
|
||||
}
|
||||
var parts []string
|
||||
if res.Geo.City != "" {
|
||||
@ -83,6 +68,33 @@ func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) {
|
||||
ip := strings.TrimSpace(r.URL.Query().Get("ip"))
|
||||
if ip == "" {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
first, _, _ := strings.Cut(xff, ",")
|
||||
ip = strings.TrimSpace(first)
|
||||
} else if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
ip = host
|
||||
} else {
|
||||
ip = r.RemoteAddr
|
||||
}
|
||||
}
|
||||
res := c.lookup(ip)
|
||||
|
||||
if r.URL.Query().Get("format") == "json" ||
|
||||
strings.Contains(r.Header.Get("Accept"), "application/json") {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
enc := json.NewEncoder(w)
|
||||
enc.SetIndent("", " ")
|
||||
_ = enc.Encode(res)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
c.writeText(w, res)
|
||||
}
|
||||
|
||||
func (c *IPCheck) serve(ctx context.Context) error {
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /check", c.handle)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user