// Mirror of https://github.com/therootcompany/golib.git
// Synced 2026-01-27 23:18:05 +00:00 · 212 lines · 4.5 KiB · Go
package main
|
|
|
|
import (
	"errors"
	"flag"
	"fmt"
	"io"
	"io/fs"
	"net"
	"os"
	"strings"

	"golang.org/x/term"

	"github.com/hirochachacha/go-smb2"
)
|
|
|
|
// Program identity and license metadata.
const (
	// name is the program name.
	name = "smbtest"

	// licenseYear, licenseOwner and licenseType describe the copyright
	// holder and the license the program is released under.
	licenseYear  = "2026"
	licenseOwner = "AJ ONeal"
	licenseType  = "CC0-1.0"
)
|
|
|
|
// Build metadata, set by GoReleaser via ldflags. The literals below are the
// values reported when nothing overrides them.
var (
	version = "0.0.0-dev"
	commit  = "0000000"
	date    = "0001-01-01T00:00:00Z"
)
|
|
|
|
// printVersion displays the version, commit, and build date.
|
|
func printVersion() {
|
|
if len(commit) > 7 {
|
|
commit = commit[:7]
|
|
}
|
|
fmt.Fprintf(os.Stderr, "%s v%s %s (%s)\n", name, version, commit, date)
|
|
fmt.Fprintf(os.Stderr, "Copyright (C) %s %s\n", licenseYear, licenseOwner)
|
|
fmt.Fprintf(os.Stderr, "Licensed under the %s license\n", licenseType)
|
|
}
|
|
|
|
// CLIConfig holds the values that main binds to its command-line flags.
type CLIConfig struct {
	showVersion bool // -version: print the version banner and exit

	// Bound to -user, -host and -share respectively.
	user, host, share string

	remotePath string // -remote-path
}
|
|
|
|
func main() {
|
|
cfg := CLIConfig{}
|
|
|
|
mainFlags := flag.NewFlagSet("", flag.ContinueOnError)
|
|
|
|
mainFlags.BoolVar(&cfg.showVersion, "version", false, "Print version and exit")
|
|
mainFlags.StringVar(&cfg.user, "user", os.Getenv("SMB_USERNAME"), "ex: 'jon', or set SMB_USERNAME (password will be prompted unless SMB_PASSWORD is set")
|
|
mainFlags.StringVar(&cfg.host, "host", os.Getenv("SMB_HOST"), "ex: 'localhost:445', or set SMB_HOST")
|
|
mainFlags.StringVar(&cfg.share, "share", os.Getenv("SMB_SHARE"), "ex: 'Public', or set SMB_SHARE")
|
|
mainFlags.StringVar(&cfg.remotePath, "remote-path", os.Getenv("SMB_REMOTE_PATH"), "ex: 'Public/goodies.zip', or set SMB_REMOTE_PATH")
|
|
|
|
mainFlags.Usage = func() {
|
|
printVersion()
|
|
out := mainFlags.Output()
|
|
_, _ = fmt.Fprintf(out, "\n")
|
|
_, _ = fmt.Fprintf(out, "USAGE\n")
|
|
_, _ = fmt.Fprintf(out, " smbtest [options] <url>\n")
|
|
mainFlags.PrintDefaults()
|
|
}
|
|
|
|
if len(os.Args) > 1 {
|
|
switch os.Args[1] {
|
|
case "-V", "version", "-version", "--version":
|
|
printVersion()
|
|
return
|
|
case "help", "-help", "--help":
|
|
mainFlags.SetOutput(os.Stdout)
|
|
mainFlags.Usage()
|
|
return
|
|
}
|
|
}
|
|
|
|
if err := mainFlags.Parse(os.Args[1:]); err != nil {
|
|
fmt.Fprintln(os.Stderr, err)
|
|
|
|
mainFlags.SetOutput(os.Stderr)
|
|
mainFlags.Usage()
|
|
os.Exit(1)
|
|
return
|
|
}
|
|
|
|
// Handle --version flag after parsing
|
|
if cfg.showVersion {
|
|
printVersion()
|
|
return
|
|
}
|
|
|
|
pass, hasPass := os.LookupEnv("SMB_PASSWORD")
|
|
if !hasPass {
|
|
fmt.Fprintf(os.Stderr, "SMB_PASSWORD is not set: ")
|
|
fmt.Print("Password: ")
|
|
password, err := term.ReadPassword(int(os.Stdin.Fd()))
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "failed to read password: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
fmt.Fprintf(os.Stderr, "\n")
|
|
pass = strings.TrimRight(string(password), "\r\n \t")
|
|
}
|
|
|
|
fmt.Printf("%s@%s/%s/%s", cfg.user, cfg.host, cfg.share, cfg.remotePath)
|
|
trySMB(cfg.host, cfg.share, cfg.user, pass, cfg.remotePath)
|
|
fmt.Println("")
|
|
}
|
|
|
|
type SMBClient struct {
|
|
conn net.Conn
|
|
session *smb2.Session
|
|
fs *smb2.Share
|
|
}
|
|
|
|
func NewSMBClient(host, share, username, password string) (*SMBClient, error) {
|
|
conn, err := net.Dial("tcp", host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
d := &smb2.Dialer{
|
|
Initiator: &smb2.NTLMInitiator{
|
|
User: username,
|
|
Password: password,
|
|
},
|
|
}
|
|
|
|
s, err := d.Dial(conn)
|
|
if err != nil {
|
|
_ = conn.Close()
|
|
return nil, err
|
|
}
|
|
|
|
f, err := s.Mount(share)
|
|
if err != nil {
|
|
_ = s.Logoff()
|
|
_ = conn.Close()
|
|
return nil, err
|
|
}
|
|
|
|
return &SMBClient{conn: conn, session: s, fs: f}, nil
|
|
}
|
|
|
|
func (c *SMBClient) ListFiles(path string) ([]string, error) {
|
|
matches, err := fs.Glob(c.fs.DirFS(path), "*")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return matches, nil
|
|
}
|
|
|
|
func (c *SMBClient) ReceiveFile(path string, w io.Writer) error {
|
|
f, err := c.fs.Open(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() { _ = f.Close() }()
|
|
|
|
_, err = io.Copy(w, f)
|
|
return err
|
|
}
|
|
|
|
func (c *SMBClient) Close() {
|
|
_ = c.fs.Umount()
|
|
_ = c.session.Logoff()
|
|
_ = c.conn.Close()
|
|
}
|
|
|
|
func trySMB(host, share, username, password, rpath string) {
|
|
client, err := NewSMBClient(host, share, username, password)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
return
|
|
}
|
|
defer client.Close()
|
|
|
|
// List files
|
|
files, err := client.ListFiles(".")
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "List error: %v\n", err)
|
|
return
|
|
}
|
|
if len(files) == 0 {
|
|
fmt.Println("No files")
|
|
}
|
|
for _, f := range files {
|
|
fmt.Println(" ", f)
|
|
}
|
|
|
|
// Receive file
|
|
f, err := os.Create(rpath)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Create error: %v\n", err)
|
|
return
|
|
}
|
|
defer func() {
|
|
if err := f.Close(); err != nil {
|
|
fmt.Fprintf(os.Stderr, "Close error: %v\n", err)
|
|
}
|
|
}()
|
|
|
|
err = client.ReceiveFile(rpath, f)
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "Receive error: %v\n", err)
|
|
}
|
|
}
|