479 lines
9.3 KiB
Go
479 lines
9.3 KiB
Go
|
package vaulttokenmanager
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"crypto/hmac"
|
||
|
"crypto/rand"
|
||
|
"crypto/sha256"
|
||
|
"encoding/base64"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"log"
|
||
|
"os"
|
||
|
"path"
|
||
|
"strconv"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
vaultapi "github.com/hashicorp/vault/api"
|
||
|
)
|
||
|
|
||
|
// tokenRefreshInterval is the maximum age of the latest signing token before
// rotateTokensBackground writes a replacement to Vault. The background age
// check runs every tokenRefreshInterval/5.
const tokenRefreshInterval = 10 * time.Hour
|
||
|
|
||
|
// notFoundError reports that no token secret could be read from the
// manager's Vault path. populate treats it as the signal to create the
// first token instead of failing.
type notFoundError struct{}

// Error implements the error interface.
func (*notFoundError) Error() string {
	const msg = "token not found"
	return msg
}
|
||
|
|
||
|
// token is one version of the signing secret kept in Vault. Secret and
// Created are the fields stored in the KV secret's data (JSON keys "token"
// and "token-ts"). version is the KV version number read from the secret's
// metadata by parseTokenVaultSecret and is never serialized.
type token struct {
	// Secret is the HMAC key; makeToken generates it as base64url text.
	Secret string `json:"token"`

	// Created is the creation time in Unix seconds; compared against
	// tokenRefreshInterval to decide when to rotate.
	Created int64 `json:"token-ts"`

	version int `json:"-"`
}
|
||
|
|
||
|
// TokenManager signs and validates data with HMAC-SHA256 keys that are stored
// as successive versions of a single Vault KV secret. A background goroutine
// (started by New) rotates the key periodically. getToken and
// getTokenVersion access the cached fields under lock.
type TokenManager struct {
	key      string // secret name; path() joins it onto basePath
	basePath string // "kv/data/ntppool/<depEnv>/", set by New

	latest   *token         // most recent token; nil means "reload from Vault"
	versions map[int]*token // tokens by KV version; entries are never removed in this file

	vault *vaultapi.Client // client created by New
	lock  sync.RWMutex
}
|
||
|
|
||
|
func New(key, depEnv string) (*TokenManager, error) {
|
||
|
|
||
|
if len(depEnv) == 0 {
|
||
|
return nil, fmt.Errorf("invalid deployment mode parameter %q", depEnv)
|
||
|
}
|
||
|
|
||
|
var basePath = fmt.Sprintf("kv/data/ntppool/%s/", depEnv)
|
||
|
|
||
|
cl, err := vaultClient()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
tm := &TokenManager{
|
||
|
key: key,
|
||
|
basePath: basePath,
|
||
|
vault: cl,
|
||
|
versions: map[int]*token{},
|
||
|
}
|
||
|
|
||
|
err = tm.populate()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// todo: pass context so it can be shutdown
|
||
|
go tm.rotateTokensBackground()
|
||
|
|
||
|
return tm, nil
|
||
|
}
|
||
|
|
||
|
// getSignatureVersion extracts the token version from a signature of the
// form "<version>-<base64 hmac>" (the format produced by signWith).
//
// It returns an error when there is no '-' separator or it is the first
// byte, when the prefix is not a decimal integer, or when the version is
// not positive. On success the returned version is always >= 1.
func getSignatureVersion(sig []byte) (int, error) {
	idx := bytes.IndexByte(sig, '-')
	if idx < 1 {
		return 0, fmt.Errorf("invalid signature")
	}

	version, err := strconv.Atoi(string(sig[:idx]))
	if err != nil {
		return 0, fmt.Errorf("unknown signature version: %w", err)
	}
	if version <= 0 {
		// Report the bad value without a (nil) error operand, which would
		// otherwise be rendered as "%!s(<nil>)".
		return 0, fmt.Errorf("unknown signature version %d", version)
	}

	return version, nil
}
|
||
|
|
||
|
func (tm *TokenManager) Validate(sig []byte, data ...[]byte) (bool, error) {
|
||
|
version, err := getSignatureVersion(sig)
|
||
|
if err != nil || version == 0 {
|
||
|
return false, err
|
||
|
}
|
||
|
|
||
|
token, err := tm.getTokenVersion(context.Background(), version)
|
||
|
if err != nil {
|
||
|
return false, err
|
||
|
}
|
||
|
|
||
|
expected, err := tm.signWith(token, data...)
|
||
|
if err != nil {
|
||
|
return false, err
|
||
|
}
|
||
|
if len(expected) > 0 && bytes.Equal(sig, expected) {
|
||
|
return true, nil
|
||
|
}
|
||
|
|
||
|
log.Printf("exp: %s", expected)
|
||
|
log.Printf("got: %s", sig)
|
||
|
|
||
|
return false, fmt.Errorf("could not validate signature")
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) Sign(data ...[]byte) ([]byte, error) {
|
||
|
|
||
|
token, err := tm.getToken(context.Background())
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
// log.Printf("got version: %d, token: %s", token.version, token.Secret)
|
||
|
|
||
|
return tm.signWith(token, data...)
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) signWith(token *token, data ...[]byte) ([]byte, error) {
|
||
|
|
||
|
hm := hmac.New(sha256.New, []byte(token.Secret))
|
||
|
|
||
|
b := bytes.Join(data, []byte("|"))
|
||
|
|
||
|
p, err := hm.Write([]byte(b))
|
||
|
if err != nil || p != len(b) {
|
||
|
return nil, fmt.Errorf("hmac error: %s", err)
|
||
|
}
|
||
|
|
||
|
sha := hm.Sum(nil)
|
||
|
|
||
|
r := strconv.AppendInt([]byte{}, int64(token.version), 10)
|
||
|
|
||
|
r = append(r, []byte("-")...)
|
||
|
|
||
|
shaenc := make([]byte, base64.RawURLEncoding.EncodedLen(len(sha)))
|
||
|
base64.RawURLEncoding.Encode(shaenc, sha)
|
||
|
|
||
|
r = append(r, shaenc...)
|
||
|
|
||
|
return r, nil
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) rotateTokensBackground() {
|
||
|
ctx := context.Background() // for when the app has context properly
|
||
|
|
||
|
l := log.New(os.Stderr, "rotateTokensBackground: ", 0)
|
||
|
|
||
|
ticker := time.NewTicker(tokenRefreshInterval / 5)
|
||
|
defer ticker.Stop()
|
||
|
|
||
|
for {
|
||
|
|
||
|
latest, err := tm.getToken(ctx)
|
||
|
if err != nil {
|
||
|
l.Printf("could not get token: %s", err)
|
||
|
}
|
||
|
|
||
|
latestTime := time.Unix(latest.Created, 0)
|
||
|
|
||
|
l.Printf("checking token age, latest is %s old (interval: %s)", time.Since(latestTime), tokenRefreshInterval.String())
|
||
|
|
||
|
if age := time.Since(latestTime); age > tokenRefreshInterval {
|
||
|
l.Printf("token age (%s) is more than %s, rotate it", age, tokenRefreshInterval)
|
||
|
tm.createNewToken(ctx, latest.version)
|
||
|
tm.lock.Lock()
|
||
|
tm.latest = nil
|
||
|
tm.lock.Unlock()
|
||
|
tm.getToken(ctx)
|
||
|
l.Printf("finished renewing token")
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
l.Printf("context done")
|
||
|
return
|
||
|
case <-ticker.C:
|
||
|
}
|
||
|
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) createNewToken(ctx context.Context, cas int) error {
|
||
|
data := map[string]interface{}{
|
||
|
"data": makeToken(),
|
||
|
"metadata": map[string]interface{}{
|
||
|
"cas_required": true,
|
||
|
},
|
||
|
"options": map[string]interface{}{
|
||
|
"cas": cas,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
_, err := tm.vault.Logical().WriteWithContext(ctx, tm.path(), data)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) path() string {
|
||
|
return path.Join(tm.basePath, tm.key)
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) populate() error {
|
||
|
ctx := context.Background()
|
||
|
|
||
|
t, err := tm.getToken(ctx)
|
||
|
if err != nil {
|
||
|
if _, ok := err.(*notFoundError); !ok {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if t == nil {
|
||
|
err := tm.createNewToken(ctx, 0)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("could not save token: %s", err)
|
||
|
}
|
||
|
|
||
|
t, err = tm.getToken(ctx)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if t == nil {
|
||
|
return fmt.Errorf("could not find token data")
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if t != nil {
|
||
|
tm.latest = t
|
||
|
tm.versions[t.version] = t
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) getTokenVersionCache(ctx context.Context, version int) (*token, error) {
|
||
|
tm.lock.RLock()
|
||
|
defer tm.lock.RUnlock()
|
||
|
|
||
|
if t, ok := tm.versions[version]; ok {
|
||
|
return t, nil
|
||
|
}
|
||
|
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) getTokenVersion(ctx context.Context, version int) (*token, error) {
|
||
|
|
||
|
token, err := tm.getTokenVersionCache(ctx, version)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if token != nil {
|
||
|
return token, nil
|
||
|
}
|
||
|
|
||
|
latest, err := tm.getToken(ctx)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if latest.version < version {
|
||
|
return nil, fmt.Errorf("invalid signature version")
|
||
|
}
|
||
|
|
||
|
tm.lock.Lock()
|
||
|
defer tm.lock.Unlock()
|
||
|
|
||
|
// in case it was set while we were waiting for a lock
|
||
|
if token, ok := tm.versions[version]; ok {
|
||
|
return token, nil
|
||
|
}
|
||
|
|
||
|
log.Printf("requesting token %q/%d", tm.key, version)
|
||
|
rv, err := tm.getKVVersion(ctx, tm.key, version)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
token, err = parseTokenVaultSecret(rv.Data)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
tm.versions[version] = token
|
||
|
|
||
|
return token, err
|
||
|
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) getToken(ctx context.Context) (*token, error) {
|
||
|
tm.lock.RLock()
|
||
|
|
||
|
if token := tm.latest; token != nil {
|
||
|
tm.lock.RUnlock()
|
||
|
return token, nil
|
||
|
}
|
||
|
|
||
|
log.Printf("getToken didn't have latest token, getting from vault")
|
||
|
|
||
|
tm.lock.RUnlock()
|
||
|
tm.lock.Lock()
|
||
|
defer tm.lock.Unlock()
|
||
|
|
||
|
// in case it was set while we were waiting for a lock
|
||
|
if tm.latest != nil {
|
||
|
return tm.latest, nil
|
||
|
}
|
||
|
|
||
|
log.Printf("getToken calling getKV")
|
||
|
|
||
|
rv, err := tm.getKV(ctx, tm.key)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if rv == nil {
|
||
|
return nil, ¬FoundError{}
|
||
|
}
|
||
|
|
||
|
t, err := parseTokenVaultSecret(rv.Data)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
tm.latest = t
|
||
|
tm.versions[t.version] = t
|
||
|
|
||
|
log.Printf("getToken returning success")
|
||
|
|
||
|
return t, nil
|
||
|
}
|
||
|
|
||
|
func parseTokenVaultSecret(data map[string]interface{}) (*token, error) {
|
||
|
t := &token{}
|
||
|
|
||
|
var err error
|
||
|
|
||
|
if dataif, ok := data["data"]; ok {
|
||
|
data := dataif.(map[string]interface{})
|
||
|
|
||
|
if tokData, ok := data["token"]; ok {
|
||
|
if tokStr, ok := tokData.(string); ok {
|
||
|
t.Secret = tokStr
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if tokData, ok := data["token-ts"]; ok {
|
||
|
if tokInt, ok := tokData.(json.Number); ok {
|
||
|
t.Created, err = tokInt.Int64()
|
||
|
if t.Created == 0 || err != nil {
|
||
|
log.Printf("could not parse Created from token secret (%T: %+v): %s", tokData, tokData, err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if metaif, ok := data["metadata"]; ok {
|
||
|
meta := metaif.(map[string]interface{})
|
||
|
if version, ok := meta["version"]; ok {
|
||
|
if v, ok := version.(json.Number); ok {
|
||
|
v64, err := v.Int64()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
t.version = int(v64)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if t.version == 0 || len(t.Secret) == 0 {
|
||
|
return nil, fmt.Errorf("expected token data not found")
|
||
|
}
|
||
|
|
||
|
return t, nil
|
||
|
}
|
||
|
|
||
|
func makeToken() *token {
|
||
|
randomBytes := make([]byte, 16)
|
||
|
_, err := rand.Read(randomBytes)
|
||
|
if err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
return &token{
|
||
|
Secret: base64.URLEncoding.EncodeToString(randomBytes),
|
||
|
Created: time.Now().Unix(),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var hasOutputVaultEnvMessage bool
|
||
|
|
||
|
func vaultClient() (*vaultapi.Client, error) {
|
||
|
|
||
|
c := vaultapi.DefaultConfig()
|
||
|
|
||
|
if c.Address == "https://127.0.0.1:8200" {
|
||
|
c.Address = "https://vault.ntppool.org"
|
||
|
}
|
||
|
|
||
|
cl, err := vaultapi.NewClient(c)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
// VAULT_TOKEN is read automatically from the environment if set
|
||
|
// so we just try the file here
|
||
|
token, err := os.ReadFile("/vault/secrets/token")
|
||
|
if err == nil {
|
||
|
cl.SetToken(string(token))
|
||
|
} else {
|
||
|
if !hasOutputVaultEnvMessage {
|
||
|
hasOutputVaultEnvMessage = true
|
||
|
log.Printf("could not read /vault/secrets/token (%s), using VAULT_TOKEN", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return cl, nil
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) getKV(ctx context.Context, k string) (*vaultapi.Secret, error) {
|
||
|
|
||
|
cl, err := vaultClient()
|
||
|
if err != nil {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
rv, err := cl.Logical().ReadWithContext(ctx, tm.path())
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return rv, nil
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) getKVVersion(ctx context.Context, k string, version int) (*vaultapi.Secret, error) {
|
||
|
|
||
|
cl, err := vaultClient()
|
||
|
if err != nil {
|
||
|
return nil, nil
|
||
|
}
|
||
|
|
||
|
data := map[string][]string{
|
||
|
"version": {strconv.Itoa(version)},
|
||
|
}
|
||
|
|
||
|
rv, err := cl.Logical().ReadWithDataWithContext(ctx, tm.path(), data)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return rv, nil
|
||
|
}
|
||
|
|
||
|
func (tm *TokenManager) SetKV(ctx context.Context, k string, data *vaultapi.Secret) error {
|
||
|
p := tm.path()
|
||
|
cl, err := vaultClient()
|
||
|
if err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
_, err = cl.Logical().WriteWithContext(ctx, p, data.Data)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|