You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/kms/kms_ext_test.go

370 lines
11 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package kms_test
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"strings"
"testing"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/types/scope"
aead "github.com/hashicorp/go-kms-wrapping/v2/aead"
"github.com/hashicorp/go-kms-wrapping/v2/extras/multi"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestKms checks that the wrappers returned by the KMS for each scope and
// purpose carry key bytes and key IDs that have not been seen for any other
// scope/purpose pair, and that requesting the same wrappers a second time
// returns exactly the keys recorded on the first pass.
//
// NOTE: This is a sequential test that relies on the actions that have come
// before: the second pass of the loop depends on the key bytes and key IDs
// recorded during the first pass.
func TestKms(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	extWrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, extWrapper)
	org, proj := iam.TestScopes(t, iam.TestRepo(t, conn, extWrapper))

	// This next sequence is run twice to ensure that calling for the keys
	// twice returns the same value each time and doesn't simply populate more
	// keys into the KMS object.
	keyBytes := map[string]bool{}
	keyIds := map[string]bool{}
	for i := 1; i < 3; i++ {
		// This iterates through wrappers for all three scopes and four
		// purposes, ensuring that the key bytes and IDs are different for each
		// of them, simulating calling the KMS object from different scopes for
		// different purposes and ensuring the keys are different when that
		// happens.
		t.Run(fmt.Sprintf("verify wrappers different x %d", i), func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			for _, scopeId := range []string{scope.Global.String(), org.GetPublicId(), proj.GetPublicId()} {
				for _, purpose := range []kms.KeyPurpose{kms.KeyPurposeUnknown, kms.KeyPurposeOplog, kms.KeyPurposeDatabase, kms.KeyPurposeSessions, kms.KeyPurposeTokens} {
					wrapper, err := kmsCache.GetWrapper(ctx, scopeId, purpose)
					if purpose == kms.KeyPurposeUnknown {
						// The unknown purpose is expected to be rejected.
						require.Error(err)
						continue
					}
					require.NoError(err)

					// Named "pooled" rather than "multi" so the imported multi
					// package is not shadowed for the rest of the loop body.
					pooled, ok := wrapper.(*multi.PooledWrapper)
					require.True(ok)
					mKeyId, err := pooled.KeyId(ctx)
					require.NoError(err)
					aeadWrapper, ok := pooled.WrapperForKeyId(mKeyId).(*aead.Wrapper)
					require.True(ok)
					aeadKeyId, err := aeadWrapper.KeyId(ctx)
					require.NoError(err)
					wrapperBytes, err := aeadWrapper.KeyBytes(ctx)
					require.NoError(err)

					// Key bytes are base64 encoded once so they can be used as
					// a map key for both the lookup and the bookkeeping.
					encodedKey := base64.StdEncoding.EncodeToString(wrapperBytes)
					foundKeyBytes := keyBytes[encodedKey]
					foundKeyId := keyIds[aeadKeyId]
					if i == 1 {
						// First pass: nothing may have been seen before;
						// remember the key for the second pass.
						assert.False(foundKeyBytes)
						assert.False(foundKeyId)
						keyBytes[encodedKey] = true
						keyIds[aeadKeyId] = true
						continue
					}
					// Second pass: the very same key must come back.
					assert.True(foundKeyBytes)
					assert.True(foundKeyId)
				}
			}
		})
	}
}
// TestKms_ReconcileKeys is a table-driven test of Kms.ReconcileKeys. Each case
// starts from a database holding only freshly created global-scope keys,
// optionally removes keys via its setup func, runs ReconcileKeys and then
// checks that the expected wrappers can be retrieved again.
func TestKms_ReconcileKeys(t *testing.T) {
	t.Parallel()
	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	wrapper := db.TestWrapper(t)
	org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
	org2, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
	tests := []struct {
		name     string
		kms      *kms.Kms
		scopeIds []string
		reader   io.Reader
		// setup receives the subtest's *testing.T so that a failing require
		// calls FailNow on the test that is actually running, never on the
		// parent test from the subtest's goroutine.
		setup           func(*testing.T, *kms.Kms)
		wantPurpose     []kms.KeyPurpose
		wantErr         bool
		wantErrMatch    *errors.Template
		wantErrContains string
	}{
		{
			name:            "missing-reader",
			kms:             kms.TestKms(t, conn, wrapper),
			wantErr:         true,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing rand reader",
		},
		{
			name: "reader-interface-is-nil",
			kms:  kms.TestKms(t, conn, wrapper),
			// A nil *strings.Reader stored in an io.Reader: the interface
			// value itself is non-nil even though the pointer inside it is.
			reader:          func() io.Reader { var sr *strings.Reader; var r io.Reader = sr; return r }(),
			wantErr:         true,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing rand reader",
		},
		{
			name:    "nothing-to-reconcile",
			kms:     kms.TestKms(t, conn, wrapper),
			reader:  rand.Reader,
			wantErr: false,
		},
		{
			name:   "reconcile-audit-key",
			kms:    kms.TestKms(t, conn, wrapper),
			reader: rand.Reader,
			setup: func(t *testing.T, k *kms.Kms) {
				kms.TestKmsDeleteKeyPurpose(t, conn, kms.KeyPurposeAudit)
				// make sure the audit key is really gone before reconciling.
				_, err := k.GetWrapper(testCtx, scope.Global.String(), kms.KeyPurposeAudit)
				require.Error(t, err)
			},
			wantPurpose: []kms.KeyPurpose{kms.KeyPurposeAudit},
			wantErr:     false,
		},
		{
			name:     "reconcile-oidc-key-multiple-scopes",
			kms:      kms.TestKms(t, conn, wrapper),
			scopeIds: []string{org.PublicId, org2.PublicId},
			reader:   rand.Reader,
			setup: func(t *testing.T, k *kms.Kms) {
				// create initial keys for the test scope ids...
				for _, id := range []string{org.PublicId, org2.PublicId} {
					err := k.CreateKeys(testCtx, id)
					require.NoError(t, err)
				}
				kms.TestKmsDeleteKeyPurpose(t, conn, kms.KeyPurposeOidc)
				// make sure the kms is in the proper state for the unit test
				// before proceeding.
				for _, id := range []string{org.PublicId, org2.PublicId} {
					_, err := k.GetWrapper(testCtx, id, kms.KeyPurposeOidc)
					require.Error(t, err)
				}
			},
			wantPurpose: []kms.KeyPurpose{kms.KeyPurposeOidc},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			// start with no keys...
			kms.TestKmsDeleteAllKeys(t, conn)
			// create initial keys for the global scope...
			err := tt.kms.CreateKeys(testCtx, scope.Global.String())
			require.NoError(err)
			if tt.setup != nil {
				tt.setup(t, tt.kms)
			}

			err = tt.kms.ReconcileKeys(testCtx, tt.reader, kms.WithScopeIds(tt.scopeIds...))
			if tt.wantErr {
				// require (not assert): the checks below dereference err and
				// would panic on a nil error instead of failing cleanly.
				require.Error(err)
				if tt.wantErrMatch != nil {
					assert.Truef(errors.Match(tt.wantErrMatch, err), "expected %q and got err: %+v", tt.wantErrMatch.Code, err)
				}
				if tt.wantErrContains != "" {
					assert.Contains(err.Error(), tt.wantErrContains)
				}
				return
			}
			assert.NoError(err)

			// every requested scope must now have a wrapper for each purpose
			// the case expects to have been reconciled.
			for _, id := range tt.scopeIds {
				for _, p := range tt.wantPurpose {
					_, err := tt.kms.GetWrapper(testCtx, id, p)
					require.NoError(err)
				}
			}
			// the global audit wrapper is checked for every successful case.
			_, err = tt.kms.GetWrapper(testCtx, scope.Global.String(), kms.KeyPurposeAudit)
			require.NoError(err)
		})
	}
}
func TestKms_GetDerivedPurposeCache(t *testing.T) {
t.Parallel()
assert := assert.New(t)
conn, _ := db.TestSetup(t, "postgres")
rootWrapper := db.TestWrapper(t)
kmsCache := kms.TestKms(t, conn, rootWrapper)
derivedWrapper := db.TestWrapper(t)
kmsCache.GetDerivedPurposeCache().Store(1, derivedWrapper)
v, ok := kmsCache.GetDerivedPurposeCache().Load(1)
assert.True(ok)
assert.Equal(derivedWrapper, v)
}
func TestKms_VerifyGlobalRoot(t *testing.T) {
t.Parallel()
testCtx := context.Background()
assert, require := assert.New(t), require.New(t)
conn, _ := db.TestSetup(t, "postgres")
rootWrapper := db.TestWrapper(t)
kmsCache := kms.TestKms(t, conn, rootWrapper)
assert.Error(kmsCache.VerifyGlobalRoot(testCtx))
require.NoError(kmsCache.CreateKeys(testCtx, "global"))
assert.NoError(kmsCache.VerifyGlobalRoot(testCtx))
}
// TestKms_GetWrapper is a table-driven test of Kms.GetWrapper covering a
// missing purpose, a missing scope id, and a successful lookup of the database
// wrapper for the "global" scope.
func TestKms_GetWrapper(t *testing.T) {
	t.Parallel()
	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rootWrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, rootWrapper)
	require.NoError(t, kmsCache.CreateKeys(testCtx, "global"))

	type testCase struct {
		name            string
		kms             *kms.Kms
		purpose         kms.KeyPurpose
		scopeId         string
		opt             []kms.Option
		wantErr         bool
		wantErrMatch    *errors.Template
		wantErrContains string
	}
	cases := []testCase{
		{
			name:            "missing-purpose",
			kms:             kmsCache,
			scopeId:         "global",
			wantErr:         true,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing purpose",
		},
		{
			name:            "missing-scope-id",
			kms:             kmsCache,
			purpose:         kms.KeyPurposeDatabase,
			wantErr:         true,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing scope id",
		},
		{
			name:    "success",
			kms:     kmsCache,
			purpose: kms.KeyPurposeDatabase,
			scopeId: "global",
		},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, err := c.kms.GetWrapper(testCtx, c.scopeId, c.purpose, c.opt...)

			// Happy path first: no error and a usable wrapper.
			if !c.wantErr {
				require.NoError(t, err)
				assert.NotNil(t, got)
				return
			}

			// Error path: the error must exist before it is inspected.
			require.Error(t, err)
			if c.wantErrMatch != nil {
				assert.Truef(t, errors.Match(c.wantErrMatch, err), "expected %q and got err: %+v", c.wantErrMatch.Code, err)
			}
			if c.wantErrContains != "" {
				assert.True(t, strings.Contains(err.Error(), c.wantErrContains))
			}
		})
	}
}
// TestKms_CreateKeys is a table-driven test of Kms.CreateKeys. Every case
// first removes all existing keys and then calls CreateKeys with the case's
// scope id and options, checking either the expected error or success.
func TestKms_CreateKeys(t *testing.T) {
	t.Parallel()
	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rootWrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, rootWrapper)
	require.NoError(t, kmsCache.CreateKeys(testCtx, "global"))
	rw := db.New(conn)

	type testCase struct {
		name            string
		kms             *kms.Kms
		scopeId         string
		opt             []kms.Option
		wantErr         bool
		wantErrMatch    *errors.Template
		wantErrContains string
	}
	cases := []testCase{
		{
			name:            "missing-scope-id",
			kms:             kmsCache,
			wantErr:         true,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing scope id",
		},
		{
			name:            "invalid-scope",
			kms:             kmsCache,
			scopeId:         "o_1234567890",
			wantErr:         true,
			wantErrContains: "violates foreign key constraint",
		},
		{
			// reader supplied, writer nil
			name:            "missing-writer-opt",
			kms:             kmsCache,
			scopeId:         "global",
			opt:             []kms.Option{kms.WithReaderWriter(rw, nil)},
			wantErr:         true,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing writer",
		},
		{
			// writer supplied, reader nil
			name:            "missing-reader-opt",
			kms:             kmsCache,
			scopeId:         "global",
			opt:             []kms.Option{kms.WithReaderWriter(nil, rw)},
			wantErr:         true,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing reader",
		},
		{
			name:    "success",
			kms:     kmsCache,
			scopeId: "global",
		},
		{
			name:    "success-with-reader-writer",
			kms:     kmsCache,
			opt:     []kms.Option{kms.WithReaderWriter(rw, rw)},
			scopeId: "global",
		},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			// Each case starts from an empty key set.
			kms.TestKmsDeleteAllKeys(t, conn)

			err := c.kms.CreateKeys(testCtx, c.scopeId, c.opt...)
			if !c.wantErr {
				require.NoError(t, err)
				return
			}

			// Error path: the error must exist before it is inspected.
			require.Error(t, err)
			if c.wantErrMatch != nil {
				assert.Truef(t, errors.Match(c.wantErrMatch, err), "expected %q and got err: %+v", c.wantErrMatch.Code, err)
			}
			if c.wantErrContains != "" {
				assert.True(t, strings.Contains(err.Error(), c.wantErrContains))
			}
		})
	}
}