// vault-token-manager/vault.go

package vaulttokenmanager
import (
	"bytes"
	"context"
	"crypto/hmac"
	crand "crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"log"
	"math/rand"
	"os"
	"strconv"
	"sync"
	"time"

	vaultapi "github.com/hashicorp/vault/api"
)
// notFoundError is returned by getToken when Vault has no secret at the
// configured path; populate checks for it to decide whether to create the
// first token.
type notFoundError struct{}

// Error implements the error interface.
func (*notFoundError) Error() string { return "token not found" }
// token is one version of the HMAC signing key kept in Vault. Secret and
// Created are the fields written to the KV secret (JSON keys "token" and
// "token-ts"); version is the KV version number, which is not part of the
// stored data and is filled in from the secret's metadata when it is read.
type token struct {
	Secret  string `json:"token"`
	Created int64  `json:"token-ts"` // Unix seconds
	// version is unexported, so encoding/json never marshals it; a json tag
	// here would have no effect (go vet's structtag check flags one).
	version int
}
// TokenManager signs and validates data with HMAC keys stored as versions of
// a Vault KV secret, and rotates the key in the background. Create one with
// New.
type TokenManager struct {
	path  string           // Vault path of the token secret (Config.Path)
	vault *vaultapi.Client // client used for every Vault read and write
	cfg   *Config

	// latest is the most recent token; nil makes getToken reload it from Vault.
	latest *token
	// versions caches every token loaded so far, keyed by KV version.
	versions map[int]*token
	// lock guards latest and versions.
	lock sync.RWMutex
}
// Config configures a TokenManager at initialization (see New).
type Config struct {
	// Vault is a client configured with SetToken, or one that falls back to
	// VAULT_TOKEN from the environment.
	Vault *vaultapi.Client

	// Path is the key in a KVv2 secrets engine, for example
	// "kv/data/project/tokens".
	Path string

	// RefreshInterval is how often a new token is written. Zero means 16
	// hours; New rejects anything shorter than 10 minutes. Each rotation is
	// shifted randomly by up to +/- 90 seconds so that several instances are
	// less likely to rotate the token twice. Vault keeps 10 versions by
	// default, so signatures stay verifiable for roughly 10 * RefreshInterval.
	// https://developer.hashicorp.com/vault/tutorials/secrets-management/versioned-kv#step-4-specify-the-number-of-versions-to-keep
	RefreshInterval time.Duration
}
// New returns a TokenManager for the given Config. It loads the current
// token from Vault, writing an initial one if the path holds no secret yet.
// It then starts a goroutine that rotates the token until ctx is cancelled.
//
// cfg must be non-nil with a non-nil Vault client and a non-empty Path. A
// zero RefreshInterval defaults to 16 hours; values below 10 minutes are
// rejected. The caller's Config is copied and is not modified.
func New(ctx context.Context, cfg *Config) (*TokenManager, error) {
	if cfg == nil {
		return nil, fmt.Errorf("config is required")
	}
	if cfg.Vault == nil {
		return nil, fmt.Errorf("config Vault client is required")
	}
	if cfg.Path == "" {
		return nil, fmt.Errorf("config Path is required")
	}
	// Work on a private copy. Defaults are then not written back into the
	// caller's struct, and later edits by the caller cannot race with the
	// background goroutine.
	c := *cfg
	if c.RefreshInterval == 0 {
		c.RefreshInterval = 16 * time.Hour
	}
	if c.RefreshInterval < 10*time.Minute {
		return nil, fmt.Errorf("RefreshInterval must be at least 10 minutes")
	}
	tm := &TokenManager{
		path:     c.Path,
		vault:    c.Vault,
		cfg:      &c,
		versions: map[int]*token{},
	}
	if err := tm.populate(); err != nil {
		return nil, err
	}
	go tm.rotateTokensBackground(ctx)
	return tm, nil
}
// getSignatureVersion extracts the key version from a signature of the form
// "<version>-<digest>" as produced by signWith. It returns an error when the
// '-' separator is missing or first, when the prefix is not a decimal
// integer, or when the version is not positive.
func getSignatureVersion(sig []byte) (int, error) {
	idx := bytes.IndexByte(sig, '-')
	if idx < 1 {
		return 0, fmt.Errorf("invalid signature")
	}
	version, err := strconv.Atoi(string(sig[:idx]))
	if err != nil {
		return 0, fmt.Errorf("unknown signature version: %w", err)
	}
	// Kept separate from the parse error so the message never formats a nil
	// error (the previous combined check printed "%!s(<nil>)" for version 0).
	if version < 1 {
		return 0, fmt.Errorf("unknown signature version %d", version)
	}
	return version, nil
}
// ValidateBytes reports whether sig is a valid signature for data, as
// produced by SignBytes. The version prefix of sig selects the key. The key
// is taken from the local cache, or fetched from Vault if it is not cached.
//
// An error is returned if the signature is malformed, the key version is
// not available anymore, or the signature does not match.
func (tm *TokenManager) ValidateBytes(sig []byte, data ...[]byte) (bool, error) {
	version, err := getSignatureVersion(sig)
	if err != nil {
		return false, err
	}
	token, err := tm.getTokenVersion(context.Background(), version)
	if err != nil {
		return false, err
	}
	expected, err := tm.signWith(token, data...)
	if err != nil {
		return false, err
	}
	// hmac.Equal compares in constant time, so response timing does not
	// reveal how much of a forged signature was correct. The expected value
	// is deliberately not logged: it is itself a valid signature for data.
	if len(expected) > 0 && hmac.Equal(sig, expected) {
		return true, nil
	}
	return false, fmt.Errorf("could not validate signature")
}
// SignBytes signs data with the most recent key and returns the signature
// as "<key version>-<digest>". The digest is an HMAC-SHA256, encoded as
// unpadded base64url. The version prefix lets ValidateBytes pick the
// matching key later.
func (tm *TokenManager) SignBytes(data ...[]byte) ([]byte, error) {
	ctx := context.Background()
	latest, err := tm.getToken(ctx)
	if err != nil {
		return nil, err
	}
	return tm.signWith(latest, data...)
}
// signWith computes the signature of data under token. The data slices are
// joined with "|" and hashed with HMAC-SHA256 keyed by token.Secret. The
// result is "<token.version>-<digest>", where digest is unpadded base64url.
//
// Because the slices are joined with a plain separator, ["a|b"] and
// ["a", "b"] produce the same signature. Callers must not rely on field
// boundaries when a field can contain "|". The format is unchanged so that
// existing signatures stay valid.
//
// The error result is non-nil only when token is nil.
func (tm *TokenManager) signWith(token *token, data ...[]byte) ([]byte, error) {
	if token == nil {
		return nil, fmt.Errorf("no token to sign with")
	}
	mac := hmac.New(sha256.New, []byte(token.Secret))
	// hash.Hash.Write never returns an error, so its result is not checked.
	mac.Write(bytes.Join(data, []byte("|")))
	sum := mac.Sum(nil)

	enc := base64.RawURLEncoding
	// Allocate the output once: up to 20 bytes for the version, one for the
	// '-' separator, then the encoded digest.
	out := make([]byte, 0, 21+enc.EncodedLen(len(sum)))
	out = strconv.AppendInt(out, int64(token.version), 10)
	out = append(out, '-')
	start := len(out)
	out = out[:start+enc.EncodedLen(len(sum))]
	enc.Encode(out[start:], sum)
	return out, nil
}
// rotateTokensBackground keeps the token at tm.path fresh until ctx is
// cancelled. It wakes up when the current token is due for rotation. That
// is its Created time plus RefreshInterval, shifted randomly by +/- 90
// seconds. It also wakes at least every RefreshInterval/5.
//
// When no usable token can be obtained (Vault unreachable, or a rotation
// failed with nothing newer in Vault), the failure is logged and retried on
// the next tick instead of immediately.
func (tm *TokenManager) rotateTokensBackground(ctx context.Context) {
	l := log.New(os.Stderr, "vault-token-manager: ", 0)
	interval := tm.cfg.RefreshInterval
	ticker := time.NewTicker(interval / 5)
	defer ticker.Stop()
	for {
		// Fallback wait when there is no token to schedule from.
		wait := interval / 5
		latest, err := tm.rotateIfStale(ctx)
		if err != nil {
			l.Printf("token refresh failed: %s", err)
		} else {
			// randomize the time within +/- 90 seconds
			jitter := time.Second * time.Duration(rand.Intn(180)-90)
			wait = time.Until(time.Unix(latest.Created, 0).Add(interval)) + jitter
		}
		timer := time.NewTimer(wait)
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-ticker.C:
			timer.Stop()
		case <-timer.C:
		}
	}
}

