Fix WithKeyId option (#970)

Currently this would validate that a KeyId was in the returned wrapper
but then return the base wrapper; and if it required a database lookup
it wouldn't validate it anyway.

All uses of this were for decryption so it didn't actually affect
anything to this point (you'd error later instead of sooner) but for any
case where we need to encrypt against a specific version or derive a key
against a specific version this would be broken, in a way that could
lead you to use the wrong key version.

This also properly sends the Key ID around sessions. This worked
currently for the same reason as above but would have broken for any
session in flight once a key was rotated.
pull/975/head
Jeff Mitchell 5 years ago committed by GitHub
parent 519efa6dd1
commit c3684d20db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -139,10 +139,6 @@ func (k *Kms) GetExternalWrappers() *ExternalWrappers {
return ret
}
// generateKeyId builds the canonical key ID for a given scope, purpose, and
// key version, in the form "<scopeId>_<purpose>_<version>".
func generateKeyId(scopeId string, purpose KeyPurpose, version uint32) string {
	// Concatenate the parts directly rather than via a single format string;
	// fmt.Sprint renders the uint32 version in decimal, matching %d.
	return scopeId + "_" + purpose.String() + "_" + fmt.Sprint(version)
}
// GetWrapper returns a wrapper for the given scope and purpose. When a keyId is
// passed, it will ensure that the returned wrapper has that key ID in the
// multiwrapper. This is not necessary for encryption but should be supplied for
@ -168,9 +164,12 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose
val, ok := k.scopePurposeCache.Load(scopeId + purpose.String())
if ok {
wrapper := val.(*multiwrapper.MultiWrapper)
if opts.withKeyId == "" || wrapper.WrapperForKeyID(opts.withKeyId) != nil {
if opts.withKeyId == "" {
return wrapper, nil
}
if keyIdWrapper := wrapper.WrapperForKeyID(opts.withKeyId); keyIdWrapper != nil {
return keyIdWrapper, nil
}
// Fall through to refresh our multiwrapper for this scope/purpose from the DB
}
@ -192,6 +191,13 @@ func (k *Kms) GetWrapper(ctx context.Context, scopeId string, purpose KeyPurpose
}
k.scopePurposeCache.Store(scopeId+purpose.String(), wrapper)
if opts.withKeyId != "" {
if keyIdWrapper := wrapper.WrapperForKeyID(opts.withKeyId); keyIdWrapper != nil {
return keyIdWrapper, nil
}
return nil, errors.New(errors.KeyNotFound, op, "unable to find specified key ID")
}
return wrapper, nil
}
@ -299,6 +305,8 @@ func (k *Kms) loadDek(ctx context.Context, scopeId string, purpose KeyPurpose, r
keys, err = repo.ListTokenKeys(ctx)
case KeyPurposeSessions:
keys, err = repo.ListSessionKeys(ctx)
default:
return nil, errors.New(errors.InvalidParameter, op, "unknown or invalid DEK purpose specified")
}
if err != nil {
return nil, errors.Wrap(err, op, errors.WithMsg("error listing root keys"))

@ -0,0 +1,127 @@
package kms_test
import (
"context"
"encoding/base64"
"fmt"
"strings"
"testing"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
"github.com/hashicorp/go-kms-wrapping/wrappers/multiwrapper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// NOTE: This is a sequential test that relies on the actions that have come
// before. Please see the comments for details.
func TestKms(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	// extWrapper is the valid external root wrapper used to create the test
	// KMS; badExtWrapper is an unrelated wrapper used later to prove that
	// decryption fails when the wrong external key is supplied.
	extWrapper := db.TestWrapper(t)
	badExtWrapper := db.TestWrapper(t)
	repo, err := kms.NewRepository(rw, rw)
	require.NoError(t, err)
	kmsCache := kms.TestKms(t, conn, extWrapper)
	org, proj := iam.TestScopes(t, iam.TestRepo(t, conn, extWrapper))
	// Verify that the cache is empty, so we can show that by the end of the
	// test sequence we did actually look up keys and store them in the cache
	t.Run("verify cache empty", func(t *testing.T) {
		var count int
		kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool {
			count++
			return true
		})
		assert.Equal(t, 0, count)
	})
	// Verify that the root keys are all in the database and can be decrypted
	// with the correct wrapper from the KMS object
	t.Run("verify root keys", func(t *testing.T) {
		assert, require := assert.New(t), require.New(t)
		rootKeys, err := repo.ListRootKeys(ctx)
		require.NoError(err)
		wrappers := kmsCache.GetExternalWrappers()
		for _, key := range rootKeys {
			kvs, err := repo.ListRootKeyVersions(ctx, wrappers.Root(), key.GetPrivateId())
			require.NoError(err)
			// Each freshly-created root key has exactly one version holding a
			// 32-byte key.
			assert.Len(kvs, 1)
			assert.Len(kvs[0].GetKey(), 32)
		}
	})
	// Verify that the wrong wrapper causes decryption to fail
	t.Run("bad external keys", func(t *testing.T) {
		assert, require := assert.New(t), require.New(t)
		rootKeys, err := repo.ListRootKeys(ctx)
		require.NoError(err)
		for _, key := range rootKeys {
			_, err := repo.ListRootKeyVersions(ctx, badExtWrapper, key.GetPrivateId())
			require.Error(err)
			// AEAD decryption with the wrong key surfaces as an
			// authentication failure.
			assert.True(strings.Contains(err.Error(), "message authentication failed"), err.Error())
		}
	})
	// This next sequence is run twice to ensure that calling for the keys twice
	// returns the same value each time and doesn't simply populate more keys
	// into the KMS object
	keyBytes := map[string]bool{}
	keyIds := map[string]bool{}
	scopePurposeMap := map[string]interface{}{}
	for i := 1; i < 3; i++ {
		// This iterates through wrappers for all three scopes and four purposes,
		// ensuring that the key bytes and IDs are different for each of them,
		// simulating calling the KMS object from different scopes for different
		// purposes and ensuring the keys are different when that happens.
		t.Run(fmt.Sprintf("verify wrappers different x %d", i), func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			for _, scopeId := range []string{scope.Global.String(), org.GetPublicId(), proj.GetPublicId()} {
				for _, purpose := range []kms.KeyPurpose{kms.KeyPurposeUnknown, kms.KeyPurposeOplog, kms.KeyPurposeDatabase, kms.KeyPurposeSessions, kms.KeyPurposeTokens} {
					wrapper, err := kmsCache.GetWrapper(ctx, scopeId, purpose)
					if purpose == kms.KeyPurposeUnknown {
						// The unknown purpose must always be rejected.
						require.Error(err)
						continue
					}
					require.NoError(err)
					multi, ok := wrapper.(*multiwrapper.MultiWrapper)
					require.True(ok)
					aeadWrapper, ok := multi.WrapperForKeyID(multi.KeyID()).(*aead.Wrapper)
					require.True(ok)
					foundKeyBytes := keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())]
					foundKeyId := keyIds[aeadWrapper.KeyID()]
					if i == 1 {
						// First pass: every key and ID must be new.
						assert.False(foundKeyBytes)
						assert.False(foundKeyId)
						keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())] = true
						keyIds[aeadWrapper.KeyID()] = true
					} else {
						// Second pass: every key and ID must already have been
						// seen, proving nothing new was created.
						assert.True(foundKeyBytes)
						assert.True(foundKeyId)
					}
				}
			}
		})
		// Verify that the cache has been populated with unique values. The
		// second time we validate that the items we find when going through the
		// cache a second time are the same as the first. If they were recreated
		// the pointers would be different.
		t.Run(fmt.Sprintf("verify cache populated x %d", i), func(t *testing.T) {
			var count int
			kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool {
				count++
				if i == 1 {
					scopePurposeMap[key.(string)] = value
				} else {
					assert.Same(t, scopePurposeMap[key.(string)], value)
				}
				return true
			})
			// four purposes x 3 scopes
			assert.Equal(t, 12, count)
		})
	}
}

