From 9da0faabd7888b31b65a991204accae2c6885ce3 Mon Sep 17 00:00:00 2001 From: Jim Date: Fri, 3 Jun 2022 12:28:54 -0400 Subject: [PATCH] refactor (servers/workerauth): add StoreNodeInformationTx(...) (#2125) refactor to export StoreNodeInformationTx(...) which can be used by other repositories within their transactions to store node information. --- internal/servers/repository_workerauth.go | 60 +++--- .../servers/repository_workerauth_test.go | 178 +++++++++++++++++- 2 files changed, 211 insertions(+), 27 deletions(-) diff --git a/internal/servers/repository_workerauth.go b/internal/servers/repository_workerauth.go index 95d22465dd..44bd705204 100644 --- a/internal/servers/repository_workerauth.go +++ b/internal/servers/repository_workerauth.go @@ -74,10 +74,16 @@ func (r *WorkerAuthRepositoryStorage) Store(ctx context.Context, msg nodee.Messa // Determine type of message to store switch t := msg.(type) { case *types.NodeInformation: - err := r.storeNodeInformation(ctx, t) + // Encrypt the private key + databaseWrapper, err := r.kms.GetWrapper(ctx, scope.Global.String(), kms.KeyPurposeDatabase) if err != nil { return errors.Wrap(ctx, err, op) } + if _, err := r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(read db.Reader, w db.Writer) error { + return StoreNodeInformationTx(ctx, w, databaseWrapper, t) + }); err != nil { + return errors.Wrap(ctx, err, op) + } case *types.RootCertificates: err := r.storeRootCertificates(ctx, t) if err != nil { @@ -90,11 +96,26 @@ func (r *WorkerAuthRepositoryStorage) Store(ctx context.Context, msg nodee.Messa return nil } +// StoreNodeInformationTx stores NodeInformation. No options are currently +// supported. +// +// This function encapsulates all the work required within a dbw.TxHandler and +// allows this capability to be shared with other repositories or just called +// within a transaction. To be clear, this repository function doesn't include +// its own transaction and is intended to be used within a transaction provided +// by the caller. +// // Node information is stored in two parts: // * the workerAuth record is stored with a reference to a worker // * certificate bundles are stored with a reference to the workerAuth record and issuing root certificate -func (r *WorkerAuthRepositoryStorage) storeNodeInformation(ctx context.Context, node *types.NodeInformation) error { +func StoreNodeInformationTx(ctx context.Context, writer db.Writer, databaseWrapper wrapping.Wrapper, node *types.NodeInformation, _ ...Option) error { const op = "servers.(WorkerAuthRepositoryStorage).storeNodeInformation" + if isNil(writer) { + return errors.New(ctx, errors.InvalidParameter, op, "missing writer") + } + if isNil(databaseWrapper) { + return errors.New(ctx, errors.InvalidParameter, op, "missing database wrapper") + } if node == nil { return errors.New(ctx, errors.InvalidParameter, op, "missing NodeInformation") } @@ -104,11 +125,7 @@ func (r *WorkerAuthRepositoryStorage) storeNodeInformation(ctx context.Context, nodeAuth.WorkerEncryptionPubKey = node.EncryptionPublicKeyBytes nodeAuth.WorkerSigningPubKey = node.CertificatePublicKeyPkix - // Encrypt the private key - databaseWrapper, err := r.kms.GetWrapper(ctx, scope.Global.String(), kms.KeyPurposeDatabase) - if err != nil { - return errors.Wrap(ctx, err, op) - } + var err error nodeAuth.KeyId, err = databaseWrapper.KeyId(ctx) if err != nil { return errors.Wrap(ctx, err, op) @@ -134,30 +151,25 @@ func (r *WorkerAuthRepositoryStorage) storeNodeInformation(ctx context.Context, return errors.Wrap(ctx, err, op) } - _, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(read db.Reader, w db.Writer) error { - // Store WorkerAuth - storeNodeAuth := nodeAuth.clone() - if err := w.Create(ctx, &storeNodeAuth); err != nil { + // Store WorkerAuth + if err := writer.Create(ctx, &nodeAuth); err != nil { + return errors.Wrap(ctx, err, op) + } + // Then store cert bundles associated with this WorkerAuth + for _, c := range node.CertificateBundles { + if err := storeWorkerCertBundle(ctx, c, nodeAuth.WorkerKeyIdentifier, writer); err != nil { return errors.Wrap(ctx, err, op) } - // Then store cert bundles associated with this WorkerAuth - for _, c := range node.CertificateBundles { - if err := r.storeWorkerCertBundle(ctx, c, nodeAuth.WorkerKeyIdentifier, w); err != nil { - return errors.Wrap(ctx, err, op) - } - } - return nil - }, - ) - if err != nil { - return errors.Wrap(ctx, err, op) } return nil } -func (r *WorkerAuthRepositoryStorage) storeWorkerCertBundle(ctx context.Context, bundle *types.CertificateBundle, - workerKeyIdentifier string, writer db.Writer, +func storeWorkerCertBundle( + ctx context.Context, + bundle *types.CertificateBundle, + workerKeyIdentifier string, + writer db.Writer, ) error { const op = "servers.(WorkerAuthRepositoryStorage).storeWorkerCertBundle" if bundle == nil { diff --git a/internal/servers/repository_workerauth_test.go b/internal/servers/repository_workerauth_test.go index 7d90077ef8..11f7bea3b8 100644 --- a/internal/servers/repository_workerauth_test.go +++ b/internal/servers/repository_workerauth_test.go @@ -6,20 +6,21 @@ import ( "testing" "github.com/fatih/structs" - "github.com/mitchellh/mapstructure" - "google.golang.org/protobuf/types/known/structpb" - "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/types/scope" + wrapping "github.com/hashicorp/go-kms-wrapping/v2" "github.com/hashicorp/nodeenrollment" "github.com/hashicorp/nodeenrollment/registration" "github.com/hashicorp/nodeenrollment/rotation" "github.com/hashicorp/nodeenrollment/storage/file" "github.com/hashicorp/nodeenrollment/types" + "github.com/mitchellh/mapstructure" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/curve25519" + "google.golang.org/protobuf/types/known/structpb" ) // Test RootCertificate storage @@ -235,3 +236,174 @@ func TestUnsupportedMessages(t *testing.T) { err = storage.Remove(ctx, &types.NodeCredentials{Id: "bogus"}) require.Error(err) } + +func TestStoreNodeInformationTx(t *testing.T) { + t.Parallel() + testCtx := context.Background() + + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + + testWrapper := db.TestWrapper(t) + testKeyId, err := testWrapper.KeyId(testCtx) + require.NoError(t, err) + + kmsCache := kms.TestKms(t, conn, testWrapper) + // Ensures the global scope contains a valid root key + err = kmsCache.CreateKeys(context.Background(), scope.Global.String(), kms.WithRandomReader(rand.Reader)) + require.NoError(t, err) + databaseWrapper, err := kmsCache.GetWrapper(context.Background(), scope.Global.String(), kms.KeyPurposeDatabase) + require.NoError(t, err) + + testRootStorage, err := NewRepositoryStorage(testCtx, rw, rw, kmsCache) + require.NoError(t, err) + + _, err = rotation.RotateRootCertificates(testCtx, testRootStorage) + require.NoError(t, err) + + // Create struct to pass in with workerId that will be passed along to + // storage + testWorker := TestWorker(t, conn, testWrapper) + + testState, err := AttachWorkerIdToState(testCtx, testWorker.PublicId) + require.NoError(t, err) + + testNodeInfoFn := func() *types.NodeInformation { + // This happens on the worker + fileStorage, err := file.NewFileStorage(testCtx) + require.NoError(t, err) + nodeCreds, err := types.NewNodeCredentials(testCtx, fileStorage) + require.NoError(t, err) + + nodePubKey, err := curve25519.X25519(nodeCreds.EncryptionPrivateKeyBytes, curve25519.Basepoint) + require.NoError(t, err) + // Add in node information to storage so we have a key to use + nodeInfo := &types.NodeInformation{ + Id: testKeyId, + CertificatePublicKeyPkix: nodeCreds.CertificatePublicKeyPkix, + CertificatePublicKeyType: nodeCreds.CertificatePrivateKeyType, + EncryptionPublicKeyBytes: nodePubKey, + EncryptionPublicKeyType: nodeCreds.EncryptionPrivateKeyType, + ServerEncryptionPrivateKeyBytes: []byte("whatever"), + RegistrationNonce: nodeCreds.RegistrationNonce, + State: testState, + } + return nodeInfo + } + + tests := []struct { + name string + writer db.Writer + databaseWrapper wrapping.Wrapper + node *types.NodeInformation + wantErr bool + wantErrIs errors.Code + wantErrContains string + }{ + { + name: "missing-writer", + databaseWrapper: databaseWrapper, + node: testNodeInfoFn(), + wantErr: true, + wantErrIs: errors.InvalidParameter, + wantErrContains: "missing writer", + }, + { + name: "missing-wrapper", + writer: rw, + node: testNodeInfoFn(), + wantErr: true, + wantErrIs: errors.InvalidParameter, + wantErrContains: "missing database wrapper", + }, + { + name: "missing-node", + writer: rw, + databaseWrapper: databaseWrapper, + wantErr: true, + wantErrIs: errors.InvalidParameter, + wantErrContains: "missing NodeInformation", + }, + { + name: "key-id-error", + writer: rw, + databaseWrapper: &mockTestWrapper{err: errors.New(testCtx, errors.Internal, "testing", "key-id-error")}, + node: testNodeInfoFn(), + wantErr: true, + wantErrIs: errors.Internal, + wantErrContains: "key-id-error", + }, + { + name: "encrypt-error", + writer: rw, + databaseWrapper: &mockTestWrapper{ + encryptError: true, + err: errors.New(testCtx, errors.Encrypt, "testing", "encrypt-error"), + }, + node: testNodeInfoFn(), + wantErr: true, + wantErrIs: errors.Encrypt, + wantErrContains: "encrypt-error", + }, + { + name: "create-error", + writer: &db.Db{}, + databaseWrapper: databaseWrapper, + node: testNodeInfoFn(), + wantErr: true, + wantErrContains: "db.Create", + }, + { + name: "success", + writer: rw, + databaseWrapper: databaseWrapper, + node: testNodeInfoFn(), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + err := StoreNodeInformationTx(testCtx, tc.writer, tc.databaseWrapper, tc.node) + if tc.wantErr { + require.Error(err) + if tc.wantErrIs != errors.Unknown { + assert.True(errors.Match(errors.T(tc.wantErrIs), err)) + } + if tc.wantErrContains != "" { + assert.Contains(err.Error(), tc.wantErrContains) + } + return + } + require.NoError(err) + }) + } +} + +type mockTestWrapper struct { + wrapping.Wrapper + decryptError bool + encryptError bool + err error + keyId string +} + +func (m *mockTestWrapper) KeyId(context.Context) (string, error) { + if m.err != nil { + return "", m.err + } + return m.keyId, nil +} + +func (m *mockTestWrapper) Encrypt(ctx context.Context, plaintext []byte, options ...wrapping.Option) (*wrapping.BlobInfo, error) { + if m.err != nil && m.encryptError { + return nil, m.err + } + panic("todo") +} + +func (m *mockTestWrapper) Decrypt(ctx context.Context, ciphertext *wrapping.BlobInfo, options ...wrapping.Option) ([]byte, error) { + if m.err != nil && m.decryptError { + return nil, m.err + } + return []byte("decrypted"), nil +}