// rotateIfStale returns the current token. First, if the cached latest token
// is older than RefreshInterval, it writes a new one with check-and-set on
// the old version and re-reads the current token from Vault.
func (tm *TokenManager) rotateIfStale(ctx context.Context) (*token, error) {
	latest, err := tm.getToken(ctx)
	if err != nil {
		return nil, fmt.Errorf("could not get token: %w", err)
	}
	if time.Since(time.Unix(latest.Created, 0)) <= tm.cfg.RefreshInterval {
		return latest, nil
	}
	writeErr := tm.createNewToken(ctx, latest.version)
	// Drop the cached token so the current version is re-read from Vault,
	// whether it was written by this call or by another writer.
	tm.lock.Lock()
	tm.latest = nil
	tm.lock.Unlock()
	current, err := tm.getToken(ctx)
	if err != nil {
		return nil, fmt.Errorf("could not reload token after rotation: %w", err)
	}
	// If a newer version exists, the write error (presumably a check-and-set
	// conflict with another writer) is harmless. Otherwise the token is still
	// stale; report the failure so the caller backs off.
	if writeErr != nil && current.version == latest.version {
		return nil, fmt.Errorf("could not rotate token (cas %d): %w", latest.version, writeErr)
	}
	return current, nil
}
// createNewToken writes a freshly generated token to tm.path. It passes cas
// as Vault KV v2's check-and-set option, so the write only succeeds when cas
// matches the path's current version (0 when the path has no secret yet).
// NOTE(review): the "metadata.cas_required" field is sent in the data write
// body. KV v2 normally configures cas_required via the metadata endpoint,
// so Vault may ignore it here. Confirm.
func (tm *TokenManager) createNewToken(ctx context.Context, cas int) error {
	tok := makeToken()
	if tok == nil {
		// Never write a null payload; makeToken returns nil when it cannot
		// gather random bytes.
		return fmt.Errorf("could not generate token")
	}
	data := map[string]interface{}{
		"data": tok,
		"metadata": map[string]interface{}{
			"cas_required": true,
		},
		"options": map[string]interface{}{
			"cas": cas,
		},
	}
	_, err := tm.vault.Logical().WriteWithContext(ctx, tm.path, data)
	return err
}
// populate loads the latest token from Vault into the cache. If Vault has
// no secret at tm.path, it writes an initial token (check-and-set version 0)
// and loads that. getToken stores whatever it reads in tm.latest and
// tm.versions, so no extra bookkeeping is needed here.
func (tm *TokenManager) populate() error {
	ctx := context.Background()
	_, err := tm.getToken(ctx)
	if err == nil {
		return nil
	}
	if _, notFound := err.(*notFoundError); !notFound {
		return err
	}
	if werr := tm.createNewToken(ctx, 0); werr != nil {
		// Another writer may have created the token between our read and
		// this check-and-set write; use theirs if it is readable now.
		if _, gerr := tm.getToken(ctx); gerr == nil {
			return nil
		}
		return fmt.Errorf("could not save token: %w", werr)
	}
	_, err = tm.getToken(ctx)
	return err
}
// getTokenVersionCache looks up version in the in-memory cache under the
// read lock. It returns nil when that version has not been loaded yet. ctx
// is currently unused and the error result is always nil.
func (tm *TokenManager) getTokenVersionCache(ctx context.Context, version int) (*token, error) {
	tm.lock.RLock()
	cached := tm.versions[version]
	tm.lock.RUnlock()
	return cached, nil
}
// getTokenVersion returns the token with the given KV version. It uses the
// cache when possible and otherwise reads that version from Vault, caching
// the result.
//
// Versions newer than the latest known token (or below 1) are rejected
// without a Vault request. An error is returned when Vault returns no
// secret, or no token data, for the version.
func (tm *TokenManager) getTokenVersion(ctx context.Context, version int) (*token, error) {
	cached, err := tm.getTokenVersionCache(ctx, version)
	if err != nil {
		return nil, err
	}
	if cached != nil {
		return cached, nil
	}
	latest, err := tm.getToken(ctx)
	if err != nil {
		return nil, err
	}
	if version < 1 || latest.version < version {
		return nil, fmt.Errorf("invalid signature version")
	}
	// Fetch without holding the lock, so a slow or failing Vault request does
	// not block signing and validation in other goroutines. Concurrent misses
	// for the same version may each fetch it; the first stored result wins.
	rv, err := tm.getKVVersion(ctx, version)
	if err != nil {
		return nil, err
	}
	if rv == nil {
		return nil, fmt.Errorf("token version %d not found", version)
	}
	fetched, err := parseTokenVaultSecret(rv.Data)
	if err != nil {
		return nil, err
	}
	tm.lock.Lock()
	defer tm.lock.Unlock()
	if existing, ok := tm.versions[version]; ok {
		return existing, nil
	}
	tm.versions[version] = fetched
	return fetched, nil
}
// getToken returns the most recent token. The cached tm.latest is used when
// set. Otherwise the current version is read from Vault while holding the
// write lock, and the result is cached in tm.latest and tm.versions. A
// *notFoundError is returned when Vault has no secret at tm.path.
func (tm *TokenManager) getToken(ctx context.Context) (*token, error) {
	// Fast path under the shared lock.
	tm.lock.RLock()
	cached := tm.latest
	tm.lock.RUnlock()
	if cached != nil {
		return cached, nil
	}

	tm.lock.Lock()
	defer tm.lock.Unlock()
	// Re-check: another goroutine may have loaded it while we waited.
	if tm.latest != nil {
		return tm.latest, nil
	}
	secret, err := tm.getKV(ctx)
	switch {
	case err != nil:
		return nil, err
	case secret == nil:
		return nil, &notFoundError{}
	}
	fresh, err := parseTokenVaultSecret(secret.Data)
	if err != nil {
		return nil, err
	}
	tm.latest = fresh
	tm.versions[fresh.version] = fresh
	return fresh, nil
}
// parseTokenVaultSecret builds a token from the Data of a KV v2 read.
//
// It expects data["data"] to hold the stored token fields ("token" string,
// "token-ts" json.Number), and data["metadata"]["version"] to hold the KV
// version as a json.Number. Values of other types are ignored. A missing or
// unparsable "token-ts" is logged and leaves Created at 0.
//
// An error is returned when the token string or a positive version is
// missing. This includes when "data" or "metadata" is absent, null, or not
// an object.
func parseTokenVaultSecret(data map[string]interface{}) (*token, error) {
	t := &token{}
	// Comma-ok assertions: an unchecked assertion panics when Vault returns
	// null (or a non-object) for "data" or "metadata".
	if secret, ok := data["data"].(map[string]interface{}); ok {
		if s, ok := secret["token"].(string); ok {
			t.Secret = s
		}
		if raw, ok := secret["token-ts"]; ok {
			if n, ok := raw.(json.Number); ok {
				created, err := n.Int64()
				t.Created = created
				if created == 0 || err != nil {
					log.Printf("could not parse Created from token secret (%T: %+v): %s", raw, raw, err)
				}
			}
		}
	}
	if meta, ok := data["metadata"].(map[string]interface{}); ok {
		if v, ok := meta["version"].(json.Number); ok {
			v64, err := v.Int64()
			if err != nil {
				return nil, err
			}
			t.version = int(v64)
		}
	}
	if t.version < 1 || len(t.Secret) == 0 {
		return nil, fmt.Errorf("expected token data not found")
	}
	return t, nil
}
// makeToken generates a new token. The secret is 32 random bytes from
// crypto/rand, encoded as padded base64url, and Created is the current Unix
// time. version is left at 0; it is set from the Vault metadata when the
// token is read back. makeToken returns nil if the system random source
// fails.
func makeToken() *token {
	// The secret is an HMAC key, so it must come from a cryptographically
	// secure source; math/rand output is predictable. 32 bytes matches the
	// SHA-256 output size, the minimum key length RFC 2104 recommends.
	randomBytes := make([]byte, 32)
	if _, err := crand.Read(randomBytes); err != nil {
		return nil
	}
	return &token{
		Secret:  base64.URLEncoding.EncodeToString(randomBytes),
		Created: time.Now().Unix(),
	}
}
// getKV reads the current version of the secret at tm.path. On error the
// returned secret is always nil. A nil secret with a nil error is passed
// through unchanged; getToken treats it as "not found".
func (tm *TokenManager) getKV(ctx context.Context) (*vaultapi.Secret, error) {
	secret, err := tm.vault.Logical().ReadWithContext(ctx, tm.path)
	if err == nil {
		return secret, nil
	}
	return nil, err
}
// getKVVersion reads a specific version of the secret at tm.path by passing
// it as the "version" query parameter. On error the returned secret is
// always nil; otherwise the client's result, possibly nil, is returned as-is.
func (tm *TokenManager) getKVVersion(ctx context.Context, version int) (*vaultapi.Secret, error) {
	query := map[string][]string{"version": {strconv.Itoa(version)}}
	secret, err := tm.vault.Logical().ReadWithDataWithContext(ctx, tm.path, query)
	if err == nil {
		return secret, nil
	}
	return nil, err
}
// func (tm *TokenManager) setKV(ctx context.Context, k string, data *vaultapi.Secret) error {
// p := tm.path()
// cl, err := vaultClient()
// if err != nil {
// return nil
// }
// _, err = cl.Logical().WriteWithContext(ctx, p, data.Data)
// if err != nil {
// return err
// }
// return nil
// }