@ -1,127 +1,68 @@
package kms_test
package kms
import (
"context"
"encoding/base64"
"fmt"
"strings"
"crypto/rand"
"testing"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
"github.com/hashicorp/go-kms-wrapping/wrappers/multiwrapper"
"github.com/stretchr/testify/assert"
"github.com/hashicorp/go-uuid"
"github.com/stretchr/testify/require"
)
// NOTE: This is a sequential test that relies on the actions that have come
// before. Please see the comments for details.
func TestKms(t *testing.T) {
func TestKms_KeyId(t *testing.T) {
t.Parallel()
require := require.New(t)
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
badWrapper := db.TestWrapper(t)
repo, err := kms.NewRepository(rw, rw)
require.NoError(t, err)
kmsCache := kms.TestKms(t, conn, wrapper)
org, proj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
extWrapper := db.TestWrapper(t)
repo, err := NewRepository(rw, rw)
require.NoError(err)
// Verify that the cache is empty, so we can show that by the end of the
// test sequence we did actually look up keys and store them in the cache
t.Run("verify cache empty", func(t *testing.T) {
var count int
kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool {
count++
return true
})
assert.Equal(t, 0, count)
})
// Verify that the root keys are all in the database and can be decrypted
// with the correct wrapper from the KMS object
t.Run("verify root keys", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
rootKeys, err := repo.ListRootKeys(ctx)
require.NoError(err)
wrappers := kmsCache.GetExternalWrappers()
for _, key := range rootKeys {
kvs, err := repo.ListRootKeyVersions(ctx, wrappers.Root(), key.GetPrivateId())
require.NoError(err)
assert.Len(kvs, 1)
assert.Len(kvs[0].GetKey(), 32)
}
})
// Verify that the wrong wrapper causes decryption to fail
t.Run("bad external keys", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
rootKeys, err := repo.ListRootKeys(ctx)
require.NoError(err)
for _, key := range rootKeys {
_, err := repo.ListRootKeyVersions(ctx, badWrapper, key.GetPrivateId())
require.Error(err)
assert.True(strings.Contains(err.Error(), "message authentication failed"), err.Error())
}
})
// This next sequence is run twice to ensure that calling for the keys twice
// returns the same value each time and doesn't simply populate more keys
// into the KMS object
keyBytes := map[string]bool{}
keyIds := map[string]bool{}
scopePurposeMap := map[string]interface{}{}
for i := 1; i < 3; i++ {
// This iterates through wrappers for all three scopes and four purposes,
// ensuring that the key bytes and IDs are different for each of them,
// simulating calling the KMS object from different scopes for different
// purposes and ensuring the keys are different when that happens.
t.Run(fmt.Sprintf("verify wrappers different x %d", i), func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
for _, scopeId := range []string{scope.Global.String(), org.GetPublicId(), proj.GetPublicId()} {
for _, purpose := range []kms.KeyPurpose{kms.KeyPurposeUnknown, kms.KeyPurposeOplog, kms.KeyPurposeDatabase, kms.KeyPurposeSessions, kms.KeyPurposeTokens} {
wrapper, err := kmsCache.GetWrapper(ctx, scopeId, purpose)
if purpose == kms.KeyPurposeUnknown {
require.Error(err)
continue
}
require.NoError(err)
multi, ok := wrapper.(*multiwrapper.MultiWrapper)
require.True(ok)
aeadWrapper, ok := multi.WrapperForKeyID(multi.KeyID()).(*aead.Wrapper)
require.True(ok)
foundKeyBytes := keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())]
foundKeyId := keyIds[aeadWrapper.KeyID()]
if i == 1 {
assert.False(foundKeyBytes)
assert.False(foundKeyId)
keyBytes[base64.StdEncoding.EncodeToString(aeadWrapper.GetKeyBytes())] = true
keyIds[aeadWrapper.KeyID()] = true
} else {
assert.True(foundKeyBytes)
assert.True(foundKeyId)
}
}
}
})
// Verify that the cache has been populated with unique values. The
// second time we validate that the items we find when going through the
// cache a second time are the same as the first. If they were recreated
// the pointers would be different.
t.Run(fmt.Sprintf("verify cache populated x %d", i), func(t *testing.T) {
var count int
kmsCache.GetScopePurposeCache().Range(func(key interface{}, value interface{}) bool {
count++
if i == 1 {
scopePurposeMap[key.(string)] = value
} else {
assert.Same(t, scopePurposeMap[key.(string)], value)
}
return true
})
// four purposes x 3 scopes
assert.Equal(t, 12, count)
})
}
// Make the global scope base keys
_, err = CreateKeysTx(ctx, rw, rw, extWrapper, rand.Reader, scope.Global.String())
require.NoError(err)
// Get the global scope's root wrapper
kmsCache, err := NewKms(repo)
require.NoError(err)
require.NoError(kmsCache.AddExternalWrappers(WithRootWrapper(extWrapper)))
globalRootWrapper, _, err := kmsCache.loadRoot(ctx, scope.Global.String())
require.NoError(err)
dks, err := repo.ListDatabaseKeys(ctx)
require.NoError(err)
require.Len(dks, 1)
// Create another key version
newKeyBytes, err := uuid.GenerateRandomBytes(32)
require.NoError(err)
_, err = repo.CreateDatabaseKeyVersion(ctx, globalRootWrapper, dks[0].GetPrivateId(), newKeyBytes)
require.NoError(err)
dkvs, err := repo.ListDatabaseKeyVersions(ctx, globalRootWrapper, dks[0].GetPrivateId())
require.NoError(err)
require.Len(dkvs, 2)
keyId1 := dkvs[0].GetPrivateId()
keyId2 := dkvs[1].GetPrivateId()
// First test: just getting the key should return the latest
wrapper, err := kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase)
require.NoError(err)
require.Equal(keyId2, wrapper.KeyID())
// Second: ask for each in turn
wrapper, err = kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase, WithKeyId(keyId1))
require.NoError(err)
require.Equal(keyId1, wrapper.KeyID())
wrapper, err = kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase, WithKeyId(keyId2))
require.NoError(err)
require.Equal(keyId2, wrapper.KeyID())
// Last: verify something bogus finds nothing
_, err = kmsCache.GetWrapper(ctx, scope.Global.String(), KeyPurposeDatabase, WithKeyId("foo"))
require.Error(err)
}

