Simplify worker auth rotation timing logic (#3237)

* Simplify worker auth rotation timing logic

* Add some jitter
pull/3238/head
Jeff Mitchell 3 years ago
parent 6f4f51cc9f
commit e0787f4de0

@ -144,6 +144,7 @@ type Controller struct {
} }
func New(ctx context.Context, conf *Config) (*Controller, error) { func New(ctx context.Context, conf *Config) (*Controller, error) {
const op = "controller.New"
metric.InitializeApiCollectors(conf.PrometheusRegisterer) metric.InitializeApiCollectors(conf.PrometheusRegisterer)
c := &Controller{ c := &Controller{
conf: conf, conf: conf,
@ -395,7 +396,7 @@ func New(ctx context.Context, conf *Config) (*Controller, error) {
} }
_, err = server.RotateRoots(ctx, serversRepo, nodeenrollment.WithCertificateLifetime(conf.TestOverrideWorkerAuthCaCertificateLifetime)) _, err = server.RotateRoots(ctx, serversRepo, nodeenrollment.WithCertificateLifetime(conf.TestOverrideWorkerAuthCaCertificateLifetime))
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to ensure worker auth roots exist: %w", err) event.WriteSysEvent(ctx, op, "unable to ensure worker auth roots exist, may be due to multiple controllers starting at once, continuing")
} }
if downstreamersFactory != nil { if downstreamersFactory != nil {

@ -473,6 +473,9 @@ type TestControllerOpts struct {
// The amount of time between the scheduler waking up to run it's registered // The amount of time between the scheduler waking up to run it's registered
// jobs. // jobs.
SchedulerRunJobInterval time.Duration SchedulerRunJobInterval time.Duration
// The time to use for CA certificate lifetime for worker auth
WorkerAuthCaCertificateLifetime time.Duration
} }
func NewTestController(t testing.TB, opts *TestControllerOpts) *TestController { func NewTestController(t testing.TB, opts *TestControllerOpts) *TestController {
@ -774,6 +777,7 @@ func TestControllerConfig(t testing.TB, ctx context.Context, tc *TestController,
RawConfig: opts.Config, RawConfig: opts.Config,
Server: tc.b, Server: tc.b,
DisableAuthorizationFailures: opts.DisableAuthorizationFailures, DisableAuthorizationFailures: opts.DisableAuthorizationFailures,
TestOverrideWorkerAuthCaCertificateLifetime: opts.WorkerAuthCaCertificateLifetime,
} }
} }

@ -8,6 +8,8 @@ import (
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt" "fmt"
"math/rand"
"sync/atomic"
"time" "time"
berrors "github.com/hashicorp/boundary/internal/errors" berrors "github.com/hashicorp/boundary/internal/errors"
@ -17,6 +19,18 @@ import (
"github.com/hashicorp/nodeenrollment/types" "github.com/hashicorp/nodeenrollment/types"
) )
// The default time to use when we encounter an error or some other reason
// we can't get a better reset time
const defaultAuthRotationResetDuration = 5 * time.Second
// AuthRotationResetDuration allows us to view it from tests, which allows us to
// get test timing right. It'll start at 0 which will cause us to run
// immediately.
var AuthRotationResetDuration time.Duration
// AuthRotationNextRotation is useful in tests to understand how long to sleep
var AuthRotationNextRotation atomic.Pointer[time.Time]
func (w *Worker) startAuthRotationTicking(cancelCtx context.Context) { func (w *Worker) startAuthRotationTicking(cancelCtx context.Context) {
const op = "worker.(Worker).startAuthRotationTicking" const op = "worker.(Worker).startAuthRotationTicking"
@ -25,22 +39,27 @@ func (w *Worker) startAuthRotationTicking(cancelCtx context.Context) {
return return
} }
// The default time to use when we encounter an error or some other reason // Get a valid value in
// we can't get a better reset time startNow := time.Now()
const defaultResetDuration = time.Hour AuthRotationNextRotation.Store(&startNow)
var resetDuration time.Duration // will start at 0, so we run immediately
timer := time.NewTimer(defaultResetDuration) timer := time.NewTimer(defaultAuthRotationResetDuration)
lastRotation := time.Time{}.UTC() lastRotation := time.Time{}.UTC()
for { for {
// You're not supposed to call reset on timers that haven't been stopped or // Per the example in the docs, if you stop a timer you should drain the
// expired, so we stop it here and reset it to the current resetDuration. That way if // channel if it returns false. However, their example of blindly
// we want to adjust time, e.g. for tests, we can set the resetDuration // reading from the channel can deadlock. So we just do a select here to
// value from within the loop // be safer.
timer.Stop() timer.Stop()
select {
case <-timer.C:
default:
}
if w.TestOverrideAuthRotationPeriod != 0 { if w.TestOverrideAuthRotationPeriod != 0 {
resetDuration = w.TestOverrideAuthRotationPeriod AuthRotationResetDuration = w.TestOverrideAuthRotationPeriod
} }
timer.Reset(resetDuration) timer.Reset(AuthRotationResetDuration)
select { select {
case <-cancelCtx.Done(): case <-cancelCtx.Done():
@ -48,7 +67,10 @@ func (w *Worker) startAuthRotationTicking(cancelCtx context.Context) {
return return
case <-timer.C: case <-timer.C:
resetDuration = defaultResetDuration // Add some jitter in case there is some issue to prevent thundering
// herd
jitter := time.Duration(rand.Intn(6)) * time.Second
AuthRotationResetDuration = defaultAuthRotationResetDuration + jitter
// Check if it's time to rotate and if not don't do anything // Check if it's time to rotate and if not don't do anything
currentNodeCreds, err := types.LoadNodeCredentials( currentNodeCreds, err := types.LoadNodeCredentials(
@ -59,18 +81,15 @@ func (w *Worker) startAuthRotationTicking(cancelCtx context.Context) {
) )
if err != nil { if err != nil {
if errors.Is(err, nodeenrollment.ErrNotFound) { if errors.Is(err, nodeenrollment.ErrNotFound) {
// Be silent, but check again soon // Be silent as we likely haven't authorized yet
resetDuration = 5 * time.Second
continue continue
} }
event.WriteError(cancelCtx, op, err) event.WriteError(cancelCtx, op, err)
resetDuration = 5 * time.Second
continue continue
} }
if currentNodeCreds == nil { if currentNodeCreds == nil {
event.WriteSysEvent(cancelCtx, op, "no error loading worker pki auth creds but nil creds, skipping rotation") event.WriteSysEvent(cancelCtx, op, "no error loading worker pki auth creds but nil creds, skipping rotation")
resetDuration = 5 * time.Second
continue continue
} }
@ -81,7 +100,6 @@ func (w *Worker) startAuthRotationTicking(cancelCtx context.Context) {
// again fairly quickly because this may also be prior to the // again fairly quickly because this may also be prior to the
// first rotation and we want to ensure the validity period is // first rotation and we want to ensure the validity period is
// checked soon. // checked soon.
resetDuration = 5 * time.Second
continue continue
} }
@ -167,21 +185,22 @@ func (w *Worker) startAuthRotationTicking(cancelCtx context.Context) {
)..., )...,
) )
// See if we don't need to do anything // See if we don't need to do anything, and if so calculate reset
// duration and loop back
if !lastRotation.IsZero() && now.Before(nextRotation) { if !lastRotation.IsZero() && now.Before(nextRotation) {
resetDuration = lastRotation.Add(rotationInterval / 2).Sub(now) AuthRotationResetDuration = lastRotation.Add(rotationInterval / 2).Sub(now)
continue continue
} }
newRotationInterval, err := rotateWorkerAuth(cancelCtx, w, currentNodeCreds) newRotationInterval, err := rotateWorkerAuth(cancelCtx, w, currentNodeCreds)
if err != nil { if err != nil {
resetDuration = 5 * time.Second
event.WriteError(cancelCtx, op, err) event.WriteError(cancelCtx, op, err)
continue continue
} }
lastRotation = now lastRotation = now
nextRotation = lastRotation.Add(newRotationInterval / 2) nextRotation = lastRotation.Add(newRotationInterval / 2)
resetDuration = nextRotation.Sub(now) AuthRotationNextRotation.Store(&nextRotation)
AuthRotationResetDuration = nextRotation.Sub(now)
event.WriteSysEvent(cancelCtx, op, "worker credentials rotated", "next_rotation", nextRotation) event.WriteSysEvent(cancelCtx, op, "worker credentials rotated", "next_rotation", nextRotation)
} }
} }

