mirror of
https://github.com/therootcompany/golib.git
synced 2026-04-24 12:48:00 +00:00
refactor(check-ip): IPCheck struct holds flag config + handler method
Follow golang-cli-flags pattern: config struct holds parsed flags and loaded resources; handle and serve are methods on *IPCheck. Adds -V/help pre-parse handling. Inlines clientIP into the handler.
This commit is contained in:
parent
0c281a494b
commit
7aa4493cb0
@ -24,45 +24,78 @@ import (
|
|||||||
const (
|
const (
|
||||||
defaultBlocklistRepo = "https://github.com/bitwire-it/ipblocklist.git"
|
defaultBlocklistRepo = "https://github.com/bitwire-it/ipblocklist.git"
|
||||||
refreshInterval = 47 * time.Minute
|
refreshInterval = 47 * time.Minute
|
||||||
|
version = "dev"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// IPCheck holds the parsed CLI config and the loaded data sources used by
|
||||||
|
// the HTTP handler.
|
||||||
|
type IPCheck struct {
|
||||||
|
Bind string
|
||||||
|
ConfPath string
|
||||||
|
RepoURL string
|
||||||
|
CacheDir string
|
||||||
|
|
||||||
|
inbound *dataset.View[ipcohort.Cohort]
|
||||||
|
outbound *dataset.View[ipcohort.Cohort]
|
||||||
|
geo *geoip.Databases
|
||||||
|
}
|
||||||
|
|
||||||
|
func printVersion(w *os.File) {
|
||||||
|
fmt.Fprintf(w, "check-ip %s\n", version)
|
||||||
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var bind, confPath, repoURL, cacheDir string
|
cfg := IPCheck{}
|
||||||
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
|
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
|
||||||
fs.StringVar(&bind, "serve", "", "bind address for the HTTP API, e.g. :8080")
|
fs.StringVar(&cfg.Bind, "serve", "", "bind address for the HTTP API, e.g. :8080")
|
||||||
fs.StringVar(&confPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
|
fs.StringVar(&cfg.ConfPath, "geoip-conf", "", "path to GeoIP.conf (default: ./GeoIP.conf or ~/.config/maxmind/GeoIP.conf)")
|
||||||
fs.StringVar(&repoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
|
fs.StringVar(&cfg.RepoURL, "blocklist-repo", defaultBlocklistRepo, "git URL of the blocklist repo (must match bitwire-it layout)")
|
||||||
fs.StringVar(&cacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
|
fs.StringVar(&cfg.CacheDir, "cache-dir", "", "cache parent dir, holds bitwire-it/ and maxmind/ subdirs (default: OS user cache)")
|
||||||
fs.Usage = func() {
|
fs.Usage = func() {
|
||||||
fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0])
|
fmt.Fprintf(os.Stderr, "Usage: %s --serve <bind> [flags]\n", os.Args[0])
|
||||||
fs.PrintDefaults()
|
fs.PrintDefaults()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(os.Args) > 1 {
|
||||||
|
switch os.Args[1] {
|
||||||
|
case "-V", "-version", "--version", "version":
|
||||||
|
printVersion(os.Stdout)
|
||||||
|
os.Exit(0)
|
||||||
|
case "help", "-help", "--help":
|
||||||
|
printVersion(os.Stdout)
|
||||||
|
fmt.Fprintln(os.Stdout, "")
|
||||||
|
fs.SetOutput(os.Stdout)
|
||||||
|
fs.Usage()
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := fs.Parse(os.Args[1:]); err != nil {
|
if err := fs.Parse(os.Args[1:]); err != nil {
|
||||||
if errors.Is(err, flag.ErrHelp) {
|
if errors.Is(err, flag.ErrHelp) {
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
if cacheDir == "" {
|
if cfg.CacheDir == "" {
|
||||||
d, err := os.UserCacheDir()
|
d, err := os.UserCacheDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("cache-dir: %v", err)
|
log.Fatalf("cache-dir: %v", err)
|
||||||
}
|
}
|
||||||
cacheDir = d
|
cfg.CacheDir = d
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||||
defer stop()
|
defer stop()
|
||||||
|
|
||||||
repo := gitshallow.New(repoURL, filepath.Join(cacheDir, "bitwire-it"), 1, "")
|
repo := gitshallow.New(cfg.RepoURL, filepath.Join(cfg.CacheDir, "bitwire-it"), 1, "")
|
||||||
group := dataset.NewGroup(repo)
|
group := dataset.NewGroup(repo)
|
||||||
inbound := dataset.Add(group, func() (*ipcohort.Cohort, error) {
|
cfg.inbound = dataset.Add(group, func() (*ipcohort.Cohort, error) {
|
||||||
return ipcohort.LoadFiles(
|
return ipcohort.LoadFiles(
|
||||||
repo.FilePath("tables/inbound/single_ips.txt"),
|
repo.FilePath("tables/inbound/single_ips.txt"),
|
||||||
repo.FilePath("tables/inbound/networks.txt"),
|
repo.FilePath("tables/inbound/networks.txt"),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
outbound := dataset.Add(group, func() (*ipcohort.Cohort, error) {
|
cfg.outbound = dataset.Add(group, func() (*ipcohort.Cohort, error) {
|
||||||
return ipcohort.LoadFiles(
|
return ipcohort.LoadFiles(
|
||||||
repo.FilePath("tables/outbound/single_ips.txt"),
|
repo.FilePath("tables/outbound/single_ips.txt"),
|
||||||
repo.FilePath("tables/outbound/networks.txt"),
|
repo.FilePath("tables/outbound/networks.txt"),
|
||||||
@ -75,9 +108,9 @@ func main() {
|
|||||||
log.Printf("refresh: %v", err)
|
log.Printf("refresh: %v", err)
|
||||||
})
|
})
|
||||||
|
|
||||||
maxmind := filepath.Join(cacheDir, "maxmind")
|
maxmind := filepath.Join(cfg.CacheDir, "maxmind")
|
||||||
geo, err := geoip.OpenDatabases(
|
geo, err := geoip.OpenDatabases(
|
||||||
confPath,
|
cfg.ConfPath,
|
||||||
filepath.Join(maxmind, geoip.CityEdition+".mmdb"),
|
filepath.Join(maxmind, geoip.CityEdition+".mmdb"),
|
||||||
filepath.Join(maxmind, geoip.ASNEdition+".mmdb"),
|
filepath.Join(maxmind, geoip.ASNEdition+".mmdb"),
|
||||||
)
|
)
|
||||||
@ -85,11 +118,12 @@ func main() {
|
|||||||
log.Fatalf("geoip: %v", err)
|
log.Fatalf("geoip: %v", err)
|
||||||
}
|
}
|
||||||
defer func() { _ = geo.Close() }()
|
defer func() { _ = geo.Close() }()
|
||||||
|
cfg.geo = geo
|
||||||
|
|
||||||
if bind == "" {
|
if cfg.Bind == "" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := serve(ctx, bind, inbound, outbound, geo); err != nil {
|
if err := cfg.serve(ctx); err != nil {
|
||||||
log.Fatalf("serve: %v", err)
|
log.Fatalf("serve: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -12,8 +12,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/therootcompany/golib/net/geoip"
|
"github.com/therootcompany/golib/net/geoip"
|
||||||
"github.com/therootcompany/golib/net/ipcohort"
|
|
||||||
"github.com/therootcompany/golib/sync/dataset"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Result is the JSON verdict for a single IP.
|
// Result is the JSON verdict for a single IP.
|
||||||
@ -25,25 +23,26 @@ type Result struct {
|
|||||||
Geo geoip.Info `json:"geo,omitzero"`
|
Geo geoip.Info `json:"geo,omitzero"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func serve(
|
func (c *IPCheck) handle(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx context.Context,
|
|
||||||
bind string,
|
|
||||||
inbound, outbound *dataset.View[ipcohort.Cohort],
|
|
||||||
geo *geoip.Databases,
|
|
||||||
) error {
|
|
||||||
handle := func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ip := strings.TrimSpace(r.URL.Query().Get("ip"))
|
ip := strings.TrimSpace(r.URL.Query().Get("ip"))
|
||||||
if ip == "" {
|
if ip == "" {
|
||||||
ip = clientIP(r)
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||||
|
first, _, _ := strings.Cut(xff, ",")
|
||||||
|
ip = strings.TrimSpace(first)
|
||||||
|
} else if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||||
|
ip = host
|
||||||
|
} else {
|
||||||
|
ip = r.RemoteAddr
|
||||||
}
|
}
|
||||||
in := inbound.Value().Contains(ip)
|
}
|
||||||
out := outbound.Value().Contains(ip)
|
in := c.inbound.Value().Contains(ip)
|
||||||
|
out := c.outbound.Value().Contains(ip)
|
||||||
res := Result{
|
res := Result{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Blocked: in || out,
|
Blocked: in || out,
|
||||||
BlockedInbound: in,
|
BlockedInbound: in,
|
||||||
BlockedOutbound: out,
|
BlockedOutbound: out,
|
||||||
Geo: geo.Lookup(ip),
|
Geo: c.geo.Lookup(ip),
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.URL.Query().Get("format") == "json" ||
|
if r.URL.Query().Get("format") == "json" ||
|
||||||
@ -84,12 +83,13 @@ func serve(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *IPCheck) serve(ctx context.Context) error {
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
mux.HandleFunc("GET /check", handle)
|
mux.HandleFunc("GET /check", c.handle)
|
||||||
mux.HandleFunc("GET /{$}", handle)
|
mux.HandleFunc("GET /{$}", c.handle)
|
||||||
|
|
||||||
srv := &http.Server{
|
srv := &http.Server{
|
||||||
Addr: bind,
|
Addr: c.Bind,
|
||||||
Handler: mux,
|
Handler: mux,
|
||||||
BaseContext: func(_ net.Listener) context.Context { return ctx },
|
BaseContext: func(_ net.Listener) context.Context { return ctx },
|
||||||
}
|
}
|
||||||
@ -100,22 +100,9 @@ func serve(
|
|||||||
_ = srv.Shutdown(shutCtx)
|
_ = srv.Shutdown(shutCtx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
log.Printf("listening on %s", bind)
|
log.Printf("listening on %s", c.Bind)
|
||||||
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// clientIP extracts the caller's IP, honoring X-Forwarded-For when present.
|
|
||||||
func clientIP(r *http.Request) string {
|
|
||||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
|
||||||
first, _, _ := strings.Cut(xff, ",")
|
|
||||||
return strings.TrimSpace(first)
|
|
||||||
}
|
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
||||||
if err != nil {
|
|
||||||
return r.RemoteAddr
|
|
||||||
}
|
|
||||||
return host
|
|
||||||
}
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user