mirror of https://github.com/hashicorp/boundary
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
972 lines
37 KiB
972 lines
37 KiB
// Copyright IBM Corp. 2020, 2025
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package kms
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/boundary/internal/db"
|
|
"github.com/hashicorp/boundary/internal/errors"
|
|
wrappingKms "github.com/hashicorp/go-kms-wrapping/extras/kms/v2"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// Test_New verifies that New rejects a missing reader or writer with an
// InvalidParameter error and, when both are supplied, returns a Kms holding
// them together with two wrapping KMS instances equal to ones built directly:
// a general-purpose instance covering every DEK purpose except oplog plus the
// worker-auth, worker-auth-storage, recovery and BSR purposes, and an
// oplog-only instance using the "kms_oplog" table name prefix.
func Test_New(t *testing.T) {
	t.Parallel()
	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)

	// Oplog keys live in their own wrapping KMS, so they are skipped here.
	dekPurposes := ValidDekPurposes()
	expectedPurposes := make([]wrappingKms.KeyPurpose, 0, len(dekPurposes))
	for _, purpose := range dekPurposes {
		if purpose == KeyPurposeOplog {
			continue
		}
		expectedPurposes = append(expectedPurposes, wrappingKms.KeyPurpose(purpose.String()))
	}
	expectedPurposes = append(expectedPurposes,
		wrappingKms.KeyPurpose(KeyPurposeWorkerAuth.String()),
		wrappingKms.KeyPurpose(KeyPurposeWorkerAuthStorage.String()),
		wrappingKms.KeyPurpose(KeyPurposeRecovery.String()),
		wrappingKms.KeyPurpose(KeyPurposeBsr.String()),
	)
	expectedUnderlying, err := wrappingKms.New(db.NewChangeSafeDbwReader(rw), db.NewChangeSafeDbwWriter(rw), expectedPurposes)
	require.NoError(t, err)

	expectedOplog, err := wrappingKms.New(
		db.NewChangeSafeDbwReader(rw),
		db.NewChangeSafeDbwWriter(rw),
		[]wrappingKms.KeyPurpose{wrappingKms.KeyPurpose(KeyPurposeOplog.String())},
		wrappingKms.WithTableNamePrefix("kms_oplog"),
	)
	require.NoError(t, err)

	tests := []struct {
		name            string
		r               *db.Db
		w               *db.Db
		want            *Kms
		wantErr         bool
		wantErrMatch    *errors.Template
		wantErrContains string
	}{
		{
			name:            "nil-reader",
			w:               rw,
			wantErr:         true,
			wantErrContains: "missing reader",
			wantErrMatch:    errors.T(errors.InvalidParameter),
		},
		{
			name:            "nil-writer",
			r:               rw,
			wantErr:         true,
			wantErrContains: "missing writer",
			wantErrMatch:    errors.T(errors.InvalidParameter),
		},
		{
			name: "success",
			r:    rw,
			w:    rw,
			want: &Kms{
				reader:             rw,
				writer:             rw,
				underlying:         expectedUnderlying,
				underlyingForOplog: expectedOplog,
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			got, err := New(testCtx, tc.r, tc.w)
			if !tc.wantErr {
				require.NoError(t, err)
				assert.Equal(t, tc.want, got)
				return
			}
			require.Error(t, err)
			if tc.wantErrMatch != nil {
				assert.Truef(t, errors.Match(tc.wantErrMatch, err), "expected %q and got err: %+v", tc.wantErrMatch.Code, err)
			}
			if tc.wantErrContains != "" {
				assert.Contains(t, err.Error(), tc.wantErrContains)
			}
		})
	}
}
|
|
|
|
// Test_NewUsingReaderWriter verifies that NewUsingReaderWriter rejects a nil
// reader or writer, and a reader/writer it cannot convert to a *db.Db, each
// with an InvalidParameter error. On success it must return a Kms that holds
// the supplied reader AND writer together with the expected general-purpose
// and oplog wrapping KMS instances.
func Test_NewUsingReaderWriter(t *testing.T) {
	t.Parallel()
	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)

	tests := []struct {
		name            string
		r               db.Reader
		w               db.Writer
		want            *Kms
		wantErr         bool
		wantErrMatch    *errors.Template
		wantErrContains string
	}{
		{
			name:            "nil-reader",
			w:               rw,
			wantErr:         true,
			wantErrContains: "missing reader",
			wantErrMatch:    errors.T(errors.InvalidParameter),
		},
		{
			name:            "nil-writer",
			r:               rw,
			wantErr:         true,
			wantErrContains: "missing writer",
			wantErrMatch:    errors.T(errors.InvalidParameter),
		},
		{
			name:            "invalid-reader",
			r:               &invalidReader{},
			w:               rw,
			wantErr:         true,
			wantErrContains: "unable to convert reader to db.Db",
			wantErrMatch:    errors.T(errors.InvalidParameter),
		},
		{
			name:            "invalid-writer",
			r:               rw,
			w:               &invalidWriter{},
			wantErr:         true,
			wantErrContains: "unable to convert writer to db.Db",
			wantErrMatch:    errors.T(errors.InvalidParameter),
		},
		{
			name: "success",
			r:    rw,
			w:    rw,
			want: &Kms{
				reader: rw,
				// The returned Kms retains the writer as well as the reader;
				// leaving it out here would make the equality check below
				// expect a Kms with no writer at all.
				writer: rw,
				underlying: func() *wrappingKms.Kms {
					purposes := stdNewKmsPurposes()
					wrapped, err := wrappingKms.New(db.NewChangeSafeDbwReader(rw), db.NewChangeSafeDbwWriter(rw), purposes)
					require.NoError(t, err)
					return wrapped
				}(),
				underlyingForOplog: func() *wrappingKms.Kms {
					wrapped, err := wrappingKms.New(
						db.NewChangeSafeDbwReader(rw),
						db.NewChangeSafeDbwWriter(rw),
						[]wrappingKms.KeyPurpose{wrappingKms.KeyPurpose(KeyPurposeOplog.String())},
						wrappingKms.WithTableNamePrefix("kms_oplog"),
					)
					require.NoError(t, err)
					return wrapped
				}(),
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)

			got, err := NewUsingReaderWriter(testCtx, tc.r, tc.w)
			if tc.wantErr {
				require.Error(err)
				if tc.wantErrMatch != nil {
					assert.Truef(errors.Match(tc.wantErrMatch, err), "expected %q and got err: %+v", tc.wantErrMatch.Code, err)
				}
				if tc.wantErrContains != "" {
					assert.Contains(err.Error(), tc.wantErrContains)
				}
				return
			}
			require.NoError(err)
			assert.Equal(tc.want, got)
		})
	}
}
|
|
|
|
// Test_ListKeys verifies ListKeys against a scope whose keys have been
// created. The listing holds one key per DEK purpose plus one additional key
// (the scope's root key). An unknown scope yields a not-found error, and an
// empty scope id is rejected.
func Test_ListKeys(t *testing.T) {
	t.Parallel()
	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	kmsCache := TestKms(t, conn, db.TestWrapper(t))
	require.NoError(t, kmsCache.CreateKeys(testCtx, "global"))

	t.Run("success", func(t *testing.T) {
		t.Parallel()
		keys, err := kmsCache.ListKeys(testCtx, "global")
		require.NoError(t, err)
		require.Len(t, keys, len(ValidDekPurposes())+1)
	})

	failures := []struct {
		name    string
		scopeId string
		check   func(t *testing.T, err error)
	}{
		{
			name:    "unknown-scope",
			scopeId: "myscope",
			check: func(t *testing.T, err error) {
				assert.True(t, errors.IsNotFoundError(err))
			},
		},
		{
			name:    "empty-scope",
			scopeId: "",
			check: func(t *testing.T, err error) {
				assert.Error(t, err)
			},
		},
	}
	for _, f := range failures {
		f := f // captured by a parallel subtest
		t.Run(f.name, func(t *testing.T) {
			t.Parallel()
			_, err := kmsCache.ListKeys(testCtx, f.scopeId)
			f.check(t, err)
		})
	}
}
|
|
|
|
// Test_RotateKeys verifies that RotateKeys adds one new version to every key
// in a scope. It also checks how the reader/writer option is validated:
// supplying only a reader or only a writer is rejected, and so is a reader or
// writer that cannot be converted. A valid reader/writer pair succeeds.
func Test_RotateKeys(t *testing.T) {
	t.Parallel()
	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	extWrapper := db.TestWrapper(t)
	kmsCache := TestKms(t, conn, extWrapper)
	err := kmsCache.CreateKeys(testCtx, "global")
	require.NoError(t, err)
	rw := db.New(conn)

	t.Run("success", func(t *testing.T) {
		// arrange
		// we're not trying to test the ListKeys function, although we need to use it to validate the rotation
		keysBefore, err := kmsCache.ListKeys(testCtx, "global")
		require.NoError(t, err)
		require.Len(t, keysBefore, len(ValidDekPurposes())+1)
		versionsBefore := 0
		for _, key := range keysBefore {
			versionsBefore += len(key.Versions)
		}

		// act
		err = kmsCache.RotateKeys(testCtx, "global")
		require.NoError(t, err)

		// assert
		keysAfter, err := kmsCache.ListKeys(testCtx, "global")
		require.NoError(t, err)
		versionsAfter := 0
		for _, key := range keysAfter {
			versionsAfter += len(key.Versions)
		}
		// Every key gains exactly one version. Deriving the expected total
		// from the listing, instead of hard-coding it, keeps this test valid
		// when DEK purposes are added or removed.
		require.Equal(t, versionsBefore+len(keysBefore), versionsAfter)
	})

	t.Run("reader provided, missing writer", func(t *testing.T) {
		// arrange
		// Build an option that sets only the reader.
		withReader := func(reader db.Reader) Option {
			return func(o *options) {
				o.withReader = reader
			}
		}

		// act
		err := kmsCache.RotateKeys(testCtx, "global", withReader(rw))

		// assert
		assert.ErrorContains(t, err, "missing writer")
	})

	t.Run("writer provided, missing reader", func(t *testing.T) {
		// arrange
		// Build an option that sets only the writer.
		withWriter := func(writer db.Writer) Option {
			return func(o *options) {
				o.withWriter = writer
			}
		}

		// act
		err := kmsCache.RotateKeys(testCtx, "global", withWriter(rw))

		// assert
		assert.ErrorContains(t, err, "missing reader")
	})

	t.Run("invalid reader", func(t *testing.T) {
		// act
		err := kmsCache.RotateKeys(testCtx, "global", WithReaderWriter(&invalidReader{}, rw))

		// assert
		assert.ErrorContains(t, err, "unable to convert reader")
	})

	t.Run("invalid writer", func(t *testing.T) {
		// act
		err := kmsCache.RotateKeys(testCtx, "global", WithReaderWriter(rw, &invalidWriter{}))

		// assert
		assert.ErrorContains(t, err, "unable to convert writer")
	})

	t.Run("both reader and writer succeed", func(t *testing.T) {
		err := kmsCache.RotateKeys(testCtx, "global", WithReaderWriter(rw, rw))
		require.NoError(t, err)
	})

	// other options are passed directly and shouldn't need to be tested
}
|
|
|
|
// Test_ListDataKeyVersionReferencers verifies that ListDataKeyVersionReferencers
// returns a non-empty list both with default options and with an explicit
// reader/writer pair. It must error when only one of reader or writer is
// supplied, or when either of them cannot be converted.
func Test_ListDataKeyVersionReferencers(t *testing.T) {
	t.Parallel()
	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	kmsCache := TestKms(t, conn, db.TestWrapper(t))

	// Options that set only one side of the reader/writer pair.
	writerOnly := func(writer db.Writer) Option {
		return func(o *options) { o.withWriter = writer }
	}
	readerOnly := func(reader db.Reader) Option {
		return func(o *options) { o.withReader = reader }
	}

	// requireReferencers lists the referencers using opts and checks that the
	// call succeeds with at least one entry. The exact length and table names
	// are deliberately not checked: they change whenever a referencing table
	// is added or removed, so this is only a sanity check.
	requireReferencers := func(t *testing.T, opts ...Option) {
		t.Helper()
		referencers, err := kmsCache.ListDataKeyVersionReferencers(testCtx, opts...)
		require.NoError(t, err)
		assert.NotEmpty(t, referencers)
	}

	t.Run("returns-something", func(t *testing.T) {
		requireReferencers(t)
	})

	failures := []struct {
		name string
		opt  Option
	}{
		{name: "errors when just a writer is specified", opt: writerOnly(rw)},
		{name: "errors when just a reader is specified", opt: readerOnly(rw)},
		{name: "errors when an invalid writer is specified", opt: WithReaderWriter(rw, &invalidWriter{})},
		{name: "errors when an invalid reader is specified", opt: WithReaderWriter(&invalidReader{}, rw)},
	}
	for _, f := range failures {
		t.Run(f.name, func(t *testing.T) {
			_, err := kmsCache.ListDataKeyVersionReferencers(testCtx, f.opt)
			require.Error(t, err)
		})
	}

	t.Run("succeeds-with-reader-writer", func(t *testing.T) {
		requireReferencers(t, WithReaderWriter(rw, rw))
	})
}
|
|
|
|
// Test_ListDataKeyVersionDestructionJobs verifies ListDataKeyVersionDestructionJobs
// after the global scope's keys have been rotated once. An empty scope id is
// rejected. An invalid scope, or a scope without jobs, yields no jobs. A seeded
// job is reported with its aggregated run totals, status, key version and scope.
func Test_ListDataKeyVersionDestructionJobs(t *testing.T) {
	t.Parallel()
	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	extWrapper := db.TestWrapper(t)
	kmsCache := TestKms(t, conn, extWrapper)
	err := kmsCache.CreateKeys(testCtx, "global")
	require.NoError(t, err)
	err = kmsCache.RotateKeys(testCtx, "global")
	require.NoError(t, err)
	keys, err := kmsCache.ListKeys(testCtx, "global")
	require.NoError(t, err)

	// seedDatabaseKeyJob inserts a destruction job for the oldest (first
	// listed) version of the database key, with two runs:
	//   - auth_token: 100 rows, marked running
	//   - auth_oidc_method: 200 rows
	// It returns that key version. The job tables are truncated when the
	// calling subtest finishes.
	seedDatabaseKeyJob := func(t *testing.T) wrappingKms.KeyVersion {
		t.Helper()
		var kvToDestroy wrappingKms.KeyVersion
		for _, key := range keys {
			if key.Purpose == wrappingKms.KeyPurpose(KeyPurposeDatabase.String()) {
				kvToDestroy = key.Versions[0]
			}
		}
		// Fail loudly instead of seeding rows for an empty key id.
		require.NotEmpty(t, kvToDestroy.Id, "no database key version found")

		sqldb, err := conn.SqlDB(testCtx)
		require.NoError(t, err)
		// Register the cleanup before seeding so that a failure part-way
		// through cannot leave rows behind for the next subtest.
		t.Cleanup(func() {
			_, err := sqldb.ExecContext(testCtx, "truncate kms_data_key_version_destruction_job, kms_data_key_version_destruction_job_run CASCADE")
			require.NoError(t, err)
		})
		for _, stmt := range []string{
			"insert into kms_data_key_version_destruction_job(key_id) values ($1)",
			"insert into kms_data_key_version_destruction_job_run(key_id, table_name, total_count) values ($1, 'auth_token', 100)",
			"insert into kms_data_key_version_destruction_job_run(key_id, table_name, total_count) values ($1, 'auth_oidc_method', 200)",
			"update kms_data_key_version_destruction_job_run set is_running=true where key_id=$1 and table_name='auth_token'",
		} {
			_, err := sqldb.ExecContext(testCtx, stmt, kvToDestroy.Id)
			require.NoError(t, err)
		}
		return kvToDestroy
	}

	t.Run("errors-on-empty-scope-ids", func(t *testing.T) {
		_, err := kmsCache.ListDataKeyVersionDestructionJobs(testCtx, "")
		assert.Error(t, err)
	})
	t.Run("lists-no-jobs-when-the-scope-id-is-invalid", func(t *testing.T) {
		jobs, err := kmsCache.ListDataKeyVersionDestructionJobs(testCtx, "notascope")
		require.NoError(t, err)
		assert.Empty(t, jobs)
	})
	t.Run("lists-no-jobs-when-there-are-none", func(t *testing.T) {
		jobs, err := kmsCache.ListDataKeyVersionDestructionJobs(testCtx, "global")
		require.NoError(t, err)
		assert.Empty(t, jobs)
	})
	t.Run("lists-jobs-when-there-are-some", func(t *testing.T) {
		kvToDestroy := seedDatabaseKeyJob(t)

		jobs, err := kmsCache.ListDataKeyVersionDestructionJobs(testCtx, "global")
		require.NoError(t, err)
		require.Len(t, jobs, 1)
		job := jobs[0]
		// 100 + 200 rows across the two seeded runs, none completed yet.
		assert.Equal(t, 0, int(job.CompletedCount))
		assert.Equal(t, 300, int(job.TotalCount))
		assert.Equal(t, "running", job.Status)
		assert.Equal(t, kvToDestroy.Id, job.KeyId)
		assert.Equal(t, "global", job.ScopeId)
	})
	t.Run("lists-no-jobs-when-given-unknown-scope", func(t *testing.T) {
		seedDatabaseKeyJob(t)

		jobs, err := kmsCache.ListDataKeyVersionDestructionJobs(testCtx, "myscope")
		require.NoError(t, err)
		assert.Empty(t, jobs)
	})
}
|
|
|
|
// TestMonitorTableRewrappingRuns exercises MonitorTableRewrappingRuns against
// a scope whose keys have been rotated once. It covers:
//   - no runnable job
//   - an empty table name
//   - a table with no registered rewrapping function
//   - another table's run already in progress
//   - picking the single pending run
//   - preferring the oldest pending run
//   - resuming a run already marked as running
//   - clearing the running flag even when the context is canceled mid-run
//
// The subtests install callbacks in the package-level tableNameToRewrapFn
// registry, so this test deliberately does NOT call t.Parallel. A top-level
// test that is not parallel never runs concurrently with parallel tests, which
// keeps other tests from racing on that map. In addition, the registry is
// replaced with a private copy up front and the original is restored on
// cleanup, so the subtests' changes do not leak into the package's other tests.
func TestMonitorTableRewrappingRuns(t *testing.T) {
	origRewrapFns := tableNameToRewrapFn
	tableNameToRewrapFn = make(map[string]RewrapFn, len(origRewrapFns))
	for tableName, fn := range origRewrapFns {
		tableNameToRewrapFn[tableName] = fn
	}
	t.Cleanup(func() { tableNameToRewrapFn = origRewrapFns })

	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	extWrapper := db.TestWrapper(t)
	kmsCache := TestKms(t, conn, extWrapper)
	err := kmsCache.CreateKeys(testCtx, "global")
	require.NoError(t, err)
	err = kmsCache.RotateKeys(testCtx, "global")
	require.NoError(t, err)
	keys, err := kmsCache.ListKeys(testCtx, "global")
	require.NoError(t, err)
	sqldb, err := conn.SqlDB(testCtx)
	require.NoError(t, err)

	databasePurpose := wrappingKms.KeyPurpose(KeyPurposeDatabase.String())
	tokensPurpose := wrappingKms.KeyPurpose(KeyPurposeTokens.String())

	// oldestVersionsOf returns the oldest (first listed) version of every key
	// whose purpose is one of purposes, in the order ListKeys returned the keys.
	oldestVersionsOf := func(purposes ...wrappingKms.KeyPurpose) []wrappingKms.KeyVersion {
		var kvs []wrappingKms.KeyVersion
		for _, key := range keys {
			for _, p := range purposes {
				if key.Purpose == p {
					kvs = append(kvs, key.Versions[0])
					break
				}
			}
		}
		return kvs
	}

	// oldestDatabaseVersion returns the oldest version of the database key.
	oldestDatabaseVersion := func(t *testing.T) wrappingKms.KeyVersion {
		t.Helper()
		kvs := oldestVersionsOf(databasePurpose)
		require.Len(t, kvs, 1)
		return kvs[0]
	}

	// truncateJobsOnCleanup removes all destruction jobs and runs once the
	// calling subtest finishes. Call it before seeding, so that partially
	// seeded state is also removed if seeding fails.
	truncateJobsOnCleanup := func(t *testing.T) {
		t.Cleanup(func() {
			_, err := sqldb.ExecContext(testCtx, "truncate kms_data_key_version_destruction_job, kms_data_key_version_destruction_job_run CASCADE")
			require.NoError(t, err)
		})
	}

	// exec runs a seeding statement and fails the test on error.
	exec := func(t *testing.T, query string, args ...any) {
		t.Helper()
		_, err := sqldb.ExecContext(testCtx, query, args...)
		require.NoError(t, err)
	}

	t.Run("does-nothing-when-no-run-available", func(t *testing.T) {
		callbackCalled := false
		tableNameToRewrapFn["auth_token"] = func(ctx context.Context, dataKeyVersionId string, scopeId string, reader db.Reader, writer db.Writer, kms GetWrapperer) error {
			callbackCalled = true
			return nil
		}
		err := kmsCache.MonitorTableRewrappingRuns(testCtx, "auth_token")
		require.NoError(t, err)
		assert.False(t, callbackCalled, "auth_token callback should not have been called")
	})

	t.Run("returns-an-error-when-given-empty-table-name", func(t *testing.T) {
		err := kmsCache.MonitorTableRewrappingRuns(testCtx, "")
		require.Error(t, err)
	})

	t.Run("returns-an-error-when-no-rewrapping-function-registered", func(t *testing.T) {
		// Only this test's private copy of the registry is modified.
		delete(tableNameToRewrapFn, "auth_token")
		err := kmsCache.MonitorTableRewrappingRuns(testCtx, "auth_token")
		require.Error(t, err)
	})

	t.Run("does-nothing-when-another-run-is-running", func(t *testing.T) {
		kvToDestroy := oldestDatabaseVersion(t)
		truncateJobsOnCleanup(t)
		exec(t, "insert into kms_data_key_version_destruction_job(key_id) values ($1)", kvToDestroy.Id)
		exec(t, "insert into kms_data_key_version_destruction_job_run(key_id, table_name, total_count) values ($1, 'auth_token', 100)", kvToDestroy.Id)
		exec(t, "insert into kms_data_key_version_destruction_job_run(key_id, table_name, total_count) values ($1, 'auth_oidc_method', 100)", kvToDestroy.Id)
		// Another table's run for the same key is already in progress.
		exec(t, "update kms_data_key_version_destruction_job_run set is_running=true where key_id=$1 and table_name='auth_oidc_method'", kvToDestroy.Id)

		callbackCalled := false
		tableNameToRewrapFn["auth_token"] = func(ctx context.Context, dataKeyVersionId string, scopeId string, reader db.Reader, writer db.Writer, kms GetWrapperer) error {
			callbackCalled = true
			return nil
		}
		err := kmsCache.MonitorTableRewrappingRuns(testCtx, "auth_token")
		require.NoError(t, err)
		assert.False(t, callbackCalled, "auth_token callback should not have been called")
	})

	t.Run("chooses-one-when-one-run-available", func(t *testing.T) {
		kvToDestroy := oldestDatabaseVersion(t)
		truncateJobsOnCleanup(t)
		exec(t, "insert into kms_data_key_version_destruction_job(key_id) values ($1)", kvToDestroy.Id)
		exec(t, "insert into kms_data_key_version_destruction_job_run(key_id, table_name, total_count) values ($1, 'auth_token', 100)", kvToDestroy.Id)

		callbackCalled := make(chan struct{})
		returnFromCallback := make(chan struct{})
		// release unblocks the callback exactly once. It is also deferred, so
		// that a failing assertion below cannot leave the callback (and with
		// it the monitor goroutine) blocked forever.
		released := false
		release := func() {
			if !released {
				released = true
				close(returnFromCallback)
			}
		}
		defer release()
		tableNameToRewrapFn["auth_token"] = func(ctx context.Context, dataKeyVersionId string, scopeId string, reader db.Reader, writer db.Writer, kms GetWrapperer) error {
			close(callbackCalled)
			assert.Equal(t, "global", scopeId)
			// Block here until we want it to return
			<-returnFromCallback
			return nil
		}

		// Buffered so that the monitor goroutine can always deliver its result
		// and exit, even if the test has already failed and stopped listening.
		monitorErrCh := make(chan error, 1)
		// Run in goroutine so we can check status while the job is running
		go func() {
			monitorErrCh <- kmsCache.MonitorTableRewrappingRuns(testCtx, "auth_token")
		}()

		// Wait for callback to have been called
		select {
		case <-callbackCalled:
		case err := <-monitorErrCh:
			t.Fatalf("MonitorTableRewrappingRuns returned before calling the callback: %v", err)
		case <-time.After(5 * time.Second):
			t.Fatalf("Callback had not been called after 5 seconds")
		}

		// The callback is now blocked, so the run must be marked as running.
		isRunning := false
		err := sqldb.QueryRowContext(testCtx, "select is_running from kms_data_key_version_destruction_job_run where key_id=$1 and table_name='auth_token'", kvToDestroy.Id).Scan(&isRunning)
		require.NoError(t, err)
		assert.True(t, isRunning, "is_running should be set to true")

		// Trigger callback to return
		release()

		// Wait for function to return
		select {
		case err := <-monitorErrCh:
			require.NoError(t, err)
		case <-time.After(5 * time.Second):
			t.Fatalf("Function has not returned 5 seconds after callback finished")
		}

		// Look at the db state again after the function has returned.
		completedCount := 0
		err = sqldb.QueryRowContext(testCtx, "select completed_count, is_running from kms_data_key_version_destruction_job_run where key_id=$1 and table_name='auth_token'", kvToDestroy.Id).Scan(&completedCount, &isRunning)
		require.NoError(t, err)
		assert.EqualValues(t, 100, completedCount, "completed_count should have been updated")
		assert.False(t, isRunning, "is_running should be set to false")
	})

	t.Run("chooses-oldest-when-two-runs-available", func(t *testing.T) {
		kvsToDestroy := oldestVersionsOf(databasePurpose, tokensPurpose)
		require.Len(t, kvsToDestroy, 2)
		truncateJobsOnCleanup(t)
		// Jobs are inserted in slice order, so kvsToDestroy[0] gets the oldest job.
		for _, kv := range kvsToDestroy {
			exec(t, "insert into kms_data_key_version_destruction_job(key_id) values ($1)", kv.Id)
			exec(t, "insert into kms_data_key_version_destruction_job_run(key_id, table_name, total_count) values ($1, 'auth_token', 100)", kv.Id)
		}

		rewrappedKeyVersion := ""
		tableNameToRewrapFn["auth_token"] = func(ctx context.Context, dataKeyVersionId string, scopeId string, reader db.Reader, writer db.Writer, kms GetWrapperer) error {
			rewrappedKeyVersion = dataKeyVersionId
			return nil
		}
		err := kmsCache.MonitorTableRewrappingRuns(testCtx, "auth_token")
		require.NoError(t, err)
		assert.Equal(t, kvsToDestroy[0].Id, rewrappedKeyVersion, "auth_token callback should have been called with the oldest job")
	})

	t.Run("resumes-running-one-when-two-runs-available-even-if-not-oldest", func(t *testing.T) {
		kvsToDestroy := oldestVersionsOf(databasePurpose, tokensPurpose)
		require.Len(t, kvsToDestroy, 2)
		truncateJobsOnCleanup(t)
		for i, kv := range kvsToDestroy {
			exec(t, "insert into kms_data_key_version_destruction_job(key_id) values ($1)", kv.Id)
			exec(t, "insert into kms_data_key_version_destruction_job_run(key_id, table_name, total_count) values ($1, 'auth_token', 100)", kv.Id)
			if i == 1 {
				// Mark the newer job's run as already in progress.
				exec(t, "update kms_data_key_version_destruction_job_run set is_running=true where key_id=$1 and table_name='auth_token'", kv.Id)
			}
		}

		rewrappedKeyVersion := ""
		tableNameToRewrapFn["auth_token"] = func(ctx context.Context, dataKeyVersionId string, scopeId string, reader db.Reader, writer db.Writer, kms GetWrapperer) error {
			rewrappedKeyVersion = dataKeyVersionId
			return nil
		}
		err := kmsCache.MonitorTableRewrappingRuns(testCtx, "auth_token")
		require.NoError(t, err)
		assert.Equal(t, kvsToDestroy[1].Id, rewrappedKeyVersion, "auth_token callback should have been called with the already running job")
	})

	t.Run("updates-the-running-state-even-when-context-canceled", func(t *testing.T) {
		kvToDestroy := oldestDatabaseVersion(t)
		truncateJobsOnCleanup(t)
		exec(t, "insert into kms_data_key_version_destruction_job(key_id) values ($1)", kvToDestroy.Id)
		exec(t, "insert into kms_data_key_version_destruction_job_run(key_id, table_name, total_count) values ($1, 'auth_token', 100)", kvToDestroy.Id)

		callbackCalled := make(chan struct{})
		tableNameToRewrapFn["auth_token"] = func(ctx context.Context, dataKeyVersionId string, scopeId string, reader db.Reader, writer db.Writer, kms GetWrapperer) error {
			close(callbackCalled)
			// Block here until the context is canceled
			<-ctx.Done()
			return ctx.Err()
		}

		// Buffered so that the monitor goroutine can always deliver its result
		// and exit, even if the test has already failed and stopped listening.
		monitorErrCh := make(chan error, 1)
		newCtx, cancel := context.WithCancel(testCtx)
		// Also unblocks the callback if the test fails before canceling below.
		defer cancel()
		// Run in goroutine so we can check status while the job is running
		go func() {
			monitorErrCh <- kmsCache.MonitorTableRewrappingRuns(newCtx, "auth_token")
		}()

		// Wait for callback to have been called
		select {
		case <-callbackCalled:
		case err := <-monitorErrCh:
			t.Fatalf("MonitorTableRewrappingRuns returned before calling the callback: %v", err)
		case <-time.After(5 * time.Second):
			t.Fatalf("Callback had not been called after 5 seconds")
		}

		// The callback is now blocked, so the run must be marked as running.
		isRunning := false
		err := sqldb.QueryRowContext(testCtx, "select is_running from kms_data_key_version_destruction_job_run where key_id=$1 and table_name='auth_token'", kvToDestroy.Id).Scan(&isRunning)
		require.NoError(t, err)
		assert.True(t, isRunning, "is_running should be set to true")

		// Trigger callback to return
		cancel()

		// Wait for function to return
		select {
		case err := <-monitorErrCh:
			require.True(t, errors.Is(err, context.Canceled))
		case <-time.After(5 * time.Second):
			t.Fatalf("Function has not returned 5 seconds after callback finished")
		}

		// Look at the db state again after the function has returned.
		completedCount := 0
		err = sqldb.QueryRowContext(testCtx, "select completed_count, is_running from kms_data_key_version_destruction_job_run where key_id=$1 and table_name='auth_token'", kvToDestroy.Id).Scan(&completedCount, &isRunning)
		require.NoError(t, err)
		assert.EqualValues(t, 100, completedCount, "completed_count should have been updated")
		assert.False(t, isRunning, "is_running should be set to false")
	})
}
|
|
|
|
// TestMonitorDataKeyVersionDestruction verifies MonitorDataKeyVersionDestruction
// after the global scope's keys have been rotated once:
//   - with no jobs it is a no-op
//   - a job whose run is not complete is left in place
//   - once the run is complete, the key version is destroyed, and the job is
//     removed along with it
func TestMonitorDataKeyVersionDestruction(t *testing.T) {
	t.Parallel()
	testCtx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	extWrapper := db.TestWrapper(t)
	kmsCache := TestKms(t, conn, extWrapper)
	err := kmsCache.CreateKeys(testCtx, "global")
	require.NoError(t, err)
	err = kmsCache.RotateKeys(testCtx, "global")
	require.NoError(t, err)
	keys, err := kmsCache.ListKeys(testCtx, "global")
	require.NoError(t, err)
	sqldb, err := conn.SqlDB(testCtx)
	require.NoError(t, err)

	// seedDatabaseKeyJob creates a destruction job for the oldest (first
	// listed) version of the database key, with a single auth_token run of
	// 100 rows. When completed is true, the run is marked fully completed. It
	// returns the targeted key version. The job tables are truncated when the
	// calling subtest finishes. The cleanup is registered before seeding, so
	// that partially seeded rows cannot leak into later subtests.
	seedDatabaseKeyJob := func(t *testing.T, completed bool) wrappingKms.KeyVersion {
		t.Helper()
		var kvToDestroy wrappingKms.KeyVersion
		for _, key := range keys {
			if key.Purpose == wrappingKms.KeyPurpose(KeyPurposeDatabase.String()) {
				kvToDestroy = key.Versions[0]
			}
		}
		// Fail loudly instead of seeding rows for an empty key id.
		require.NotEmpty(t, kvToDestroy.Id, "no database key version found")

		t.Cleanup(func() {
			_, err := sqldb.ExecContext(testCtx, "truncate kms_data_key_version_destruction_job, kms_data_key_version_destruction_job_run CASCADE")
			require.NoError(t, err)
		})
		_, err := sqldb.ExecContext(testCtx, "insert into kms_data_key_version_destruction_job(key_id) values ($1)", kvToDestroy.Id)
		require.NoError(t, err)
		_, err = sqldb.ExecContext(testCtx, "insert into kms_data_key_version_destruction_job_run(key_id, table_name, total_count) values ($1, 'auth_token', 100)", kvToDestroy.Id)
		require.NoError(t, err)
		if completed {
			_, err = sqldb.ExecContext(testCtx, "update kms_data_key_version_destruction_job_run set completed_count=100 where key_id=$1 and table_name='auth_token'", kvToDestroy.Id)
			require.NoError(t, err)
		}
		return kvToDestroy
	}

	// jobExists reports whether a destruction job row exists for keyId.
	jobExists := func(t *testing.T, keyId string) bool {
		t.Helper()
		exists := false
		err := sqldb.QueryRowContext(testCtx, "select exists(select 1 from kms_data_key_version_destruction_job where key_id=$1)", keyId).Scan(&exists)
		require.NoError(t, err)
		return exists
	}

	t.Run("does-nothing-when-no-jobs-available", func(t *testing.T) {
		err := kmsCache.MonitorDataKeyVersionDestruction(testCtx)
		require.NoError(t, err)
	})

	t.Run("does-nothing-when-the-job-isnt-completed", func(t *testing.T) {
		kvToDestroy := seedDatabaseKeyJob(t, false)

		err := kmsCache.MonitorDataKeyVersionDestruction(testCtx)
		require.NoError(t, err)
		assert.True(t, jobExists(t, kvToDestroy.Id), "the job should still exist")
	})

	t.Run("deletes-the-key-when-the-job-is-completed", func(t *testing.T) {
		kvToDestroy := seedDatabaseKeyJob(t, true)

		err := kmsCache.MonitorDataKeyVersionDestruction(testCtx)
		require.NoError(t, err)
		// The job is deleted by virtue of cascading foreign key references
		assert.False(t, jobExists(t, kvToDestroy.Id), "the job should be deleted")

		// The key version itself must no longer be listed.
		keysAfter, err := kmsCache.ListKeys(testCtx, "global")
		require.NoError(t, err)
		for _, key := range keysAfter {
			for _, kv := range key.Versions {
				assert.NotEqual(t, kvToDestroy.Id, kv.Id, "the destroyed key version should no longer be listed")
			}
		}
	})
}
|
|
|
|
func TestDestroyKeyVersion(t *testing.T) {
|
|
t.Parallel()
|
|
testCtx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
extWrapper := db.TestWrapper(t)
|
|
kmsCache := TestKms(t, conn, extWrapper)
|
|
err := kmsCache.CreateKeys(testCtx, "global")
|
|
require.NoError(t, err)
|
|
err = kmsCache.RotateKeys(testCtx, "global")
|
|
require.NoError(t, err)
|
|
keys, err := kmsCache.ListKeys(testCtx, "global")
|
|
require.NoError(t, err)
|
|
sqldb, err := conn.SqlDB(testCtx)
|
|
require.NoError(t, err)
|
|
|
|
t.Run("returns-an-error-when-given-an-empty-scope-id", func(t *testing.T) {
|
|
_, err := kmsCache.DestroyKeyVersion(testCtx, "", "krkv_DoesntExist")
|
|
assert.Error(t, err)
|
|
})
|
|
t.Run("returns-an-error-when-given-an-empty-key-version-id", func(t *testing.T) {
|
|
_, err := kmsCache.DestroyKeyVersion(testCtx, "global", "")
|
|
assert.Error(t, err)
|
|
})
|
|
t.Run("returns-an-error-when-attempting-to-destroy-an-unknown-key-version", func(t *testing.T) {
|
|
_, err := kmsCache.DestroyKeyVersion(testCtx, "global", "krkv_DoesntExist")
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Match(errors.T(errors.KeyNotFound), err))
|
|
})
|
|
t.Run("returns-an-error-when-attempting-to-destroy-a-key-version-in-an-unknown-scope", func(t *testing.T) {
|
|
var kvToDestroy wrappingKms.KeyVersion
|
|
for _, key := range keys {
|
|
if key.Purpose == wrappingKms.KeyPurpose(KeyPurposeDatabase.String()) {
|
|
kvToDestroy = key.Versions[len(key.Versions)-1]
|
|
}
|
|
}
|
|
_, err := kmsCache.DestroyKeyVersion(testCtx, "p_DoesntExist", kvToDestroy.Id)
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Match(errors.T(errors.RecordNotFound), err))
|
|
})
|
|
t.Run("returns-an-error-when-attempting-to-destroy-latest-key-version", func(t *testing.T) {
|
|
var kvToDestroy wrappingKms.KeyVersion
|
|
for _, key := range keys {
|
|
if key.Purpose == wrappingKms.KeyPurpose(KeyPurposeDatabase.String()) {
|
|
kvToDestroy = key.Versions[len(key.Versions)-1]
|
|
}
|
|
}
|
|
_, err := kmsCache.DestroyKeyVersion(testCtx, "global", kvToDestroy.Id)
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Match(errors.T(errors.InvalidParameter), err))
|
|
})
|
|
t.Run("returns-an-error-when-attempting-to-destroy-oplog-key-version", func(t *testing.T) {
|
|
var kvToDestroy wrappingKms.KeyVersion
|
|
for _, key := range keys {
|
|
if key.Purpose == wrappingKms.KeyPurpose(KeyPurposeOplog.String()) {
|
|
kvToDestroy = key.Versions[0]
|
|
}
|
|
}
|
|
require.NotEmpty(t, kvToDestroy.Id)
|
|
_, err := kmsCache.DestroyKeyVersion(testCtx, "global", kvToDestroy.Id)
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Match(errors.T(errors.KeyNotFound), err))
|
|
})
|
|
t.Run("succeeds-synchronously-when-destroying-old-root-key-version", func(t *testing.T) {
|
|
var chosenKey wrappingKms.Key
|
|
var kvToDestroy wrappingKms.KeyVersion
|
|
for _, key := range keys {
|
|
if key.Purpose == wrappingKms.KeyPurpose(KeyPurposeRootKey.String()) {
|
|
kvToDestroy = key.Versions[0]
|
|
chosenKey = key
|
|
}
|
|
}
|
|
destroyed, err := kmsCache.DestroyKeyVersion(testCtx, "global", kvToDestroy.Id)
|
|
require.NoError(t, err)
|
|
assert.True(t, destroyed)
|
|
// Check that the key version is destroyed
|
|
keys, err := kmsCache.ListKeys(testCtx, "global")
|
|
require.NoError(t, err)
|
|
for _, key := range keys {
|
|
if key.Id == chosenKey.Id {
|
|
require.Len(t, key.Versions, 1, "the root key should only have one version left")
|
|
assert.NotEqual(t, key.Versions[0].Id, kvToDestroy.Id)
|
|
}
|
|
}
|
|
})
|
|
t.Run("succeeds-synchronously-when-destroying-old-data-key-version-that-encrypts-no-data", func(t *testing.T) {
|
|
require.NoError(t, err)
|
|
var chosenKey wrappingKms.Key
|
|
var kvToDestroy wrappingKms.KeyVersion
|
|
for _, key := range keys {
|
|
if key.Purpose == wrappingKms.KeyPurpose(KeyPurposeAudit.String()) {
|
|
kvToDestroy = key.Versions[0]
|
|
chosenKey = key
|
|
}
|
|
}
|
|
destroyed, err := kmsCache.DestroyKeyVersion(testCtx, "global", kvToDestroy.Id)
|
|
require.NoError(t, err)
|
|
assert.True(t, destroyed)
|
|
// Check that the key version is destroyed
|
|
keys, err := kmsCache.ListKeys(testCtx, "global")
|
|
require.NoError(t, err)
|
|
for _, key := range keys {
|
|
if key.Id == chosenKey.Id {
|
|
require.Len(t, key.Versions, 1, "the data key should only have one version left")
|
|
assert.NotEqual(t, key.Versions[0].Id, kvToDestroy.Id)
|
|
}
|
|
}
|
|
})
|
|
t.Run("creates-asynchronous-job-and-runs-for-data-key-version-that-encrypt-data", func(t *testing.T) {
|
|
var kvToDestroy wrappingKms.KeyVersion
|
|
for _, key := range keys {
|
|
if key.Purpose == wrappingKms.KeyPurpose(KeyPurposeDatabase.String()) {
|
|
kvToDestroy = key.Versions[0]
|
|
}
|
|
}
|
|
// Create a row that references the key we want to destroy.
|
|
// Note: This could break if the underlying table definitions change.
|
|
// Future contributors; forgive me, but we do really want to test this
|
|
// functionality here, and we can't import any of the normal test helpers
|
|
// since it would create circular dependencies.
|
|
_, err = sqldb.ExecContext(testCtx, "insert into worker_auth_ca_certificate(certificate, not_valid_after, public_key, private_key, state, issuing_ca, key_id) values ('certificate', CURRENT_TIMESTAMP+'1h'::interval, 'public_key', 'private_key', 'current', 'roots', $1)", kvToDestroy.Id)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
_, err = sqldb.ExecContext(testCtx, "delete from worker_auth_ca_certificate")
|
|
require.NoError(t, err)
|
|
})
|
|
destroyed, err := kmsCache.DestroyKeyVersion(testCtx, "global", kvToDestroy.Id)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
_, err = sqldb.ExecContext(testCtx, "delete from kms_data_key_version_destruction_job; delete from kms_data_key_version_destruction_job_run")
|
|
require.NoError(t, err)
|
|
})
|
|
assert.False(t, destroyed)
|
|
// Check that the data key version destruction job was created
|
|
jobs, err := kmsCache.ListDataKeyVersionDestructionJobs(testCtx, "global")
|
|
require.NoError(t, err)
|
|
require.Len(t, jobs, 1)
|
|
assert.Equal(t, jobs[0].KeyId, kvToDestroy.Id)
|
|
assert.Equal(t, jobs[0].CompletedCount, int64(0))
|
|
assert.Equal(t, jobs[0].TotalCount, int64(1))
|
|
assert.Equal(t, jobs[0].Status, "pending")
|
|
assert.Equal(t, jobs[0].ScopeId, "global")
|
|
assert.True(t, jobs[0].CreateTime.AsTime().Before(time.Now()))
|
|
})
|
|
t.Run("errors-when-destroying-key-that-is-already-destroying", func(t *testing.T) {
|
|
var kvToDestroy wrappingKms.KeyVersion
|
|
for _, key := range keys {
|
|
if key.Purpose == wrappingKms.KeyPurpose(KeyPurposeDatabase.String()) {
|
|
kvToDestroy = key.Versions[0]
|
|
}
|
|
}
|
|
_, err = sqldb.ExecContext(testCtx, "insert into worker_auth_ca_certificate(certificate, not_valid_after, public_key, private_key, state, issuing_ca, key_id) values ('certificate', CURRENT_TIMESTAMP+'1h'::interval, 'public_key', 'private_key', 'current', 'roots', $1)", kvToDestroy.Id)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() {
|
|
_, err = sqldb.ExecContext(testCtx, "delete from worker_auth_ca_certificate")
|
|
require.NoError(t, err)
|
|
})
|
|
destroyed, err := kmsCache.DestroyKeyVersion(testCtx, "global", kvToDestroy.Id)
|
|
require.NoError(t, err)
|
|
assert.False(t, destroyed)
|
|
_, err = kmsCache.DestroyKeyVersion(testCtx, "global", kvToDestroy.Id)
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Match(errors.T(errors.InvalidParameter), err), "error did not match InvalidParameter as expected: %v", err)
|
|
})
|
|
}
|
|
|
|
func Test_RegisterTableRewrapFn(t *testing.T) {
|
|
rewrapFn := func(ctx context.Context, dataKeyVersionId, scopeId string, reader db.Reader, writer db.Writer, kms GetWrapperer) error {
|
|
return nil
|
|
}
|
|
RegisterTableRewrapFn("a_table", rewrapFn)
|
|
t.Cleanup(func() {
|
|
tableNameToRewrapFn = map[string]RewrapFn{}
|
|
})
|
|
assert.Panics(t, func() {
|
|
RegisterTableRewrapFn("a_table", rewrapFn)
|
|
})
|
|
}
|
|
|
|
func Test_ListTablesSupportingRewrap(t *testing.T) {
|
|
assert.Empty(t, ListTablesSupportingRewrap())
|
|
rewrapFn := func(ctx context.Context, dataKeyVersionId, scopeId string, reader db.Reader, writer db.Writer, kms GetWrapperer) error {
|
|
return nil
|
|
}
|
|
RegisterTableRewrapFn("a_table", rewrapFn)
|
|
t.Cleanup(func() {
|
|
tableNameToRewrapFn = map[string]RewrapFn{}
|
|
})
|
|
registeredTables := ListTablesSupportingRewrap()
|
|
assert.Len(t, registeredTables, 1)
|
|
assert.Equal(t, registeredTables[0], "a_table")
|
|
}
|
|
|
|
type invalidReader struct {
|
|
db.Reader
|
|
}
|
|
|
|
type invalidWriter struct {
|
|
db.Writer
|
|
}
|