@ -30,27 +30,27 @@ func TestRotationTicking(t *testing.T) {
Level: hclog.Trace, Level: hclog.Trace,
}) })
const rotationPeriod = 20 * time.Second
conf, err := config.DevController() conf, err := config.DevController()
require.NoError(err) require.NoError(err)
c := controller.NewTestController(t, &controller.TestControllerOpts{ c := controller.NewTestController(t, &controller.TestControllerOpts{
Config: conf, Config: conf,
Logger: logger.Named("controller"), Logger: logger.Named("controller"),
WorkerAuthCaCertificateLifetime: rotationPeriod,
}) })
t.Cleanup(c.Shutdown) t.Cleanup(c.Shutdown)
const rotationPeriod = 20 * time.Second
// names should not be set when using pki workers // names should not be set when using pki workers
wConf, err := config.DevWorker() wConf, err := config.DevWorker()
require.NoError(err) require.NoError(err)
wConf.Worker.Name = "" wConf.Worker.Name = ""
wConf.Worker.InitialUpstreams = c.ClusterAddrs() wConf.Worker.InitialUpstreams = c.ClusterAddrs()
w := worker.NewTestWorker(t, &worker.TestWorkerOpts{ w := worker.NewTestWorker(t, &worker.TestWorkerOpts{
InitialUpstreams: c.ClusterAddrs(), InitialUpstreams: c.ClusterAddrs(),
Logger: logger.Named("worker"), Logger: logger.Named("worker"),
AuthRotationPeriod: rotationPeriod, Config: wConf,
Config: wConf,
}) })
t.Cleanup(w.Shutdown) t.Cleanup(w.Shutdown)
@ -76,17 +76,21 @@ func TestRotationTicking(t *testing.T) {
// Decode the proto into the request and create the worker // Decode the proto into the request and create the worker
req := new(types.FetchNodeCredentialsRequest) req := new(types.FetchNodeCredentialsRequest)
require.NoError(proto.Unmarshal(reqBytes, req)) require.NoError(proto.Unmarshal(reqBytes, req))
worker, err := serversRepo.CreateWorker(c.Context(), &server.Worker{ newWorker, err := serversRepo.CreateWorker(c.Context(), &server.Worker{
Worker: &store.Worker{ Worker: &store.Worker{
ScopeId: scope.Global.String(), ScopeId: scope.Global.String(),
}, },
}, server.WithFetchNodeCredentialsRequest(req)) }, server.WithFetchNodeCredentialsRequest(req))
require.NoError(err) require.NoError(err)
// Verify we see one authorized set of credentials now // Wait for a short while; there will be an initial rotation of credentials
// after authentication
time.Sleep(rotationPeriod / 2)
// Verify we see authorized credentials now
auths, err = workerAuthRepo.List(c.Context(), (*types.NodeInformation)(nil)) auths, err = workerAuthRepo.List(c.Context(), (*types.NodeInformation)(nil))
require.NoError(err) require.NoError(err)
assert.Len(auths, 1) assert.Len(auths, 2)
// Fetch creds and store current key // Fetch creds and store current key
currNodeCreds, err := types.LoadNodeCredentials( currNodeCreds, err := types.LoadNodeCredentials(
w.Context(), w.Context(),
@ -103,7 +107,9 @@ func TestRotationTicking(t *testing.T) {
// Now we wait and check that we see new values in the DB and different // Now we wait and check that we see new values in the DB and different
// creds on the worker after each rotation period // creds on the worker after each rotation period
for i := 2; i < 5; i++ { for i := 2; i < 5; i++ {
time.Sleep(rotationPeriod) t.Log("iteration", i)
nextRotation := worker.AuthRotationNextRotation.Load()
time.Sleep((*nextRotation).Sub(time.Now()) + 5*time.Second)
// Verify we see 2- after credentials have rotated, we should see current and previous // Verify we see 2- after credentials have rotated, we should see current and previous
auths, err = workerAuthRepo.List(c.Context(), (*types.NodeInformation)(nil)) auths, err = workerAuthRepo.List(c.Context(), (*types.NodeInformation)(nil))
@ -130,7 +136,7 @@ func TestRotationTicking(t *testing.T) {
assert.Equal(priorKeyId, previousKeyId) assert.Equal(priorKeyId, previousKeyId)
// Get workerAuthSet for this worker id and compare keys // Get workerAuthSet for this worker id and compare keys
workerAuthSet, err := workerAuthRepo.FindWorkerAuthByWorkerId(c.Context(), worker.PublicId) workerAuthSet, err := workerAuthRepo.FindWorkerAuthByWorkerId(c.Context(), newWorker.PublicId)
require.NoError(err) require.NoError(err)
assert.NotNil(workerAuthSet) assert.NotNil(workerAuthSet)
assert.NotNil(workerAuthSet.Previous) assert.NotNil(workerAuthSet.Previous)

Loading…
Cancel
Save