package vaulttokenmanager import ( "bytes" "context" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "log" "math/rand" "os" "strconv" "sync" "time" vaultapi "github.com/hashicorp/vault/api" ) type notFoundError struct{} func (m *notFoundError) Error() string { return "token not found" } type token struct { Secret string `json:"token"` Created int64 `json:"token-ts"` version int `json:"-"` } type TokenManager struct { path string vault *vaultapi.Client cfg *Config latest *token versions map[int]*token lock sync.RWMutex } // Config configures the token manager on initalization type Config struct { // Vault is a vault client configured with SetToken (or defaulting to VAULT_TOKEN // from the environment) Vault *vaultapi.Client // Path is the key in vault (in a KVv2 secrets engine), for example "kv/data/project/tokens" Path string // RefreshInterval is the interval a new token is written. Vault defaults to keeping // 10 versions (so by default the validity period of the signatures is 10 * interval). // Defaults to 16 hours, plus/minus 90 seconds to minimize race conditions (rotating // the token twice). // https://developer.hashicorp.com/vault/tutorials/secrets-management/versioned-kv#step-4-specify-the-number-of-versions-to-keep RefreshInterval time.Duration } // New returns a TokenManager using specified Config. A goroutine will // run to refresh the token until the context is cancelled. func New(ctx context.Context, cfg *Config) (*TokenManager, error) { if cfg.RefreshInterval == 0 { cfg.RefreshInterval = 16 * time.Hour } if cfg.RefreshInterval < time.Minute*10 { return nil, fmt.Errorf("RefreshInterval must be at least 10 minutes") } tm := &TokenManager{ path: cfg.Path, vault: cfg.Vault, cfg: cfg, versions: map[int]*token{}, } err := tm.populate() if err != nil { return nil, err } go tm.rotateTokensBackground(ctx) return tm, nil } func getSignatureVersion(sig []byte) (int, error) { idx := bytes.IndexByte(sig, '-') if idx < 1 { return 0, fmt.Errorf("invalid signature") } versionb := sig[0:idx] version, err := strconv.Atoi(string(versionb)) if err != nil || version == 0 { return 0, fmt.Errorf("unknown signature version %d: %s", version, err) } return version, nil } // ValidateBytes will validate the signature matches the specified data. The // signature from SignBytes includes a key version. An error is returned // if the key version isn't available anymore. func (tm *TokenManager) ValidateBytes(sig []byte, data ...[]byte) (bool, error) { version, err := getSignatureVersion(sig) if err != nil || version == 0 { return false, err } token, err := tm.getTokenVersion(context.Background(), version) if err != nil { return false, err } expected, err := tm.signWith(token, data...) if err != nil { return false, err } if len(expected) > 0 && bytes.Equal(sig, expected) { return true, nil } log.Printf("exp: %s", expected) log.Printf("got: %s", sig) return false, fmt.Errorf("could not validate signature") } // SignBytes returns a base64 encoded hmac signature of the given data, // using the most recent key. The signature is prefixed with the key version func (tm *TokenManager) SignBytes(data ...[]byte) ([]byte, error) { token, err := tm.getToken(context.Background()) if err != nil { return nil, err } // log.Printf("got version: %d, token: %s", token.version, token.Secret) return tm.signWith(token, data...) } func (tm *TokenManager) signWith(token *token, data ...[]byte) ([]byte, error) { hm := hmac.New(sha256.New, []byte(token.Secret)) b := bytes.Join(data, []byte("|")) p, err := hm.Write([]byte(b)) if err != nil || p != len(b) { return nil, fmt.Errorf("hmac error: %s", err) } sha := hm.Sum(nil) r := strconv.AppendInt([]byte{}, int64(token.version), 10) r = append(r, []byte("-")...) shaenc := make([]byte, base64.RawURLEncoding.EncodedLen(len(sha))) base64.RawURLEncoding.Encode(shaenc, sha) r = append(r, shaenc...) return r, nil } func (tm *TokenManager) rotateTokensBackground(ctx context.Context) { l := log.New(os.Stderr, "vault-token-manager: ", 0) ticker := time.NewTicker(tm.cfg.RefreshInterval / 5) defer ticker.Stop() for { latest, err := tm.getToken(ctx) if err != nil { l.Printf("could not get token: %s", err) } latestTime := time.Unix(latest.Created, 0) // l.Printf("checking token age %q, latest is %s old (interval: %s)", tm.path, time.Since(latestTime), tm.cfg.RefreshInterval.String()) if age := time.Since(latestTime); age > tm.cfg.RefreshInterval { // l.Printf("token age (%s) is more than %s, rotate it", age, tm.cfg.RefreshInterval) tm.createNewToken(ctx, latest.version) tm.lock.Lock() tm.latest = nil tm.lock.Unlock() tm.getToken(ctx) // l.Printf("finished renewing token") } untilNext := time.Until(latestTime.Add(tm.cfg.RefreshInterval)) // randomize the time within +/- 90 seconds untilNext = untilNext + time.Second*time.Duration(rand.Intn(180)-90) timer := time.NewTimer(untilNext) select { case <-ctx.Done(): ticker.Stop() return case <-ticker.C: timer.Stop() case <-timer.C: } } } func (tm *TokenManager) createNewToken(ctx context.Context, cas int) error { data := map[string]interface{}{ "data": makeToken(), "metadata": map[string]interface{}{ "cas_required": true, }, "options": map[string]interface{}{ "cas": cas, }, } _, err := tm.vault.Logical().WriteWithContext(ctx, tm.path, data) if err != nil { return err } return nil } func (tm *TokenManager) populate() error { ctx := context.Background() t, err := tm.getToken(ctx) if err != nil { if _, ok := err.(*notFoundError); !ok { return err } } if t == nil { err := tm.createNewToken(ctx, 0) if err != nil { return fmt.Errorf("could not save token: %s", err) } t, err = tm.getToken(ctx) if err != nil { return err } if t == nil { return fmt.Errorf("could not find token data") } } if t != nil { tm.latest = t tm.versions[t.version] = t } return nil } func (tm *TokenManager) getTokenVersionCache(ctx context.Context, version int) (*token, error) { tm.lock.RLock() defer tm.lock.RUnlock() if t, ok := tm.versions[version]; ok { return t, nil } return nil, nil } func (tm *TokenManager) getTokenVersion(ctx context.Context, version int) (*token, error) { token, err := tm.getTokenVersionCache(ctx, version) if err != nil { return nil, err } if token != nil { return token, nil } latest, err := tm.getToken(ctx) if err != nil { return nil, err } if latest.version < version { return nil, fmt.Errorf("invalid signature version") } tm.lock.Lock() defer tm.lock.Unlock() // in case it was set while we were waiting for a lock if token, ok := tm.versions[version]; ok { return token, nil } // log.Printf("requesting token %q/%d", tm.path, version) rv, err := tm.getKVVersion(ctx, version) if err != nil { return nil, err } token, err = parseTokenVaultSecret(rv.Data) if err != nil { return nil, err } tm.versions[version] = token return token, err } // GetToken returns the most recent available secret token func (tm *TokenManager) GetToken(ctx context.Context) (*token, error) { return tm.getToken(ctx) } func (tm *TokenManager) getToken(ctx context.Context) (*token, error) { tm.lock.RLock() if token := tm.latest; token != nil { tm.lock.RUnlock() return token, nil } // log.Printf("getToken didn't have latest token,`` getting from vault") tm.lock.RUnlock() tm.lock.Lock() defer tm.lock.Unlock() // in case it was set while we were waiting for a lock if tm.latest != nil { return tm.latest, nil } rv, err := tm.getKV(ctx) if err != nil { return nil, err } if rv == nil { return nil, ¬FoundError{} } t, err := parseTokenVaultSecret(rv.Data) if err != nil { return nil, err } tm.latest = t tm.versions[t.version] = t return t, nil } func parseTokenVaultSecret(data map[string]interface{}) (*token, error) { t := &token{} var err error if dataif, ok := data["data"]; ok { data := dataif.(map[string]interface{}) if tokData, ok := data["token"]; ok { if tokStr, ok := tokData.(string); ok { t.Secret = tokStr } } if tokData, ok := data["token-ts"]; ok { if tokInt, ok := tokData.(json.Number); ok { t.Created, err = tokInt.Int64() if t.Created == 0 || err != nil { log.Printf("could not parse Created from token secret (%T: %+v): %s", tokData, tokData, err) } } } } if metaif, ok := data["metadata"]; ok { meta := metaif.(map[string]interface{}) if version, ok := meta["version"]; ok { if v, ok := version.(json.Number); ok { v64, err := v.Int64() if err != nil { return nil, err } t.version = int(v64) } } } if t.version == 0 || len(t.Secret) == 0 { return nil, fmt.Errorf("expected token data not found") } return t, nil } func makeToken() *token { randomBytes := make([]byte, 16) _, err := rand.Read(randomBytes) if err != nil { return nil } return &token{ Secret: base64.URLEncoding.EncodeToString(randomBytes), Created: time.Now().Unix(), } } func (tm *TokenManager) getKV(ctx context.Context) (*vaultapi.Secret, error) { rv, err := tm.vault.Logical().ReadWithContext(ctx, tm.path) if err != nil { return nil, err } return rv, nil } func (tm *TokenManager) getKVVersion(ctx context.Context, version int) (*vaultapi.Secret, error) { data := map[string][]string{ "version": {strconv.Itoa(version)}, } rv, err := tm.vault.Logical().ReadWithDataWithContext(ctx, tm.path, data) if err != nil { return nil, err } return rv, nil } // func (tm *TokenManager) setKV(ctx context.Context, k string, data *vaultapi.Secret) error { // p := tm.path() // cl, err := vaultClient() // if err != nil { // return nil // } // _, err = cl.Logical().WriteWithContext(ctx, p, data.Data) // if err != nil { // return err // } // return nil // }