From 39e4be8f841739a8aef635a9f668a8a6fa9c04bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ask=20Bj=C3=B8rn=20Hansen?= Date: Mon, 13 Feb 2023 23:10:55 -0800 Subject: [PATCH] Cleanup API a little and add minimal docs --- vault.go | 183 ++++++++++++++++++++++++------------------------------- 1 file changed, 78 insertions(+), 105 deletions(-) diff --git a/vault.go b/vault.go index 4d365de..3407e0b 100644 --- a/vault.go +++ b/vault.go @@ -4,14 +4,13 @@ import ( "bytes" "context" "crypto/hmac" - "crypto/rand" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "log" + "math/rand" "os" - "path" "strconv" "sync" "time" @@ -19,8 +18,6 @@ import ( vaultapi "github.com/hashicorp/vault/api" ) -const tokenRefreshInterval = 10 * time.Hour - type notFoundError struct{} func (m *notFoundError) Error() string { @@ -34,43 +31,57 @@ type token struct { } type TokenManager struct { - key string - basePath string + path string + vault *vaultapi.Client + + cfg *Config latest *token versions map[int]*token - vault *vaultapi.Client - lock sync.RWMutex + lock sync.RWMutex } -func New(key, depEnv string) (*TokenManager, error) { +// Config configures the token manager on initalization +type Config struct { + // Vault is a vault client configured with SetToken (or defaulting to VAULT_TOKEN + // from the environment) + Vault *vaultapi.Client + // Path is the key in vault (in a KVv2 secrets engine), for example "kv/data/project/tokens" + Path string + // RefreshInterval is the interval a new token is written. Vault defaults to keeping + // 10 versions (so by default the validity period of the signatures is 10 * interval). + // Defaults to 16 hours, plus/minus 90 seconds to minimize race conditions (rotating + // the token twice). + // https://developer.hashicorp.com/vault/tutorials/secrets-management/versioned-kv#step-4-specify-the-number-of-versions-to-keep + RefreshInterval time.Duration +} - if len(depEnv) == 0 { - return nil, fmt.Errorf("invalid deployment mode parameter %q", depEnv) +// New returns a TokenManager using specified Config. A goroutine will +// run to refresh the token until the context is cancelled. +func New(ctx context.Context, cfg *Config) (*TokenManager, error) { + + if cfg.RefreshInterval == 0 { + cfg.RefreshInterval = 16 * time.Hour } - var basePath = fmt.Sprintf("kv/data/ntppool/%s/", depEnv) - - cl, err := vaultClient() - if err != nil { - return nil, err + if cfg.RefreshInterval < time.Minute*10 { + return nil, fmt.Errorf("RefreshInterval must be at least 10 minutes") } tm := &TokenManager{ - key: key, - basePath: basePath, - vault: cl, + path: cfg.Path, + vault: cfg.Vault, + cfg: cfg, versions: map[int]*token{}, } - err = tm.populate() + err := tm.populate() if err != nil { return nil, err } - // todo: pass context so it can be shutdown - go tm.rotateTokensBackground() + go tm.rotateTokensBackground(ctx) return tm, nil } @@ -89,7 +100,10 @@ func getSignatureVersion(sig []byte) (int, error) { return version, nil } -func (tm *TokenManager) Validate(sig []byte, data ...[]byte) (bool, error) { +// ValidateBytes will validate the signature matches the specified data. The +// signature from SignBytes includes a key version. An error is returned +// if the key version isn't available anymore. +func (tm *TokenManager) ValidateBytes(sig []byte, data ...[]byte) (bool, error) { version, err := getSignatureVersion(sig) if err != nil || version == 0 { return false, err @@ -114,7 +128,9 @@ func (tm *TokenManager) Validate(sig []byte, data ...[]byte) (bool, error) { return false, fmt.Errorf("could not validate signature") } -func (tm *TokenManager) Sign(data ...[]byte) ([]byte, error) { +// SignBytes returns a base64 encoded hmac signature of the given data, +// using the most recent key. The signature is prefixed with the key version +func (tm *TokenManager) SignBytes(data ...[]byte) ([]byte, error) { token, err := tm.getToken(context.Background()) if err != nil { @@ -150,12 +166,10 @@ func (tm *TokenManager) signWith(token *token, data ...[]byte) ([]byte, error) { return r, nil } -func (tm *TokenManager) rotateTokensBackground() { - ctx := context.Background() // for when the app has context properly +func (tm *TokenManager) rotateTokensBackground(ctx context.Context) { + l := log.New(os.Stderr, "vault-token-manager: ", 0) - l := log.New(os.Stderr, "rotateTokensBackground: ", 0) - - ticker := time.NewTicker(tokenRefreshInterval / 5) + ticker := time.NewTicker(tm.cfg.RefreshInterval / 5) defer ticker.Stop() for { @@ -167,23 +181,32 @@ func (tm *TokenManager) rotateTokensBackground() { latestTime := time.Unix(latest.Created, 0) - l.Printf("checking token age, latest is %s old (interval: %s)", time.Since(latestTime), tokenRefreshInterval.String()) + // l.Printf("checking token age %q, latest is %s old (interval: %s)", tm.path, time.Since(latestTime), tm.cfg.RefreshInterval.String()) - if age := time.Since(latestTime); age > tokenRefreshInterval { - l.Printf("token age (%s) is more than %s, rotate it", age, tokenRefreshInterval) + if age := time.Since(latestTime); age > tm.cfg.RefreshInterval { + // l.Printf("token age (%s) is more than %s, rotate it", age, tm.cfg.RefreshInterval) tm.createNewToken(ctx, latest.version) tm.lock.Lock() tm.latest = nil tm.lock.Unlock() tm.getToken(ctx) - l.Printf("finished renewing token") + // l.Printf("finished renewing token") } + untilNext := time.Until(latestTime.Add(tm.cfg.RefreshInterval)) + + // randomize the time within +/- 90 seconds + untilNext = untilNext + time.Second*time.Duration(rand.Intn(180)-90) + + timer := time.NewTimer(untilNext) + select { case <-ctx.Done(): - l.Printf("context done") + ticker.Stop() return case <-ticker.C: + timer.Stop() + case <-timer.C: } } @@ -200,7 +223,7 @@ func (tm *TokenManager) createNewToken(ctx context.Context, cas int) error { }, } - _, err := tm.vault.Logical().WriteWithContext(ctx, tm.path(), data) + _, err := tm.vault.Logical().WriteWithContext(ctx, tm.path, data) if err != nil { return err } @@ -208,10 +231,6 @@ func (tm *TokenManager) createNewToken(ctx context.Context, cas int) error { return nil } -func (tm *TokenManager) path() string { - return path.Join(tm.basePath, tm.key) -} - func (tm *TokenManager) populate() error { ctx := context.Background() @@ -283,8 +302,8 @@ func (tm *TokenManager) getTokenVersion(ctx context.Context, version int) (*toke return token, nil } - log.Printf("requesting token %q/%d", tm.key, version) - rv, err := tm.getKVVersion(ctx, tm.key, version) + // log.Printf("requesting token %q/%d", tm.path, version) + rv, err := tm.getKVVersion(ctx, version) if err != nil { return nil, err } @@ -308,7 +327,7 @@ func (tm *TokenManager) getToken(ctx context.Context) (*token, error) { return token, nil } - log.Printf("getToken didn't have latest token, getting from vault") + // log.Printf("getToken didn't have latest token,`` getting from vault") tm.lock.RUnlock() tm.lock.Lock() @@ -319,9 +338,7 @@ func (tm *TokenManager) getToken(ctx context.Context) (*token, error) { return tm.latest, nil } - log.Printf("getToken calling getKV") - - rv, err := tm.getKV(ctx, tm.key) + rv, err := tm.getKV(ctx) if err != nil { return nil, err } @@ -337,8 +354,6 @@ func (tm *TokenManager) getToken(ctx context.Context) (*token, error) { tm.latest = t tm.versions[t.version] = t - log.Printf("getToken returning success") - return t, nil } @@ -398,44 +413,8 @@ func makeToken() *token { } } -var hasOutputVaultEnvMessage bool - -func vaultClient() (*vaultapi.Client, error) { - - c := vaultapi.DefaultConfig() - - if c.Address == "https://127.0.0.1:8200" { - c.Address = "https://vault.ntppool.org" - } - - cl, err := vaultapi.NewClient(c) - if err != nil { - return nil, err - } - - // VAULT_TOKEN is read automatically from the environment if set - // so we just try the file here - token, err := os.ReadFile("/vault/secrets/token") - if err == nil { - cl.SetToken(string(token)) - } else { - if !hasOutputVaultEnvMessage { - hasOutputVaultEnvMessage = true - log.Printf("could not read /vault/secrets/token (%s), using VAULT_TOKEN", err) - } - } - - return cl, nil -} - -func (tm *TokenManager) getKV(ctx context.Context, k string) (*vaultapi.Secret, error) { - - cl, err := vaultClient() - if err != nil { - return nil, nil - } - - rv, err := cl.Logical().ReadWithContext(ctx, tm.path()) +func (tm *TokenManager) getKV(ctx context.Context) (*vaultapi.Secret, error) { + rv, err := tm.vault.Logical().ReadWithContext(ctx, tm.path) if err != nil { return nil, err } @@ -443,18 +422,12 @@ func (tm *TokenManager) getKV(ctx context.Context, k string) (*vaultapi.Secret, return rv, nil } -func (tm *TokenManager) getKVVersion(ctx context.Context, k string, version int) (*vaultapi.Secret, error) { - - cl, err := vaultClient() - if err != nil { - return nil, nil - } - +func (tm *TokenManager) getKVVersion(ctx context.Context, version int) (*vaultapi.Secret, error) { data := map[string][]string{ "version": {strconv.Itoa(version)}, } - rv, err := cl.Logical().ReadWithDataWithContext(ctx, tm.path(), data) + rv, err := tm.vault.Logical().ReadWithDataWithContext(ctx, tm.path, data) if err != nil { return nil, err } @@ -462,17 +435,17 @@ func (tm *TokenManager) getKVVersion(ctx context.Context, k string, version int) return rv, nil } -func (tm *TokenManager) SetKV(ctx context.Context, k string, data *vaultapi.Secret) error { - p := tm.path() - cl, err := vaultClient() - if err != nil { - return nil - } +// func (tm *TokenManager) setKV(ctx context.Context, k string, data *vaultapi.Secret) error { +// p := tm.path() +// cl, err := vaultClient() +// if err != nil { +// return nil +// } - _, err = cl.Logical().WriteWithContext(ctx, p, data.Data) - if err != nil { - return err - } +// _, err = cl.Logical().WriteWithContext(ctx, p, data.Data) +// if err != nil { +// return err +// } - return nil -} +// return nil +// }