367 lines
7.4 KiB
Go
367 lines
7.4 KiB
Go
package updater
|
|
|
|
import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"mime/multipart"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/UnnoTed/fileb0x/file"
	"github.com/airking05/termui"
)
|
|
|
|
// Auth carries the credentials that are sent to the b0x server as HTTP
// basic authentication on every request.
type Auth struct {
	Username, Password string
}
|
|
|
|
// ResponseInit is the JSON payload the server answers the hash list request
// with. The client compares Hashes against its own file hashes to find new
// and changed files.
type ResponseInit struct {
	// Success tells whether the server managed to produce the list.
	Success bool
	// Hashes maps a file path to the hash the server has stored for it.
	Hashes map[string]string
}
|
|
|
|
// ProgressReader wraps an io.Reader and reports, through Reporter, how many
// bytes every Read call returned. A caller can use it to track how much of
// the underlying data has been consumed.
type ProgressReader struct {
	io.Reader

	// Reporter is called after every Read with the number of bytes that call
	// returned (possibly 0). A nil Reporter disables reporting, so the zero
	// value of ProgressReader (with a Reader set) is safe to use.
	Reporter func(r int64)
}

// Read reads from the wrapped Reader and then hands the byte count to
// Reporter, if one is set. It returns exactly what the wrapped Reader
// returned.
func (pr *ProgressReader) Read(p []byte) (int, error) {
	n, err := pr.Reader.Read(p)
	if pr.Reporter != nil {
		pr.Reporter(int64(n))
	}
	return n, err
}
|
|
|
|
// Updater sends files that should be update to the b0x server
|
|
type Updater struct {
|
|
Server string
|
|
Auth Auth
|
|
ui []termui.Bufferer
|
|
|
|
RemoteHashes map[string]string
|
|
LocalHashes map[string]string
|
|
ToUpdate []string
|
|
Workers int
|
|
}
|
|
|
|
// Init gets the list of file hash from the server
|
|
func (up *Updater) Init() error {
|
|
return up.Get()
|
|
}
|
|
|
|
// Get gets the list of file hash from the server
|
|
func (up *Updater) Get() error {
|
|
log.Println("Creating hash list request...")
|
|
req, err := http.NewRequest("GET", up.Server, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
req.SetBasicAuth(up.Auth.Username, up.Auth.Password)
|
|
|
|
log.Println("Sending hash list request...")
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if resp.StatusCode == http.StatusUnauthorized {
|
|
return errors.New("Error Unautorized")
|
|
}
|
|
|
|
log.Println("Reading hash list response's body...")
|
|
var buf bytes.Buffer
|
|
_, err = buf.ReadFrom(resp.Body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Println("Parsing hash list response's body...")
|
|
ri := &ResponseInit{}
|
|
err = json.Unmarshal(buf.Bytes(), &ri)
|
|
if err != nil {
|
|
log.Println("Body is", buf.Bytes())
|
|
return err
|
|
}
|
|
resp.Body.Close()
|
|
|
|
// copy hash list
|
|
if ri.Success {
|
|
log.Println("Copying hash list...")
|
|
up.RemoteHashes = ri.Hashes
|
|
up.LocalHashes = map[string]string{}
|
|
log.Println("Done")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Updatable checks if there is any file that should be updaTed
|
|
func (up *Updater) Updatable(files map[string]*file.File) (bool, error) {
|
|
hasUpdates := !up.EqualHashes(files)
|
|
|
|
if hasUpdates {
|
|
log.Println("----------------------------------------")
|
|
log.Println("-- Found files that should be updated --")
|
|
log.Println("----------------------------------------")
|
|
} else {
|
|
log.Println("-----------------------")
|
|
log.Println("-- Nothing to update --")
|
|
log.Println("-----------------------")
|
|
}
|
|
|
|
return hasUpdates, nil
|
|
}
|
|
|
|
// EqualHash checks if a local file hash equals a remote file hash
|
|
// it returns false when a remote file hash isn't found (new files)
|
|
func (up *Updater) EqualHash(name string) bool {
|
|
hash, existsLocally := up.LocalHashes[name]
|
|
_, existsRemotely := up.RemoteHashes[name]
|
|
if !existsRemotely || !existsLocally || hash != up.RemoteHashes[name] {
|
|
if hash != up.RemoteHashes[name] {
|
|
log.Println("Found changes in file: ", name)
|
|
|
|
} else if !existsRemotely && existsLocally {
|
|
log.Println("Found new file: ", name)
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// EqualHashes builds the list of local hashes before
|
|
// checking if there is any that should be updated
|
|
func (up *Updater) EqualHashes(files map[string]*file.File) bool {
|
|
for _, f := range files {
|
|
log.Println("Checking file for changes:", f.Path)
|
|
|
|
if len(f.Bytes) == 0 && !f.ReplacedText {
|
|
data, err := ioutil.ReadFile(f.OriginalPath)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
f.Bytes = data
|
|
|
|
// removes the []byte("") from the string
|
|
// when the data isn't in the Bytes variable
|
|
} else if len(f.Bytes) == 0 && f.ReplacedText && len(f.Data) > 0 {
|
|
f.Data = strings.TrimPrefix(f.Data, `[]byte("`)
|
|
f.Data = strings.TrimSuffix(f.Data, `")`)
|
|
f.Data = strings.Replace(f.Data, "\\x", "", -1)
|
|
|
|
var err error
|
|
f.Bytes, err = hex.DecodeString(f.Data)
|
|
if err != nil {
|
|
log.Println("SHIT", err)
|
|
return false
|
|
}
|
|
|
|
f.Data = ""
|
|
}
|
|
|
|
sha := sha256.New()
|
|
if _, err := sha.Write(f.Bytes); err != nil {
|
|
panic(err)
|
|
return false
|
|
}
|
|
|
|
up.LocalHashes[f.Path] = hex.EncodeToString(sha.Sum(nil))
|
|
}
|
|
|
|
// check if there is any file to update
|
|
update := false
|
|
for k := range up.LocalHashes {
|
|
if !up.EqualHash(k) {
|
|
up.ToUpdate = append(up.ToUpdate, k)
|
|
update = true
|
|
}
|
|
}
|
|
|
|
return !update
|
|
}
|
|
|
|
type job struct {
|
|
current int
|
|
files *file.File
|
|
total int
|
|
}
|
|
|
|
// UpdateFiles sends all files that should be updated to the server
|
|
// the limit is 3 concurrent files at once
|
|
func (up *Updater) UpdateFiles(files map[string]*file.File) error {
|
|
updatable, err := up.Updatable(files)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !updatable {
|
|
return nil
|
|
}
|
|
|
|
// everything's height
|
|
height := 3
|
|
err = termui.Init()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
defer termui.Close()
|
|
|
|
// info text
|
|
p := termui.NewPar("PRESS ANY KEY TO QUIT")
|
|
p.Height = height
|
|
p.Width = 50
|
|
p.TextFgColor = termui.ColorWhite
|
|
up.ui = append(up.ui, p)
|
|
|
|
doneTotal := 0
|
|
total := len(up.ToUpdate)
|
|
jobs := make(chan *job, total)
|
|
done := make(chan bool, total)
|
|
|
|
if up.Workers <= 0 {
|
|
up.Workers = 1
|
|
}
|
|
|
|
// just so it can listen to events
|
|
go func() {
|
|
termui.Loop()
|
|
}()
|
|
|
|
// cancel with any key
|
|
termui.Handle("/sys/kbd", func(termui.Event) {
|
|
termui.StopLoop()
|
|
os.Exit(1)
|
|
})
|
|
|
|
// stops rendering when total is reached
|
|
go func(upp *Updater, d *int) {
|
|
for {
|
|
if *d >= total {
|
|
break
|
|
}
|
|
|
|
termui.Render(upp.ui...)
|
|
}
|
|
}(up, &doneTotal)
|
|
|
|
for i := 0; i < up.Workers; i++ {
|
|
// creates a progress bar
|
|
g := termui.NewGauge()
|
|
g.Width = termui.TermWidth()
|
|
g.Height = height
|
|
g.BarColor = termui.ColorBlue
|
|
g.Y = len(up.ui) * height
|
|
up.ui = append(up.ui, g)
|
|
|
|
go up.worker(jobs, done, g)
|
|
}
|
|
|
|
for i, name := range up.ToUpdate {
|
|
jobs <- &job{
|
|
current: i + 1,
|
|
files: files[name],
|
|
total: total,
|
|
}
|
|
}
|
|
close(jobs)
|
|
|
|
for i := 0; i < total; i++ {
|
|
<-done
|
|
doneTotal++
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (up *Updater) worker(jobs <-chan *job, done chan<- bool, g *termui.Gauge) {
|
|
for job := range jobs {
|
|
f := job.files
|
|
fr := bytes.NewReader(f.Bytes)
|
|
g.BorderLabel = fmt.Sprintf("%d/%d %s", job.current, job.total, f.Path)
|
|
|
|
// updates progress bar's percentage
|
|
var total int64
|
|
pr := &ProgressReader{fr, func(r int64) {
|
|
total += r
|
|
g.Percent = int(float64(total) / float64(fr.Size()) * 100)
|
|
}}
|
|
|
|
r, w := io.Pipe()
|
|
writer := multipart.NewWriter(w)
|
|
|
|
// copy the file into the form
|
|
go func(fr *ProgressReader) {
|
|
defer w.Close()
|
|
part, err := writer.CreateFormFile("file", f.Path)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
_, err = io.Copy(part, fr)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
err = writer.Close()
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
}(pr)
|
|
|
|
// create a post request with basic auth
|
|
// and the file included in a form
|
|
req, err := http.NewRequest("POST", up.Server, r)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
req.Header.Set("Content-Type", writer.FormDataContentType())
|
|
req.SetBasicAuth(up.Auth.Username, up.Auth.Password)
|
|
|
|
// sends the request
|
|
client := &http.Client{}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
body := &bytes.Buffer{}
|
|
_, err = body.ReadFrom(resp.Body)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if err := resp.Body.Close(); err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
if body.String() != "ok" {
|
|
panic(body.String())
|
|
}
|
|
|
|
done <- true
|
|
}
|
|
}
|