AJ ONeal 8ebc571928
refactor: apply check-ip CLI conventions to sibling cmds
Propagate the patterns used in cmd/check-ip to the other command-line
tools touched by this PR:

- flag.FlagSet + Config struct instead of package-level flag.String
  pointers (geoip-update, ipcohort-contains, git-shallow-sync).
- -V/--version/version and help/-help/--help handled before Parse,
  matching the project's CLI conventions.
- Stderr "Loading X... Nms (counts)" progress lines on the stages that
  actually take time: blocklist cohort parse (ipcohort-contains),
  per-edition fetch (geoip-update), and repo sync (git-shallow-sync).
  Stdout stays machine-parseable.
2026-04-20 19:13:47 -06:00

137 lines
2.9 KiB
Go

// ipcohort-contains checks whether one or more IP addresses appear in a set
// of cohort files (plain text, one IP/CIDR per line).
//
// Usage:
//
//	ipcohort-contains [flags] <file>... -- <ip>...
//	ipcohort-contains [flags] --ip <ip> <file>...
//	echo "<ip>" | ipcohort-contains <file>...
//
// Exit code: 0 if all queried IPs are found, 1 if any are not found, 2 on error.
package main
import (
"bufio"
"errors"
"flag"
"fmt"
"os"
"strings"
"time"
"github.com/therootcompany/golib/net/ipcohort"
)
// Build identification for this command.
const (
	// version is the string reported after the program name by the
	// version and help subcommands.
	version = "dev"
)
// Config holds the values populated from command-line flags.
type Config struct {
	// IP is the single address to look up, set by the -ip flag.
	// When empty, main takes the addresses from the arguments after a
	// "--" separator or, failing that, from stdin.
	IP string
}
// main parses the command line, loads the cohort files, and reports for each
// queried IP whether the loaded cohort contains it.
//
// IPs are taken from the -ip flag, or from the arguments following a "--"
// separator; when neither yields any, they are read from stdin, one per line
// (blank lines and lines starting with "#" are skipped).
//
// Progress and error messages go to stderr; results go to stdout as
// "<ip>\tFOUND" or "<ip>\tNOT FOUND", one line per IP.
//
// Exit status: 0 when every IP is found, 1 when at least one is not found,
// 2 on usage, load, or stdin-read errors.
func main() {
	cfg := Config{}

	fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
	fs.StringVar(&cfg.IP, "ip", "", "IP address to check (alternative to -- separator)")
	fs.Usage = func() {
		// Write to the FlagSet's configured output instead of hard-coding
		// stderr: the "help" subcommand redirects output to stdout, and
		// PrintDefaults already follows that redirect. Using the same writer
		// keeps the whole usage text on a single stream.
		out := fs.Output()
		fmt.Fprintf(out, "Usage: %s [flags] <file>... -- <ip>...\n", os.Args[0])
		fmt.Fprintf(out, " %s --ip <ip> <file>...\n", os.Args[0])
		fmt.Fprintf(out, " echo <ip> | %s <file>...\n", os.Args[0])
		fs.PrintDefaults()
		fmt.Fprintln(out, "Exit: 0=all found, 1=not found, 2=error")
	}

	// Version and help are recognized only as the first argument and are
	// handled before flag parsing.
	if len(os.Args) > 1 {
		switch os.Args[1] {
		case "-V", "-version", "--version", "version":
			fmt.Fprintf(os.Stdout, "ipcohort-contains %s\n", version)
			os.Exit(0)
		case "help", "-help", "--help":
			fmt.Fprintf(os.Stdout, "ipcohort-contains %s\n\n", version)
			fs.SetOutput(os.Stdout)
			fs.Usage()
			os.Exit(0)
		}
	}

	if err := fs.Parse(os.Args[1:]); err != nil {
		// With ContinueOnError the flag package has already printed the
		// error and/or usage; only the exit status is left to decide.
		if errors.Is(err, flag.ErrHelp) {
			os.Exit(0)
		}
		os.Exit(2)
	}

	// Split the positional arguments into cohort files and IPs.
	args := fs.Args()
	var filePaths, ips []string
	switch {
	case cfg.IP != "":
		// -ip given: every positional argument is a file path.
		filePaths = args
		ips = []string{cfg.IP}
	default:
		// Files come before the first "--", IPs after it. Without a
		// separator every argument is a file and IPs come from stdin.
		sep := -1
		for i, a := range args {
			if a == "--" {
				sep = i
				break
			}
		}
		if sep >= 0 {
			filePaths = args[:sep]
			ips = args[sep+1:]
		} else {
			filePaths = args
		}
	}

	if len(filePaths) == 0 {
		fmt.Fprintln(os.Stderr, "error: at least one file path required")
		fs.Usage()
		os.Exit(2)
	}

	// Progress line on stderr: "Loading cohort... <elapsed> (entries=N)".
	fmt.Fprint(os.Stderr, "Loading cohort... ")
	t := time.Now()
	cohort, err := ipcohort.LoadFiles(filePaths...)
	if err != nil {
		// Terminate the unfinished progress line before printing the error.
		fmt.Fprintln(os.Stderr)
		fmt.Fprintf(os.Stderr, "error: %v\n", err)
		os.Exit(2)
	}
	fmt.Fprintf(os.Stderr, "%s (entries=%d)\n",
		time.Since(t).Round(time.Millisecond),
		cohort.Size(),
	)

	// No IPs on the command line: read them from stdin.
	if len(ips) == 0 {
		sc := bufio.NewScanner(os.Stdin)
		for sc.Scan() {
			if line := strings.TrimSpace(sc.Text()); line != "" && !strings.HasPrefix(line, "#") {
				ips = append(ips, line)
			}
		}
		if err := sc.Err(); err != nil {
			fmt.Fprintf(os.Stderr, "error reading stdin: %v\n", err)
			os.Exit(2)
		}
	}
	if len(ips) == 0 {
		fmt.Fprintln(os.Stderr, "error: no IP addresses to check")
		fs.Usage()
		os.Exit(2)
	}

	// Blank stderr line between the progress output and the results.
	fmt.Fprintln(os.Stderr)

	allFound := true
	for _, ip := range ips {
		if cohort.Contains(ip) {
			fmt.Printf("%s\tFOUND\n", ip)
		} else {
			fmt.Printf("%s\tNOT FOUND\n", ip)
			allFound = false
		}
	}
	if !allFound {
		os.Exit(1)
	}
}