diff --git a/cmd/smbtest/go.mod b/cmd/smbtest/go.mod new file mode 100644 index 0000000..2a9190d --- /dev/null +++ b/cmd/smbtest/go.mod @@ -0,0 +1,14 @@ +module example.com/m + +go 1.24.3 + +require ( + github.com/hirochachacha/go-smb2 v1.1.0 + golang.org/x/term v0.39.0 +) + +require ( + github.com/geoffgarside/ber v1.1.0 // indirect + golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de // indirect + golang.org/x/sys v0.40.0 // indirect +) diff --git a/cmd/smbtest/go.sum b/cmd/smbtest/go.sum new file mode 100644 index 0000000..d5d3752 --- /dev/null +++ b/cmd/smbtest/go.sum @@ -0,0 +1,15 @@ +github.com/geoffgarside/ber v1.1.0 h1:qTmFG4jJbwiSzSXoNJeHcOprVzZ8Ulde2Rrrifu5U9w= +github.com/geoffgarside/ber v1.1.0/go.mod h1:jVPKeCbj6MvQZhwLYsGwaGI52oUorHoHKNecGT85ZCc= +github.com/hirochachacha/go-smb2 v1.1.0 h1:b6hs9qKIql9eVXAiN0M2wSFY5xnhbHAQoCwRKbaRTZI= +github.com/hirochachacha/go-smb2 v1.1.0/go.mod h1:8F1A4d5EZzrGu5R7PU163UcMRDJQl4FtcxjBfsY8TZE= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de h1:ikNHVSjEfnvz6sxdSPCaPt572qowuyMDMJLLm3Db3ig= +golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= +golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/cmd/smbtest/main.go b/cmd/smbtest/main.go new file mode 100644 index 0000000..29014af --- /dev/null +++ b/cmd/smbtest/main.go @@ -0,0 +1,211 @@ +package main + +import ( + "flag" + "fmt" + "io" + "io/fs" + "net" + "os" + "strings" + + "golang.org/x/term" + + "github.com/hirochachacha/go-smb2" +) + +const ( + name = "smbtest" + licenseYear = "2026" + licenseOwner = "AJ ONeal" + licenseType = "CC0-1.0" +) + +// set by GoReleaser via ldflags +var ( + version = "0.0.0-dev" + commit = "0000000" + date = "0001-01-01T00:00:00Z" +) + +// printVersion displays the version, commit, and build date. +func printVersion() { + if len(commit) > 7 { + commit = commit[:7] + } + fmt.Fprintf(os.Stderr, "%s v%s %s (%s)\n", name, version, commit, date) + fmt.Fprintf(os.Stderr, "Copyright (C) %s %s\n", licenseYear, licenseOwner) + fmt.Fprintf(os.Stderr, "Licensed under the %s license\n", licenseType) +} + +type CLIConfig struct { + showVersion bool + user string + host string + share string + remotePath string +} + +func main() { + cfg := CLIConfig{} + + mainFlags := flag.NewFlagSet("", flag.ContinueOnError) + + mainFlags.BoolVar(&cfg.showVersion, "version", false, "Print version and exit") + mainFlags.StringVar(&cfg.user, "user", os.Getenv("SMB_USERNAME"), "ex: 'jon', or set SMB_USERNAME (password will be prompted unless SMB_PASSWORD is set") + mainFlags.StringVar(&cfg.host, "host", os.Getenv("SMB_HOST"), "ex: 'localhost:445', or set SMB_HOST") + mainFlags.StringVar(&cfg.share, "share", os.Getenv("SMB_SHARE"), "ex: 'Public', or set SMB_SHARE") + mainFlags.StringVar(&cfg.remotePath, "remote-path", os.Getenv("SMB_REMOTE_PATH"), "ex: 'Public/goodies.zip', or set SMB_REMOTE_PATH") + + mainFlags.Usage = func() { + printVersion() + out := mainFlags.Output() + _, _ = fmt.Fprintf(out, "\n") + _, _ = fmt.Fprintf(out, "USAGE\n") + _, _ = fmt.Fprintf(out, " smbtest [options] \n") + mainFlags.PrintDefaults() + } + + if len(os.Args) > 1 { + switch os.Args[1] { + case "-V", "version", "-version", "--version": + printVersion() + return + case "help", "-help", "--help": + mainFlags.SetOutput(os.Stdout) + mainFlags.Usage() + return + } + } + + if err := mainFlags.Parse(os.Args[1:]); err != nil { + fmt.Fprintln(os.Stderr, err) + + mainFlags.SetOutput(os.Stderr) + mainFlags.Usage() + os.Exit(1) + return + } + + // Handle --version flag after parsing + if cfg.showVersion { + printVersion() + return + } + + pass, hasPass := os.LookupEnv("SMB_PASSWORD") + if !hasPass { + fmt.Fprintf(os.Stderr, "SMB_PASSWORD is not set: ") + fmt.Print("Password: ") + password, err := term.ReadPassword(int(os.Stdin.Fd())) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to read password: %v\n", err) + os.Exit(1) + } + fmt.Fprintf(os.Stderr, "\n") + pass = strings.TrimRight(string(password), "\r\n \t") + } + + fmt.Printf("%s@%s/%s/%s", cfg.user, cfg.host, cfg.share, cfg.remotePath) + trySMB(cfg.host, cfg.share, cfg.user, pass, cfg.remotePath) + fmt.Println("") +} + +type SMBClient struct { + conn net.Conn + session *smb2.Session + fs *smb2.Share +} + +func NewSMBClient(host, share, username, password string) (*SMBClient, error) { + conn, err := net.Dial("tcp", host) + if err != nil { + return nil, err + } + + d := &smb2.Dialer{ + Initiator: &smb2.NTLMInitiator{ + User: username, + Password: password, + }, + } + + s, err := d.Dial(conn) + if err != nil { + _ = conn.Close() + return nil, err + } + + f, err := s.Mount(share) + if err != nil { + _ = s.Logoff() + _ = conn.Close() + return nil, err + } + + return &SMBClient{conn: conn, session: s, fs: f}, nil +} + +func (c *SMBClient) ListFiles(path string) ([]string, error) { + matches, err := fs.Glob(c.fs.DirFS(path), "*") + if err != nil { + return nil, err + } + return matches, nil +} + +func (c *SMBClient) ReceiveFile(path string, w io.Writer) error { + f, err := c.fs.Open(path) + if err != nil { + return err + } + defer func() { _ = f.Close() }() + + _, err = io.Copy(w, f) + return err +} + +func (c *SMBClient) Close() { + _ = c.fs.Umount() + _ = c.session.Logoff() + _ = c.conn.Close() +} + +func trySMB(host, share, username, password, rpath string) { + client, err := NewSMBClient(host, share, username, password) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + return + } + defer client.Close() + + // List files + files, err := client.ListFiles(".") + if err != nil { + fmt.Fprintf(os.Stderr, "List error: %v\n", err) + return + } + if len(files) == 0 { + fmt.Println("No files") + } + for _, f := range files { + fmt.Println(" ", f) + } + + // Receive file + f, err := os.Create(rpath) + if err != nil { + fmt.Fprintf(os.Stderr, "Create error: %v\n", err) + return + } + defer func() { + if err := f.Close(); err != nil { + fmt.Fprintf(os.Stderr, "Close error: %v\n", err) + } + }() + + err = client.ReceiveFile(rpath, f) + if err != nil { + fmt.Fprintf(os.Stderr, "Receive error: %v\n", err) + } +}