From c2cd6febcaf382d3e7a22eb620380b5457e19fe0 Mon Sep 17 00:00:00 2001 From: Chris Marchesi Date: Wed, 10 Nov 2021 17:08:42 -0800 Subject: [PATCH] internal/host/plugin: add UpdateSet (#1686) * internal/host/plugin: add UpdateSet This adds the UpdateSet repository function. UpdateSet runs the OnUpdateSet plugin call within the transaction as per the changes in #1683. The plugin call returns the updated set, the hosts in the set, the rows updated, and any error. * Ensure that attributes can be cleared by supplying a full null * move version increment to the end and track perferred endpoint update separately * move plugin call to end of transaction * only clone returned set once, and move host fetch call out of transaction * Request a sync when attributes are updated * Add test to assert that old preferred endpoints are being removed correctly --- internal/host/plugin/repository_host_set.go | 355 ++++++++- .../host/plugin/repository_host_set_test.go | 754 ++++++++++++++++++ 2 files changed, 1108 insertions(+), 1 deletion(-) diff --git a/internal/host/plugin/repository_host_set.go b/internal/host/plugin/repository_host_set.go index 5997ab363d..a4f2bfd94b 100644 --- a/internal/host/plugin/repository_host_set.go +++ b/internal/host/plugin/repository_host_set.go @@ -3,12 +3,14 @@ package plugin import ( "context" "fmt" + "strings" "time" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/host" + "github.com/hashicorp/boundary/internal/host/store" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/libs/endpoint" "github.com/hashicorp/boundary/internal/libs/patchstruct" @@ -20,6 +22,7 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/wrapperspb" ) // CreateSet inserts s into the repository and returns a new HostSet @@ -168,6 +171,344 @@ func (r *Repository) CreateSet(ctx context.Context, scopeId string, s *HostSet, return returnedHostSet, plg, nil } +// UpdateSet updates the repository for host set entry s with the +// values populated, for the fields listed in fieldMask. It returns a +// new HostSet containing the updated values, the hosts in the set, +// and a count of the number of records updated. s is not changed. +// +// s must contain a valid PublicId and CatalogId. Name, Description, +// Attributes, and PreferredEndpoints can be updated. Name must be +// unique among all sets associated with a single catalog. +// +// An attribute of s will be set to NULL in the database if the +// attribute in s is the zero value and it is included in fieldMask. +// Note that this does not apply to s.Attributes - a null +// s.Attributes is a no-op for modifications. Rather, if fields need +// to be reset, its field in c.Attributes should individually set to +// null. +// +// Updates are sent to OnUpdateSet with a full copy of both the +// current set, the state of the new set should it be updated, along +// with its parent host catalog and persisted state (can include +// secrets). This is a stateless call and does not affect the final +// record written, but some plugins may perform some actions on this +// call. Update of the record in the database is aborted if this call +// fails. +func (r *Repository) UpdateSet(ctx context.Context, scopeId string, s *HostSet, version uint32, fieldMask []string, opt ...Option) (*HostSet, []*Host, *hostplugin.Plugin, int, error) { + const op = "plugin.(Repository).UpdateSet" + if s == nil { + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "nil HostSet") + } + if s.HostSet == nil { + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "nil embedded HostSet") + } + if s.PublicId == "" { + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "no public id") + } + if scopeId == "" { + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, "no scope id") + } + if len(fieldMask) == 0 { + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.EmptyFieldMask, op, "empty field mask") + } + + // Get the old set first. We patch the record first before + // sending it to the DB for updating so that we can run on + // OnUpdateSet. Note that the field masks are still used for + // updating. + sets, plg, err := r.getSets(ctx, s.PublicId, "") + if err != nil { + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + + if len(sets) == 0 { + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.RecordNotFound, op, fmt.Sprintf("host set id %q not found", s.PublicId)) + } + + if len(sets) != 1 { + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.Internal, op, fmt.Sprintf("unexpected amount of sets found, want=1, got=%d", len(sets))) + } + + currentSet := sets[0] + + // Assert the version of the current set to make sure we're + // updating the correct one. + if currentSet.GetVersion() != version { + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.VersionMismatch, op, fmt.Sprintf("set version mismatch, want=%d, got=%d", currentSet.GetVersion(), version)) + } + + // Clone the set so that we can set fields. + newSet := currentSet.clone() + var updateAttributes bool + var dbMask, nullFields []string + const ( + endpointOpNoop = "endpointOpNoop" + endpointOpDelete = "endpointOpDelete" + endpointOpUpdate = "endpointOpUpdate" + ) + var endpointOp string = endpointOpNoop + for _, f := range fieldMask { + switch { + case strings.EqualFold("name", f) && s.Name == "": + nullFields = append(nullFields, "name") + newSet.Name = s.Name + case strings.EqualFold("name", f) && s.Name != "": + dbMask = append(dbMask, "name") + newSet.Name = s.Name + case strings.EqualFold("description", f) && s.Description == "": + nullFields = append(nullFields, "description") + newSet.Description = s.Description + case strings.EqualFold("description", f) && s.Description != "": + dbMask = append(dbMask, "description") + newSet.Description = s.Description + case strings.EqualFold("PreferredEndpoints", f) && len(s.PreferredEndpoints) == 0: + endpointOp = endpointOpDelete + newSet.PreferredEndpoints = s.PreferredEndpoints + case strings.EqualFold("PreferredEndpoints", f) && len(s.PreferredEndpoints) != 0: + endpointOp = endpointOpUpdate + newSet.PreferredEndpoints = s.PreferredEndpoints + case strings.EqualFold("attributes", strings.Split(f, ".")[0]): + // Flag attributes for updating. While multiple masks may be + // sent, we only need to do this once. + updateAttributes = true + + default: + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidFieldMask, op, fmt.Sprintf("invalid field mask: %s", f)) + } + } + + if updateAttributes { + dbMask = append(dbMask, "attributes") + if s.Attributes != nil { + newSet.Attributes, err = patchstruct.PatchBytes(newSet.Attributes, s.Attributes) + if err != nil { + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("error in set attribute JSON")) + } + } else { + newSet.Attributes = make([]byte, 0) + } + + // Flag the record as needing a sync since we've updated + // attributes. + dbMask = append(dbMask, "NeedSync") + newSet.NeedSync = true + } + + // Get the host catalog for the set and its persisted data. + catalog, persisted, err := r.getCatalog(ctx, newSet.CatalogId) + if err != nil { + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("error looking up catalog with id %q", newSet.CatalogId))) + } + + // Assert that the catalog scope ID and supplied scope ID match. + if catalog.ScopeId != scopeId { + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("catalog id %q not in scope id %q", newSet.CatalogId, scopeId)) + } + + // Get the plugin client. + plgClient, ok := r.plugins[catalog.GetPluginId()] + if !ok || plgClient == nil { + return nil, nil, nil, db.NoRowsAffected, errors.New(ctx, errors.Internal, op, fmt.Sprintf("plugin %q not available", catalog.GetPluginId())) + } + + // Convert the catalog values to API protobuf values, which is what + // we use for the plugin hook calls. + plgHc, err := toPluginCatalog(ctx, catalog) + if err != nil { + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + currentPlgSet, err := toPluginSet(ctx, currentSet) + if err != nil { + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + newPlgSet, err := toPluginSet(ctx, newSet) + if err != nil { + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + + // Get the preferred endpoints to write out. + var preferredEndpoints []interface{} + if endpointOp == endpointOpUpdate { + preferredEndpoints = make([]interface{}, 0, len(newSet.PreferredEndpoints)) + for i, e := range newSet.PreferredEndpoints { + obj, err := host.NewPreferredEndpoint(ctx, newSet.PublicId, uint32(i+1), e) + if err != nil { + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + preferredEndpoints = append(preferredEndpoints, obj) + } + } + + oplogWrapper, err := r.kms.GetWrapper(ctx, scopeId, kms.KeyPurposeOplog) + if err != nil { + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get oplog wrapper")) + } + + // If the call to the plugin succeeded, we do not want to call it again if + // the transaction failed and is being retried. + var pluginCalledSuccessfully bool + + var setUpdated, preferredEndpointsUpdated bool + var returnedSet *HostSet + var hosts []*Host + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(r db.Reader, w db.Writer) error { + returnedSet = newSet.clone() + msgs := make([]*oplog.Message, 0, len(preferredEndpoints)+2) + ticket, err := w.GetTicket(s) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get ticket")) + } + + if len(dbMask) != 0 || len(nullFields) != 0 { + var hsOplogMsg oplog.Message + numUpdated, err := w.Update( + ctx, + returnedSet, + dbMask, + nullFields, + db.NewOplogMsg(&hsOplogMsg), + db.WithVersion(&version), + ) + if err != nil { + return errors.Wrap(ctx, err, op) + } + + if numUpdated != 1 { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("expected 1 set to be updated, got %d", numUpdated)) + } + + setUpdated = true + msgs = append(msgs, &hsOplogMsg) + } + + switch endpointOp { + case endpointOpDelete: + // Delete all old endpoint entries. + var peDeleteOplogMsg oplog.Message + pep := &host.PreferredEndpoint{ + PreferredEndpoint: &store.PreferredEndpoint{ + HostSetId: newSet.PublicId, + Priority: 1, + }, + } + _, err := w.Delete(ctx, pep, db.WithWhere("host_set_id = ?", newSet.PublicId), db.NewOplogMsg(&peDeleteOplogMsg)) + if err != nil { + return errors.Wrap(ctx, err, op) + } + + // Only append the oplog message if an operation was actually + // performed. + if peDeleteOplogMsg.Message != nil { + preferredEndpointsUpdated = true + msgs = append(msgs, &peDeleteOplogMsg) + } + + case endpointOpUpdate: + // Delete all old endpoint entries. + var peDeleteOplogMsg oplog.Message + pep := &host.PreferredEndpoint{ + PreferredEndpoint: &store.PreferredEndpoint{ + HostSetId: newSet.PublicId, + Priority: 1, + }, + } + _, err := w.Delete(ctx, pep, db.WithWhere("host_set_id = ?", newSet.PublicId), db.NewOplogMsg(&peDeleteOplogMsg)) + if err != nil { + return errors.Wrap(ctx, err, op) + } + + // Only append the oplog message if an operation was actually + // performed. + if peDeleteOplogMsg.Message != nil { + msgs = append(msgs, &peDeleteOplogMsg) + } + + // Create the new entries. + peCreateOplogMsgs := make([]*oplog.Message, 0, len(preferredEndpoints)) + if err := w.CreateItems(ctx, preferredEndpoints, db.NewOplogMsgs(&peCreateOplogMsgs)); err != nil { + return err + } + + preferredEndpointsUpdated = true + msgs = append(msgs, peCreateOplogMsgs...) + } + + if !setUpdated && preferredEndpointsUpdated { + returnedSet.Version = uint32(version) + 1 + var hsOplogMsg oplog.Message + numUpdated, err := w.Update( + ctx, + returnedSet, + []string{"version"}, + []string{}, + db.NewOplogMsg(&hsOplogMsg), + db.WithVersion(&version), + ) + if err != nil { + return errors.Wrap(ctx, err, op) + } + + if numUpdated != 1 { + return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("expected 1 set to be updated, got %d", numUpdated)) + } + + msgs = append(msgs, &hsOplogMsg) + } + + if len(msgs) != 0 { + metadata := s.oplog(oplog.OpType_OP_TYPE_UPDATE) + if err := w.WriteOplogEntryWith(ctx, oplogWrapper, ticket, metadata, msgs); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to write oplog")) + } + } + + if !pluginCalledSuccessfully { + _, err = plgClient.OnUpdateSet(ctx, &plgpb.OnUpdateSetRequest{ + CurrentSet: currentPlgSet, + NewSet: newPlgSet, + Catalog: plgHc, + Persisted: persisted, + }) + if err != nil { + if status.Code(err) != codes.Unimplemented { + return errors.Wrap(ctx, err, op) + } + } + } + + return nil + }, + ) + + if err != nil { + if errors.IsUniqueError(err) { + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("in %s: name %s already exists", newSet.PublicId, newSet.Name))) + } + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("in %s", newSet.PublicId))) + } + + hosts, err = listHostBySetIds(ctx, r.reader, []string{returnedSet.PublicId}, opt...) + if err != nil { + return nil, nil, nil, db.NoRowsAffected, errors.Wrap(ctx, err, op) + } + + var numUpdated int + if setUpdated || preferredEndpointsUpdated { + numUpdated = 1 + } + + if updateAttributes { + // Request a host sync since we have updated attributes. + _ = r.scheduler.UpdateJobNextRunInAtLeast(ctx, setSyncJobName, 0) + } + + return returnedSet, hosts, plg, numUpdated, nil +} + // LookupSet will look up a host set in the repository and return the host set, // as well as host IDs that match. If the host set is not found, it will return // nil, nil, nil. No options are currently supported. @@ -357,8 +698,20 @@ func toPluginSet(ctx context.Context, in *HostSet) (*pb.HostSet, error) { if in == nil { return nil, errors.New(ctx, errors.InvalidParameter, op, "nil storage plugin") } + + var name, description *wrapperspb.StringValue + if inName := in.GetName(); inName != "" { + name = wrapperspb.String(inName) + } + if inDescription := in.GetDescription(); inDescription != "" { + description = wrapperspb.String(inDescription) + } + hs := &pb.HostSet{ - Id: in.GetPublicId(), + Id: in.GetPublicId(), + Name: name, + Description: description, + PreferredEndpoints: in.GetPreferredEndpoints(), } if in.GetAttributes() != nil { attrs := &structpb.Struct{} diff --git a/internal/host/plugin/repository_host_set_test.go b/internal/host/plugin/repository_host_set_test.go index 017dad03ea..c3d71e9143 100644 --- a/internal/host/plugin/repository_host_set_test.go +++ b/internal/host/plugin/repository_host_set_test.go @@ -11,6 +11,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/host" "github.com/hashicorp/boundary/internal/host/plugin/store" @@ -18,6 +19,7 @@ import ( "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" hostplg "github.com/hashicorp/boundary/internal/plugin/host" + hostplugin "github.com/hashicorp/boundary/internal/plugin/host" "github.com/hashicorp/boundary/internal/scheduler" plgpb "github.com/hashicorp/boundary/sdk/pbs/plugin" "github.com/stretchr/testify/assert" @@ -25,6 +27,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/wrapperspb" ) func TestRepository_CreateSet(t *testing.T) { @@ -352,6 +355,757 @@ func TestRepository_CreateSet(t *testing.T) { }) } +func TestRepository_UpdateSet(t *testing.T) { + ctx := context.Background() + dbConn, _ := db.TestSetup(t, "postgres") + dbRW := db.New(dbConn) + dbWrapper := db.TestWrapper(t) + sched := scheduler.TestScheduler(t, dbConn, dbWrapper) + dbKmsCache := kms.TestKms(t, dbConn, dbWrapper) + _, projectScope := iam.TestScopes(t, iam.TestRepo(t, dbConn, dbWrapper)) + + // Define a plugin "manager", basically just a map with a mock + // plugin in it. This also includes functionality to capture the + // state, and set an error and the returned secrets to nil. Note + // that the way that this is set up means that the tests cannot run + // in parallel, but there could be other factors affecting that as + // well. + var gotOnUpdateCallCount int + var gotOnUpdateSetRequest *plgpb.OnUpdateSetRequest + var pluginError error + testPlugin := hostplg.TestPlugin(t, dbConn, "test") + testPluginMap := map[string]plgpb.HostPluginServiceClient{ + testPlugin.GetPublicId(): &WrappingPluginClient{ + Server: &TestPluginServer{ + OnUpdateSetFn: func(_ context.Context, req *plgpb.OnUpdateSetRequest) (*plgpb.OnUpdateSetResponse, error) { + gotOnUpdateCallCount++ + gotOnUpdateSetRequest = req + return &plgpb.OnUpdateSetResponse{}, pluginError + }, + }, + }, + } + + // Set up a test catalog and the secrets for it + testCatalog := TestCatalog(t, dbConn, projectScope.PublicId, testPlugin.GetPublicId()) + testCatalogSecret, err := newHostCatalogSecret(ctx, testCatalog.GetPublicId(), mustStruct(map[string]interface{}{ + "one": "two", + })) + require.NoError(t, err) + scopeWrapper, err := dbKmsCache.GetWrapper(ctx, testCatalog.GetScopeId(), kms.KeyPurposeDatabase) + require.NoError(t, err) + require.NoError(t, testCatalogSecret.encrypt(ctx, scopeWrapper)) + testCatalogSecretQ, testCatalogSecretV := testCatalogSecret.upsertQuery() + secretsUpdated, err := dbRW.Exec(ctx, testCatalogSecretQ, testCatalogSecretV) + require.NoError(t, err) + require.Equal(t, 1, secretsUpdated) + + // Create a test duplicate set. We don't use this set, it just + // exists to ensure that we can test for conflicts when setting + // the name. + const testDuplicateSetName = "duplicate-set-name" + testDuplicateSet := TestSet(t, dbConn, dbKmsCache, sched, testCatalog, testPluginMap) + testDuplicateSet.Name = testDuplicateSetName + setsUpdated, err := dbRW.Update(ctx, testDuplicateSet, []string{"name"}, []string{}) + require.NoError(t, err) + require.Equal(t, 1, setsUpdated) + + // Define some helpers here to make the test table more readable. + type changeHostSetFunc func(s *HostSet) *HostSet + + changeSetToNil := func() changeHostSetFunc { + return func(_ *HostSet) *HostSet { + return nil + } + } + + changeEmbeddedSetToNil := func() changeHostSetFunc { + return func(s *HostSet) *HostSet { + s.HostSet = nil + return s + } + } + + changePublicId := func(v string) changeHostSetFunc { + return func(s *HostSet) *HostSet { + s.PublicId = v + return s + } + } + + changeName := func(v string) changeHostSetFunc { + return func(s *HostSet) *HostSet { + s.Name = v + return s + } + } + + changeDescription := func(s string) changeHostSetFunc { + return func(c *HostSet) *HostSet { + c.Description = s + return c + } + } + + changePreferredEndpoints := func(s []string) changeHostSetFunc { + return func(c *HostSet) *HostSet { + c.PreferredEndpoints = s + return c + } + } + + changeAttributes := func(m map[string]interface{}) changeHostSetFunc { + return func(c *HostSet) *HostSet { + c.Attributes = mustMarshal(m) + return c + } + } + + changeAttributesNil := func() changeHostSetFunc { + return func(c *HostSet) *HostSet { + c.Attributes = nil + return c + } + } + + // Define some checks that will be used in the below tests. Some of + // these are re-used, so we define them here. Most of these are + // assertions and no particular one is non-fatal in that they will + // stop execution. Note that these are executed after wantIsErr, so + // if that is set in an individual table test, these will not be + // executed. + // + // Note that we define some state here, similar to how we + // previously defined gotOnUpdateCatalogRequest above next the + // plugin map. + type checkFunc func(t *testing.T, ctx context.Context) + var ( + gotSet *HostSet + gotNumUpdated int + ) + + checkVersion := func(want uint32) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Equal(want, gotSet.Version) + } + } + + checkName := func(want string) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Equal(want, gotSet.Name) + } + } + + checkDescription := func(want string) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Equal(want, gotSet.Description) + } + } + + checkPreferredEndpoints := func(want []string) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Equal(want, gotSet.PreferredEndpoints) + } + } + + checkAttributes := func(want map[string]interface{}) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + require := require.New(t) + st := &structpb.Struct{} + require.NoError(proto.Unmarshal(gotSet.Attributes, st)) + assert.Empty(cmp.Diff(mustStruct(want), st, protocmp.Transform())) + } + } + + checkNeedSync := func(want bool) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Equal(want, gotSet.NeedSync) + } + } + + checkUpdateSetRequestCurrentNameNil := func() checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Nil(gotOnUpdateSetRequest.CurrentSet.Name) + } + } + + checkUpdateSetRequestNewName := func(want string) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Equal(wrapperspb.String(want), gotOnUpdateSetRequest.NewSet.Name) + } + } + + checkUpdateSetRequestNewNameNil := func() checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Nil(gotOnUpdateSetRequest.NewSet.Name) + } + } + + checkUpdateSetRequestCurrentDescriptionNil := func() checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Nil(gotOnUpdateSetRequest.CurrentSet.Description) + } + } + + checkUpdateSetRequestNewDescription := func(want string) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Equal(wrapperspb.String(want), gotOnUpdateSetRequest.NewSet.Description) + } + } + + checkUpdateSetRequestNewDescriptionNil := func() checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Nil(gotOnUpdateSetRequest.NewSet.Description) + } + } + + checkUpdateSetRequestCurrentPreferredEndpoints := func(want []string) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Equal(want, gotOnUpdateSetRequest.CurrentSet.PreferredEndpoints) + } + } + + checkUpdateSetRequestNewPreferredEndpoints := func(want []string) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Equal(want, gotOnUpdateSetRequest.NewSet.PreferredEndpoints) + } + } + + checkUpdateSetRequestNewPreferredEndpointsNil := func() checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Nil(gotOnUpdateSetRequest.NewSet.PreferredEndpoints) + } + } + + checkUpdateSetRequestCurrentAttributes := func(want map[string]interface{}) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Empty(cmp.Diff(mustStruct(want), gotOnUpdateSetRequest.CurrentSet.Attributes, protocmp.Transform())) + } + } + + checkUpdateSetRequestNewAttributes := func(want map[string]interface{}) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Empty(cmp.Diff(mustStruct(want), gotOnUpdateSetRequest.NewSet.Attributes, protocmp.Transform())) + } + } + + checkUpdateSetRequestPersistedSecrets := func(want map[string]interface{}) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Empty(cmp.Diff(mustStruct(want), gotOnUpdateSetRequest.Persisted.Secrets, protocmp.Transform())) + } + } + + checkNumUpdated := func(want int) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.Equal(want, gotNumUpdated) + } + } + + checkVerifySetOplog := func(op oplog.OpType) checkFunc { + return func(t *testing.T, ctx context.Context) { + t.Helper() + assert := assert.New(t) + assert.NoError( + db.TestVerifyOplog( + t, + dbRW, + gotSet.PublicId, + db.WithOperation(op), + db.WithCreateNotBefore(10*time.Second), + ), + ) + } + } + + tests := []struct { + name string + withScopeId *string + withEmptyPluginMap bool + withPluginError error + changeFuncs []changeHostSetFunc + version uint32 + fieldMask []string + wantCheckFuncs []checkFunc + wantIsErr errors.Code + }{ + { + name: "nil set", + changeFuncs: []changeHostSetFunc{changeSetToNil()}, + wantIsErr: errors.InvalidParameter, + }, + { + name: "nil embedded set", + changeFuncs: []changeHostSetFunc{changeEmbeddedSetToNil()}, + wantIsErr: errors.InvalidParameter, + }, + { + name: "missing public id", + changeFuncs: []changeHostSetFunc{changePublicId("")}, + wantIsErr: errors.InvalidParameter, + }, + { + name: "missing scope id", + withScopeId: func() *string { a := ""; return &a }(), + wantIsErr: errors.InvalidParameter, + }, + { + name: "empty field mask", + fieldMask: nil, // Should be testing on len + wantIsErr: errors.EmptyFieldMask, + }, + { + name: "bad set id", + changeFuncs: []changeHostSetFunc{changePublicId("badid")}, + fieldMask: []string{"name"}, + wantIsErr: errors.RecordNotFound, + }, + { + name: "version mismatch", + changeFuncs: []changeHostSetFunc{changeName("foo")}, + version: 1, + fieldMask: []string{"name"}, + wantIsErr: errors.VersionMismatch, + }, + { + name: "mismatched scope id to catalog scope", + withScopeId: func() *string { a := "badid"; return &a }(), + version: 2, + fieldMask: []string{"name"}, + wantIsErr: errors.InvalidParameter, + }, + { + name: "plugin lookup error", + withEmptyPluginMap: true, + version: 2, + fieldMask: []string{"name"}, + wantIsErr: errors.Internal, + }, + { + name: "plugin invocation error", + withPluginError: errors.New(context.Background(), errors.Internal, "TestRepository_UpdateSet/plugin_invocation_error", "test plugin error"), + version: 2, + fieldMask: []string{"name"}, + wantIsErr: errors.Internal, + }, + { + name: "update name (duplicate)", + changeFuncs: []changeHostSetFunc{changeName(testDuplicateSetName)}, + version: 2, + fieldMask: []string{"name"}, + wantIsErr: errors.NotUnique, + }, + { + name: "update name", + changeFuncs: []changeHostSetFunc{changeName("foo")}, + version: 2, + fieldMask: []string{"name"}, + wantCheckFuncs: []checkFunc{ + checkVersion(3), + checkUpdateSetRequestCurrentNameNil(), + checkUpdateSetRequestNewName("foo"), + checkName("foo"), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(false), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + { + name: "update name to same", + changeFuncs: []changeHostSetFunc{changeName("")}, + version: 2, + fieldMask: []string{"name"}, + wantCheckFuncs: []checkFunc{ + checkVersion(2), // Version remains same even though row is updated + checkUpdateSetRequestCurrentNameNil(), + checkUpdateSetRequestNewNameNil(), + checkName(""), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(false), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + { + name: "update description", + changeFuncs: []changeHostSetFunc{changeDescription("foo")}, + version: 2, + fieldMask: []string{"description"}, + wantCheckFuncs: []checkFunc{ + checkVersion(3), + checkUpdateSetRequestCurrentDescriptionNil(), + checkUpdateSetRequestNewDescription("foo"), + checkDescription("foo"), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(false), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + { + name: "update description to same", + changeFuncs: []changeHostSetFunc{changeDescription("")}, + version: 2, + fieldMask: []string{"description"}, + wantCheckFuncs: []checkFunc{ + checkVersion(2), // Version remains same even though row is updated + checkUpdateSetRequestCurrentDescriptionNil(), + checkUpdateSetRequestNewDescriptionNil(), + checkDescription(""), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(false), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + { + name: "update preferred endpoints", + changeFuncs: []changeHostSetFunc{changePreferredEndpoints([]string{"cidr:10.0.0.0/24"})}, + version: 2, + fieldMask: []string{"PreferredEndpoints"}, + wantCheckFuncs: []checkFunc{ + checkVersion(3), + checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}), + checkUpdateSetRequestNewPreferredEndpoints([]string{"cidr:10.0.0.0/24"}), + checkPreferredEndpoints([]string{"cidr:10.0.0.0/24"}), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(false), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + { + name: "clear preferred endpoints", + changeFuncs: []changeHostSetFunc{changePreferredEndpoints(nil)}, + version: 2, + fieldMask: []string{"PreferredEndpoints"}, + wantCheckFuncs: []checkFunc{ + checkVersion(3), + checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}), + checkUpdateSetRequestNewPreferredEndpointsNil(), + checkPreferredEndpoints(nil), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(false), + checkNumUpdated(1), + }, + }, + { + name: "update attributes (add)", + changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{ + "baz": "qux", + })}, + version: 2, + fieldMask: []string{"attributes"}, + wantCheckFuncs: []checkFunc{ + checkVersion(3), + checkUpdateSetRequestCurrentAttributes(map[string]interface{}{ + "foo": "bar", + }), + checkUpdateSetRequestNewAttributes(map[string]interface{}{ + "foo": "bar", + "baz": "qux", + }), + checkAttributes(map[string]interface{}{ + "foo": "bar", + "baz": "qux", + }), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(true), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + { + name: "update attributes (overwrite)", + changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{ + "foo": "baz", + })}, + version: 2, + fieldMask: []string{"attributes"}, + wantCheckFuncs: []checkFunc{ + checkVersion(3), + checkUpdateSetRequestCurrentAttributes(map[string]interface{}{ + "foo": "bar", + }), + checkUpdateSetRequestNewAttributes(map[string]interface{}{ + "foo": "baz", + }), + checkAttributes(map[string]interface{}{ + "foo": "baz", + }), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(true), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + { + name: "update attributes (null)", + changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{ + "foo": nil, + })}, + version: 2, + fieldMask: []string{"attributes"}, + wantCheckFuncs: []checkFunc{ + checkVersion(3), + checkUpdateSetRequestCurrentAttributes(map[string]interface{}{ + "foo": "bar", + }), + checkUpdateSetRequestNewAttributes(map[string]interface{}{}), + checkAttributes(map[string]interface{}{}), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(true), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + { + name: "update attributes (full null)", + changeFuncs: []changeHostSetFunc{changeAttributesNil()}, + version: 2, + fieldMask: []string{"attributes"}, + wantCheckFuncs: []checkFunc{ + checkVersion(3), + checkUpdateSetRequestCurrentAttributes(map[string]interface{}{ + "foo": "bar", + }), + checkUpdateSetRequestNewAttributes(map[string]interface{}{}), + checkAttributes(map[string]interface{}{}), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(true), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + { + name: "update attributes (combined)", + changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{ + "a": "b", + "foo": "baz", + })}, + version: 2, + fieldMask: []string{"attributes.a", "attributes.foo"}, + wantCheckFuncs: []checkFunc{ + checkVersion(3), + checkUpdateSetRequestCurrentAttributes(map[string]interface{}{ + "foo": "bar", + }), + checkUpdateSetRequestNewAttributes(map[string]interface{}{ + "a": "b", + "foo": "baz", + }), + checkAttributes(map[string]interface{}{ + "a": "b", + "foo": "baz", + }), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(true), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + { + name: "update name and preferred endpoints", + changeFuncs: []changeHostSetFunc{ + changePreferredEndpoints([]string{"cidr:10.0.0.0/24"}), + changeName("foo"), + }, + version: 2, + fieldMask: []string{"name", "PreferredEndpoints"}, + wantCheckFuncs: []checkFunc{ + checkVersion(3), + checkUpdateSetRequestCurrentNameNil(), + checkUpdateSetRequestNewName("foo"), + checkName("foo"), + checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}), + checkUpdateSetRequestNewPreferredEndpoints([]string{"cidr:10.0.0.0/24"}), + checkPreferredEndpoints([]string{"cidr:10.0.0.0/24"}), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + checkNeedSync(false), + checkNumUpdated(1), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, + } + + // Create a host that will be verified as belonging to the created set below. + testHostExternIdPrefix := "test-host-external-id" + testHosts := make([]*Host, 3) + for i := 0; i <= 2; i++ { + testHosts[i] = TestHost(t, dbConn, testCatalog.PublicId, fmt.Sprintf("%s-%d", testHostExternIdPrefix, i)) + } + + // Finally define a function for bringing the test subject host + // set. + setupHostSet := func(t *testing.T, ctx context.Context) *HostSet { + t.Helper() + require := require.New(t) + + set := TestSet( + t, + dbConn, + dbKmsCache, + sched, + testCatalog, + testPluginMap, + WithPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}), + ) + // Set some (default) attributes on our test set + set.Attributes = mustMarshal(map[string]interface{}{ + "foo": "bar", + }) + + // Set some fake sync detail to the set. + set.LastSyncTime = timestamp.New(time.Now()) + set.NeedSync = false + + numSetsUpdated, err := dbRW.Update(ctx, set, []string{"attributes", "LastSyncTime", "NeedSync"}, []string{}) + require.NoError(err) + require.Equal(1, numSetsUpdated) + + // Add some hosts to the host set. + TestSetMembers(t, dbConn, set.PublicId, testHosts) + + t.Cleanup(func() { + t.Helper() + assert := assert.New(t) + n, err := dbRW.Delete(ctx, set) + assert.NoError(err) + assert.Equal(1, n) + }) + + return set + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + assert := assert.New(t) + origSet := setupHostSet(t, ctx) + + pluginMap := testPluginMap + if tt.withEmptyPluginMap { + pluginMap = make(map[string]plgpb.HostPluginServiceClient) + } + pluginError = tt.withPluginError + t.Cleanup(func() { pluginError = nil }) + repo, err := NewRepository(dbRW, dbRW, dbKmsCache, sched, pluginMap) + require.NoError(err) + require.NotNil(repo) + + workingSet := origSet.clone() + for _, cf := range tt.changeFuncs { + workingSet = cf(workingSet) + } + + scopeId := testCatalog.ScopeId + if tt.withScopeId != nil { + scopeId = *tt.withScopeId + } + + var gotHosts []*Host + var gotPlugin *hostplugin.Plugin + gotSet, gotHosts, gotPlugin, gotNumUpdated, err = repo.UpdateSet(ctx, scopeId, workingSet, tt.version, tt.fieldMask) + t.Cleanup(func() { gotOnUpdateCallCount = 0 }) + if tt.wantIsErr != 0 { + require.Equal(db.NoRowsAffected, gotNumUpdated) + require.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err: %q got: %q", tt.wantIsErr, err) + return + } + require.NoError(err) + require.Equal(1, gotOnUpdateCallCount) + t.Cleanup(func() { gotOnUpdateSetRequest = nil }) + + // Quick assertion that the set is not nil and that the plugin + // ID in the catalog referenced by the set matches the plugin + // ID in the returned plugin. + require.NotNil(gotSet) + require.NotNil(gotPlugin) + assert.Equal(testCatalog.PublicId, gotSet.CatalogId) + assert.Equal(testCatalog.PluginId, gotPlugin.PublicId) + + // Also assert that the hosts returned by the request are the ones that belong to the set + wantHostMap := make(map[string]string, len(testHosts)) + for _, h := range testHosts { + wantHostMap[h.PublicId] = h.ExternalId + } + gotHostMap := make(map[string]string, len(gotHosts)) + for _, h := range gotHosts { + gotHostMap[h.PublicId] = h.ExternalId + } + assert.Equal(wantHostMap, gotHostMap) + + // Perform checks + for _, check := range tt.wantCheckFuncs { + check(t, ctx) + } + }) + } +} + func TestRepository_LookupSet(t *testing.T) { ctx := context.Background() conn, _ := db.TestSetup(t, "postgres")