@ -72,7 +72,7 @@ func createRootKeyTx(ctx context.Context, w db.Writer, keyWrapper wrapping.Wrapp
}
// no oplog entries for root key versions
if err := w.Create(ctx, &kv); err != nil {
return nil, nil, errors.Wrap(err, op, errors.WithMsg(fmt.Sprintf("key versions")))
return nil, nil, errors.Wrap(err, op, errors.WithMsg("key versions"))
}
return &rk, &kv, nil

@ -219,7 +219,7 @@ func (ws *workerServiceServer) LookupSession(ctx context.Context, req *pbs.Looku
resp.ConnectionsLeft -= int32(authzSummary.CurrentConnectionCount)
}
wrapper, err := ws.kms.GetWrapper(ctx, sessionInfo.ScopeId, kms.KeyPurposeSessions)
wrapper, err := ws.kms.GetWrapper(ctx, sessionInfo.ScopeId, kms.KeyPurposeSessions, kms.WithKeyId(sessionInfo.KeyId))
if err != nil {
return nil, status.Errorf(codes.Internal, "Error getting sessions wrapper: %v", err)
}

@ -76,6 +76,7 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.
}
newSession.Certificate = certBytes
newSession.PublicId = id
newSession.KeyId = sessionWrapper.KeyID()
var returnedSession *Session
_, err = r.writer.DoTx(

@ -344,8 +344,10 @@ func TestRepository_CreateSession(t *testing.T) {
assert.NotNil(ses.CreateTime)
assert.NotNil(ses.States[0].StartTime)
assert.Equal(ses.States[0].Status, StatusPending)
assert.Equal(wrapper.KeyID(), ses.KeyId)
foundSession, _, err := repo.LookupSession(context.Background(), ses.PublicId)
assert.NoError(err)
assert.Equal(wrapper.KeyID(), foundSession.KeyId)
// Account for slight offsets in nanos
assert.True(foundSession.ExpirationTime.Timestamp.AsTime().Sub(ses.ExpirationTime.Timestamp.AsTime()) < time.Second)

@ -160,6 +160,7 @@ func (s *Session) Clone() interface{} {
Endpoint: s.Endpoint,
ConnectionLimit: s.ConnectionLimit,
WorkerFilter: s.WorkerFilter,
KeyId: s.KeyId,
}
if len(s.States) > 0 {
clone.States = make([]*State, 0, len(s.States))

Loading…
Cancel
Save