refactor: apply check-ip CLI conventions to sibling cmds

Propagate the patterns used in cmd/check-ip to the other command-line
tools touched by this PR:

- flag.FlagSet + Config struct instead of package-level flag.String
  pointers (geoip-update, ipcohort-contains, git-shallow-sync).
- -V/--version/version and help/-help/--help handled before Parse,
  matching the project's CLI conventions.
- Stderr "Loading X... Nms (counts)" progress lines on the stages that
  actually take time: blocklist cohort parse (ipcohort-contains),
  per-edition fetch (geoip-update), and repo sync (git-shallow-sync).
  Stdout stays machine-parseable.
This commit is contained in:
AJ ONeal 2026-04-20 19:13:47 -06:00
parent a181133c2f
commit 8ebc571928
No known key found for this signature in database
3 changed files with 159 additions and 75 deletions

View File

@ -1,6 +1,9 @@
// geoip-update downloads GeoLite2 edition tarballs listed in GeoIP.conf
// via conditional HTTP GETs, writing them to the configured directory.
package main package main
import ( import (
"errors"
"flag" "flag"
"fmt" "fmt"
"net/http" "net/http"
@ -12,26 +15,58 @@ import (
"github.com/therootcompany/golib/net/httpcache" "github.com/therootcompany/golib/net/httpcache"
) )
const version = "dev"
type Config struct {
ConfPath string
Dir string
FreshDays int
}
func main() { func main() {
configPath := flag.String("config", "GeoIP.conf", "path to GeoIP.conf") cfg := Config{}
dir := flag.String("dir", "", "directory to store .tar.gz files (overrides DatabaseDirectory in config)") fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
freshDays := flag.Int("fresh-days", 3, "skip download if file is younger than N days") fs.StringVar(&cfg.ConfPath, "config", "GeoIP.conf", "path to GeoIP.conf")
flag.Parse() fs.StringVar(&cfg.Dir, "dir", "", "directory to store .tar.gz files (overrides DatabaseDirectory in config)")
fs.IntVar(&cfg.FreshDays, "fresh-days", 3, "skip download if file is younger than N days")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags]\n", os.Args[0])
fs.PrintDefaults()
}
data, err := os.ReadFile(*configPath) if len(os.Args) > 1 {
switch os.Args[1] {
case "-V", "-version", "--version", "version":
fmt.Fprintf(os.Stdout, "geoip-update %s\n", version)
os.Exit(0)
case "help", "-help", "--help":
fmt.Fprintf(os.Stdout, "geoip-update %s\n\n", version)
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
if err := fs.Parse(os.Args[1:]); err != nil {
if errors.Is(err, flag.ErrHelp) {
os.Exit(0)
}
os.Exit(1)
}
data, err := os.ReadFile(cfg.ConfPath)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err) fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
cfg, err := geoip.ParseConf(string(data)) conf, err := geoip.ParseConf(string(data))
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err) fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1) os.Exit(1)
} }
outDir := *dir outDir := cfg.Dir
if outDir == "" { if outDir == "" {
outDir = cfg.DatabaseDirectory outDir = conf.DatabaseDirectory
} }
if outDir == "" { if outDir == "" {
outDir = "." outDir = "."
@ -41,16 +76,16 @@ func main() {
os.Exit(1) os.Exit(1)
} }
if len(cfg.EditionIDs) == 0 { if len(conf.EditionIDs) == 0 {
fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", *configPath) fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", cfg.ConfPath)
os.Exit(1) os.Exit(1)
} }
authHeader := http.Header{"Authorization": []string{httpcache.BasicAuth(cfg.AccountID, cfg.LicenseKey)}} authHeader := http.Header{"Authorization": []string{httpcache.BasicAuth(conf.AccountID, conf.LicenseKey)}}
maxAge := time.Duration(*freshDays) * 24 * time.Hour maxAge := time.Duration(cfg.FreshDays) * 24 * time.Hour
exitCode := 0 exitCode := 0
for _, edition := range cfg.EditionIDs { for _, edition := range conf.EditionIDs {
path := filepath.Join(outDir, geoip.TarGzName(edition)) path := filepath.Join(outDir, geoip.TarGzName(edition))
cacher := &httpcache.Cacher{ cacher := &httpcache.Cacher{
URL: geoip.DownloadBase + "/" + edition + "/download?suffix=tar.gz", URL: geoip.DownloadBase + "/" + edition + "/download?suffix=tar.gz",
@ -58,18 +93,22 @@ func main() {
MaxAge: maxAge, MaxAge: maxAge,
Header: authHeader, Header: authHeader,
} }
fmt.Fprintf(os.Stderr, "Fetching %s... ", edition)
t := time.Now()
updated, err := cacher.Fetch() updated, err := cacher.Fetch()
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr)
fmt.Fprintf(os.Stderr, "error: %s: %v\n", edition, err) fmt.Fprintf(os.Stderr, "error: %s: %v\n", edition, err)
exitCode = 1 exitCode = 1
continue continue
} }
info, _ := os.Stat(path) state := "fresh"
state := "fresh: "
if updated { if updated {
state = "updated:" state = "updated"
} }
fmt.Printf("%s %s -> %s (%s)\n", state, edition, path, info.ModTime().Format("2006-01-02")) fmt.Fprintf(os.Stderr, "%s (%s)\n", time.Since(t).Round(time.Millisecond), state)
info, _ := os.Stat(path)
fmt.Printf("%-10s %s -> %s (%s)\n", state+":", edition, path, info.ModTime().Format("2006-01-02"))
} }
os.Exit(exitCode) os.Exit(exitCode)
} }

