Cleanup API a little and add minimal docs

This commit is contained in:
Ask Bjørn Hansen 2023-02-13 23:10:55 -08:00
parent 57639813e0
commit 39e4be8f84

181
vault.go
View File

@ -4,14 +4,13 @@ import (
"bytes" "bytes"
"context" "context"
"crypto/hmac" "crypto/hmac"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"log" "log"
"math/rand"
"os" "os"
"path"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -19,8 +18,6 @@ import (
vaultapi "github.com/hashicorp/vault/api" vaultapi "github.com/hashicorp/vault/api"
) )
const tokenRefreshInterval = 10 * time.Hour
type notFoundError struct{} type notFoundError struct{}
func (m *notFoundError) Error() string { func (m *notFoundError) Error() string {
@ -34,43 +31,57 @@ type token struct {
} }
type TokenManager struct { type TokenManager struct {
key string path string
basePath string vault *vaultapi.Client
cfg *Config
latest *token latest *token
versions map[int]*token versions map[int]*token
vault *vaultapi.Client
lock sync.RWMutex lock sync.RWMutex
} }
func New(key, depEnv string) (*TokenManager, error) { // Config configures the token manager on initalization
type Config struct {
if len(depEnv) == 0 { // Vault is a vault client configured with SetToken (or defaulting to VAULT_TOKEN
return nil, fmt.Errorf("invalid deployment mode parameter %q", depEnv) // from the environment)
Vault *vaultapi.Client
// Path is the key in vault (in a KVv2 secrets engine), for example "kv/data/project/tokens"
Path string
// RefreshInterval is the interval a new token is written. Vault defaults to keeping
// 10 versions (so by default the validity period of the signatures is 10 * interval).
// Defaults to 16 hours, plus/minus 90 seconds to minimize race conditions (rotating
// the token twice).
// https://developer.hashicorp.com/vault/tutorials/secrets-management/versioned-kv#step-4-specify-the-number-of-versions-to-keep
RefreshInterval time.Duration
} }
var basePath = fmt.Sprintf("kv/data/ntppool/%s/", depEnv) // New returns a TokenManager using specified Config. A goroutine will
// run to refresh the token until the context is cancelled.
func New(ctx context.Context, cfg *Config) (*TokenManager, error) {
cl, err := vaultClient() if cfg.RefreshInterval == 0 {
if err != nil { cfg.RefreshInterval = 16 * time.Hour
return nil, err }
if cfg.RefreshInterval < time.Minute*10 {
return nil, fmt.Errorf("RefreshInterval must be at least 10 minutes")
} }
tm := &TokenManager{ tm := &TokenManager{
key: key, path: cfg.Path,
basePath: basePath, vault: cfg.Vault,
vault: cl, cfg: cfg,
versions: map[int]*token{}, versions: map[int]*token{},
} }
err = tm.populate() err := tm.populate()
if err != nil { if err != nil {
return nil, err return nil, err
} }
// todo: pass context so it can be shutdown go tm.rotateTokensBackground(ctx)
go tm.rotateTokensBackground()
return tm, nil return tm, nil
} }
@ -89,7 +100,10 @@ func getSignatureVersion(sig []byte) (int, error) {
return version, nil return version, nil
} }
func (tm *TokenManager) Validate(sig []byte, data ...[]byte) (bool, error) { // ValidateBytes will validate the signature matches the specified data. The
// signature from SignBytes includes a key version. An error is returned
// if the key version isn't available anymore.
func (tm *TokenManager) ValidateBytes(sig []byte, data ...[]byte) (bool, error) {
version, err := getSignatureVersion(sig) version, err := getSignatureVersion(sig)
if err != nil || version == 0 { if err != nil || version == 0 {
return false, err return false, err
@ -114,7 +128,9 @@ func (tm *TokenManager) Validate(sig []byte, data ...[]byte) (bool, error) {
return false, fmt.Errorf("could not validate signature") return false, fmt.Errorf("could not validate signature")
} }
func (tm *TokenManager) Sign(data ...[]byte) ([]byte, error) { // SignBytes returns a base64 encoded hmac signature of the given data,
// using the most recent key. The signature is prefixed with the key version
func (tm *TokenManager) SignBytes(data ...[]byte) ([]byte, error) {
token, err := tm.getToken(context.Background()) token, err := tm.getToken(context.Background())
if err != nil { if err != nil {
@ -150,12 +166,10 @@ func (tm *TokenManager) signWith(token *token, data ...[]byte) ([]byte, error) {
return r, nil return r, nil
} }
func (tm *TokenManager) rotateTokensBackground() { func (tm *TokenManager) rotateTokensBackground(ctx context.Context) {
ctx := context.Background() // for when the app has context properly l := log.New(os.Stderr, "vault-token-manager: ", 0)
l := log.New(os.Stderr, "rotateTokensBackground: ", 0) ticker := time.NewTicker(tm.cfg.RefreshInterval / 5)
ticker := time.NewTicker(tokenRefreshInterval / 5)
defer ticker.Stop() defer ticker.Stop()
for { for {
@ -167,23 +181,32 @@ func (tm *TokenManager) rotateTokensBackground() {
latestTime := time.Unix(latest.Created, 0) latestTime := time.Unix(latest.Created, 0)
l.Printf("checking token age, latest is %s old (interval: %s)", time.Since(latestTime), tokenRefreshInterval.String()) // l.Printf("checking token age %q, latest is %s old (interval: %s)", tm.path, time.Since(latestTime), tm.cfg.RefreshInterval.String())
if age := time.Since(latestTime); age > tokenRefreshInterval { if age := time.Since(latestTime); age > tm.cfg.RefreshInterval {
l.Printf("token age (%s) is more than %s, rotate it", age, tokenRefreshInterval) // l.Printf("token age (%s) is more than %s, rotate it", age, tm.cfg.RefreshInterval)
tm.createNewToken(ctx, latest.version) tm.createNewToken(ctx, latest.version)
tm.lock.Lock() tm.lock.Lock()
tm.latest = nil tm.latest = nil
tm.lock.Unlock() tm.lock.Unlock()
tm.getToken(ctx) tm.getToken(ctx)
l.Printf("finished renewing token") // l.Printf("finished renewing token")
} }
untilNext := time.Until(latestTime.Add(tm.cfg.RefreshInterval))
// randomize the time within +/- 90 seconds
untilNext = untilNext + time.Second*time.Duration(rand.Intn(180)-90)
timer := time.NewTimer(untilNext)
select { select {
case <-ctx.Done(): case <-ctx.Done():
l.Printf("context done") ticker.Stop()
return return
case <-ticker.C: case <-ticker.C:
timer.Stop()
case <-timer.C:
} }
} }
@ -200,7 +223,7 @@ func (tm *TokenManager) createNewToken(ctx context.Context, cas int) error {
}, },
} }
_, err := tm.vault.Logical().WriteWithContext(ctx, tm.path(), data) _, err := tm.vault.Logical().WriteWithContext(ctx, tm.path, data)
if err != nil { if err != nil {
return err return err
} }
@ -208,10 +231,6 @@ func (tm *TokenManager) createNewToken(ctx context.Context, cas int) error {
return nil return nil
} }
func (tm *TokenManager) path() string {
return path.Join(tm.basePath, tm.key)
}
func (tm *TokenManager) populate() error { func (tm *TokenManager) populate() error {
ctx := context.Background() ctx := context.Background()
@ -283,8 +302,8 @@ func (tm *TokenManager) getTokenVersion(ctx context.Context, version int) (*toke
return token, nil return token, nil
} }
log.Printf("requesting token %q/%d", tm.key, version) // log.Printf("requesting token %q/%d", tm.path, version)
rv, err := tm.getKVVersion(ctx, tm.key, version) rv, err := tm.getKVVersion(ctx, version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -308,7 +327,7 @@ func (tm *TokenManager) getToken(ctx context.Context) (*token, error) {
return token, nil return token, nil
} }
log.Printf("getToken didn't have latest token, getting from vault") // log.Printf("getToken didn't have latest token,`` getting from vault")
tm.lock.RUnlock() tm.lock.RUnlock()
tm.lock.Lock() tm.lock.Lock()
@ -319,9 +338,7 @@ func (tm *TokenManager) getToken(ctx context.Context) (*token, error) {
return tm.latest, nil return tm.latest, nil
} }
log.Printf("getToken calling getKV") rv, err := tm.getKV(ctx)
rv, err := tm.getKV(ctx, tm.key)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -337,8 +354,6 @@ func (tm *TokenManager) getToken(ctx context.Context) (*token, error) {
tm.latest = t tm.latest = t
tm.versions[t.version] = t tm.versions[t.version] = t
log.Printf("getToken returning success")
return t, nil return t, nil
} }
@ -398,44 +413,8 @@ func makeToken() *token {
} }
} }
var hasOutputVaultEnvMessage bool func (tm *TokenManager) getKV(ctx context.Context) (*vaultapi.Secret, error) {
rv, err := tm.vault.Logical().ReadWithContext(ctx, tm.path)
func vaultClient() (*vaultapi.Client, error) {
c := vaultapi.DefaultConfig()
if c.Address == "https://127.0.0.1:8200" {
c.Address = "https://vault.ntppool.org"
}
cl, err := vaultapi.NewClient(c)
if err != nil {
return nil, err
}
// VAULT_TOKEN is read automatically from the environment if set
// so we just try the file here
token, err := os.ReadFile("/vault/secrets/token")
if err == nil {
cl.SetToken(string(token))
} else {
if !hasOutputVaultEnvMessage {
hasOutputVaultEnvMessage = true
log.Printf("could not read /vault/secrets/token (%s), using VAULT_TOKEN", err)
}
}
return cl, nil
}
func (tm *TokenManager) getKV(ctx context.Context, k string) (*vaultapi.Secret, error) {
cl, err := vaultClient()
if err != nil {
return nil, nil
}
rv, err := cl.Logical().ReadWithContext(ctx, tm.path())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -443,18 +422,12 @@ func (tm *TokenManager) getKV(ctx context.Context, k string) (*vaultapi.Secret,
return rv, nil return rv, nil
} }
func (tm *TokenManager) getKVVersion(ctx context.Context, k string, version int) (*vaultapi.Secret, error) { func (tm *TokenManager) getKVVersion(ctx context.Context, version int) (*vaultapi.Secret, error) {
cl, err := vaultClient()
if err != nil {
return nil, nil
}
data := map[string][]string{ data := map[string][]string{
"version": {strconv.Itoa(version)}, "version": {strconv.Itoa(version)},
} }
rv, err := cl.Logical().ReadWithDataWithContext(ctx, tm.path(), data) rv, err := tm.vault.Logical().ReadWithDataWithContext(ctx, tm.path, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -462,17 +435,17 @@ func (tm *TokenManager) getKVVersion(ctx context.Context, k string, version int)
return rv, nil return rv, nil
} }
func (tm *TokenManager) SetKV(ctx context.Context, k string, data *vaultapi.Secret) error { // func (tm *TokenManager) setKV(ctx context.Context, k string, data *vaultapi.Secret) error {
p := tm.path() // p := tm.path()
cl, err := vaultClient() // cl, err := vaultClient()
if err != nil { // if err != nil {
return nil // return nil
} // }
_, err = cl.Logical().WriteWithContext(ctx, p, data.Data) // _, err = cl.Logical().WriteWithContext(ctx, p, data.Data)
if err != nil { // if err != nil {
return err // return err
} // }
return nil // return nil
} // }