internal/host/plugin: add UpdateSet (#1686)

* internal/host/plugin: add UpdateSet

This adds the UpdateSet repository function.

UpdateSet runs the OnUpdateSet plugin call within the transaction as per
the changes in #1683.

The plugin call returns the updated set, the hosts in the set, the rows
updated, and any error.

* Ensure that attributes can be cleared by supplying a full null

* move version increment to the end and track perferred endpoint update separately

* move plugin call to end of transaction

* only clone returned set once, and move host fetch call out of transaction

* Request a sync when attributes are updated

* Add test to assert that old preferred endpoints are being removed correctly
pull/1684/head
Chris Marchesi 5 years ago committed by GitHub
parent 772bb88d8d
commit c2cd6febca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,12 +3,14 @@ package plugin
import (
"context"
"fmt"
"strings"
"time"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/host"
"github.com/hashicorp/boundary/internal/host/store"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/libs/endpoint"
"github.com/hashicorp/boundary/internal/libs/patchstruct"
@ -20,6 +22,7 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
// CreateSet inserts s into the repository and returns a new HostSet
@ -168,6 +171,344 @@ func (r *Repository) CreateSet(ctx context.Context, scopeId string, s *HostSet,
return returnedHostSet, plg, nil
}
// UpdateSet updates the repository for host set entry s with the
// values populated, for the fields listed in fieldMask. It returns a
// new HostSet containing the updated values, the hosts in the set,
// and a count of the number of records updated. s is not changed.
//
// s must contain a valid PublicId and CatalogId. Name, Description,
// Attributes, and PreferredEndpoints can be updated. Name must be
// unique among all sets associated with a single catalog.
//
// An attribute of s will be set to NULL in the database if the
// attribute in s is the zero value and it is included in fieldMask.
// Note that this does not apply to s.Attributes - a null
// s.Attributes is a no-op for modifications. Rather, if fields need
// to be reset, its field in c.Attributes should individually set to
// null.
//
// Updates are sent to OnUpdateSet with a full copy of both the
// current set, the state of the new set should it be updated, along
// with its parent host catalog and persisted state (can include
// secrets). This is a stateless call and does not affect the final
// record written, but some plugins may perform some actions on this
// call. Update of the record in the database is aborted if this call
// fails.
func (r *Repository) UpdateSet(ctx context.Context, scopeId string, s *HostSet, version uint32, fieldMask []string, opt ...Option) (*HostSet, []*Host, *hostplugin.Plugin, int, error) {
const op = "plugin.(Repository).UpdateSet"
if s == nil {
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "nil HostSet")
}
if s.HostSet == nil {
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "nil embedded HostSet")
}
if s.PublicId == "" {
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "no public id")
}
if scopeId == "" {
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "no scope id")
}
if len(fieldMask) == 0 {
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.EmptyFieldMask, op, "empty field mask")
}
// Get the old set first. We patch the record first before
// sending it to the DB for updating so that we can run on
// OnUpdateSet. Note that the field masks are still used for
// updating.
sets, plg, err := r.getSets(ctx, s.PublicId, "")
if err != nil {
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
if len(sets) == 0 {
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.RecordNotFound, op, fmt.Sprintf("host set id %q not found", s.PublicId))
}
if len(sets) != 1 {
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.Internal, op, fmt.Sprintf("unexpected amount of sets found, want=1, got=%d", len(sets)))
}
currentSet := sets[0]
// Assert the version of the current set to make sure we're
// updating the correct one.
if currentSet.GetVersion() != version {
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.VersionMismatch, op, fmt.Sprintf("set version mismatch, want=%d, got=%d", currentSet.GetVersion(), version))
}
// Clone the set so that we can set fields.
newSet := currentSet.clone()
var updateAttributes bool
var dbMask, nullFields []string
const (
endpointOpNoop = "endpointOpNoop"
endpointOpDelete = "endpointOpDelete"
endpointOpUpdate = "endpointOpUpdate"
)
var endpointOp string = endpointOpNoop
for _, f := range fieldMask {
switch {
case strings.EqualFold("name", f) && s.Name == "":
nullFields = append(nullFields, "name")
newSet.Name = s.Name
case strings.EqualFold("name", f) && s.Name != "":
dbMask = append(dbMask, "name")
newSet.Name = s.Name
case strings.EqualFold("description", f) && s.Description == "":
nullFields = append(nullFields, "description")
newSet.Description = s.Description
case strings.EqualFold("description", f) && s.Description != "":
dbMask = append(dbMask, "description")
newSet.Description = s.Description
case strings.EqualFold("PreferredEndpoints", f) && len(s.PreferredEndpoints) == 0:
endpointOp = endpointOpDelete
newSet.PreferredEndpoints = s.PreferredEndpoints
case strings.EqualFold("PreferredEndpoints", f) && len(s.PreferredEndpoints) != 0:
endpointOp = endpointOpUpdate
newSet.PreferredEndpoints = s.PreferredEndpoints
case strings.EqualFold("attributes", strings.Split(f, ".")[0]):
// Flag attributes for updating. While multiple masks may be
// sent, we only need to do this once.
updateAttributes = true
default:
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidFieldMask, op, fmt.Sprintf("invalid field mask: %s", f))
}
}
if updateAttributes {
dbMask = append(dbMask, "attributes")
if s.Attributes != nil {
newSet.Attributes, err = patchstruct.PatchBytes(newSet.Attributes, s.Attributes)
if err != nil {
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("error in set attribute JSON"))
}
} else {
newSet.Attributes = make([]byte, 0)
}
// Flag the record as needing a sync since we've updated
// attributes.
dbMask = append(dbMask, "NeedSync")
newSet.NeedSync = true
}
// Get the host catalog for the set and its persisted data.
catalog, persisted, err := r.getCatalog(ctx, newSet.CatalogId)
if err != nil {
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("error looking up catalog with id %q", newSet.CatalogId)))
}
// Assert that the catalog scope ID and supplied scope ID match.
if catalog.ScopeId != scopeId {
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("catalog id %q not in scope id %q", newSet.CatalogId, scopeId))
}
// Get the plugin client.
plgClient, ok := r.plugins[catalog.GetPluginId()]
if !ok || plgClient == nil {
return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.Internal, op, fmt.Sprintf("plugin %q not available", catalog.GetPluginId()))
}
// Convert the catalog values to API protobuf values, which is what
// we use for the plugin hook calls.
plgHc, err := toPluginCatalog(ctx, catalog)
if err != nil {
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
currentPlgSet, err := toPluginSet(ctx, currentSet)
if err != nil {
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
newPlgSet, err := toPluginSet(ctx, newSet)
if err != nil {
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
// Get the preferred endpoints to write out.
var preferredEndpoints []interface{}
if endpointOp == endpointOpUpdate {
preferredEndpoints = make([]interface{}, 0, len(newSet.PreferredEndpoints))
for i, e := range newSet.PreferredEndpoints {
obj, err := host.NewPreferredEndpoint(ctx, newSet.PublicId, uint32(i+1), e)
if err != nil {
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
preferredEndpoints = append(preferredEndpoints, obj)
}
}
oplogWrapper, err := r.kms.GetWrapper(ctx, scopeId, kms.KeyPurposeOplog)
if err != nil {
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get oplog wrapper"))
}
// If the call to the plugin succeeded, we do not want to call it again if
// the transaction failed and is being retried.
var pluginCalledSuccessfully bool
var setUpdated, preferredEndpointsUpdated bool
var returnedSet *HostSet
var hosts []*Host
_, err = r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(r db.Reader, w db.Writer) error {
returnedSet = newSet.clone()
msgs := make([]*oplog.Message, 0, len(preferredEndpoints)+2)
ticket, err := w.GetTicket(s)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get ticket"))
}
if len(dbMask) != 0 || len(nullFields) != 0 {
var hsOplogMsg oplog.Message
numUpdated, err := w.Update(
ctx,
returnedSet,
dbMask,
nullFields,
db.NewOplogMsg(&hsOplogMsg),
db.WithVersion(&version),
)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if numUpdated != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("expected 1 set to be updated, got %d", numUpdated))
}
setUpdated = true
msgs = append(msgs, &hsOplogMsg)
}
switch endpointOp {
case endpointOpDelete:
// Delete all old endpoint entries.
var peDeleteOplogMsg oplog.Message
pep := &host.PreferredEndpoint{
PreferredEndpoint: &store.PreferredEndpoint{
HostSetId: newSet.PublicId,
Priority: 1,
},
}
_, err := w.Delete(ctx, pep, db.WithWhere("host_set_id = ?", newSet.PublicId), db.NewOplogMsg(&peDeleteOplogMsg))
if err != nil {
return errors.Wrap(ctx, err, op)
}
// Only append the oplog message if an operation was actually
// performed.
if peDeleteOplogMsg.Message != nil {
preferredEndpointsUpdated = true
msgs = append(msgs, &peDeleteOplogMsg)
}
case endpointOpUpdate:
// Delete all old endpoint entries.
var peDeleteOplogMsg oplog.Message
pep := &host.PreferredEndpoint{
PreferredEndpoint: &store.PreferredEndpoint{
HostSetId: newSet.PublicId,
Priority: 1,
},
}
_, err := w.Delete(ctx, pep, db.WithWhere("host_set_id = ?", newSet.PublicId), db.NewOplogMsg(&peDeleteOplogMsg))
if err != nil {
return errors.Wrap(ctx, err, op)
}
// Only append the oplog message if an operation was actually
// performed.
if peDeleteOplogMsg.Message != nil {
msgs = append(msgs, &peDeleteOplogMsg)
}
// Create the new entries.
peCreateOplogMsgs := make([]*oplog.Message, 0, len(preferredEndpoints))
if err := w.CreateItems(ctx, preferredEndpoints, db.NewOplogMsgs(&peCreateOplogMsgs)); err != nil {
return err
}
preferredEndpointsUpdated = true
msgs = append(msgs, peCreateOplogMsgs...)
}
if !setUpdated && preferredEndpointsUpdated {
returnedSet.Version = uint32(version) + 1
var hsOplogMsg oplog.Message
numUpdated, err := w.Update(
ctx,
returnedSet,
[]string{"version"},
[]string{},
db.NewOplogMsg(&hsOplogMsg),
db.WithVersion(&version),
)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if numUpdated != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("expected 1 set to be updated, got %d", numUpdated))
}
msgs = append(msgs, &hsOplogMsg)
}
if len(msgs) != 0 {
metadata := s.oplog(oplog.OpType_OP_TYPE_UPDATE)
if err := w.WriteOplogEntryWith(ctx, oplogWrapper, ticket, metadata, msgs); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to write oplog"))
}
}
if !pluginCalledSuccessfully {
_, err = plgClient.OnUpdateSet(ctx, &plgpb.OnUpdateSetRequest{
CurrentSet: currentPlgSet,
NewSet: newPlgSet,
Catalog: plgHc,
Persisted: persisted,
})
if err != nil {
if status.Code(err) != codes.Unimplemented {
return errors.Wrap(ctx, err, op)
}
}
}
return nil
},
)
if err != nil {
if errors.IsUniqueError(err) {
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("in %s: name %s already exists", newSet.PublicId, newSet.Name)))
}
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("in %s", newSet.PublicId)))
}
hosts, err = listHostBySetIds(ctx, r.reader, []string{returnedSet.PublicId}, opt...)
if err != nil {
return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op)
}
var numUpdated int
if setUpdated || preferredEndpointsUpdated {
numUpdated = 1
}
if updateAttributes {
// Request a host sync since we have updated attributes.
_ = r.scheduler.UpdateJobNextRunInAtLeast(ctx, setSyncJobName, 0)
}
return returnedSet, hosts, plg, numUpdated, nil
}
// LookupSet will look up a host set in the repository and return the host set,
// as well as host IDs that match. If the host set is not found, it will return
// nil, nil, nil. No options are currently supported.
@ -357,8 +698,20 @@ func toPluginSet(ctx context.Context, in *HostSet) (*pb.HostSet, error) {
if in == nil {
return nil, errors.New(ctx, errors.InvalidParameter, op, "nil storage plugin")
}
var name, description *wrapperspb.StringValue
if inName := in.GetName(); inName != "" {
name = wrapperspb.String(inName)
}
if inDescription := in.GetDescription(); inDescription != "" {
description = wrapperspb.String(inDescription)
}
hs := &pb.HostSet{
Id: in.GetPublicId(),
Id: in.GetPublicId(),
Name: name,
Description: description,
PreferredEndpoints: in.GetPreferredEndpoints(),
}
if in.GetAttributes() != nil {
attrs := &structpb.Struct{}

@ -11,6 +11,7 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/host"
"github.com/hashicorp/boundary/internal/host/plugin/store"
@ -18,6 +19,7 @@ import (
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
hostplg "github.com/hashicorp/boundary/internal/plugin/host"
hostplugin "github.com/hashicorp/boundary/internal/plugin/host"
"github.com/hashicorp/boundary/internal/scheduler"
plgpb "github.com/hashicorp/boundary/sdk/pbs/plugin"
"github.com/stretchr/testify/assert"
@ -25,6 +27,7 @@ import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
func TestRepository_CreateSet(t *testing.T) {
@ -352,6 +355,757 @@ func TestRepository_CreateSet(t *testing.T) {
})
}
func TestRepository_UpdateSet(t *testing.T) {
ctx := context.Background()
dbConn, _ := db.TestSetup(t, "postgres")
dbRW := db.New(dbConn)
dbWrapper := db.TestWrapper(t)
sched := scheduler.TestScheduler(t, dbConn, dbWrapper)
dbKmsCache := kms.TestKms(t, dbConn, dbWrapper)
_, projectScope := iam.TestScopes(t, iam.TestRepo(t, dbConn, dbWrapper))
// Define a plugin "manager", basically just a map with a mock
// plugin in it. This also includes functionality to capture the
// state, and set an error and the returned secrets to nil. Note
// that the way that this is set up means that the tests cannot run
// in parallel, but there could be other factors affecting that as
// well.
var gotOnUpdateCallCount int
var gotOnUpdateSetRequest *plgpb.OnUpdateSetRequest
var pluginError error
testPlugin := hostplg.TestPlugin(t, dbConn, "test")
testPluginMap := map[string]plgpb.HostPluginServiceClient{
testPlugin.GetPublicId(): &WrappingPluginClient{
Server: &TestPluginServer{
OnUpdateSetFn: func(_ context.Context, req *plgpb.OnUpdateSetRequest) (*plgpb.OnUpdateSetResponse, error) {
gotOnUpdateCallCount++
gotOnUpdateSetRequest = req
return &plgpb.OnUpdateSetResponse{}, pluginError
},
},
},
}
// Set up a test catalog and the secrets for it
testCatalog := TestCatalog(t, dbConn, projectScope.PublicId, testPlugin.GetPublicId())
testCatalogSecret, err := newHostCatalogSecret(ctx, testCatalog.GetPublicId(), mustStruct(map[string]interface{}{
"one": "two",
}))
require.NoError(t, err)
scopeWrapper, err := dbKmsCache.GetWrapper(ctx, testCatalog.GetScopeId(), kms.KeyPurposeDatabase)
require.NoError(t, err)
require.NoError(t, testCatalogSecret.encrypt(ctx, scopeWrapper))
testCatalogSecretQ, testCatalogSecretV := testCatalogSecret.upsertQuery()
secretsUpdated, err := dbRW.Exec(ctx, testCatalogSecretQ, testCatalogSecretV)
require.NoError(t, err)
require.Equal(t, 1, secretsUpdated)
// Create a test duplicate set. We don't use this set, it just
// exists to ensure that we can test for conflicts when setting
// the name.
const testDuplicateSetName = "duplicate-set-name"
testDuplicateSet := TestSet(t, dbConn, dbKmsCache, sched, testCatalog, testPluginMap)
testDuplicateSet.Name = testDuplicateSetName
setsUpdated, err := dbRW.Update(ctx, testDuplicateSet, []string{"name"}, []string{})
require.NoError(t, err)
require.Equal(t, 1, setsUpdated)
// Define some helpers here to make the test table more readable.
type changeHostSetFunc func(s *HostSet) *HostSet
changeSetToNil := func() changeHostSetFunc {
return func(_ *HostSet) *HostSet {
return nil
}
}
changeEmbeddedSetToNil := func() changeHostSetFunc {
return func(s *HostSet) *HostSet {
s.HostSet = nil
return s
}
}
changePublicId := func(v string) changeHostSetFunc {
return func(s *HostSet) *HostSet {
s.PublicId = v
return s
}
}
changeName := func(v string) changeHostSetFunc {
return func(s *HostSet) *HostSet {
s.Name = v
return s
}
}
changeDescription := func(s string) changeHostSetFunc {
return func(c *HostSet) *HostSet {
c.Description = s
return c
}
}
changePreferredEndpoints := func(s []string) changeHostSetFunc {
return func(c *HostSet) *HostSet {
c.PreferredEndpoints = s
return c
}
}
changeAttributes := func(m map[string]interface{}) changeHostSetFunc {
return func(c *HostSet) *HostSet {
c.Attributes = mustMarshal(m)
return c
}
}
changeAttributesNil := func() changeHostSetFunc {
return func(c *HostSet) *HostSet {
c.Attributes = nil
return c
}
}
// Define some checks that will be used in the below tests. Some of
// these are re-used, so we define them here. Most of these are
// assertions and no particular one is non-fatal in that they will
// stop execution. Note that these are executed after wantIsErr, so
// if that is set in an individual table test, these will not be
// executed.
//
// Note that we define some state here, similar to how we
// previously defined gotOnUpdateCatalogRequest above next the
// plugin map.
type checkFunc func(t *testing.T, ctx context.Context)
var (
gotSet *HostSet
gotNumUpdated int
)
checkVersion := func(want uint32) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Equal(want, gotSet.Version)
}
}
checkName := func(want string) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Equal(want, gotSet.Name)
}
}
checkDescription := func(want string) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Equal(want, gotSet.Description)
}
}
checkPreferredEndpoints := func(want []string) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Equal(want, gotSet.PreferredEndpoints)
}
}
checkAttributes := func(want map[string]interface{}) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
require := require.New(t)
st := &structpb.Struct{}
require.NoError(proto.Unmarshal(gotSet.Attributes, st))
assert.Empty(cmp.Diff(mustStruct(want), st, protocmp.Transform()))
}
}
checkNeedSync := func(want bool) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Equal(want, gotSet.NeedSync)
}
}
checkUpdateSetRequestCurrentNameNil := func() checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Nil(gotOnUpdateSetRequest.CurrentSet.Name)
}
}
checkUpdateSetRequestNewName := func(want string) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Equal(wrapperspb.String(want), gotOnUpdateSetRequest.NewSet.Name)
}
}
checkUpdateSetRequestNewNameNil := func() checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Nil(gotOnUpdateSetRequest.NewSet.Name)
}
}
checkUpdateSetRequestCurrentDescriptionNil := func() checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Nil(gotOnUpdateSetRequest.CurrentSet.Description)
}
}
checkUpdateSetRequestNewDescription := func(want string) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Equal(wrapperspb.String(want), gotOnUpdateSetRequest.NewSet.Description)
}
}
checkUpdateSetRequestNewDescriptionNil := func() checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Nil(gotOnUpdateSetRequest.NewSet.Description)
}
}
checkUpdateSetRequestCurrentPreferredEndpoints := func(want []string) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Equal(want, gotOnUpdateSetRequest.CurrentSet.PreferredEndpoints)
}
}
checkUpdateSetRequestNewPreferredEndpoints := func(want []string) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Equal(want, gotOnUpdateSetRequest.NewSet.PreferredEndpoints)
}
}
checkUpdateSetRequestNewPreferredEndpointsNil := func() checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Nil(gotOnUpdateSetRequest.NewSet.PreferredEndpoints)
}
}
checkUpdateSetRequestCurrentAttributes := func(want map[string]interface{}) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Empty(cmp.Diff(mustStruct(want), gotOnUpdateSetRequest.CurrentSet.Attributes, protocmp.Transform()))
}
}
checkUpdateSetRequestNewAttributes := func(want map[string]interface{}) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Empty(cmp.Diff(mustStruct(want), gotOnUpdateSetRequest.NewSet.Attributes, protocmp.Transform()))
}
}
checkUpdateSetRequestPersistedSecrets := func(want map[string]interface{}) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Empty(cmp.Diff(mustStruct(want), gotOnUpdateSetRequest.Persisted.Secrets, protocmp.Transform()))
}
}
checkNumUpdated := func(want int) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.Equal(want, gotNumUpdated)
}
}
checkVerifySetOplog := func(op oplog.OpType) checkFunc {
return func(t *testing.T, ctx context.Context) {
t.Helper()
assert := assert.New(t)
assert.NoError(
db.TestVerifyOplog(
t,
dbRW,
gotSet.PublicId,
db.WithOperation(op),
db.WithCreateNotBefore(10*time.Second),
),
)
}
}
tests := []struct {
name string
withScopeId *string
withEmptyPluginMap bool
withPluginError error
changeFuncs []changeHostSetFunc
version uint32
fieldMask []string
wantCheckFuncs []checkFunc
wantIsErr errors.Code
}{
{
name: "nil set",
changeFuncs: []changeHostSetFunc{changeSetToNil()},
wantIsErr: errors.InvalidParameter,
},
{
name: "nil embedded set",
changeFuncs: []changeHostSetFunc{changeEmbeddedSetToNil()},
wantIsErr: errors.InvalidParameter,
},
{
name: "missing public id",
changeFuncs: []changeHostSetFunc{changePublicId("")},
wantIsErr: errors.InvalidParameter,
},
{
name: "missing scope id",
withScopeId: func() *string { a := ""; return &a }(),
wantIsErr: errors.InvalidParameter,
},
{
name: "empty field mask",
fieldMask: nil, // Should be testing on len
wantIsErr: errors.EmptyFieldMask,
},
{
name: "bad set id",
changeFuncs: []changeHostSetFunc{changePublicId("badid")},
fieldMask: []string{"name"},
wantIsErr: errors.RecordNotFound,
},
{
name: "version mismatch",
changeFuncs: []changeHostSetFunc{changeName("foo")},
version: 1,
fieldMask: []string{"name"},
wantIsErr: errors.VersionMismatch,
},
{
name: "mismatched scope id to catalog scope",
withScopeId: func() *string { a := "badid"; return &a }(),
version: 2,
fieldMask: []string{"name"},
wantIsErr: errors.InvalidParameter,
},
{
name: "plugin lookup error",
withEmptyPluginMap: true,
version: 2,
fieldMask: []string{"name"},
wantIsErr: errors.Internal,
},
{
name: "plugin invocation error",
withPluginError: errors.New(context.Background(), errors.Internal, "TestRepository_UpdateSet/plugin_invocation_error", "test plugin error"),
version: 2,
fieldMask: []string{"name"},
wantIsErr: errors.Internal,
},
{
name: "update name (duplicate)",
changeFuncs: []changeHostSetFunc{changeName(testDuplicateSetName)},
version: 2,
fieldMask: []string{"name"},
wantIsErr: errors.NotUnique,
},
{
name: "update name",
changeFuncs: []changeHostSetFunc{changeName("foo")},
version: 2,
fieldMask: []string{"name"},
wantCheckFuncs: []checkFunc{
checkVersion(3),
checkUpdateSetRequestCurrentNameNil(),
checkUpdateSetRequestNewName("foo"),
checkName("foo"),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(false),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update name to same",
changeFuncs: []changeHostSetFunc{changeName("")},
version: 2,
fieldMask: []string{"name"},
wantCheckFuncs: []checkFunc{
checkVersion(2), // Version remains same even though row is updated
checkUpdateSetRequestCurrentNameNil(),
checkUpdateSetRequestNewNameNil(),
checkName(""),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(false),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update description",
changeFuncs: []changeHostSetFunc{changeDescription("foo")},
version: 2,
fieldMask: []string{"description"},
wantCheckFuncs: []checkFunc{
checkVersion(3),
checkUpdateSetRequestCurrentDescriptionNil(),
checkUpdateSetRequestNewDescription("foo"),
checkDescription("foo"),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(false),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update description to same",
changeFuncs: []changeHostSetFunc{changeDescription("")},
version: 2,
fieldMask: []string{"description"},
wantCheckFuncs: []checkFunc{
checkVersion(2), // Version remains same even though row is updated
checkUpdateSetRequestCurrentDescriptionNil(),
checkUpdateSetRequestNewDescriptionNil(),
checkDescription(""),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(false),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update preferred endpoints",
changeFuncs: []changeHostSetFunc{changePreferredEndpoints([]string{"cidr:10.0.0.0/24"})},
version: 2,
fieldMask: []string{"PreferredEndpoints"},
wantCheckFuncs: []checkFunc{
checkVersion(3),
checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}),
checkUpdateSetRequestNewPreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
checkPreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(false),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "clear preferred endpoints",
changeFuncs: []changeHostSetFunc{changePreferredEndpoints(nil)},
version: 2,
fieldMask: []string{"PreferredEndpoints"},
wantCheckFuncs: []checkFunc{
checkVersion(3),
checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}),
checkUpdateSetRequestNewPreferredEndpointsNil(),
checkPreferredEndpoints(nil),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(false),
checkNumUpdated(1),
},
},
{
name: "update attributes (add)",
changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{
"baz": "qux",
})},
version: 2,
fieldMask: []string{"attributes"},
wantCheckFuncs: []checkFunc{
checkVersion(3),
checkUpdateSetRequestCurrentAttributes(map[string]interface{}{
"foo": "bar",
}),
checkUpdateSetRequestNewAttributes(map[string]interface{}{
"foo": "bar",
"baz": "qux",
}),
checkAttributes(map[string]interface{}{
"foo": "bar",
"baz": "qux",
}),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(true),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update attributes (overwrite)",
changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{
"foo": "baz",
})},
version: 2,
fieldMask: []string{"attributes"},
wantCheckFuncs: []checkFunc{
checkVersion(3),
checkUpdateSetRequestCurrentAttributes(map[string]interface{}{
"foo": "bar",
}),
checkUpdateSetRequestNewAttributes(map[string]interface{}{
"foo": "baz",
}),
checkAttributes(map[string]interface{}{
"foo": "baz",
}),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(true),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update attributes (null)",
changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{
"foo": nil,
})},
version: 2,
fieldMask: []string{"attributes"},
wantCheckFuncs: []checkFunc{
checkVersion(3),
checkUpdateSetRequestCurrentAttributes(map[string]interface{}{
"foo": "bar",
}),
checkUpdateSetRequestNewAttributes(map[string]interface{}{}),
checkAttributes(map[string]interface{}{}),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(true),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update attributes (full null)",
changeFuncs: []changeHostSetFunc{changeAttributesNil()},
version: 2,
fieldMask: []string{"attributes"},
wantCheckFuncs: []checkFunc{
checkVersion(3),
checkUpdateSetRequestCurrentAttributes(map[string]interface{}{
"foo": "bar",
}),
checkUpdateSetRequestNewAttributes(map[string]interface{}{}),
checkAttributes(map[string]interface{}{}),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(true),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update attributes (combined)",
changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{
"a": "b",
"foo": "baz",
})},
version: 2,
fieldMask: []string{"attributes.a", "attributes.foo"},
wantCheckFuncs: []checkFunc{
checkVersion(3),
checkUpdateSetRequestCurrentAttributes(map[string]interface{}{
"foo": "bar",
}),
checkUpdateSetRequestNewAttributes(map[string]interface{}{
"a": "b",
"foo": "baz",
}),
checkAttributes(map[string]interface{}{
"a": "b",
"foo": "baz",
}),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(true),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update name and preferred endpoints",
changeFuncs: []changeHostSetFunc{
changePreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
changeName("foo"),
},
version: 2,
fieldMask: []string{"name", "PreferredEndpoints"},
wantCheckFuncs: []checkFunc{
checkVersion(3),
checkUpdateSetRequestCurrentNameNil(),
checkUpdateSetRequestNewName("foo"),
checkName("foo"),
checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}),
checkUpdateSetRequestNewPreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
checkPreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
checkUpdateSetRequestPersistedSecrets(map[string]interface{}{
"one": "two",
}),
checkNeedSync(false),
checkNumUpdated(1),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
}
// Create a host that will be verified as belonging to the created set below.
testHostExternIdPrefix := "test-host-external-id"
testHosts := make([]*Host, 3)
for i := 0; i <= 2; i++ {
testHosts[i] = TestHost(t, dbConn, testCatalog.PublicId, fmt.Sprintf("%s-%d", testHostExternIdPrefix, i))
}
// Finally define a function for bringing the test subject host
// set.
setupHostSet := func(t *testing.T, ctx context.Context) *HostSet {
t.Helper()
require := require.New(t)
set := TestSet(
t,
dbConn,
dbKmsCache,
sched,
testCatalog,
testPluginMap,
WithPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}),
)
// Set some (default) attributes on our test set
set.Attributes = mustMarshal(map[string]interface{}{
"foo": "bar",
})
// Set some fake sync detail to the set.
set.LastSyncTime = timestamp.New(time.Now())
set.NeedSync = false
numSetsUpdated, err := dbRW.Update(ctx, set, []string{"attributes", "LastSyncTime", "NeedSync"}, []string{})
require.NoError(err)
require.Equal(1, numSetsUpdated)
// Add some hosts to the host set.
TestSetMembers(t, dbConn, set.PublicId, testHosts)
t.Cleanup(func() {
t.Helper()
assert := assert.New(t)
n, err := dbRW.Delete(ctx, set)
assert.NoError(err)
assert.Equal(1, n)
})
return set
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
origSet := setupHostSet(t, ctx)
pluginMap := testPluginMap
if tt.withEmptyPluginMap {
pluginMap = make(map[string]plgpb.HostPluginServiceClient)
}
pluginError = tt.withPluginError
t.Cleanup(func() { pluginError = nil })
repo, err := NewRepository(dbRW, dbRW, dbKmsCache, sched, pluginMap)
require.NoError(err)
require.NotNil(repo)
workingSet := origSet.clone()
for _, cf := range tt.changeFuncs {
workingSet = cf(workingSet)
}
scopeId := testCatalog.ScopeId
if tt.withScopeId != nil {
scopeId = *tt.withScopeId
}
var gotHosts []*Host
var gotPlugin *hostplugin.Plugin
gotSet, gotHosts, gotPlugin, gotNumUpdated, err = repo.UpdateSet(ctx, scopeId, workingSet, tt.version, tt.fieldMask)
t.Cleanup(func() { gotOnUpdateCallCount = 0 })
if tt.wantIsErr != 0 {
require.Equal(db.NoRowsAffected, gotNumUpdated)
require.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err: %q got: %q", tt.wantIsErr, err)
return
}
require.NoError(err)
require.Equal(1, gotOnUpdateCallCount)
t.Cleanup(func() { gotOnUpdateSetRequest = nil })
// Quick assertion that the set is not nil and that the plugin
// ID in the catalog referenced by the set matches the plugin
// ID in the returned plugin.
require.NotNil(gotSet)
require.NotNil(gotPlugin)
assert.Equal(testCatalog.PublicId, gotSet.CatalogId)
assert.Equal(testCatalog.PluginId, gotPlugin.PublicId)
// Also assert that the hosts returned by the request are the ones that belong to the set
wantHostMap := make(map[string]string, len(testHosts))
for _, h := range testHosts {
wantHostMap[h.PublicId] = h.ExternalId
}
gotHostMap := make(map[string]string, len(gotHosts))
for _, h := range gotHosts {
gotHostMap[h.PublicId] = h.ExternalId
}
assert.Equal(wantHostMap, gotHostMap)
// Perform checks
for _, check := range tt.wantCheckFuncs {
check(t, ctx)
}
})
}
}
func TestRepository_LookupSet(t *testing.T) {
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")

Loading…
Cancel
Save