View File

@ -1,73 +1,92 @@
// git-shallow-sync is a simple CLI tool to synchronize a shallow git repository // git-shallow-sync syncs a shallow git clone at the given local path,
// using the github.com/therootcompany/golib/net/gitshallow package. // cloning on first run and fetching + hard-resetting on subsequent runs.
// //
// Usage: // Usage:
// //
// git-shallow-sync <repository-url> <local-path> // git-shallow-sync <repository-url> <local-path>
//
// Example:
//
// git-shallow-sync git@github.com:bitwire-it/ipblocklist.git ~/srv/app/ipblocklist
package main package main
import ( import (
"errors"
"flag"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"time"
"github.com/therootcompany/golib/net/gitshallow" "github.com/therootcompany/golib/net/gitshallow"
) )
const ( const version = "dev"
defaultDepth = 1 // shallow by default
defaultBranch = "" // empty = default branch + --single-branch type Config struct {
) Depth int
Branch string
}
func main() { func main() {
if len(os.Args) != 3 { cfg := Config{}
name := filepath.Base(os.Args[0]) fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fmt.Fprintf(os.Stderr, "Usage: %s <repository-url> <local-path>\n", name) fs.IntVar(&cfg.Depth, "depth", 1, "clone/fetch depth (-1 for full history)")
fmt.Fprintf(os.Stderr, "Example:\n") fs.StringVar(&cfg.Branch, "branch", "", "branch to track (empty: remote default)")
fmt.Fprintf(os.Stderr, " %s git@github.com:bitwire-it/ipblocklist.git ~/srv/app/ipblocklist\n", name) fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <repository-url> <local-path>\n", os.Args[0])
fs.PrintDefaults()
}
if len(os.Args) > 1 {
switch os.Args[1] {
case "-V", "-version", "--version", "version":
fmt.Fprintf(os.Stdout, "git-shallow-sync %s\n", version)
os.Exit(0)
case "help", "-help", "--help":
fmt.Fprintf(os.Stdout, "git-shallow-sync %s\n\n", version)
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
if err := fs.Parse(os.Args[1:]); err != nil {
if errors.Is(err, flag.ErrHelp) {
os.Exit(0)
}
os.Exit(1) os.Exit(1)
} }
url := os.Args[1] args := fs.Args()
path := os.Args[2] if len(args) != 2 {
fs.Usage()
os.Exit(1)
}
url := args[0]
path := args[1]
// Expand ~ to home directory if len(path) > 0 && path[0] == '~' {
if path[0] == '~' {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Failed to get home directory: %v\n", err) fmt.Fprintf(os.Stderr, "error: resolve home: %v\n", err)
os.Exit(1) os.Exit(1)
} }
path = filepath.Join(home, path[1:]) path = filepath.Join(home, path[1:])
} }
absPath, err := filepath.Abs(path) absPath, err := filepath.Abs(path)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Invalid path: %v\n", err) fmt.Fprintf(os.Stderr, "error: invalid path: %v\n", err)
os.Exit(1) os.Exit(1)
} }
fmt.Printf("Syncing repository:\n") fmt.Fprintf(os.Stderr, "Syncing %s -> %s... ", url, absPath)
fmt.Printf(" URL: %s\n", url) t := time.Now()
fmt.Printf(" Path: %s\n", absPath) repo := gitshallow.New(url, absPath, cfg.Depth, cfg.Branch)
repo := gitshallow.New(url, absPath, defaultDepth, defaultBranch)
updated, err := repo.Sync() updated, err := repo.Sync()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Sync failed: %v\n", err) fmt.Fprintln(os.Stderr)
fmt.Fprintf(os.Stderr, "error: sync: %v\n", err)
os.Exit(1) os.Exit(1)
} }
state := "already up to date"
if updated { if updated {
fmt.Println("Repository was updated (new commits pulled).") state = "updated"
} else {
fmt.Println("Repository is already up to date.")
} }
fmt.Fprintf(os.Stderr, "%s (%s)\n", time.Since(t).Round(time.Millisecond), state)
fmt.Println("Sync complete.")
} }

View File

