refactor (servers/workerauth): add StoreNodeInformationTx(...) (#2125)

refactor to export StoreNodeInformationTx(...) which can be used
by other repositories within their transactions to store
node information.
pull/2131/head
Jim 4 years ago committed by GitHub
parent 6cbb36f2b1
commit 9da0faabd7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -74,10 +74,16 @@ func (r *WorkerAuthRepositoryStorage) Store(ctx context.Context, msg nodee.Messa
// Determine type of message to store
switch t := msg.(type) {
case *types.NodeInformation:
err := r.storeNodeInformation(ctx, t)
// Encrypt the private key
databaseWrapper, err := r.kms.GetWrapper(ctx, scope.Global.String(), kms.KeyPurposeDatabase)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if _, err := r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(read db.Reader, w db.Writer) error {
return StoreNodeInformationTx(ctx, w, databaseWrapper, t)
}); err != nil {
return errors.Wrap(ctx, err, op)
}
case *types.RootCertificates:
err := r.storeRootCertificates(ctx, t)
if err != nil {
@ -90,11 +96,26 @@ func (r *WorkerAuthRepositoryStorage) Store(ctx context.Context, msg nodee.Messa
return nil
}
// StoreNodeInformationTx stores NodeInformation. No options are currently
// supported.
//
// This function encapsulates all the work required within a dbw.TxHandler and
// allows this capability to be shared with other repositories or just called
// within a transaction. To be clear, this repository function doesn't include
// its own transaction and is intended to be used within a transaction provided
// by the caller.
//
// Node information is stored in two parts:
// * the workerAuth record is stored with a reference to a worker
// * certificate bundles are stored with a reference to the workerAuth record and issuing root certificate
func (r *WorkerAuthRepositoryStorage) storeNodeInformation(ctx context.Context, node *types.NodeInformation) error {
func StoreNodeInformationTx(ctx context.Context, writer db.Writer, databaseWrapper wrapping.Wrapper, node *types.NodeInformation, _ ...Option) error {
const op = "servers.(WorkerAuthRepositoryStorage).storeNodeInformation"
if isNil(writer) {
return errors.New(ctx, errors.InvalidParameter, op, "missing writer")
}
if isNil(databaseWrapper) {
return errors.New(ctx, errors.InvalidParameter, op, "missing database wrapper")
}
if node == nil {
return errors.New(ctx, errors.InvalidParameter, op, "missing NodeInformation")
}
@ -104,11 +125,7 @@ func (r *WorkerAuthRepositoryStorage) storeNodeInformation(ctx context.Context,
nodeAuth.WorkerEncryptionPubKey = node.EncryptionPublicKeyBytes
nodeAuth.WorkerSigningPubKey = node.CertificatePublicKeyPkix
// Encrypt the private key
databaseWrapper, err := r.kms.GetWrapper(ctx, scope.Global.String(), kms.KeyPurposeDatabase)
if err != nil {
return errors.Wrap(ctx, err, op)
}
var err error
nodeAuth.KeyId, err = databaseWrapper.KeyId(ctx)
if err != nil {
return errors.Wrap(ctx, err, op)
@ -134,30 +151,25 @@ func (r *WorkerAuthRepositoryStorage) storeNodeInformation(ctx context.Context,
return errors.Wrap(ctx, err, op)
}
_, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(read db.Reader, w db.Writer) error {
// Store WorkerAuth
storeNodeAuth := nodeAuth.clone()
if err := w.Create(ctx, &storeNodeAuth); err != nil {
// Store WorkerAuth
if err := writer.Create(ctx, &nodeAuth); err != nil {
return errors.Wrap(ctx, err, op)
}
// Then store cert bundles associated with this WorkerAuth
for _, c := range node.CertificateBundles {
if err := storeWorkerCertBundle(ctx, c, nodeAuth.WorkerKeyIdentifier, writer); err != nil {
return errors.Wrap(ctx, err, op)
}
// Then store cert bundles associated with this WorkerAuth
for _, c := range node.CertificateBundles {
if err := r.storeWorkerCertBundle(ctx, c, nodeAuth.WorkerKeyIdentifier, w); err != nil {
return errors.Wrap(ctx, err, op)
}
}
return nil
},
)
if err != nil {
return errors.Wrap(ctx, err, op)
}
return nil
}
func (r *WorkerAuthRepositoryStorage) storeWorkerCertBundle(ctx context.Context, bundle *types.CertificateBundle,
workerKeyIdentifier string, writer db.Writer,
func storeWorkerCertBundle(
ctx context.Context,
bundle *types.CertificateBundle,
workerKeyIdentifier string,
writer db.Writer,
) error {
const op = "servers.(WorkerAuthRepositoryStorage).storeWorkerCertBundle"
if bundle == nil {

@ -6,20 +6,21 @@ import (
"testing"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure"
"google.golang.org/protobuf/types/known/structpb"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/types/scope"
wrapping "github.com/hashicorp/go-kms-wrapping/v2"
"github.com/hashicorp/nodeenrollment"
"github.com/hashicorp/nodeenrollment/registration"
"github.com/hashicorp/nodeenrollment/rotation"
"github.com/hashicorp/nodeenrollment/storage/file"
"github.com/hashicorp/nodeenrollment/types"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/curve25519"
"google.golang.org/protobuf/types/known/structpb"
)
// Test RootCertificate storage
@ -235,3 +236,174 @@ func TestUnsupportedMessages(t *testing.T) {
err = storage.Remove(ctx, &types.NodeCredentials{Id: "bogus"})
require.Error(err)
}
func TestStoreNodeInformationTx(t *testing.T) {
t.Parallel()
testCtx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
testWrapper := db.TestWrapper(t)
testKeyId, err := testWrapper.KeyId(testCtx)
require.NoError(t, err)
kmsCache := kms.TestKms(t, conn, testWrapper)
// Ensures the global scope contains a valid root key
err = kmsCache.CreateKeys(context.Background(), scope.Global.String(), kms.WithRandomReader(rand.Reader))
require.NoError(t, err)
databaseWrapper, err := kmsCache.GetWrapper(context.Background(), scope.Global.String(), kms.KeyPurposeDatabase)
require.NoError(t, err)
testRootStorage, err := NewRepositoryStorage(testCtx, rw, rw, kmsCache)
require.NoError(t, err)
_, err = rotation.RotateRootCertificates(testCtx, testRootStorage)
require.NoError(t, err)
// Create struct to pass in with workerId that will be passed along to
// storage
testWorker := TestWorker(t, conn, testWrapper)
testState, err := AttachWorkerIdToState(testCtx, testWorker.PublicId)
require.NoError(t, err)
testNodeInfoFn := func() *types.NodeInformation {
// This happens on the worker
fileStorage, err := file.NewFileStorage(testCtx)
require.NoError(t, err)
nodeCreds, err := types.NewNodeCredentials(testCtx, fileStorage)
require.NoError(t, err)
nodePubKey, err := curve25519.X25519(nodeCreds.EncryptionPrivateKeyBytes, curve25519.Basepoint)
require.NoError(t, err)
// Add in node information to storage so we have a key to use
nodeInfo := &types.NodeInformation{
Id: testKeyId,
CertificatePublicKeyPkix: nodeCreds.CertificatePublicKeyPkix,
CertificatePublicKeyType: nodeCreds.CertificatePrivateKeyType,
EncryptionPublicKeyBytes: nodePubKey,
EncryptionPublicKeyType: nodeCreds.EncryptionPrivateKeyType,
ServerEncryptionPrivateKeyBytes: []byte("whatever"),
RegistrationNonce: nodeCreds.RegistrationNonce,
State: testState,
}
return nodeInfo
}
tests := []struct {
name string
writer db.Writer
databaseWrapper wrapping.Wrapper
node *types.NodeInformation
wantErr bool
wantErrIs errors.Code
wantErrContains string
}{
{
name: "missing-writer",
databaseWrapper: databaseWrapper,
node: testNodeInfoFn(),
wantErr: true,
wantErrIs: errors.InvalidParameter,
wantErrContains: "missing writer",
},
{
name: "missing-wrapper",
writer: rw,
node: testNodeInfoFn(),
wantErr: true,
wantErrIs: errors.InvalidParameter,
wantErrContains: "missing database wrapper",
},
{
name: "missing-node",
writer: rw,
databaseWrapper: databaseWrapper,
wantErr: true,
wantErrIs: errors.InvalidParameter,
wantErrContains: "missing NodeInformation",
},
{
name: "key-id-error",
writer: rw,
databaseWrapper: &mockTestWrapper{err: errors.New(testCtx, errors.Internal, "testing", "key-id-error")},
node: testNodeInfoFn(),
wantErr: true,
wantErrIs: errors.Internal,
wantErrContains: "key-id-error",
},
{
name: "encrypt-error",
writer: rw,
databaseWrapper: &mockTestWrapper{
encryptError: true,
err: errors.New(testCtx, errors.Encrypt, "testing", "encrypt-error"),
},
node: testNodeInfoFn(),
wantErr: true,
wantErrIs: errors.Encrypt,
wantErrContains: "encrypt-error",
},
{
name: "create-error",
writer: &db.Db{},
databaseWrapper: databaseWrapper,
node: testNodeInfoFn(),
wantErr: true,
wantErrContains: "db.Create",
},
{
name: "success",
writer: rw,
databaseWrapper: databaseWrapper,
node: testNodeInfoFn(),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
err := StoreNodeInformationTx(testCtx, tc.writer, tc.databaseWrapper, tc.node)
if tc.wantErr {
require.Error(err)
if tc.wantErrIs != errors.Unknown {
assert.True(errors.Match(errors.T(tc.wantErrIs), err))
}
if tc.wantErrContains != "" {
assert.Contains(err.Error(), tc.wantErrContains)
}
return
}
require.NoError(err)
})
}
}
type mockTestWrapper struct {
wrapping.Wrapper
decryptError bool
encryptError bool
err error
keyId string
}
func (m *mockTestWrapper) KeyId(context.Context) (string, error) {
if m.err != nil {
return "", m.err
}
return m.keyId, nil
}
func (m *mockTestWrapper) Encrypt(ctx context.Context, plaintext []byte, options ...wrapping.Option) (*wrapping.BlobInfo, error) {
if m.err != nil && m.encryptError {
return nil, m.err
}
panic("todo")
}
func (m *mockTestWrapper) Decrypt(ctx context.Context, ciphertext *wrapping.BlobInfo, options ...wrapping.Option) ([]byte, error) {
if m.err != nil && m.decryptError {
return nil, m.err
}
return []byte("decrypted"), nil
}

Loading…
Cancel
Save