From c3684d20db0d01b8f65fb2bd0921178fcb3a7c07 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Wed, 3 Mar 2021 11:23:00 -0500 Subject: [PATCH] Fix WithKeyId option (#970) Currently this would validate that a KeyId was in the returned wrapper but then return the base wrapper; and if it required a database lookup it wouldn't validate it anyways. All uses of this were for decryption so it didn't actually affect anything to this point (you'd error later instead of sooner) but for any case where we need to encrypt against a specific version or derive a key against a specific version this would be broken, in a way that could lead you to use the wrong key version. This also properly sends the Key ID around sessions. This worked currently for the same reason as above but would have broken for any session in flight once a key was rotated. --- internal/kms/kms.go | 18 +- internal/kms/kms_ext_test.go | 127 ++++++++++++++ internal/kms/kms_test.go | 163 ++++++------------ internal/kms/repository_root_key.go | 2 +- .../handlers/workers/worker_service.go | 2 +- internal/session/repository_session.go | 1 + internal/session/repository_session_test.go | 2 + internal/session/session.go | 1 + 8 files changed, 198 insertions(+), 118 deletions(-) create mode 100644 internal/kms/kms_ext_test.go diff --git a/internal/kms/kms.go b/internal/kms/kms.go index 926cac3ef8..ffe3258080 100644 --- a/internal/kms/kms.go +++ b/internal/kms/kms.go @@ -139,10 +139,6 @@ func (k *Kms) GetExternalWrappers() *ExternalWrappers { return ret } -func generateKeyId(scopeId string, purpose KeyPurpose, version uint32) string { - return fmt.Sprintf("%s_%s_%d", scopeId, purpose, version) -} - // GetWrapper returns a wrapper for the given scope and purpose. When a keyId is // passed, it will ensure that the returning wrapper has that key ID in the // multiwrapper. This is not necesary for encryption but should be supplied for @@ -168,9 +164,12 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose val, ok := k.scopePurposeCache.Load(scopeId + purpose.String()) if ok { wrapper := val.(*multiwrapper.MultiWrapper) - if opts.withKeyId == "" || wrapper.WrapperForKeyID(opts.withKeyId) != nil { + if opts.withKeyId == "" { return wrapper, nil } + if keyIdWrapper := wrapper.WrapperForKeyID(opts.withKeyId); keyIdWrapper != nil { + return keyIdWrapper, nil + } // Fall through to refresh our multiwrapper for this scope/purpose from the DB } @@ -192,6 +191,13 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose } k.scopePurposeCache.Store(scopeId+purpose.String(), wrapper) + if opts.withKeyId != "" { + if keyIdWrapper := wrapper.WrapperForKeyID(opts.withKeyId); keyIdWrapper != nil { + return keyIdWrapper, nil + } + return nil, errors.New(errors.KeyNotFound, op, "unable to find specified key ID") + } + return wrapper, nil } @@ -299,6 +305,8 @@ func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, r keys, err = repo.ListTokenKeys(ctx) case KeyPurposeSessions: keys, err = repo.ListSessionKeys(ctx) + default: + return nil, errors.New(errors.InvalidParameter, op, "unknown or invalid DEK purpose specified") } if err != nil { return nil, errors.Wrap(err, op, errors.WithMsg("error listing root keys")) diff --git a/internal/kms/kms_ext_test.go b/internal/kms/kms_ext_test.go new file mode 100644 index 0000000000..d533f4f655 --- /dev/null +++ b/internal/kms/kms_ext_test.go @@ -0,0 +1,127 @@ +package kms_test + +import ( + "context" + "encoding/base64" + "fmt" + "strings" + "testing" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/types/scope" + "github.com/hashicorp/go-kms-wrapping/wrappers/aead" + "github.com/hashicorp/go-kms-wrapping/wrappers/multiwrapper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// NOTE: This is a sequential test that relies on the actions that have come +// before. Please see the comments for details. +func TestKms(t *testing.T) { + t.Parallel() + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + extWrapper := db.TestWrapper(t) + badExtWrapper := db.TestWrapper(t) + repo, err := kms.NewRepository(rw, rw) + require.NoError(t, err) + kmsCache := kms.TestKms(t, conn, extWrapper) + org, proj := iam.TestScopes(t, iam.TestRepo(t, conn, extWrapper)) + + // Verify that the cache is empty, so we can show that by the end of the + // test sequence we did actually look up keys and store them in the cache + t.Run("verify cache empty", func(t *testing.T) { + var count int + kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool { + count++ + return true + }) + assert.Equal(t, 0, count) + }) + // Verify that the root keys are all in the database and can be decrypted + // with the correct wrapper from the KMS object + t.Run("verify root keys", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + rootKeys, err := repo.ListRootKeys(ctx) + require.NoError(err) + wrappers := kmsCache.GetExternalWrappers() + for _, key := range rootKeys { + kvs, err := repo.ListRootKeyVersions(ctx, wrappers.Root(), key.GetPrivateId()) + require.NoError(err) + assert.Len(kvs, 1) + assert.Len(kvs[0].GetKey(), 32) + } + }) + // Verify that the wrong wrapper causes decryption to fail + t.Run("bad external keys", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + rootKeys, err := repo.ListRootKeys(ctx) + require.NoError(err) + for _, key := range rootKeys { + _, err := repo.ListRootKeyVersions(ctx, badExtWrapper, key.GetPrivateId()) + require.Error(err) + assert.True(strings.Contains(err.Error(), "message authentication failed"), err.Error()) + } + }) + // This next sequence is run twice to ensure that calling for the keys twice + // returns the same value each time and doesn't simply populate more keys + // into the KMS object + keyBytes := map[string]bool{} + keyIds := map[string]bool{} + scopePurposeMap := map[string]interface{}{} + for i := 1; i < 3; i++ { + // This iterates through wrappers for all three scopes and four purposes, + // ensuring that the key bytes and IDs are different for each of them, + // simulating calling the KMS object from different scopes for different + // purposes and ensuring the keys are different when that happens. + t.Run(fmt.Sprintf("verify wrappers different x %d", i), func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + for _, scopeId := range []string{scope.Global.String(), org.GetPublicId(), proj.GetPublicId()} { + for _, purpose := range []kms.KeyPurpose{kms.KeyPurposeUnknown, kms.KeyPurposeOplog, kms.KeyPurposeDatabase, kms.KeyPurposeSessions, kms.KeyPurposeTokens} { + wrapper, err := kmsCache.GetWrapper(ctx, scopeId, purpose) + if purpose == kms.KeyPurposeUnknown { + require.Error(err) + continue + } + require.NoError(err) + multi, ok := wrapper.(*multiwrapper.MultiWrapper) + require.True(ok) + aeadWrapper, ok := multi.WrapperForKeyID(multi.KeyID()).(*aead.Wrapper) + require.True(ok) + foundKeyBytes := keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())] + foundKeyId := keyIds[aeadWrapper.KeyID()] + if i == 1 { + assert.False(foundKeyBytes) + assert.False(foundKeyId) + keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())] = true + keyIds[aeadWrapper.KeyID()] = true + } else { + assert.True(foundKeyBytes) + assert.True(foundKeyId) + } + } + } + }) + // Verify that the cache has been populated with unique values. The + // second time we validate that the items we find when going through the + // cache a second time are the same as the first. If they were recreated + // the pointers would be different. + t.Run(fmt.Sprintf("verify cache populated x %d", i), func(t *testing.T) { + var count int + kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool { + count++ + if i == 1 { + scopePurposeMap[key.(string)] = value + } else { + assert.Same(t, scopePurposeMap[key.(string)], value) + } + return true + }) + // four purposes x 3 scopes + assert.Equal(t, 12, count) + }) + } +} diff --git a/internal/kms/kms_test.go b/internal/kms/kms_test.go index 62203b49e3..0ce3bb44b4 100644 --- a/internal/kms/kms_test.go +++ b/internal/kms/kms_test.go @@ -1,127 +1,68 @@ -package kms_test +package kms import ( "context" - "encoding/base64" - "fmt" - "strings" + "crypto/rand" "testing" "github.com/hashicorp/boundary/internal/db" - "github.com/hashicorp/boundary/internal/iam" - "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/types/scope" - "github.com/hashicorp/go-kms-wrapping/wrappers/aead" - "github.com/hashicorp/go-kms-wrapping/wrappers/multiwrapper" - "github.com/stretchr/testify/assert" + "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/require" ) -// NOTE: This is a sequential test that relies on the actions that have come -// before. Please see the comments for details. -func TestKms(t *testing.T) { +func TestKms_KeyId(t *testing.T) { t.Parallel() + require := require.New(t) ctx := context.Background() conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) - wrapper := db.TestWrapper(t) - badWrapper := db.TestWrapper(t) - repo, err := kms.NewRepository(rw, rw) - require.NoError(t, err) - kmsCache := kms.TestKms(t, conn, wrapper) - org, proj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) + extWrapper := db.TestWrapper(t) + repo, err := NewRepository(rw, rw) + require.NoError(err) - // Verify that the cache is empty, so we can show that by the end of the - // test sequence we did actually look up keys and store them in the cache - t.Run("verify cache empty", func(t *testing.T) { - var count int - kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool { - count++ - return true - }) - assert.Equal(t, 0, count) - }) - // Verify that the root keys are all in the database and can be decrypted - // with the correct wrapper from the KMS object - t.Run("verify root keys", func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - rootKeys, err := repo.ListRootKeys(ctx) - require.NoError(err) - wrappers := kmsCache.GetExternalWrappers() - for _, key := range rootKeys { - kvs, err := repo.ListRootKeyVersions(ctx, wrappers.Root(), key.GetPrivateId()) - require.NoError(err) - assert.Len(kvs, 1) - assert.Len(kvs[0].GetKey(), 32) - } - }) - // Verify that the wrong wrapper causes decryption to fail - t.Run("bad external keys", func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - rootKeys, err := repo.ListRootKeys(ctx) - require.NoError(err) - for _, key := range rootKeys { - _, err := repo.ListRootKeyVersions(ctx, badWrapper, key.GetPrivateId()) - require.Error(err) - assert.True(strings.Contains(err.Error(), "message authentication failed"), err.Error()) - } - }) - // This next sequence is run twice to ensure that calling for the keys twice - // returns the same value each time and doesn't simply populate more keys - // into the KMS object - keyBytes := map[string]bool{} - keyIds := map[string]bool{} - scopePurposeMap := map[string]interface{}{} - for i := 1; i < 3; i++ { - // This iterates through wrappers for all three scopes and four purposes, - // ensuring that the key bytes and IDs are different for each of them, - // simulating calling the KMS object from different scopes for different - // purposes and ensuring the keys are different when that happens. - t.Run(fmt.Sprintf("verify wrappers different x %d", i), func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - for _, scopeId := range []string{scope.Global.String(), org.GetPublicId(), proj.GetPublicId()} { - for _, purpose := range []kms.KeyPurpose{kms.KeyPurposeUnknown, kms.KeyPurposeOplog, kms.KeyPurposeDatabase, kms.KeyPurposeSessions, kms.KeyPurposeTokens} { - wrapper, err := kmsCache.GetWrapper(ctx, scopeId, purpose) - if purpose == kms.KeyPurposeUnknown { - require.Error(err) - continue - } - require.NoError(err) - multi, ok := wrapper.(*multiwrapper.MultiWrapper) - require.True(ok) - aeadWrapper, ok := multi.WrapperForKeyID(multi.KeyID()).(*aead.Wrapper) - require.True(ok) - foundKeyBytes := keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())] - foundKeyId := keyIds[aeadWrapper.KeyID()] - if i == 1 { - assert.False(foundKeyBytes) - assert.False(foundKeyId) - keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())] = true - keyIds[aeadWrapper.KeyID()] = true - } else { - assert.True(foundKeyBytes) - assert.True(foundKeyId) - } - } - } - }) - // Verify that the cache has been populated with unique values. The - // second time we validate that the items we find when going through the - // cache a second time are the same as the first. If they were recreated - // the pointers would be different. - t.Run(fmt.Sprintf("verify cache populated x %d", i), func(t *testing.T) { - var count int - kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool { - count++ - if i == 1 { - scopePurposeMap[key.(string)] = value - } else { - assert.Same(t, scopePurposeMap[key.(string)], value) - } - return true - }) - // four purposes x 3 scopes - assert.Equal(t, 12, count) - }) - } + // Make the global scope base keys + _, err = CreateKeysTx(ctx, rw, rw, extWrapper, rand.Reader, scope.Global.String()) + require.NoError(err) + + // Get the global scope's root wrapper + kmsCache, err := NewKms(repo) + require.NoError(err) + require.NoError(kmsCache.AddExternalWrappers(WithRootWrapper(extWrapper))) + globalRootWrapper, _, err := kmsCache.loadRoot(ctx, scope.Global.String()) + require.NoError(err) + + dks, err := repo.ListDatabaseKeys(ctx) + require.NoError(err) + require.Len(dks, 1) + + // Create another key version + newKeyBytes, err := uuid.GenerateRandomBytes(32) + require.NoError(err) + _, err = repo.CreateDatabaseKeyVersion(ctx, globalRootWrapper, dks[0].GetPrivateId(), newKeyBytes) + require.NoError(err) + + dkvs, err := repo.ListDatabaseKeyVersions(ctx, globalRootWrapper, dks[0].GetPrivateId()) + require.NoError(err) + require.Len(dkvs, 2) + + keyId1 := dkvs[0].GetPrivateId() + keyId2 := dkvs[1].GetPrivateId() + + // First test: just getting the key should return the latest + wrapper, err := kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase) + require.NoError(err) + require.Equal(keyId2, wrapper.KeyID()) + + // Second: ask for each in turn + wrapper, err = kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase, WithKeyId(keyId1)) + require.NoError(err) + require.Equal(keyId1, wrapper.KeyID()) + wrapper, err = kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase, WithKeyId(keyId2)) + require.NoError(err) + require.Equal(keyId2, wrapper.KeyID()) + + // Last: verify something bogus finds nothing + _, err = kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase, WithKeyId("foo")) + require.Error(err) } diff --git a/internal/kms/repository_root_key.go b/internal/kms/repository_root_key.go index 2851354f14..9190365327 100644 --- a/internal/kms/repository_root_key.go +++ b/internal/kms/repository_root_key.go @@ -72,7 +72,7 @@ func createRootKeyTx(ctx context.Context, w db.Writer, keyWrapper wrapping.Wrapp } // no oplog entries for root key versions if err := w.Create(ctx, &kv); err != nil { - return nil, nil, errors.Wrap(err, op, errors.WithMsg(fmt.Sprintf("key versions"))) + return nil, nil, errors.Wrap(err, op, errors.WithMsg("key versions")) } return &rk, &kv, nil diff --git a/internal/servers/controller/handlers/workers/worker_service.go b/internal/servers/controller/handlers/workers/worker_service.go index 17da401845..a9fe541d83 100644 --- a/internal/servers/controller/handlers/workers/worker_service.go +++ b/internal/servers/controller/handlers/workers/worker_service.go @@ -219,7 +219,7 @@ func (ws *workerServiceServer) LookupSession(ctx context.Context, req *pbs.Looku resp.ConnectionsLeft -= int32(authzSummary.CurrentConnectionCount) } - wrapper, err := ws.kms.GetWrapper(ctx, sessionInfo.ScopeId, kms.KeyPurposeSessions) + wrapper, err := ws.kms.GetWrapper(ctx, sessionInfo.ScopeId, kms.KeyPurposeSessions, kms.WithKeyId(sessionInfo.KeyId)) if err != nil { return nil, status.Errorf(codes.Internal, "Error getting sessions wrapper: %v", err) } diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index b0b95095c5..d294293000 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -76,6 +76,7 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping. } newSession.Certificate = certBytes newSession.PublicId = id + newSession.KeyId = sessionWrapper.KeyID() var returnedSession *Session _, err = r.writer.DoTx( diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 150c381067..bcfc37cfb4 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -344,8 +344,10 @@ func TestRepository_CreateSession(t *testing.T) { assert.NotNil(ses.CreateTime) assert.NotNil(ses.States[0].StartTime) assert.Equal(ses.States[0].Status, StatusPending) + assert.Equal(wrapper.KeyID(), ses.KeyId) foundSession, _, err := repo.LookupSession(context.Background(), ses.PublicId) assert.NoError(err) + assert.Equal(wrapper.KeyID(), foundSession.KeyId) // Account for slight offsets in nanos assert.True(foundSession.ExpirationTime.Timestamp.AsTime().Sub(ses.ExpirationTime.Timestamp.AsTime()) < time.Second) diff --git a/internal/session/session.go b/internal/session/session.go index f91b5e90a4..9d172a39c9 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -160,6 +160,7 @@ func (s *Session) Clone() interface{} { Endpoint: s.Endpoint, ConnectionLimit: s.ConnectionLimit, WorkerFilter: s.WorkerFilter, + KeyId: s.KeyId, } if len(s.States) > 0 { clone.States = make([]*State, 0, len(s.States))