From 8ebc571928c181222f6cfa26e20f235be182f5b8 Mon Sep 17 00:00:00 2001 From: AJ ONeal Date: Mon, 20 Apr 2026 19:13:47 -0600 Subject: [PATCH] refactor: apply check-ip CLI conventions to sibling cmds Propagate the patterns used in cmd/check-ip to the other command-line tools touched by this PR: - flag.FlagSet + Config struct instead of package-level flag.String pointers (geoip-update, ipcohort-contains, git-shallow-sync). - -V/--version/version and help/-help/--help handled before Parse, matching the project's CLI conventions. - Stderr "Loading X... Nms (counts)" progress lines on the stages that actually take time: blocklist cohort parse (ipcohort-contains), per-edition fetch (geoip-update), and repo sync (git-shallow-sync). Stdout stays machine-parseable. --- net/geoip/cmd/geoip-update/main.go | 73 +++++++++++++---- net/gitshallow/cmd/git-shallow-sync/main.go | 89 +++++++++++++-------- net/ipcohort/cmd/ipcohort-contains/main.go | 72 +++++++++++------ 3 files changed, 159 insertions(+), 75 deletions(-) diff --git a/net/geoip/cmd/geoip-update/main.go b/net/geoip/cmd/geoip-update/main.go index 79e4451..fe557b4 100644 --- a/net/geoip/cmd/geoip-update/main.go +++ b/net/geoip/cmd/geoip-update/main.go @@ -1,6 +1,9 @@ +// geoip-update downloads GeoLite2 edition tarballs listed in GeoIP.conf +// via conditional HTTP GETs, writing them to the configured directory. package main import ( + "errors" "flag" "fmt" "net/http" @@ -12,26 +15,58 @@ import ( "github.com/therootcompany/golib/net/httpcache" ) +const version = "dev" + +type Config struct { + ConfPath string + Dir string + FreshDays int +} + func main() { - configPath := flag.String("config", "GeoIP.conf", "path to GeoIP.conf") - dir := flag.String("dir", "", "directory to store .tar.gz files (overrides DatabaseDirectory in config)") - freshDays := flag.Int("fresh-days", 3, "skip download if file is younger than N days") - flag.Parse() + cfg := Config{} + fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + fs.StringVar(&cfg.ConfPath, "config", "GeoIP.conf", "path to GeoIP.conf") + fs.StringVar(&cfg.Dir, "dir", "", "directory to store .tar.gz files (overrides DatabaseDirectory in config)") + fs.IntVar(&cfg.FreshDays, "fresh-days", 3, "skip download if file is younger than N days") + fs.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [flags]\n", os.Args[0]) + fs.PrintDefaults() + } - data, err := os.ReadFile(*configPath) + if len(os.Args) > 1 { + switch os.Args[1] { + case "-V", "-version", "--version", "version": + fmt.Fprintf(os.Stdout, "geoip-update %s\n", version) + os.Exit(0) + case "help", "-help", "--help": + fmt.Fprintf(os.Stdout, "geoip-update %s\n\n", version) + fs.SetOutput(os.Stdout) + fs.Usage() + os.Exit(0) + } + } + if err := fs.Parse(os.Args[1:]); err != nil { + if errors.Is(err, flag.ErrHelp) { + os.Exit(0) + } + os.Exit(1) + } + + data, err := os.ReadFile(cfg.ConfPath) if err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } - cfg, err := geoip.ParseConf(string(data)) + conf, err := geoip.ParseConf(string(data)) if err != nil { fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(1) } - outDir := *dir + outDir := cfg.Dir if outDir == "" { - outDir = cfg.DatabaseDirectory + outDir = conf.DatabaseDirectory } if outDir == "" { outDir = "." @@ -41,16 +76,16 @@ func main() { os.Exit(1) } - if len(cfg.EditionIDs) == 0 { - fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", *configPath) + if len(conf.EditionIDs) == 0 { + fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", cfg.ConfPath) os.Exit(1) } - authHeader := http.Header{"Authorization": []string{httpcache.BasicAuth(cfg.AccountID, cfg.LicenseKey)}} - maxAge := time.Duration(*freshDays) * 24 * time.Hour + authHeader := http.Header{"Authorization": []string{httpcache.BasicAuth(conf.AccountID, conf.LicenseKey)}} + maxAge := time.Duration(cfg.FreshDays) * 24 * time.Hour exitCode := 0 - for _, edition := range cfg.EditionIDs { + for _, edition := range conf.EditionIDs { path := filepath.Join(outDir, geoip.TarGzName(edition)) cacher := &httpcache.Cacher{ URL: geoip.DownloadBase + "/" + edition + "/download?suffix=tar.gz", @@ -58,18 +93,22 @@ func main() { MaxAge: maxAge, Header: authHeader, } + fmt.Fprintf(os.Stderr, "Fetching %s... ", edition) + t := time.Now() updated, err := cacher.Fetch() if err != nil { + fmt.Fprintln(os.Stderr) fmt.Fprintf(os.Stderr, "error: %s: %v\n", edition, err) exitCode = 1 continue } - info, _ := os.Stat(path) - state := "fresh: " + state := "fresh" if updated { - state = "updated:" + state = "updated" } - fmt.Printf("%s %s -> %s (%s)\n", state, edition, path, info.ModTime().Format("2006-01-02")) + fmt.Fprintf(os.Stderr, "%s (%s)\n", time.Since(t).Round(time.Millisecond), state) + info, _ := os.Stat(path) + fmt.Printf("%-10s %s -> %s (%s)\n", state+":", edition, path, info.ModTime().Format("2006-01-02")) } os.Exit(exitCode) } diff --git a/net/gitshallow/cmd/git-shallow-sync/main.go b/net/gitshallow/cmd/git-shallow-sync/main.go index 511635a..e1b0c81 100644 --- a/net/gitshallow/cmd/git-shallow-sync/main.go +++ b/net/gitshallow/cmd/git-shallow-sync/main.go @@ -1,73 +1,92 @@ -// git-shallow-sync is a simple CLI tool to synchronize a shallow git repository -// using the github.com/therootcompany/golib/net/gitshallow package. +// git-shallow-sync syncs a shallow git clone at the given local path, +// cloning on first run and fetching + hard-resetting on subsequent runs. // // Usage: // // git-shallow-sync -// -// Example: -// -// git-shallow-sync git@github.com:bitwire-it/ipblocklist.git ~/srv/app/ipblocklist package main import ( + "errors" + "flag" "fmt" "os" "path/filepath" + "time" "github.com/therootcompany/golib/net/gitshallow" ) -const ( - defaultDepth = 1 // shallow by default - defaultBranch = "" // empty = default branch + --single-branch -) +const version = "dev" + +type Config struct { + Depth int + Branch string +} func main() { - if len(os.Args) != 3 { - name := filepath.Base(os.Args[0]) - fmt.Fprintf(os.Stderr, "Usage: %s \n", name) - fmt.Fprintf(os.Stderr, "Example:\n") - fmt.Fprintf(os.Stderr, " %s git@github.com:bitwire-it/ipblocklist.git ~/srv/app/ipblocklist\n", name) + cfg := Config{} + fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + fs.IntVar(&cfg.Depth, "depth", 1, "clone/fetch depth (-1 for full history)") + fs.StringVar(&cfg.Branch, "branch", "", "branch to track (empty: remote default)") + fs.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: %s [flags] \n", os.Args[0]) + fs.PrintDefaults() + } + + if len(os.Args) > 1 { + switch os.Args[1] { + case "-V", "-version", "--version", "version": + fmt.Fprintf(os.Stdout, "git-shallow-sync %s\n", version) + os.Exit(0) + case "help", "-help", "--help": + fmt.Fprintf(os.Stdout, "git-shallow-sync %s\n\n", version) + fs.SetOutput(os.Stdout) + fs.Usage() + os.Exit(0) + } + } + if err := fs.Parse(os.Args[1:]); err != nil { + if errors.Is(err, flag.ErrHelp) { + os.Exit(0) + } os.Exit(1) } - url := os.Args[1] - path := os.Args[2] + args := fs.Args() + if len(args) != 2 { + fs.Usage() + os.Exit(1) + } + url := args[0] + path := args[1] - // Expand ~ to home directory - if path[0] == '~' { + if len(path) > 0 && path[0] == '~' { home, err := os.UserHomeDir() if err != nil { - fmt.Fprintf(os.Stderr, "Failed to get home directory: %v\n", err) + fmt.Fprintf(os.Stderr, "error: resolve home: %v\n", err) os.Exit(1) } path = filepath.Join(home, path[1:]) } - absPath, err := filepath.Abs(path) if err != nil { - fmt.Fprintf(os.Stderr, "Invalid path: %v\n", err) + fmt.Fprintf(os.Stderr, "error: invalid path: %v\n", err) os.Exit(1) } - fmt.Printf("Syncing repository:\n") - fmt.Printf(" URL: %s\n", url) - fmt.Printf(" Path: %s\n", absPath) - - repo := gitshallow.New(url, absPath, defaultDepth, defaultBranch) - + fmt.Fprintf(os.Stderr, "Syncing %s -> %s... ", url, absPath) + t := time.Now() + repo := gitshallow.New(url, absPath, cfg.Depth, cfg.Branch) updated, err := repo.Sync() if err != nil { - fmt.Fprintf(os.Stderr, "Sync failed: %v\n", err) + fmt.Fprintln(os.Stderr) + fmt.Fprintf(os.Stderr, "error: sync: %v\n", err) os.Exit(1) } - + state := "already up to date" if updated { - fmt.Println("Repository was updated (new commits pulled).") - } else { - fmt.Println("Repository is already up to date.") + state = "updated" } - - fmt.Println("Sync complete.") + fmt.Fprintf(os.Stderr, "%s (%s)\n", time.Since(t).Round(time.Millisecond), state) } diff --git a/net/ipcohort/cmd/ipcohort-contains/main.go b/net/ipcohort/cmd/ipcohort-contains/main.go index a59f6db..ba7d919 100644 --- a/net/ipcohort/cmd/ipcohort-contains/main.go +++ b/net/ipcohort/cmd/ipcohort-contains/main.go @@ -4,48 +4,68 @@ // Usage: // // ipcohort-contains [flags] ... -- ... -// ipcohort-contains [flags] -ip ... -// -// Examples: -// -// ipcohort-contains networks.txt single_ips.txt -- 1.2.3.4 5.6.7.8 -// ipcohort-contains -ip 1.2.3.4 single_ips.txt -// echo "1.2.3.4" | ipcohort-contains networks.txt +// ipcohort-contains [flags] --ip ... +// echo "" | ipcohort-contains ... // // Exit code: 0 if all queried IPs are found, 1 if any are not found, 2 on error. package main import ( "bufio" + "errors" "flag" "fmt" "os" "strings" + "time" "github.com/therootcompany/golib/net/ipcohort" ) +const version = "dev" + +type Config struct { + IP string +} + func main() { - ipFlag := flag.String("ip", "", "IP address to check (alternative to -- separator)") - flag.Usage = func() { + cfg := Config{} + fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError) + fs.StringVar(&cfg.IP, "ip", "", "IP address to check (alternative to -- separator)") + fs.Usage = func() { fmt.Fprintf(os.Stderr, "Usage: %s [flags] ... -- ...\n", os.Args[0]) - fmt.Fprintf(os.Stderr, " %s -ip ...\n", os.Args[0]) + fmt.Fprintf(os.Stderr, " %s --ip ...\n", os.Args[0]) fmt.Fprintf(os.Stderr, " echo | %s ...\n", os.Args[0]) - fmt.Fprintln(os.Stderr, "Flags:") - flag.PrintDefaults() + fs.PrintDefaults() fmt.Fprintln(os.Stderr, "Exit: 0=all found, 1=not found, 2=error") } - flag.Parse() - args := flag.Args() + if len(os.Args) > 1 { + switch os.Args[1] { + case "-V", "-version", "--version", "version": + fmt.Fprintf(os.Stdout, "ipcohort-contains %s\n", version) + os.Exit(0) + case "help", "-help", "--help": + fmt.Fprintf(os.Stdout, "ipcohort-contains %s\n\n", version) + fs.SetOutput(os.Stdout) + fs.Usage() + os.Exit(0) + } + } + if err := fs.Parse(os.Args[1:]); err != nil { + if errors.Is(err, flag.ErrHelp) { + os.Exit(0) + } + os.Exit(2) + } + + args := fs.Args() var filePaths, ips []string - switch { - case *ipFlag != "": + case cfg.IP != "": filePaths = args - ips = []string{*ipFlag} + ips = []string{cfg.IP} default: - // Split args at "--" sep := -1 for i, a := range args { if a == "--" { @@ -63,17 +83,23 @@ func main() { if len(filePaths) == 0 { fmt.Fprintln(os.Stderr, "error: at least one file path required") - flag.Usage() + fs.Usage() os.Exit(2) } + fmt.Fprint(os.Stderr, "Loading cohort... ") + t := time.Now() cohort, err := ipcohort.LoadFiles(filePaths...) if err != nil { + fmt.Fprintln(os.Stderr) fmt.Fprintf(os.Stderr, "error: %v\n", err) os.Exit(2) } + fmt.Fprintf(os.Stderr, "%s (entries=%d)\n", + time.Since(t).Round(time.Millisecond), + cohort.Size(), + ) - // If no IPs from flags/args, read from stdin. if len(ips) == 0 { sc := bufio.NewScanner(os.Stdin) for sc.Scan() { @@ -89,14 +115,14 @@ func main() { if len(ips) == 0 { fmt.Fprintln(os.Stderr, "error: no IP addresses to check") - flag.Usage() + fs.Usage() os.Exit(2) } + fmt.Fprintln(os.Stderr) allFound := true for _, ip := range ips { - found := cohort.Contains(ip) - if found { + if cohort.Contains(ip) { fmt.Printf("%s\tFOUND\n", ip) } else { fmt.Printf("%s\tNOT FOUND\n", ip)