@ -4,48 +4,68 @@
// Usage: // Usage:
// //
// ipcohort-contains [flags] <file>... -- <ip>... // ipcohort-contains [flags] <file>... -- <ip>...
// ipcohort-contains [flags] -ip <ip> <file>... // ipcohort-contains [flags] --ip <ip> <file>...
// // echo "<ip>" | ipcohort-contains <file>...
// Examples:
//
// ipcohort-contains networks.txt single_ips.txt -- 1.2.3.4 5.6.7.8
// ipcohort-contains -ip 1.2.3.4 single_ips.txt
// echo "1.2.3.4" | ipcohort-contains networks.txt
// //
// Exit code: 0 if all queried IPs are found, 1 if any are not found, 2 on error. // Exit code: 0 if all queried IPs are found, 1 if any are not found, 2 on error.
package main package main
import ( import (
"bufio" "bufio"
"errors"
"flag" "flag"
"fmt" "fmt"
"os" "os"
"strings" "strings"
"time"
"github.com/therootcompany/golib/net/ipcohort" "github.com/therootcompany/golib/net/ipcohort"
) )
const version = "dev"
type Config struct {
IP string
}
func main() { func main() {
ipFlag := flag.String("ip", "", "IP address to check (alternative to -- separator)") cfg := Config{}
flag.Usage = func() { fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fs.StringVar(&cfg.IP, "ip", "", "IP address to check (alternative to -- separator)")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <file>... -- <ip>...\n", os.Args[0]) fmt.Fprintf(os.Stderr, "Usage: %s [flags] <file>... -- <ip>...\n", os.Args[0])
fmt.Fprintf(os.Stderr, " %s -ip <ip> <file>...\n", os.Args[0]) fmt.Fprintf(os.Stderr, " %s --ip <ip> <file>...\n", os.Args[0])
fmt.Fprintf(os.Stderr, " echo <ip> | %s <file>...\n", os.Args[0]) fmt.Fprintf(os.Stderr, " echo <ip> | %s <file>...\n", os.Args[0])
fmt.Fprintln(os.Stderr, "Flags:") fs.PrintDefaults()
flag.PrintDefaults()
fmt.Fprintln(os.Stderr, "Exit: 0=all found, 1=not found, 2=error") fmt.Fprintln(os.Stderr, "Exit: 0=all found, 1=not found, 2=error")
} }
flag.Parse()
args := flag.Args() if len(os.Args) > 1 {
switch os.Args[1] {
case "-V", "-version", "--version", "version":
fmt.Fprintf(os.Stdout, "ipcohort-contains %s\n", version)
os.Exit(0)
case "help", "-help", "--help":
fmt.Fprintf(os.Stdout, "ipcohort-contains %s\n\n", version)
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
if err := fs.Parse(os.Args[1:]); err != nil {
if errors.Is(err, flag.ErrHelp) {
os.Exit(0)
}
os.Exit(2)
}
args := fs.Args()
var filePaths, ips []string var filePaths, ips []string
switch { switch {
case *ipFlag != "": case cfg.IP != "":
filePaths = args filePaths = args
ips = []string{*ipFlag} ips = []string{cfg.IP}
default: default:
// Split args at "--"
sep := -1 sep := -1
for i, a := range args { for i, a := range args {
if a == "--" { if a == "--" {
@ -63,17 +83,23 @@ func main() {
if len(filePaths) == 0 { if len(filePaths) == 0 {
fmt.Fprintln(os.Stderr, "error: at least one file path required") fmt.Fprintln(os.Stderr, "error: at least one file path required")
flag.Usage() fs.Usage()
os.Exit(2) os.Exit(2)
} }
fmt.Fprint(os.Stderr, "Loading cohort... ")
t := time.Now()
cohort, err := ipcohort.LoadFiles(filePaths...) cohort, err := ipcohort.LoadFiles(filePaths...)
if err != nil { if err != nil {
fmt.Fprintln(os.Stderr)
fmt.Fprintf(os.Stderr, "error: %v\n", err) fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(2) os.Exit(2)
} }
fmt.Fprintf(os.Stderr, "%s (entries=%d)\n",
time.Since(t).Round(time.Millisecond),
cohort.Size(),
)
// If no IPs from flags/args, read from stdin.
if len(ips) == 0 { if len(ips) == 0 {
sc := bufio.NewScanner(os.Stdin) sc := bufio.NewScanner(os.Stdin)
for sc.Scan() { for sc.Scan() {
@ -89,14 +115,14 @@ func main() {
if len(ips) == 0 { if len(ips) == 0 {
fmt.Fprintln(os.Stderr, "error: no IP addresses to check") fmt.Fprintln(os.Stderr, "error: no IP addresses to check")
flag.Usage() fs.Usage()
os.Exit(2) os.Exit(2)
} }
fmt.Fprintln(os.Stderr)
allFound := true allFound := true
for _, ip := range ips { for _, ip := range ips {
found := cohort.Contains(ip) if cohort.Contains(ip) {
if found {
fmt.Printf("%s\tFOUND\n", ip) fmt.Printf("%s\tFOUND\n", ip)
} else { } else {
fmt.Printf("%s\tNOT FOUND\n", ip) fmt.Printf("%s\tNOT FOUND\n", ip)