WIP: auth flow
This commit is contained in:
parent
5ba8859256
commit
3eb061a1eb
|
@ -2,8 +2,9 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-acme/lego/v3/challenge"
|
"github.com/go-acme/lego/v3/challenge"
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
|
@ -32,32 +33,16 @@ func (p *acmeProvider) CleanUp(domain, token, keyAuth string) error {
|
||||||
|
|
||||||
func handleDNSRoutes(r chi.Router) {
|
func handleDNSRoutes(r chi.Router) {
|
||||||
r.Route("/dns", func(r chi.Router) {
|
r.Route("/dns", func(r chi.Router) {
|
||||||
r.Use(func(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
ctx := r.Context()
|
|
||||||
valid, _ := ctx.Value(MWKey("valid")).(bool)
|
|
||||||
|
|
||||||
if !valid {
|
|
||||||
// misdirection
|
|
||||||
time.Sleep(250 * time.Millisecond)
|
|
||||||
w.Write([]byte("{\"success\":true}\n"))
|
|
||||||
//http.Error(w, `{"error":"could not verify token"}`, http.StatusBadRequest)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
/*
|
|
||||||
if nil != err2 {
|
|
||||||
// a little misdirection there
|
|
||||||
msg := `{"error":"internal server error"}`
|
|
||||||
http.Error(w, msg, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
r.Post("/{domain}", func(w http.ResponseWriter, r *http.Request) {
|
r.Post("/{domain}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
domain := chi.URLParam(r, "domain")
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
claims, ok := ctx.Value(MWKey("claims")).(*MgmtClaims)
|
||||||
|
if !ok || !strings.HasPrefix(domain+".", claims.Slug) {
|
||||||
|
msg := `{ "error": "invalid domain" }`
|
||||||
|
http.Error(w, msg+"\n", http.StatusUnprocessableEntity)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
ch := Challenge{}
|
ch := Challenge{}
|
||||||
|
|
||||||
|
@ -70,7 +55,6 @@ func handleDNSRoutes(r chi.Router) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
domain := chi.URLParam(r, "domain")
|
|
||||||
//domain := chi.URLParam(r, "*")
|
//domain := chi.URLParam(r, "*")
|
||||||
ch.Domain = domain
|
ch.Domain = domain
|
||||||
|
|
||||||
|
@ -80,7 +64,8 @@ func handleDNSRoutes(r chi.Router) {
|
||||||
presenters <- &ch
|
presenters <- &ch
|
||||||
err = <-ch.error
|
err = <-ch.error
|
||||||
if nil != err || "" == ch.Token || "" == ch.KeyAuth {
|
if nil != err || "" == ch.Token || "" == ch.KeyAuth {
|
||||||
msg := `{"error":"expected json in the format {\"token\":\"xxx\",\"key_authorization\":\"yyy\"}"}`
|
fmt.Println("presenter err", err, ch.Token, ch.KeyAuth)
|
||||||
|
msg := `{"error":"ACME dns-01 error"}`
|
||||||
http.Error(w, msg, http.StatusUnprocessableEntity)
|
http.Error(w, msg, http.StatusUnprocessableEntity)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -90,6 +75,7 @@ func handleDNSRoutes(r chi.Router) {
|
||||||
|
|
||||||
// TODO ugly Delete, but whatever
|
// TODO ugly Delete, but whatever
|
||||||
r.Delete("/{domain}/{token}/{keyAuth}", func(w http.ResponseWriter, r *http.Request) {
|
r.Delete("/{domain}/{token}/{keyAuth}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// TODO authenticate
|
||||||
|
|
||||||
ch := Challenge{
|
ch := Challenge{
|
||||||
Domain: chi.URLParam(r, "domain"),
|
Domain: chi.URLParam(r, "domain"),
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
r.Use(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
valid, _ := ctx.Value(MWKey("valid")).(bool)
|
||||||
|
|
||||||
|
if !valid {
|
||||||
|
// misdirection
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
w.Write([]byte("{\"success\":true}\n"))
|
||||||
|
//http.Error(w, `{"error":"could not verify token"}`, http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
/*
|
||||||
|
if nil != err2 {
|
||||||
|
// a little misdirection there
|
||||||
|
msg := `{"error":"internal server error"}`
|
||||||
|
http.Error(w, msg, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
})
|
|
@ -1,7 +1,10 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -12,18 +15,29 @@ import (
|
||||||
|
|
||||||
func handleDeviceRoutes(r chi.Router) {
|
func handleDeviceRoutes(r chi.Router) {
|
||||||
r.Route("/devices", func(r chi.Router) {
|
r.Route("/devices", func(r chi.Router) {
|
||||||
// TODO needs admin auth
|
// only the admin can get past this point
|
||||||
// r.Use() // must have slug '*'
|
r.Use(func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
claims, ok := ctx.Value(MWKey("claims")).(*MgmtClaims)
|
||||||
|
if !ok || "*" != claims.Slug {
|
||||||
|
msg := `{"error":"missing or invalid authorization token"}`
|
||||||
|
http.Error(w, msg+"\n", http.StatusUnprocessableEntity)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
r.Post("/", func(w http.ResponseWriter, r *http.Request) {
|
r.Post("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
auth := &authstore.Authorization{}
|
auth := &authstore.Authorization{}
|
||||||
|
|
||||||
|
// Slug is mandatory, ID and MachinePPID must NOT be set
|
||||||
decoder := json.NewDecoder(r.Body)
|
decoder := json.NewDecoder(r.Body)
|
||||||
err := decoder.Decode(&auth)
|
err := decoder.Decode(&auth)
|
||||||
// Slug is mandatory, ID and MachinePPID must NOT be set
|
|
||||||
epoch := time.Time{}
|
epoch := time.Time{}
|
||||||
if nil != err || "" != auth.ID || "" != auth.MachinePPID ||
|
if nil != err || "" != auth.ID || "" != auth.MachinePPID || "" == auth.Slug ||
|
||||||
"" == auth.Slug || "" == auth.SharedKey ||
|
|
||||||
epoch != auth.CreatedAt || epoch != auth.UpdatedAt || epoch != auth.DeletedAt {
|
epoch != auth.CreatedAt || epoch != auth.UpdatedAt || epoch != auth.DeletedAt {
|
||||||
result, _ := json.Marshal(&authstore.Authorization{})
|
result, _ := json.Marshal(&authstore.Authorization{})
|
||||||
msg, _ := json.Marshal(&struct {
|
msg, _ := json.Marshal(&struct {
|
||||||
|
@ -35,18 +49,46 @@ func handleDeviceRoutes(r chi.Router) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if "" == auth.SharedKey {
|
||||||
|
rnd := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(rnd); nil != err {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
auth.SharedKey = base64.RawURLEncoding.EncodeToString(rnd)
|
||||||
|
}
|
||||||
|
if len(auth.SharedKey) < 20 {
|
||||||
|
msg := `{"error":"shared_key must be >= 16 bytes"}`
|
||||||
|
http.Error(w, string(msg), http.StatusUnprocessableEntity)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
pub := authstore.ToPublicKeyString(auth.SharedKey)
|
||||||
|
if "" == auth.PublicKey {
|
||||||
|
auth.PublicKey = pub
|
||||||
|
}
|
||||||
|
if len(auth.PublicKey) > 24 {
|
||||||
|
auth.PublicKey = auth.PublicKey[:24]
|
||||||
|
}
|
||||||
|
if pub != auth.PublicKey {
|
||||||
|
msg := `{"error":"public_key must be the first 24 bytes of the base64-encoded hash of the shared_key"}`
|
||||||
|
http.Error(w, msg+"\n", http.StatusUnprocessableEntity)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
err = store.Add(auth)
|
err = store.Add(auth)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
msg := `{"error":"not really sure what happened, but it didn't go well (check the logs)"}`
|
msg := `{"error":"not really sure what happened, but it didn't go well (check the logs)"}`
|
||||||
log.Printf("/api/devices/\n", auth.Slug)
|
if authstore.ErrExists == err {
|
||||||
|
msg = fmt.Sprintf(`{ "error": "%s" }`, err.Error())
|
||||||
|
}
|
||||||
|
log.Printf("/api/devices/\n")
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
http.Error(w, msg, http.StatusInternalServerError)
|
http.Error(w, msg, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
//auth.SharedKey = "[redacted]"
|
|
||||||
result, _ := json.Marshal(auth)
|
result, _ := json.Marshal(auth)
|
||||||
w.Write(result)
|
w.Write([]byte(string(result) + "\n"))
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Get("/{slug}", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/{slug}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -69,7 +111,7 @@ func handleDeviceRoutes(r chi.Router) {
|
||||||
auth.SharedKey = "[redacted]"
|
auth.SharedKey = "[redacted]"
|
||||||
}
|
}
|
||||||
result, _ := json.Marshal(auth)
|
result, _ := json.Marshal(auth)
|
||||||
w.Write(result)
|
w.Write([]byte(string(result) + "\n"))
|
||||||
})
|
})
|
||||||
|
|
||||||
r.Delete("/{slug}", func(w http.ResponseWriter, r *http.Request) {
|
r.Delete("/{slug}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -83,7 +125,7 @@ func handleDeviceRoutes(r chi.Router) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Write([]byte("{\"success\":true}\n"))
|
w.Write([]byte(`{"success":true}` + "\n"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -65,6 +65,7 @@ func routeAll() chi.Router {
|
||||||
tokenString,
|
tokenString,
|
||||||
&MgmtClaims{},
|
&MgmtClaims{},
|
||||||
func(token *jwt.Token) (interface{}, error) {
|
func(token *jwt.Token) (interface{}, error) {
|
||||||
|
fmt.Println("parsed jwt", token)
|
||||||
kid, ok := token.Header["kid"].(string)
|
kid, ok := token.Header["kid"].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("missing jwt header 'kid' (key id)")
|
return nil, fmt.Errorf("missing jwt header 'kid' (key id)")
|
||||||
|
@ -101,20 +102,27 @@ func routeAll() chi.Router {
|
||||||
return []byte(auth.SharedKey), nil
|
return []byte(auth.SharedKey), nil
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
fmt.Println("ppid:", auth.MachinePPID)
|
||||||
|
|
||||||
return []byte(auth.MachinePPID), nil
|
return []byte(auth.MachinePPID), nil
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
var ctx context.Context
|
ctx := r.Context()
|
||||||
|
if nil != err {
|
||||||
|
fmt.Println("auth err", err)
|
||||||
|
ctx = context.WithValue(ctx, MWKey("error"), err)
|
||||||
|
}
|
||||||
if nil != tok {
|
if nil != tok {
|
||||||
ctx = context.WithValue(r.Context(), MWKey("token"), tok)
|
fmt.Println("any auth?", tok)
|
||||||
if tok.Valid {
|
if tok.Valid {
|
||||||
ctx = context.WithValue(r.Context(), MWKey("valid"), nil != tok)
|
ctx = context.WithValue(ctx, MWKey("token"), tok)
|
||||||
|
ctx = context.WithValue(ctx, MWKey("claims"), tok.Claims)
|
||||||
|
ctx = context.WithValue(ctx, MWKey("valid"), true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if nil != err {
|
fmt.Println("Good Auth?")
|
||||||
ctx = context.WithValue(r.Context(), MWKey("error"), nil != tok)
|
fmt.Println(ctx.Value(MWKey("claims")))
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
|
@ -156,6 +164,16 @@ func routeAll() chi.Router {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO hash the PPID and check against the Public Key?
|
// TODO hash the PPID and check against the Public Key?
|
||||||
|
pub := authstore.ToPublicKeyString(auth.MachinePPID)
|
||||||
|
if pub != auth.PublicKey {
|
||||||
|
msg, _ := json.Marshal(&struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}{
|
||||||
|
Error: "expected `public_key` to be the first 24 bytes of the hash of the `machine_ppid`",
|
||||||
|
})
|
||||||
|
http.Error(w, string(msg), http.StatusUnprocessableEntity)
|
||||||
|
return
|
||||||
|
}
|
||||||
original.PublicKey = auth.PublicKey
|
original.PublicKey = auth.PublicKey
|
||||||
original.MachinePPID = auth.MachinePPID
|
original.MachinePPID = auth.MachinePPID
|
||||||
err = store.Set(original)
|
err = store.Set(original)
|
||||||
|
@ -172,8 +190,30 @@ func routeAll() chi.Router {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
r.Post("/ping", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
claims, ok := ctx.Value(MWKey("claims")).(*MgmtClaims)
|
||||||
|
if !ok {
|
||||||
|
msg := `{"error":"failure to ping: 1"}`
|
||||||
|
fmt.Println("touch no claims", claims)
|
||||||
|
http.Error(w, msg+"\n", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("ping pong??", claims)
|
||||||
|
err := store.Touch(claims.Slug)
|
||||||
|
if nil != err {
|
||||||
|
msg := `{"error":"failure to ping: 2"}`
|
||||||
|
fmt.Println("touch err", err)
|
||||||
|
http.Error(w, msg+"\n", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write([]byte(`{ "success": true }` + "\n"))
|
||||||
|
})
|
||||||
|
|
||||||
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write([]byte("welcome\n"))
|
w.Write([]byte("Hello\n"))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,16 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.coolaj86.com/coolaj86/go-telebitd/mplexer/mgmt/authstore"
|
||||||
|
|
||||||
|
"github.com/denisbrodbeck/machineid"
|
||||||
jwt "github.com/dgrijalva/jwt-go"
|
jwt "github.com/dgrijalva/jwt-go"
|
||||||
_ "github.com/joho/godotenv/autoload"
|
_ "github.com/joho/godotenv/autoload"
|
||||||
)
|
)
|
||||||
|
@ -11,7 +18,7 @@ import (
|
||||||
func main() {
|
func main() {
|
||||||
var secret string
|
var secret string
|
||||||
|
|
||||||
if len(os.Args) == 2 {
|
if len(os.Args) >= 2 {
|
||||||
secret = os.Args[1]
|
secret = os.Args[1]
|
||||||
}
|
}
|
||||||
if "" == secret {
|
if "" == secret {
|
||||||
|
@ -23,7 +30,28 @@ func main() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
tok, err := getToken(secret, []string{})
|
if len(os.Args) >= 3 {
|
||||||
|
muid, err := machineid.ProtectedID("test-id|" + secret)
|
||||||
|
if nil != err {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
muidBytes, _ := hex.DecodeString(muid)
|
||||||
|
muid = base64.RawURLEncoding.EncodeToString(muidBytes)
|
||||||
|
fmt.Println(
|
||||||
|
muid,
|
||||||
|
authstore.ToPublicKeyString(muid),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
b := make([]byte, 16)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
claims := &jwt.StandardClaims{
|
||||||
|
Id: base64.RawURLEncoding.EncodeToString(b),
|
||||||
|
IssuedAt: time.Now().Unix(),
|
||||||
|
ExpiresAt: time.Now().Add(5 * time.Minute).Unix(),
|
||||||
|
}
|
||||||
|
tok, err := getToken(secret, claims)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
fmt.Fprintf(os.Stderr, "signing error: %s", err)
|
fmt.Fprintf(os.Stderr, "signing error: %s", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
@ -33,10 +61,22 @@ func main() {
|
||||||
fmt.Println(tok)
|
fmt.Println(tok)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getToken(secret string, domains []string) (token string, err error) {
|
func getToken(secret string, tokenData *jwt.StandardClaims) (token string, err error) {
|
||||||
tokenData := jwt.MapClaims{"domains": domains}
|
keyID := authstore.ToPublicKeyString(secret)
|
||||||
|
|
||||||
|
fmt.Fprintf(os.Stderr, "secret: %s\n", secret)
|
||||||
|
fmt.Fprintf(os.Stderr, "kid: %s\n", keyID)
|
||||||
|
|
||||||
|
jwtToken := &jwt.Token{
|
||||||
|
Header: map[string]interface{}{
|
||||||
|
"kid": keyID,
|
||||||
|
"typ": "JWT",
|
||||||
|
"alg": jwt.SigningMethodHS256.Alg(),
|
||||||
|
},
|
||||||
|
Claims: tokenData,
|
||||||
|
Method: jwt.SigningMethodHS256,
|
||||||
|
}
|
||||||
|
|
||||||
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, tokenData)
|
|
||||||
if token, err = jwtToken.SignedString([]byte(secret)); err != nil {
|
if token, err = jwtToken.SignedString([]byte(secret)); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
package authstore
|
package authstore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrExists = errors.New("token already exists")
|
||||||
|
|
||||||
type Authorization struct {
|
type Authorization struct {
|
||||||
ID string `db:"id,omitempty" json:"-"`
|
ID string `db:"id,omitempty" json:"-"`
|
||||||
|
|
||||||
|
@ -21,9 +26,19 @@ type Store interface {
|
||||||
SetMaster(secret string) error
|
SetMaster(secret string) error
|
||||||
Add(auth *Authorization) error
|
Add(auth *Authorization) error
|
||||||
Set(auth *Authorization) error
|
Set(auth *Authorization) error
|
||||||
|
Touch(id string) error
|
||||||
Get(id string) (*Authorization, error)
|
Get(id string) (*Authorization, error)
|
||||||
GetBySlug(id string) (*Authorization, error)
|
GetBySlug(id string) (*Authorization, error)
|
||||||
GetByPub(id string) (*Authorization, error)
|
GetByPub(id string) (*Authorization, error)
|
||||||
Delete(auth *Authorization) error
|
Delete(auth *Authorization) error
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ToPublicKeyString(secret string) string {
|
||||||
|
pubBytes := sha256.Sum256([]byte(secret))
|
||||||
|
pub := base64.RawURLEncoding.EncodeToString(pubBytes[:])
|
||||||
|
if len(pub) > 24 {
|
||||||
|
pub = pub[:24]
|
||||||
|
}
|
||||||
|
return pub
|
||||||
|
}
|
||||||
|
|
|
@ -48,13 +48,13 @@ func (s *PGStore) SetMaster(secret string) error {
|
||||||
defer done()
|
defer done()
|
||||||
|
|
||||||
pubBytes := sha256.Sum256([]byte(secret))
|
pubBytes := sha256.Sum256([]byte(secret))
|
||||||
pub := base64.RawURLEncoding.EncodeToString(pubBytes[:])
|
pub := base64.RawURLEncoding.EncodeToString(pubBytes[:])
|
||||||
pub = pub[:24]
|
pub = pub[:24]
|
||||||
auth := &Authorization{
|
auth := &Authorization{
|
||||||
Slug: "*",
|
Slug: "*",
|
||||||
SharedKey: secret,
|
SharedKey: secret,
|
||||||
MachinePPID: secret,
|
MachinePPID: secret,
|
||||||
PublicKey: pub,
|
PublicKey: pub,
|
||||||
}
|
}
|
||||||
err := s.Add(auth)
|
err := s.Add(auth)
|
||||||
|
|
||||||
|
@ -90,6 +90,7 @@ func (s *PGStore) Add(auth *Authorization) error {
|
||||||
SELECT slug FROM authorizations WHERE deleted_at = '1970-01-01 00:00:00' AND slug = $1
|
SELECT slug FROM authorizations WHERE deleted_at = '1970-01-01 00:00:00' AND slug = $1
|
||||||
)
|
)
|
||||||
`
|
`
|
||||||
|
now := time.Now()
|
||||||
res, err := tx.ExecContext(ctx, query2, auth.Slug, auth.SharedKey, auth.PublicKey)
|
res, err := tx.ExecContext(ctx, query2, auth.Slug, auth.SharedKey, auth.PublicKey)
|
||||||
if nil != err {
|
if nil != err {
|
||||||
return err
|
return err
|
||||||
|
@ -97,13 +98,16 @@ func (s *PGStore) Add(auth *Authorization) error {
|
||||||
|
|
||||||
// PostgreSQL does support RowsAffected(), but not LastInsertId()
|
// PostgreSQL does support RowsAffected(), but not LastInsertId()
|
||||||
if count, _ := res.RowsAffected(); count != 1 {
|
if count, _ := res.RowsAffected(); count != 1 {
|
||||||
return fmt.Errorf("record not added (probably exists)")
|
// TODO be more sure?
|
||||||
|
return ErrExists // fmt.Errorf("record not added (probably exists)")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); nil != err {
|
if err := tx.Commit(); nil != err {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auth.CreatedAt = now
|
||||||
|
auth.UpdatedAt = now
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,9 +116,10 @@ func (s *PGStore) Set(auth *Authorization) error {
|
||||||
defer done()
|
defer done()
|
||||||
query := `
|
query := `
|
||||||
UPDATE authorizations SET
|
UPDATE authorizations SET
|
||||||
machine_ppid=$1,
|
machine_ppid = $1,
|
||||||
shared_key=$2,
|
shared_key = $2,
|
||||||
public_key=$3
|
public_key = $3,
|
||||||
|
updated_at = 'now'
|
||||||
WHERE
|
WHERE
|
||||||
deleted_at = '1970-01-01 00:00:00'
|
deleted_at = '1970-01-01 00:00:00'
|
||||||
AND shared_key = $2
|
AND shared_key = $2
|
||||||
|
@ -131,14 +136,34 @@ func (s *PGStore) Set(auth *Authorization) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *PGStore) Touch(pub string) error {
|
||||||
|
ctx, done := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
||||||
|
defer done()
|
||||||
|
query := `
|
||||||
|
UPDATE authorizations SET
|
||||||
|
updated_at = 'now'
|
||||||
|
WHERE deleted_at = '1970-01-01 00:00:00'
|
||||||
|
AND (public_key = $1 OR slug = $1)
|
||||||
|
`
|
||||||
|
row, err := s.dbx.ExecContext(ctx, query, pub)
|
||||||
|
if nil != err {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// PostgreSQL does support RowsAffected()
|
||||||
|
if count, _ := row.RowsAffected(); count != 1 {
|
||||||
|
return fmt.Errorf("record was not updated")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *PGStore) Get(id string) (*Authorization, error) {
|
func (s *PGStore) Get(id string) (*Authorization, error) {
|
||||||
ctx, done := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
ctx, done := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
||||||
defer done()
|
defer done()
|
||||||
query := `
|
query := `
|
||||||
SELECT * FROM authorizations
|
SELECT * FROM authorizations
|
||||||
WHERE deleted_at = '1970-01-01 00:00:00'
|
WHERE deleted_at = '1970-01-01 00:00:00'
|
||||||
AND (slug = $1 OR public_key = $1 OR shared_key = $1)
|
AND (slug = $1 OR public_key = $1 OR shared_key = $1)
|
||||||
`
|
`
|
||||||
row := s.dbx.QueryRowxContext(ctx, query, id)
|
row := s.dbx.QueryRowxContext(ctx, query, id)
|
||||||
if nil != row {
|
if nil != row {
|
||||||
auth := &Authorization{}
|
auth := &Authorization{}
|
||||||
|
@ -194,7 +219,7 @@ func (s *PGStore) Delete(auth *Authorization) error {
|
||||||
}
|
}
|
||||||
// PostgreSQL does support RowsAffected()
|
// PostgreSQL does support RowsAffected()
|
||||||
if count, _ := row.RowsAffected(); count != 1 {
|
if count, _ := row.RowsAffected(); count != 1 {
|
||||||
return fmt.Errorf("record exists")
|
return fmt.Errorf("record does not exist")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue