diff --git a/internal/kms/kms.go b/internal/kms/kms.go index 926cac3ef8..ffe3258080 100644 --- a/internal/kms/kms.go +++ b/internal/kms/kms.go @@ -139,10 +139,6 @@ func (k *Kms) GetExternalWrappers() *ExternalWrappers { return ret } -func generateKeyId(scopeId string, purpose KeyPurpose, version uint32) string { - return fmt.Sprintf("%s_%s_%d", scopeId, purpose, version) -} - // GetWrapper returns a wrapper for the given scope and purpose. When a keyId is // passed, it will ensure that the returning wrapper has that key ID in the // multiwrapper. This is not necesary for encryption but should be supplied for @@ -168,9 +164,12 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose val, ok := k.scopePurposeCache.Load(scopeId + purpose.String()) if ok { wrapper := val.(*multiwrapper.MultiWrapper) - if opts.withKeyId == "" || wrapper.WrapperForKeyID(opts.withKeyId) != nil { + if opts.withKeyId == "" { return wrapper, nil } + if keyIdWrapper := wrapper.WrapperForKeyID(opts.withKeyId); keyIdWrapper != nil { + return keyIdWrapper, nil + } // Fall through to refresh our multiwrapper for this scope/purpose from the DB } @@ -192,6 +191,13 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose } k.scopePurposeCache.Store(scopeId+purpose.String(), wrapper) + if opts.withKeyId != "" { + if keyIdWrapper := wrapper.WrapperForKeyID(opts.withKeyId); keyIdWrapper != nil { + return keyIdWrapper, nil + } + return nil, errors.New(errors.KeyNotFound, op, "unable to find specified key ID") + } + return wrapper, nil } @@ -299,6 +305,8 @@ func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, r keys, err = repo.ListTokenKeys(ctx) case KeyPurposeSessions: keys, err = repo.ListSessionKeys(ctx) + default: + return nil, errors.New(errors.InvalidParameter, op, "unknown or invalid DEK purpose specified") } if err != nil { return nil, errors.Wrap(err, op, errors.WithMsg("error listing root keys")) diff --git a/internal/kms/kms_ext_test.go b/internal/kms/kms_ext_test.go new file mode 100644 index 0000000000..d533f4f655 --- /dev/null +++ b/internal/kms/kms_ext_test.go @@ -0,0 +1,127 @@ +package kms_test + +import ( + "context" + "encoding/base64" + "fmt" + "strings" + "testing" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/types/scope" + "github.com/hashicorp/go-kms-wrapping/wrappers/aead" + "github.com/hashicorp/go-kms-wrapping/wrappers/multiwrapper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// NOTE: This is a sequential test that relies on the actions that have come +// before. Please see the comments for details. +func TestKms(t *testing.T) { + t.Parallel() + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + extWrapper := db.TestWrapper(t) + badExtWrapper := db.TestWrapper(t) + repo, err := kms.NewRepository(rw, rw) + require.NoError(t, err) + kmsCache := kms.TestKms(t, conn, extWrapper) + org, proj := iam.TestScopes(t, iam.TestRepo(t, conn, extWrapper)) + + // Verify that the cache is empty, so we can show that by the end of the + // test sequence we did actually look up keys and store them in the cache + t.Run("verify cache empty", func(t *testing.T) { + var count int + kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool { + count++ + return true + }) + assert.Equal(t, 0, count) + }) + // Verify that the root keys are all in the database and can be decrypted + // with the correct wrapper from the KMS object + t.Run("verify root keys", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + rootKeys, err := repo.ListRootKeys(ctx) + require.NoError(err) + wrappers := kmsCache.GetExternalWrappers() + for _, key := range rootKeys { + kvs, err := repo.ListRootKeyVersions(ctx, wrappers.Root(), key.GetPrivateId()) + require.NoError(err) + assert.Len(kvs, 1) + assert.Len(kvs[0].GetKey(), 32) + } + }) + // Verify that the wrong wrapper causes decryption to fail + t.Run("bad external keys", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + rootKeys, err := repo.ListRootKeys(ctx) + require.NoError(err) + for _, key := range rootKeys { + _, err := repo.ListRootKeyVersions(ctx, badExtWrapper, key.GetPrivateId()) + require.Error(err) + assert.True(strings.Contains(err.Error(), "message authentication failed"), err.Error()) + } + }) + // This next sequence is run twice to ensure that calling for the keys twice + // returns the same value each time and doesn't simply populate more keys + // into the KMS object + keyBytes := map[string]bool{} + keyIds := map[string]bool{} + scopePurposeMap := map[string]interface{}{} + for i := 1; i < 3; i++ { + // This iterates through wrappers for all three scopes and four purposes, + // ensuring that the key bytes and IDs are different for each of them, + // simulating calling the KMS object from different scopes for different + // purposes and ensuring the keys are different when that happens. + t.Run(fmt.Sprintf("verify wrappers different x %d", i), func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + for _, scopeId := range []string{scope.Global.String(), org.GetPublicId(), proj.GetPublicId()} { + for _, purpose := range []kms.KeyPurpose{kms.KeyPurposeUnknown, kms.KeyPurposeOplog, kms.KeyPurposeDatabase, kms.KeyPurposeSessions, kms.KeyPurposeTokens} { + wrapper, err := kmsCache.GetWrapper(ctx, scopeId, purpose) + if purpose == kms.KeyPurposeUnknown { + require.Error(err) + continue + } + require.NoError(err) + multi, ok := wrapper.(*multiwrapper.MultiWrapper) + require.True(ok) + aeadWrapper, ok := multi.WrapperForKeyID(multi.KeyID()).(*aead.Wrapper) + require.True(ok) + foundKeyBytes := keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())] + foundKeyId := keyIds[aeadWrapper.KeyID()] + if i == 1 { + assert.False(foundKeyBytes) + assert.False(foundKeyId) + keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())] = true + keyIds[aeadWrapper.KeyID()] = true + } else { + assert.True(foundKeyBytes) + assert.True(foundKeyId) + } + } + } + }) + // Verify that the cache has been populated with unique values. The + // second time we validate that the items we find when going through the + // cache a second time are the same as the first. If they were recreated + // the pointers would be different. + t.Run(fmt.Sprintf("verify cache populated x %d", i), func(t *testing.T) { + var count int + kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool { + count++ + if i == 1 { + scopePurposeMap[key.(string)] = value + } else { + assert.Same(t, scopePurposeMap[key.(string)], value) + } + return true + }) + // four purposes x 3 scopes + assert.Equal(t, 12, count) + }) + } +} diff --git a/internal/kms/kms_test.go b/internal/kms/kms_test.go index 62203b49e3..0ce3bb44b4 100644 --- a/internal/kms/kms_test.go +++ b/internal/kms/kms_test.go @@ -1,127 +1,68 @@ -package kms_test +package kms import ( "context" - "encoding/base64" - "fmt" - "strings" + "crypto/rand" "testing" "github.com/hashicorp/boundary/internal/db" - "github.com/hashicorp/boundary/internal/iam" - "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/types/scope" - "github.com/hashicorp/go-kms-wrapping/wrappers/aead" - "github.com/hashicorp/go-kms-wrapping/wrappers/multiwrapper" - "github.com/stretchr/testify/assert" + "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" ) -// NOTE: This is a sequential test that relies on the actions that have come -// before. Please see the comments for details. -func TestKms(t *testing.T) { +func TestKms_KeyId(t *testing.T) { t.Parallel() + require := require.New(t) ctx := context.Background() conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) - wrapper := db.TestWrapper(t) - badWrapper := db.TestWrapper(t) - repo, err := kms.NewRepository(rw, rw) - require.NoError(t, err) - kmsCache := kms.TestKms(t, conn, wrapper) - org, proj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) + extWrapper := db.TestWrapper(t) + repo, err := NewRepository(rw, rw) + require.NoError(err) - // Verify that the cache is empty, so we can show that by the end of the - // test sequence we did actually look up keys and store them in the cache - t.Run("verify cache empty", func(t *testing.T) { - var count int - kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool { - count++ - return true - }) - assert.Equal(t, 0, count) - }) - // Verify that the root keys are all in the database and can be decrypted - // with the correct wrapper from the KMS object - t.Run("verify root keys", func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - rootKeys, err := repo.ListRootKeys(ctx) - require.NoError(err) - wrappers := kmsCache.GetExternalWrappers() - for _, key := range rootKeys { - kvs, err := repo.ListRootKeyVersions(ctx, wrappers.Root(), key.GetPrivateId()) - require.NoError(err) - assert.Len(kvs, 1) - assert.Len(kvs[0].GetKey(), 32) - } - }) - // Verify that the wrong wrapper causes decryption to fail - t.Run("bad external keys", func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - rootKeys, err := repo.ListRootKeys(ctx) - require.NoError(err) - for _, key := range rootKeys { - _, err := repo.ListRootKeyVersions(ctx, badWrapper, key.GetPrivateId()) - require.Error(err) - assert.True(strings.Contains(err.Error(), "message authentication failed"), err.Error()) - } - }) - // This next sequence is run twice to ensure that calling for the keys twice - // returns the same value each time and doesn't simply populate more keys - // into the KMS object - keyBytes := map[string]bool{} - keyIds := map[string]bool{} - scopePurposeMap := map[string]interface{}{} - for i := 1; i < 3; i++ { - // This iterates through wrappers for all three scopes and four purposes, - // ensuring that the key bytes and IDs are different for each of them, - // simulating calling the KMS object from different scopes for different - // purposes and ensuring the keys are different when that happens. - t.Run(fmt.Sprintf("verify wrappers different x %d", i), func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - for _, scopeId := range []string{scope.Global.String(), org.GetPublicId(), proj.GetPublicId()} { - for _, purpose := range []kms.KeyPurpose{kms.KeyPurposeUnknown, kms.KeyPurposeOplog, kms.KeyPurposeDatabase, kms.KeyPurposeSessions, kms.KeyPurposeTokens} { - wrapper, err := kmsCache.GetWrapper(ctx, scopeId, purpose) - if purpose == kms.KeyPurposeUnknown { - require.Error(err) - continue - } - require.NoError(err) - multi, ok := wrapper.(*multiwrapper.MultiWrapper) - require.True(ok) - aeadWrapper, ok := multi.WrapperForKeyID(multi.KeyID()).(*aead.Wrapper) - require.True(ok) - foundKeyBytes := keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())] - foundKeyId := keyIds[aeadWrapper.KeyID()] - if i == 1 { - assert.False(foundKeyBytes) - assert.False(foundKeyId) - keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())] = true - keyIds[aeadWrapper.KeyID()] = true - } else { - assert.True(foundKeyBytes) - assert.True(foundKeyId) - } - } - } - }) - // Verify that the cache has been populated with unique values. The - // second time we validate that the items we find when going through the - // cache a second time are the same as the first. If they were recreated - // the pointers would be different. - t.Run(fmt.Sprintf("verify cache populated x %d", i), func(t *testing.T) { - var count int - kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool { - count++ - if i == 1 { - scopePurposeMap[key.(string)] = value - } else { - assert.Same(t, scopePurposeMap[key.(string)], value) - } - return true - }) - // four purposes x 3 scopes - assert.Equal(t, 12, count) - }) - } + // Make the global scope base keys + _, err = CreateKeysTx(ctx, rw, rw, extWrapper, rand.Reader, scope.Global.String()) + require.NoError(err) + + // Get the global scope's root wrapper + kmsCache, err := NewKms(repo) + require.NoError(err) + require.NoError(kmsCache.AddExternalWrappers(WithRootWrapper(extWrapper))) + globalRootWrapper, _, err := kmsCache.loadRoot(ctx, scope.Global.String()) + require.NoError(err) + + dks, err := repo.ListDatabaseKeys(ctx) + require.NoError(err) + require.Len(dks, 1) + + // Create another key version + newKeyBytes, err := uuid.GenerateRandomBytes(32) + require.NoError(err) + _, err = repo.CreateDatabaseKeyVersion(ctx, globalRootWrapper, dks[0].GetPrivateId(), newKeyBytes) + require.NoError(err) + + dkvs, err := repo.ListDatabaseKeyVersions(ctx, globalRootWrapper, dks[0].GetPrivateId()) + require.NoError(err) + require.Len(dkvs, 2) + + keyId1 := dkvs[0].GetPrivateId() + keyId2 := dkvs[1].GetPrivateId() + + // First test: just getting the key should return the latest + wrapper, err := kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase) + require.NoError(err) + require.Equal(keyId2, wrapper.KeyID()) + + // Second: ask for each in turn + wrapper, err = kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase, WithKeyId(keyId1)) + require.NoError(err) + require.Equal(keyId1, wrapper.KeyID()) + wrapper, err = kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase, WithKeyId(keyId2)) + require.NoError(err) + require.Equal(keyId2, wrapper.KeyID()) + + // Last: verify something bogus finds nothing + _, err = kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase, WithKeyId("foo")) + require.Error(err) } diff --git a/internal/kms/repository_root_key.go b/internal/kms/repository_root_key.go index 2851354f14..9190365327 100644 --- a/internal/kms/repository_root_key.go +++ b/internal/kms/repository_root_key.go @@ -72,7 +72,7 @@ func createRootKeyTx(ctx context.Context, w db.Writer, keyWrapper wrapping.Wrapp } // no oplog entries for root key versions if err := w.Create(ctx, &kv); err != nil { - return nil, nil, errors.Wrap(err, op, errors.WithMsg(fmt.Sprintf("key versions"))) + return nil, nil, errors.Wrap(err, op, errors.WithMsg("key versions")) } return &rk, &kv, nil diff --git a/internal/servers/controller/handlers/workers/worker_service.go b/internal/servers/controller/handlers/workers/worker_service.go index 17da401845..a9fe541d83 100644 --- a/internal/servers/controller/handlers/workers/worker_service.go +++ b/internal/servers/controller/handlers/workers/worker_service.go @@ -219,7 +219,7 @@ func (ws *workerServiceServer) LookupSession(ctx context.Context, req *pbs.Looku resp.ConnectionsLeft -= int32(authzSummary.CurrentConnectionCount) } - wrapper, err := ws.kms.GetWrapper(ctx, sessionInfo.ScopeId, kms.KeyPurposeSessions) + wrapper, err := ws.kms.GetWrapper(ctx, sessionInfo.ScopeId, kms.KeyPurposeSessions, kms.WithKeyId(sessionInfo.KeyId)) if err != nil { return nil, status.Errorf(codes.Internal, "Error getting sessions wrapper: %v", err) } diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index b0b95095c5..d294293000 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -76,6 +76,7 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping. } newSession.Certificate = certBytes newSession.PublicId = id + newSession.KeyId = sessionWrapper.KeyID() var returnedSession *Session _, err = r.writer.DoTx( diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 150c381067..bcfc37cfb4 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -344,8 +344,10 @@ func TestRepository_CreateSession(t *testing.T) { assert.NotNil(ses.CreateTime) assert.NotNil(ses.States[0].StartTime) assert.Equal(ses.States[0].Status, StatusPending) + assert.Equal(wrapper.KeyID(), ses.KeyId) foundSession, _, err := repo.LookupSession(context.Background(), ses.PublicId) assert.NoError(err) + assert.Equal(wrapper.KeyID(), foundSession.KeyId) // Account for slight offsets in nanos assert.True(foundSession.ExpirationTime.Timestamp.AsTime().Sub(ses.ExpirationTime.Timestamp.AsTime()) < time.Second) diff --git a/internal/session/session.go b/internal/session/session.go index f91b5e90a4..9d172a39c9 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -160,6 +160,7 @@ func (s *Session) Clone() interface{} { Endpoint: s.Endpoint, ConnectionLimit: s.ConnectionLimit, WorkerFilter: s.WorkerFilter, + KeyId: s.KeyId, } if len(s.States) > 0 { clone.States = make([]*State, 0, len(s.States))