mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
207 lines
9.7 KiB
207 lines
9.7 KiB
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package server
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/DATA-DOG/go-sqlmock"
|
|
"github.com/hashicorp/boundary/internal/db"
|
|
"github.com/hashicorp/boundary/internal/iam"
|
|
"github.com/hashicorp/boundary/internal/kms"
|
|
"github.com/hashicorp/boundary/internal/server/store"
|
|
"github.com/hashicorp/boundary/internal/types/scope"
|
|
"github.com/hashicorp/go-kms-wrapping/extras/kms/v2/migrations"
|
|
"github.com/hashicorp/nodeenrollment/rotation"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func TestRewrap_workerAuthCertRewrapFn(t *testing.T) {
|
|
ctx := context.Background()
|
|
t.Run("errors-on-query-error", func(t *testing.T) {
|
|
conn, mock := db.TestSetupWithMock(t)
|
|
wrapper := db.TestWrapper(t)
|
|
mock.ExpectQuery(
|
|
`SELECT \* FROM "kms_schema_version" WHERE 1=1 ORDER BY "kms_schema_version"\."version" LIMIT \$1`,
|
|
).WillReturnRows(sqlmock.NewRows([]string{"version", "create_time"}).AddRow(migrations.Version, time.Now()))
|
|
mock.ExpectQuery(
|
|
`SELECT \* FROM "kms_oplog_schema_version" WHERE 1=1 ORDER BY "kms_oplog_schema_version"."version" LIMIT \$1`,
|
|
).WillReturnRows(sqlmock.NewRows([]string{"version", "create_time"}).AddRow(migrations.Version, time.Now()))
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
rw := db.New(conn)
|
|
mock.ExpectQuery(
|
|
`SELECT \* FROM "worker_auth_ca_certificate" WHERE key_id=\$1`,
|
|
).WillReturnError(errors.New("Query error"))
|
|
err := workerAuthCertRewrapFn(ctx, "some_id", "some_scope", rw, rw, kmsCache)
|
|
require.Error(t, err)
|
|
})
|
|
t.Run("success", func(t *testing.T) {
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) // despite not looking like it, this is necessary for some reason
|
|
|
|
// create a repo and rotate to generate the first auth cert
|
|
workerAuthRepo, err := NewRepositoryStorage(ctx, rw, rw, kmsCache)
|
|
assert.NoError(t, err)
|
|
roots, err := rotation.RotateRootCertificates(ctx, workerAuthRepo)
|
|
assert.NoError(t, err)
|
|
|
|
// pull this directly from the db rather than convert so we don't lose key info
|
|
currentRoot := allocRootCertificate()
|
|
assert.NoError(t, rw.SearchWhere(ctx, ¤tRoot, "state = ?", []any{"current"}, db.WithLimit(-1)))
|
|
|
|
// now things are stored in the db, we can rotate and rewrap
|
|
assert.NoError(t, kmsCache.RotateKeys(ctx, scope.Global.String()))
|
|
assert.NoError(t, workerAuthCertRewrapFn(ctx, currentRoot.KeyId, scope.Global.String(), rw, rw, kmsCache))
|
|
|
|
// now we pull the certs back from the db, decrypt it with the new key, and ensure things match
|
|
got := allocRootCertificate()
|
|
assert.NoError(t, rw.SearchWhere(ctx, &got, "state = ?", []any{"current"}, db.WithLimit(-1)))
|
|
|
|
kmsWrapper2, err := kmsCache.GetWrapper(context.Background(), scope.Global.String(), kms.KeyPurposeDatabase, kms.WithKeyId(got.GetKeyId()))
|
|
assert.NoError(t, err)
|
|
newKeyVersionId, err := kmsWrapper2.KeyId(ctx)
|
|
assert.NoError(t, err)
|
|
|
|
// decrypt with the new key version and check to make sure things match
|
|
assert.NoError(t, got.decrypt(ctx, kmsWrapper2))
|
|
assert.NotEmpty(t, got.GetKeyId())
|
|
assert.NotEqual(t, currentRoot.GetKeyId(), got.GetKeyId())
|
|
assert.Equal(t, newKeyVersionId, got.GetKeyId())
|
|
assert.NotEqual(t, currentRoot.GetCtPrivateKey(), got.GetCtPrivateKey())
|
|
assert.Equal(t, roots.Current.PrivateKeyPkcs8, got.GetPrivateKey())
|
|
})
|
|
}
|
|
|
|
func TestRewrap_workerAuthRewrapFn(t *testing.T) {
|
|
ctx := context.Background()
|
|
t.Run("errors-on-query-error", func(t *testing.T) {
|
|
conn, mock := db.TestSetupWithMock(t)
|
|
wrapper := db.TestWrapper(t)
|
|
mock.ExpectQuery(
|
|
`SELECT \* FROM "kms_schema_version" WHERE 1=1 ORDER BY "kms_schema_version"\."version" LIMIT \$1`,
|
|
).WillReturnRows(sqlmock.NewRows([]string{"version", "create_time"}).AddRow(migrations.Version, time.Now()))
|
|
mock.ExpectQuery(
|
|
`SELECT \* FROM "kms_oplog_schema_version" WHERE 1=1 ORDER BY "kms_oplog_schema_version"."version" LIMIT \$1`,
|
|
).WillReturnRows(sqlmock.NewRows([]string{"version", "create_time"}).AddRow(migrations.Version, time.Now()))
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
rw := db.New(conn)
|
|
mock.ExpectQuery(
|
|
`select distinct auth\.worker_key_identifier, auth\.controller_encryption_priv_key, auth\.key_id from server_worker worker inner join worker_auth_authorized auth on auth\.worker_id = worker\.public_id where worker\.scope_id = \$1 and auth\.key_id = \$2`,
|
|
).WillReturnError(errors.New("Query error"))
|
|
err := workerAuthRewrapFn(ctx, "some_id", "some_scope", rw, rw, kmsCache)
|
|
require.Error(t, err)
|
|
})
|
|
t.Run("success", func(t *testing.T) {
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) // despite not looking like it, this is necessary for some reason
|
|
|
|
worker := TestPkiWorker(t, conn, wrapper)
|
|
kmsWrapper, err := kmsCache.GetWrapper(ctx, worker.GetScopeId(), kms.KeyPurposeDatabase)
|
|
assert.NoError(t, err)
|
|
workerAuth := TestWorkerAuth(t, conn, worker, kmsWrapper)
|
|
|
|
// TestWorkerAuth DOES encrypt the entry, so just store in the db
|
|
rows, err := rw.Update(ctx, workerAuth, []string{"CtControllerEncryptionPrivKey"}, nil)
|
|
assert.Equal(t, 1, rows)
|
|
assert.NoError(t, err)
|
|
|
|
// now things are stored in the db, we can rotate and rewrap
|
|
assert.NoError(t, kmsCache.RotateKeys(ctx, scope.Global.String()))
|
|
assert.NoError(t, workerAuthRewrapFn(ctx, workerAuth.GetKeyId(), scope.Global.String(), rw, rw, kmsCache))
|
|
|
|
// now we pull the auth back from the db, decrypt it with the new key, and ensure things match
|
|
got := allocWorkerAuth()
|
|
got.WorkerKeyIdentifier = workerAuth.WorkerKeyIdentifier
|
|
assert.NoError(t, rw.LookupById(ctx, got))
|
|
|
|
kmsWrapper2, err := kmsCache.GetWrapper(context.Background(), worker.GetScopeId(), kms.KeyPurposeDatabase, kms.WithKeyId(got.GetKeyId()))
|
|
assert.NoError(t, err)
|
|
newKeyVersionId, err := kmsWrapper2.KeyId(ctx)
|
|
assert.NoError(t, err)
|
|
|
|
// decrypt with the new key version and check to make sure things match
|
|
assert.NoError(t, got.decrypt(ctx, kmsWrapper2))
|
|
assert.NotEmpty(t, got.GetKeyId())
|
|
assert.NotEqual(t, workerAuth.GetKeyId(), got.GetKeyId())
|
|
assert.Equal(t, newKeyVersionId, got.GetKeyId())
|
|
assert.NotEqual(t, workerAuth.GetCtControllerEncryptionPrivKey(), got.GetCtControllerEncryptionPrivKey())
|
|
assert.Equal(t, workerAuth.ControllerEncryptionPrivKey, got.GetControllerEncryptionPrivKey())
|
|
})
|
|
}
|
|
|
|
func TestRewrap_workerAuthServerLedActivationTokenRewrapFn(t *testing.T) {
|
|
ctx := context.Background()
|
|
t.Run("errors-on-query-error", func(t *testing.T) {
|
|
conn, mock := db.TestSetupWithMock(t)
|
|
wrapper := db.TestWrapper(t)
|
|
mock.ExpectQuery(
|
|
`SELECT \* FROM "kms_schema_version" WHERE 1=1 ORDER BY "kms_schema_version"\."version" LIMIT \$1`,
|
|
).WillReturnRows(sqlmock.NewRows([]string{"version", "create_time"}).AddRow(migrations.Version, time.Now()))
|
|
mock.ExpectQuery(
|
|
`SELECT \* FROM "kms_oplog_schema_version" WHERE 1=1 ORDER BY "kms_oplog_schema_version"."version" LIMIT \$1`,
|
|
).WillReturnRows(sqlmock.NewRows([]string{"version", "create_time"}).AddRow(migrations.Version, time.Now()))
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
rw := db.New(conn)
|
|
mock.ExpectQuery(
|
|
`select distinct auth_token\.worker_id, auth_token\.creation_time_encrypted, auth_token\.key_id from server_worker worker inner join worker_auth_server_led_activation_token auth_token on auth_token\.worker_id = worker\.public_id where worker\.scope_id = \$1 and auth_token\.key_id = \$2`,
|
|
).WillReturnError(errors.New("Query error"))
|
|
err := workerAuthServerLedActivationTokenRewrapFn(ctx, "some_id", "some_scope", rw, rw, kmsCache)
|
|
require.Error(t, err)
|
|
})
|
|
t.Run("success", func(t *testing.T) {
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) // despite not looking like it, this is necessary for some reason
|
|
|
|
worker := TestPkiWorker(t, conn, wrapper)
|
|
kmsWrapper, err := kmsCache.GetWrapper(ctx, worker.GetScopeId(), kms.KeyPurposeDatabase)
|
|
assert.NoError(t, err)
|
|
token := &WorkerAuthServerLedActivationToken{
|
|
WorkerAuthServerLedActivationToken: &store.WorkerAuthServerLedActivationToken{
|
|
WorkerId: worker.GetPublicId(),
|
|
TokenId: "not a real token id lmao",
|
|
CreationTime: []byte("marshaled timestamppb.Timestamp indeed"),
|
|
},
|
|
}
|
|
|
|
// encrypt the entry and store in db
|
|
assert.NoError(t, token.encrypt(ctx, kmsWrapper))
|
|
assert.NoError(t, rw.Create(ctx, token))
|
|
|
|
// now things are stored in the db, we can rotate and rewrap
|
|
assert.NoError(t, kmsCache.RotateKeys(ctx, scope.Global.String()))
|
|
assert.NoError(t, workerAuthServerLedActivationTokenRewrapFn(ctx, token.GetKeyId(), worker.GetScopeId(), rw, rw, kmsCache))
|
|
|
|
// now we pull the auth back from the db, decrypt it with the new key, and ensure things match
|
|
got := allocWorkerAuthServerLedActivationToken()
|
|
got.WorkerId = worker.GetPublicId()
|
|
assert.NoError(t, rw.LookupById(ctx, got))
|
|
|
|
kmsWrapper2, err := kmsCache.GetWrapper(context.Background(), worker.GetScopeId(), kms.KeyPurposeDatabase, kms.WithKeyId(got.GetKeyId()))
|
|
assert.NoError(t, err)
|
|
newKeyVersionId, err := kmsWrapper2.KeyId(ctx)
|
|
assert.NoError(t, err)
|
|
|
|
// decrypt with the new key version and check to make sure things match
|
|
assert.NoError(t, got.decrypt(ctx, kmsWrapper2))
|
|
assert.NotEmpty(t, got.GetKeyId())
|
|
assert.NotEqual(t, token.GetKeyId(), got.GetKeyId())
|
|
assert.Equal(t, newKeyVersionId, got.GetKeyId())
|
|
assert.NotEqual(t, token.GetCreationTimeEncrypted(), got.GetCreationTimeEncrypted())
|
|
assert.Equal(t, token.GetCreationTime(), got.GetCreationTime())
|
|
})
|
|
}
|