Cleanup API a little and add minimal docs
This commit is contained in:
parent
57639813e0
commit
39e4be8f84
181
vault.go
181
vault.go
@ -4,14 +4,13 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"math/rand"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -19,8 +18,6 @@ import (
|
|||||||
vaultapi "github.com/hashicorp/vault/api"
|
vaultapi "github.com/hashicorp/vault/api"
|
||||||
)
|
)
|
||||||
|
|
||||||
const tokenRefreshInterval = 10 * time.Hour
|
|
||||||
|
|
||||||
type notFoundError struct{}
|
type notFoundError struct{}
|
||||||
|
|
||||||
func (m *notFoundError) Error() string {
|
func (m *notFoundError) Error() string {
|
||||||
@ -34,43 +31,57 @@ type token struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TokenManager struct {
|
type TokenManager struct {
|
||||||
key string
|
path string
|
||||||
basePath string
|
vault *vaultapi.Client
|
||||||
|
|
||||||
|
cfg *Config
|
||||||
|
|
||||||
latest *token
|
latest *token
|
||||||
versions map[int]*token
|
versions map[int]*token
|
||||||
|
|
||||||
vault *vaultapi.Client
|
|
||||||
lock sync.RWMutex
|
lock sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(key, depEnv string) (*TokenManager, error) {
|
// Config configures the token manager on initalization
|
||||||
|
type Config struct {
|
||||||
|
// Vault is a vault client configured with SetToken (or defaulting to VAULT_TOKEN
|
||||||
|
// from the environment)
|
||||||
|
Vault *vaultapi.Client
|
||||||
|
// Path is the key in vault (in a KVv2 secrets engine), for example "kv/data/project/tokens"
|
||||||
|
Path string
|
||||||
|
// RefreshInterval is the interval a new token is written. Vault defaults to keeping
|
||||||
|
// 10 versions (so by default the validity period of the signatures is 10 * interval).
|
||||||
|
// Defaults to 16 hours, plus/minus 90 seconds to minimize race conditions (rotating
|
||||||
|
// the token twice).
|
||||||
|
// https://developer.hashicorp.com/vault/tutorials/secrets-management/versioned-kv#step-4-specify-the-number-of-versions-to-keep
|
||||||
|
RefreshInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
if len(depEnv) == 0 {
|
// New returns a TokenManager using specified Config. A goroutine will
|
||||||
return nil, fmt.Errorf("invalid deployment mode parameter %q", depEnv)
|
// run to refresh the token until the context is cancelled.
|
||||||
|
func New(ctx context.Context, cfg *Config) (*TokenManager, error) {
|
||||||
|
|
||||||
|
if cfg.RefreshInterval == 0 {
|
||||||
|
cfg.RefreshInterval = 16 * time.Hour
|
||||||
}
|
}
|
||||||
|
|
||||||
var basePath = fmt.Sprintf("kv/data/ntppool/%s/", depEnv)
|
if cfg.RefreshInterval < time.Minute*10 {
|
||||||
|
return nil, fmt.Errorf("RefreshInterval must be at least 10 minutes")
|
||||||
cl, err := vaultClient()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tm := &TokenManager{
|
tm := &TokenManager{
|
||||||
key: key,
|
path: cfg.Path,
|
||||||
basePath: basePath,
|
vault: cfg.Vault,
|
||||||
vault: cl,
|
cfg: cfg,
|
||||||
versions: map[int]*token{},
|
versions: map[int]*token{},
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tm.populate()
|
err := tm.populate()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: pass context so it can be shutdown
|
go tm.rotateTokensBackground(ctx)
|
||||||
go tm.rotateTokensBackground()
|
|
||||||
|
|
||||||
return tm, nil
|
return tm, nil
|
||||||
}
|
}
|
||||||
@ -89,7 +100,10 @@ func getSignatureVersion(sig []byte) (int, error) {
|
|||||||
return version, nil
|
return version, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TokenManager) Validate(sig []byte, data ...[]byte) (bool, error) {
|
// ValidateBytes will validate the signature matches the specified data. The
|
||||||
|
// signature from SignBytes includes a key version. An error is returned
|
||||||
|
// if the key version isn't available anymore.
|
||||||
|
func (tm *TokenManager) ValidateBytes(sig []byte, data ...[]byte) (bool, error) {
|
||||||
version, err := getSignatureVersion(sig)
|
version, err := getSignatureVersion(sig)
|
||||||
if err != nil || version == 0 {
|
if err != nil || version == 0 {
|
||||||
return false, err
|
return false, err
|
||||||
@ -114,7 +128,9 @@ func (tm *TokenManager) Validate(sig []byte, data ...[]byte) (bool, error) {
|
|||||||
return false, fmt.Errorf("could not validate signature")
|
return false, fmt.Errorf("could not validate signature")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TokenManager) Sign(data ...[]byte) ([]byte, error) {
|
// SignBytes returns a base64 encoded hmac signature of the given data,
|
||||||
|
// using the most recent key. The signature is prefixed with the key version
|
||||||
|
func (tm *TokenManager) SignBytes(data ...[]byte) ([]byte, error) {
|
||||||
|
|
||||||
token, err := tm.getToken(context.Background())
|
token, err := tm.getToken(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -150,12 +166,10 @@ func (tm *TokenManager) signWith(token *token, data ...[]byte) ([]byte, error) {
|
|||||||
return r, nil
|
return r, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TokenManager) rotateTokensBackground() {
|
func (tm *TokenManager) rotateTokensBackground(ctx context.Context) {
|
||||||
ctx := context.Background() // for when the app has context properly
|
l := log.New(os.Stderr, "vault-token-manager: ", 0)
|
||||||
|
|
||||||
l := log.New(os.Stderr, "rotateTokensBackground: ", 0)
|
ticker := time.NewTicker(tm.cfg.RefreshInterval / 5)
|
||||||
|
|
||||||
ticker := time.NewTicker(tokenRefreshInterval / 5)
|
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -167,23 +181,32 @@ func (tm *TokenManager) rotateTokensBackground() {
|
|||||||
|
|
||||||
latestTime := time.Unix(latest.Created, 0)
|
latestTime := time.Unix(latest.Created, 0)
|
||||||
|
|
||||||
l.Printf("checking token age, latest is %s old (interval: %s)", time.Since(latestTime), tokenRefreshInterval.String())
|
// l.Printf("checking token age %q, latest is %s old (interval: %s)", tm.path, time.Since(latestTime), tm.cfg.RefreshInterval.String())
|
||||||
|
|
||||||
if age := time.Since(latestTime); age > tokenRefreshInterval {
|
if age := time.Since(latestTime); age > tm.cfg.RefreshInterval {
|
||||||
l.Printf("token age (%s) is more than %s, rotate it", age, tokenRefreshInterval)
|
// l.Printf("token age (%s) is more than %s, rotate it", age, tm.cfg.RefreshInterval)
|
||||||
tm.createNewToken(ctx, latest.version)
|
tm.createNewToken(ctx, latest.version)
|
||||||
tm.lock.Lock()
|
tm.lock.Lock()
|
||||||
tm.latest = nil
|
tm.latest = nil
|
||||||
tm.lock.Unlock()
|
tm.lock.Unlock()
|
||||||
tm.getToken(ctx)
|
tm.getToken(ctx)
|
||||||
l.Printf("finished renewing token")
|
// l.Printf("finished renewing token")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
untilNext := time.Until(latestTime.Add(tm.cfg.RefreshInterval))
|
||||||
|
|
||||||
|
// randomize the time within +/- 90 seconds
|
||||||
|
untilNext = untilNext + time.Second*time.Duration(rand.Intn(180)-90)
|
||||||
|
|
||||||
|
timer := time.NewTimer(untilNext)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
l.Printf("context done")
|
ticker.Stop()
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
|
timer.Stop()
|
||||||
|
case <-timer.C:
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@ -200,7 +223,7 @@ func (tm *TokenManager) createNewToken(ctx context.Context, cas int) error {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := tm.vault.Logical().WriteWithContext(ctx, tm.path(), data)
|
_, err := tm.vault.Logical().WriteWithContext(ctx, tm.path, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -208,10 +231,6 @@ func (tm *TokenManager) createNewToken(ctx context.Context, cas int) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TokenManager) path() string {
|
|
||||||
return path.Join(tm.basePath, tm.key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *TokenManager) populate() error {
|
func (tm *TokenManager) populate() error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
@ -283,8 +302,8 @@ func (tm *TokenManager) getTokenVersion(ctx context.Context, version int) (*toke
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("requesting token %q/%d", tm.key, version)
|
// log.Printf("requesting token %q/%d", tm.path, version)
|
||||||
rv, err := tm.getKVVersion(ctx, tm.key, version)
|
rv, err := tm.getKVVersion(ctx, version)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -308,7 +327,7 @@ func (tm *TokenManager) getToken(ctx context.Context) (*token, error) {
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("getToken didn't have latest token, getting from vault")
|
// log.Printf("getToken didn't have latest token,`` getting from vault")
|
||||||
|
|
||||||
tm.lock.RUnlock()
|
tm.lock.RUnlock()
|
||||||
tm.lock.Lock()
|
tm.lock.Lock()
|
||||||
@ -319,9 +338,7 @@ func (tm *TokenManager) getToken(ctx context.Context) (*token, error) {
|
|||||||
return tm.latest, nil
|
return tm.latest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("getToken calling getKV")
|
rv, err := tm.getKV(ctx)
|
||||||
|
|
||||||
rv, err := tm.getKV(ctx, tm.key)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -337,8 +354,6 @@ func (tm *TokenManager) getToken(ctx context.Context) (*token, error) {
|
|||||||
tm.latest = t
|
tm.latest = t
|
||||||
tm.versions[t.version] = t
|
tm.versions[t.version] = t
|
||||||
|
|
||||||
log.Printf("getToken returning success")
|
|
||||||
|
|
||||||
return t, nil
|
return t, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -398,44 +413,8 @@ func makeToken() *token {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var hasOutputVaultEnvMessage bool
|
func (tm *TokenManager) getKV(ctx context.Context) (*vaultapi.Secret, error) {
|
||||||
|
rv, err := tm.vault.Logical().ReadWithContext(ctx, tm.path)
|
||||||
func vaultClient() (*vaultapi.Client, error) {
|
|
||||||
|
|
||||||
c := vaultapi.DefaultConfig()
|
|
||||||
|
|
||||||
if c.Address == "https://127.0.0.1:8200" {
|
|
||||||
c.Address = "https://vault.ntppool.org"
|
|
||||||
}
|
|
||||||
|
|
||||||
cl, err := vaultapi.NewClient(c)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// VAULT_TOKEN is read automatically from the environment if set
|
|
||||||
// so we just try the file here
|
|
||||||
token, err := os.ReadFile("/vault/secrets/token")
|
|
||||||
if err == nil {
|
|
||||||
cl.SetToken(string(token))
|
|
||||||
} else {
|
|
||||||
if !hasOutputVaultEnvMessage {
|
|
||||||
hasOutputVaultEnvMessage = true
|
|
||||||
log.Printf("could not read /vault/secrets/token (%s), using VAULT_TOKEN", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return cl, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tm *TokenManager) getKV(ctx context.Context, k string) (*vaultapi.Secret, error) {
|
|
||||||
|
|
||||||
cl, err := vaultClient()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
rv, err := cl.Logical().ReadWithContext(ctx, tm.path())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -443,18 +422,12 @@ func (tm *TokenManager) getKV(ctx context.Context, k string) (*vaultapi.Secret,
|
|||||||
return rv, nil
|
return rv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TokenManager) getKVVersion(ctx context.Context, k string, version int) (*vaultapi.Secret, error) {
|
func (tm *TokenManager) getKVVersion(ctx context.Context, version int) (*vaultapi.Secret, error) {
|
||||||
|
|
||||||
cl, err := vaultClient()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
data := map[string][]string{
|
data := map[string][]string{
|
||||||
"version": {strconv.Itoa(version)},
|
"version": {strconv.Itoa(version)},
|
||||||
}
|
}
|
||||||
|
|
||||||
rv, err := cl.Logical().ReadWithDataWithContext(ctx, tm.path(), data)
|
rv, err := tm.vault.Logical().ReadWithDataWithContext(ctx, tm.path, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -462,17 +435,17 @@ func (tm *TokenManager) getKVVersion(ctx context.Context, k string, version int)
|
|||||||
return rv, nil
|
return rv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tm *TokenManager) SetKV(ctx context.Context, k string, data *vaultapi.Secret) error {
|
// func (tm *TokenManager) setKV(ctx context.Context, k string, data *vaultapi.Secret) error {
|
||||||
p := tm.path()
|
// p := tm.path()
|
||||||
cl, err := vaultClient()
|
// cl, err := vaultClient()
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return nil
|
// return nil
|
||||||
}
|
// }
|
||||||
|
|
||||||
_, err = cl.Logical().WriteWithContext(ctx, p, data.Data)
|
// _, err = cl.Logical().WriteWithContext(ctx, p, data.Data)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
return err
|
// return err
|
||||||
}
|
// }
|
||||||
|
|
||||||
return nil
|
// return nil
|
||||||
}
|
// }
|
||||||
|
Loading…
Reference in New Issue
Block a user