diff --git a/internal/kms/kms.go b/internal/kms/kms.go index 5675d7bf3a..39438f9993 100644 --- a/internal/kms/kms.go +++ b/internal/kms/kms.go @@ -3,6 +3,8 @@ package kms import ( "context" "fmt" + "io" + "reflect" "sync" "github.com/hashicorp/boundary/internal/db" @@ -269,6 +271,48 @@ func (k *Kms) loadRoot(ctx context.Context, scopeId string, opt ...Option) (*mul return multi, rootKeyId, nil } +// ReconcileKeys will reconcile the keys in the kms against known possible issues. +func (k *Kms) ReconcileKeys(ctx context.Context, randomReader io.Reader) error { + const op = "kms.ReconcileKeys" + if isNil(randomReader) { + return errors.New(ctx, errors.InvalidParameter, op, "missing rand reader") + } + + // it's possible that the global audit key was created after this instance's + // database was initialized... so check if the audit wrapper is available + // for the global scope and if not, then add one to the global scope + if _, err := k.GetWrapper(ctx, scope.Global.String(), KeyPurposeAudit); err != nil { + switch { + case errors.Match(errors.T(errors.KeyNotFound), err): + globalRootWrapper, _, err := k.loadRoot(ctx, scope.Global.String()) + if err != nil { + return errors.Wrap(ctx, err, op) + } + key, err := generateKey(ctx, randomReader) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("error generating random bytes for database key in scope %s", scope.Global.String()))) + } + if _, _, err := k.repo.CreateAuditKey(ctx, globalRootWrapper, key); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("error creating audit key in scope %s", scope.Global.String()))) + } + default: + errors.Wrap(ctx, err, op) + } + } + return nil +} + +func isNil(i interface{}) bool { + if i == nil { + return true + } + switch reflect.TypeOf(i).Kind() { + case reflect.Ptr, reflect.Map, reflect.Array, reflect.Chan, reflect.Slice: + return reflect.ValueOf(i).IsNil() + } + return false +} + // Dek is an interface wrapping dek types to allow a lot less switching in loadDek type Dek interface { GetRootKeyId() string diff --git a/internal/kms/kms_ext_test.go b/internal/kms/kms_ext_test.go index d533f4f655..dcfe4acae7 100644 --- a/internal/kms/kms_ext_test.go +++ b/internal/kms/kms_ext_test.go @@ -2,12 +2,15 @@ package kms_test import ( "context" + "crypto/rand" "encoding/base64" "fmt" + "io" "strings" "testing" "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/types/scope" @@ -125,3 +128,77 @@ func TestKms(t *testing.T) { }) } } + +func TestKms_ReconcileKeys(t *testing.T) { + t.Parallel() + testCtx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + tests := []struct { + name string + kms *kms.Kms + reader io.Reader + setup func() + wantErr bool + wantErrMatch *errors.Template + wantErrContains string + }{ + { + name: "missing-reader", + kms: kms.TestKms(t, conn, wrapper), + wantErr: true, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "missing rand reader", + }, + { + name: "reader-interface-is-nil", + kms: kms.TestKms(t, conn, wrapper), + reader: func() io.Reader { var sr *strings.Reader; var r io.Reader = sr; return r }(), + wantErr: true, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "missing rand reader", + }, + { + name: "nothing-to-reconcile", + kms: kms.TestKms(t, conn, wrapper), + reader: rand.Reader, + wantErr: false, + }, + { + name: "reconcile-audit-key", + kms: kms.TestKms(t, conn, wrapper), + reader: rand.Reader, + setup: func() { + require.NoError(t, conn.Where("1=1").Delete(kms.AllocAuditKey()).Error) + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + // start with no keys... + require.NoError(conn.Where("1=1").Delete(kms.AllocRootKey()).Error) + _, err := kms.CreateKeysTx(context.Background(), rw, rw, wrapper, rand.Reader, scope.Global.String()) + require.NoError(err) + + if tt.setup != nil { + tt.setup() + } + + err = tt.kms.ReconcileKeys(testCtx, tt.reader) + if tt.wantErr { + assert.Error(err) + if tt.wantErrMatch != nil { + assert.Truef(errors.Match(tt.wantErrMatch, err), "expected %q and got err: %+v", tt.wantErrMatch.Code, err) + } + if tt.wantErrContains != "" { + assert.True(strings.Contains(err.Error(), tt.wantErrContains)) + } + return + } + assert.NoError(err) + }) + } +} diff --git a/internal/kms/repository.go b/internal/kms/repository.go index ac8b579c29..e58cd5590b 100644 --- a/internal/kms/repository.go +++ b/internal/kms/repository.go @@ -177,6 +177,10 @@ func CreateKeysTx(ctx context.Context, dbReader db.Reader, dbWriter db.Writer, r KeyTypeOidcKeyVersion: oidcKeyVersion, } if scopeId == scope.Global.String() { + k, err = generateKey(ctx, randomReader) + if err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("error generating random bytes for oidc key in scope %s", scopeId))) + } auditKey, auditKeyVersion, err := createAuditKeyTx(ctx, dbReader, dbWriter, rkvWrapper, k) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to create audit key in scope %s", scopeId))) diff --git a/internal/servers/controller/controller.go b/internal/servers/controller/controller.go index ec6d74911c..27ffe9d75c 100644 --- a/internal/servers/controller/controller.go +++ b/internal/servers/controller/controller.go @@ -168,6 +168,10 @@ func New(ctx context.Context, conf *Config) (*Controller, error) { ); err != nil { return nil, fmt.Errorf("error adding config keys to kms: %w", err) } + if err := c.kms.ReconcileKeys(ctx, c.conf.SecureRandomReader); err != nil { + return nil, fmt.Errorf("error reconciling kms keys: %w", err) + } + // now that the kms is configured, we can get the audit wrapper and rotate // the eventer audit wrapper, so the emitted events can include encrypt and // hmac-sha256 data diff --git a/internal/servers/controller/controller_test.go b/internal/servers/controller/controller_test.go new file mode 100644 index 0000000000..d0aa8c2145 --- /dev/null +++ b/internal/servers/controller/controller_test.go @@ -0,0 +1,34 @@ +package controller + +import ( + "context" + "testing" + + "github.com/hashicorp/boundary/internal/kms" + "github.com/stretchr/testify/require" +) + +func TestController_New(t *testing.T) { + t.Run("ReconcileKeys", func(t *testing.T) { + require := require.New(t) + testCtx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + tc := &TestController{ + t: t, + ctx: ctx, + cancel: cancel, + opts: nil, + } + conf := TestControllerConfig(t, ctx, tc, nil) + + // this tests a scenario where there is an audit DEK + c, err := New(testCtx, conf) + require.NoError(err) + + // this tests a scenario where there is NOT an audit DEK + require.NoError(c.conf.Server.Database.Where("1=1").Delete(kms.AllocAuditKey()).Error) + _, err = New(testCtx, conf) + require.NoError(err) + }) + +} diff --git a/internal/servers/controller/testing.go b/internal/servers/controller/testing.go index 1a20eb842f..f8fe35b192 100644 --- a/internal/servers/controller/testing.go +++ b/internal/servers/controller/testing.go @@ -406,6 +406,34 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { opts: opts, } + conf := TestControllerConfig(t, ctx, tc, opts) + var err error + tc.c, err = New(ctx, conf) + if err != nil { + tc.Shutdown() + t.Fatal(err) + } + + tc.buildClient() + + if !opts.DisableAutoStart { + if err := tc.c.Start(); err != nil { + tc.Shutdown() + t.Fatal(err) + } + } + + return tc +} + +// TestControllerConfig provides a way to create a config for a TestController. +// The tc passed as a parameter will be modified by this func. +func TestControllerConfig(t *testing.T, ctx context.Context, tc *TestController, opts *TestControllerOpts) *Config { + const op = "controller.TestControllerConfig" + if opts == nil { + opts = new(TestControllerOpts) + } + // Base server tc.b = base.NewServer(&base.Command{ Context: ctx, @@ -600,28 +628,11 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController { } } - conf := &Config{ + return &Config{ RawConfig: opts.Config, Server: tc.b, DisableAuthorizationFailures: opts.DisableAuthorizationFailures, } - - tc.c, err = New(ctx, conf) - if err != nil { - tc.Shutdown() - t.Fatal(err) - } - - tc.buildClient() - - if !opts.DisableAutoStart { - if err := tc.c.Start(); err != nil { - tc.Shutdown() - t.Fatal(err) - } - } - - return tc } func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestControllerOpts) *TestController {