refactor: apply check-ip CLI conventions to sibling cmds

Propagate the patterns used in cmd/check-ip to the other command-line
tools touched by this PR:

- flag.FlagSet + Config struct instead of package-level flag.String
  pointers (geoip-update, ipcohort-contains, git-shallow-sync).
- -V/--version/version and help/-help/--help handled before Parse,
  matching the project's CLI conventions.
- Stderr "Loading X... Nms (counts)" progress lines on the stages that
  actually take time: blocklist cohort parse (ipcohort-contains),
  per-edition fetch (geoip-update), and repo sync (git-shallow-sync).
  Stdout stays machine-parseable.
This commit is contained in:
AJ ONeal 2026-04-20 19:13:47 -06:00
parent a181133c2f
commit 8ebc571928
No known key found for this signature in database
3 changed files with 159 additions and 75 deletions

View File

@ -1,6 +1,9 @@
// geoip-update downloads GeoLite2 edition tarballs listed in GeoIP.conf
// via conditional HTTP GETs, writing them to the configured directory.
package main
import (
"errors"
"flag"
"fmt"
"net/http"
@ -12,26 +15,58 @@ import (
"github.com/therootcompany/golib/net/httpcache"
)
const version = "dev"
type Config struct {
ConfPath string
Dir string
FreshDays int
}
func main() {
configPath := flag.String("config", "GeoIP.conf", "path to GeoIP.conf")
dir := flag.String("dir", "", "directory to store .tar.gz files (overrides DatabaseDirectory in config)")
freshDays := flag.Int("fresh-days", 3, "skip download if file is younger than N days")
flag.Parse()
cfg := Config{}
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fs.StringVar(&cfg.ConfPath, "config", "GeoIP.conf", "path to GeoIP.conf")
fs.StringVar(&cfg.Dir, "dir", "", "directory to store .tar.gz files (overrides DatabaseDirectory in config)")
fs.IntVar(&cfg.FreshDays, "fresh-days", 3, "skip download if file is younger than N days")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags]\n", os.Args[0])
fs.PrintDefaults()
}
data, err := os.ReadFile(*configPath)
if len(os.Args) > 1 {
switch os.Args[1] {
case "-V", "-version", "--version", "version":
fmt.Fprintf(os.Stdout, "geoip-update %s\n", version)
os.Exit(0)
case "help", "-help", "--help":
fmt.Fprintf(os.Stdout, "geoip-update %s\n\n", version)
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
if err := fs.Parse(os.Args[1:]); err != nil {
if errors.Is(err, flag.ErrHelp) {
os.Exit(0)
}
os.Exit(1)
}
data, err := os.ReadFile(cfg.ConfPath)
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
cfg, err := geoip.ParseConf(string(data))
conf, err := geoip.ParseConf(string(data))
if err != nil {
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(1)
}
outDir := *dir
outDir := cfg.Dir
if outDir == "" {
outDir = cfg.DatabaseDirectory
outDir = conf.DatabaseDirectory
}
if outDir == "" {
outDir = "."
@ -41,16 +76,16 @@ func main() {
os.Exit(1)
}
if len(cfg.EditionIDs) == 0 {
fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", *configPath)
if len(conf.EditionIDs) == 0 {
fmt.Fprintf(os.Stderr, "error: no EditionIDs found in %s\n", cfg.ConfPath)
os.Exit(1)
}
authHeader := http.Header{"Authorization": []string{httpcache.BasicAuth(cfg.AccountID, cfg.LicenseKey)}}
maxAge := time.Duration(*freshDays) * 24 * time.Hour
authHeader := http.Header{"Authorization": []string{httpcache.BasicAuth(conf.AccountID, conf.LicenseKey)}}
maxAge := time.Duration(cfg.FreshDays) * 24 * time.Hour
exitCode := 0
for _, edition := range cfg.EditionIDs {
for _, edition := range conf.EditionIDs {
path := filepath.Join(outDir, geoip.TarGzName(edition))
cacher := &httpcache.Cacher{
URL: geoip.DownloadBase + "/" + edition + "/download?suffix=tar.gz",
@ -58,18 +93,22 @@ func main() {
MaxAge: maxAge,
Header: authHeader,
}
fmt.Fprintf(os.Stderr, "Fetching %s... ", edition)
t := time.Now()
updated, err := cacher.Fetch()
if err != nil {
fmt.Fprintln(os.Stderr)
fmt.Fprintf(os.Stderr, "error: %s: %v\n", edition, err)
exitCode = 1
continue
}
info, _ := os.Stat(path)
state := "fresh: "
state := "fresh"
if updated {
state = "updated:"
state = "updated"
}
fmt.Printf("%s %s -> %s (%s)\n", state, edition, path, info.ModTime().Format("2006-01-02"))
fmt.Fprintf(os.Stderr, "%s (%s)\n", time.Since(t).Round(time.Millisecond), state)
info, _ := os.Stat(path)
fmt.Printf("%-10s %s -> %s (%s)\n", state+":", edition, path, info.ModTime().Format("2006-01-02"))
}
os.Exit(exitCode)
}

View File

@ -1,73 +1,92 @@
// git-shallow-sync is a simple CLI tool to synchronize a shallow git repository
// using the github.com/therootcompany/golib/net/gitshallow package.
// git-shallow-sync syncs a shallow git clone at the given local path,
// cloning on first run and fetching + hard-resetting on subsequent runs.
//
// Usage:
//
// git-shallow-sync <repository-url> <local-path>
//
// Example:
//
// git-shallow-sync git@github.com:bitwire-it/ipblocklist.git ~/srv/app/ipblocklist
package main
import (
"errors"
"flag"
"fmt"
"os"
"path/filepath"
"time"
"github.com/therootcompany/golib/net/gitshallow"
)
const (
defaultDepth = 1 // shallow by default
defaultBranch = "" // empty = default branch + --single-branch
)
const version = "dev"
type Config struct {
Depth int
Branch string
}
func main() {
if len(os.Args) != 3 {
name := filepath.Base(os.Args[0])
fmt.Fprintf(os.Stderr, "Usage: %s <repository-url> <local-path>\n", name)
fmt.Fprintf(os.Stderr, "Example:\n")
fmt.Fprintf(os.Stderr, " %s git@github.com:bitwire-it/ipblocklist.git ~/srv/app/ipblocklist\n", name)
cfg := Config{}
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fs.IntVar(&cfg.Depth, "depth", 1, "clone/fetch depth (-1 for full history)")
fs.StringVar(&cfg.Branch, "branch", "", "branch to track (empty: remote default)")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <repository-url> <local-path>\n", os.Args[0])
fs.PrintDefaults()
}
if len(os.Args) > 1 {
switch os.Args[1] {
case "-V", "-version", "--version", "version":
fmt.Fprintf(os.Stdout, "git-shallow-sync %s\n", version)
os.Exit(0)
case "help", "-help", "--help":
fmt.Fprintf(os.Stdout, "git-shallow-sync %s\n\n", version)
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
if err := fs.Parse(os.Args[1:]); err != nil {
if errors.Is(err, flag.ErrHelp) {
os.Exit(0)
}
os.Exit(1)
}
url := os.Args[1]
path := os.Args[2]
args := fs.Args()
if len(args) != 2 {
fs.Usage()
os.Exit(1)
}
url := args[0]
path := args[1]
// Expand ~ to home directory
if path[0] == '~' {
if len(path) > 0 && path[0] == '~' {
home, err := os.UserHomeDir()
if err != nil {
fmt.Fprintf(os.Stderr, "Failed to get home directory: %v\n", err)
fmt.Fprintf(os.Stderr, "error: resolve home: %v\n", err)
os.Exit(1)
}
path = filepath.Join(home, path[1:])
}
absPath, err := filepath.Abs(path)
if err != nil {
fmt.Fprintf(os.Stderr, "Invalid path: %v\n", err)
fmt.Fprintf(os.Stderr, "error: invalid path: %v\n", err)
os.Exit(1)
}
fmt.Printf("Syncing repository:\n")
fmt.Printf(" URL: %s\n", url)
fmt.Printf(" Path: %s\n", absPath)
repo := gitshallow.New(url, absPath, defaultDepth, defaultBranch)
fmt.Fprintf(os.Stderr, "Syncing %s -> %s... ", url, absPath)
t := time.Now()
repo := gitshallow.New(url, absPath, cfg.Depth, cfg.Branch)
updated, err := repo.Sync()
if err != nil {
fmt.Fprintf(os.Stderr, "Sync failed: %v\n", err)
fmt.Fprintln(os.Stderr)
fmt.Fprintf(os.Stderr, "error: sync: %v\n", err)
os.Exit(1)
}
state := "already up to date"
if updated {
fmt.Println("Repository was updated (new commits pulled).")
} else {
fmt.Println("Repository is already up to date.")
state = "updated"
}
fmt.Println("Sync complete.")
fmt.Fprintf(os.Stderr, "%s (%s)\n", time.Since(t).Round(time.Millisecond), state)
}

View File

@ -4,48 +4,68 @@
// Usage:
//
// ipcohort-contains [flags] <file>... -- <ip>...
// ipcohort-contains [flags] -ip <ip> <file>...
//
// Examples:
//
// ipcohort-contains networks.txt single_ips.txt -- 1.2.3.4 5.6.7.8
// ipcohort-contains -ip 1.2.3.4 single_ips.txt
// echo "1.2.3.4" | ipcohort-contains networks.txt
// ipcohort-contains [flags] --ip <ip> <file>...
// echo "<ip>" | ipcohort-contains <file>...
//
// Exit code: 0 if all queried IPs are found, 1 if any are not found, 2 on error.
package main
import (
"bufio"
"errors"
"flag"
"fmt"
"os"
"strings"
"time"
"github.com/therootcompany/golib/net/ipcohort"
)
const version = "dev"
type Config struct {
IP string
}
func main() {
ipFlag := flag.String("ip", "", "IP address to check (alternative to -- separator)")
flag.Usage = func() {
cfg := Config{}
fs := flag.NewFlagSet(os.Args[0], flag.ContinueOnError)
fs.StringVar(&cfg.IP, "ip", "", "IP address to check (alternative to -- separator)")
fs.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [flags] <file>... -- <ip>...\n", os.Args[0])
fmt.Fprintf(os.Stderr, " %s -ip <ip> <file>...\n", os.Args[0])
fmt.Fprintf(os.Stderr, " %s --ip <ip> <file>...\n", os.Args[0])
fmt.Fprintf(os.Stderr, " echo <ip> | %s <file>...\n", os.Args[0])
fmt.Fprintln(os.Stderr, "Flags:")
flag.PrintDefaults()
fs.PrintDefaults()
fmt.Fprintln(os.Stderr, "Exit: 0=all found, 1=not found, 2=error")
}
flag.Parse()
args := flag.Args()
if len(os.Args) > 1 {
switch os.Args[1] {
case "-V", "-version", "--version", "version":
fmt.Fprintf(os.Stdout, "ipcohort-contains %s\n", version)
os.Exit(0)
case "help", "-help", "--help":
fmt.Fprintf(os.Stdout, "ipcohort-contains %s\n\n", version)
fs.SetOutput(os.Stdout)
fs.Usage()
os.Exit(0)
}
}
if err := fs.Parse(os.Args[1:]); err != nil {
if errors.Is(err, flag.ErrHelp) {
os.Exit(0)
}
os.Exit(2)
}
args := fs.Args()
var filePaths, ips []string
switch {
case *ipFlag != "":
case cfg.IP != "":
filePaths = args
ips = []string{*ipFlag}
ips = []string{cfg.IP}
default:
// Split args at "--"
sep := -1
for i, a := range args {
if a == "--" {
@ -63,17 +83,23 @@ func main() {
if len(filePaths) == 0 {
fmt.Fprintln(os.Stderr, "error: at least one file path required")
flag.Usage()
fs.Usage()
os.Exit(2)
}
fmt.Fprint(os.Stderr, "Loading cohort... ")
t := time.Now()
cohort, err := ipcohort.LoadFiles(filePaths...)
if err != nil {
fmt.Fprintln(os.Stderr)
fmt.Fprintf(os.Stderr, "error: %v\n", err)
os.Exit(2)
}
fmt.Fprintf(os.Stderr, "%s (entries=%d)\n",
time.Since(t).Round(time.Millisecond),
cohort.Size(),
)
// If no IPs from flags/args, read from stdin.
if len(ips) == 0 {
sc := bufio.NewScanner(os.Stdin)
for sc.Scan() {
@ -89,14 +115,14 @@ func main() {
if len(ips) == 0 {
fmt.Fprintln(os.Stderr, "error: no IP addresses to check")
flag.Usage()
fs.Usage()
os.Exit(2)
}
fmt.Fprintln(os.Stderr)
allFound := true
for _, ip := range ips {
found := cohort.Contains(ip)
if found {
if cohort.Contains(ip) {
fmt.Printf("%s\tFOUND\n", ip)
} else {
fmt.Printf("%s\tNOT FOUND\n", ip)