From 5cfcc41f00387dc30e5814d79090fc8e4fd8d00e Mon Sep 17 00:00:00 2001 From: Todd Date: Fri, 12 Nov 2021 09:04:48 -0700 Subject: [PATCH] Fix bug for adding preferred endpoint when non existed prior. (#1692) --- internal/host/plugin/repository_host_set.go | 34 +-- .../host/plugin/repository_host_set_test.go | 212 +++++++++++------- 2 files changed, 152 insertions(+), 94 deletions(-) diff --git a/internal/host/plugin/repository_host_set.go b/internal/host/plugin/repository_host_set.go index 69f3aef047..109f8755e1 100644 --- a/internal/host/plugin/repository_host_set.go +++ b/internal/host/plugin/repository_host_set.go @@ -410,23 +410,25 @@ func (r *Repository) UpdateSet(ctx context.Context, scopeId string, s *HostSet, } case endpointOpUpdate: - // Delete all old endpoint entries. - var peps []interface{} - for i := 1; i <= len(currentSet.PreferredEndpoints); i++ { - p := host.AllocPreferredEndpoint() - p.HostSetId, p.Priority = currentSet.GetPublicId(), uint32(i) - peps = append(peps, p) - } - deleteOplogMsgs := make([]*oplog.Message, 0, len(peps)) - _, err := w.DeleteItems(ctx, peps, db.WithDebug(true), db.NewOplogMsgs(&deleteOplogMsgs)) - if err != nil { - return errors.Wrap(ctx, err, op) - } + if len(currentSet.PreferredEndpoints) > 0 { + // Delete all old endpoint entries. + var peps []interface{} + for i := 1; i <= len(currentSet.PreferredEndpoints); i++ { + p := host.AllocPreferredEndpoint() + p.HostSetId, p.Priority = currentSet.GetPublicId(), uint32(i) + peps = append(peps, p) + } + deleteOplogMsgs := make([]*oplog.Message, 0, len(peps)) + _, err := w.DeleteItems(ctx, peps, db.WithDebug(true), db.NewOplogMsgs(&deleteOplogMsgs)) + if err != nil { + return errors.Wrap(ctx, err, op) + } - // Only append the oplog message if an operation was actually - // performed. - if len(deleteOplogMsgs) > 0 { - msgs = append(msgs, deleteOplogMsgs...) + // Only append the oplog message if an operation was actually + // performed. + if len(deleteOplogMsgs) > 0 { + msgs = append(msgs, deleteOplogMsgs...) + } } // Create the new entries. diff --git a/internal/host/plugin/repository_host_set_test.go b/internal/host/plugin/repository_host_set_test.go index 62206535a9..5bda4438b2 100644 --- a/internal/host/plugin/repository_host_set_test.go +++ b/internal/host/plugin/repository_host_set_test.go @@ -654,8 +654,82 @@ func TestRepository_UpdateSet(t *testing.T) { } } + // Finally define a function for bringing the test subject host + // set. + setupBareHostSet := func(t *testing.T, ctx context.Context) (*HostSet, []*Host) { + t.Helper() + set := TestSet( + t, + dbConn, + dbKmsCache, + sched, + testCatalog, + dummyPluginMap, + ) + + t.Cleanup(func() { + t.Helper() + assert := assert.New(t) + n, err := dbRW.Delete(ctx, set) + assert.NoError(err) + assert.Equal(1, n) + }) + + return set, nil + } + + // Create a host that will be verified as belonging to the created set below. + testHostExternIdPrefix := "test-host-external-id" + testHosts := make([]*Host, 3) + for i := 0; i <= 2; i++ { + testHosts[i] = TestHost(t, dbConn, testCatalog.PublicId, fmt.Sprintf("%s-%d", testHostExternIdPrefix, i)) + } + + // Finally define a function for bringing the test subject host + // set. + setupHostSet := func(t *testing.T, ctx context.Context) (*HostSet, []*Host) { + t.Helper() + require := require.New(t) + + set := TestSet( + t, + dbConn, + dbKmsCache, + sched, + testCatalog, + dummyPluginMap, + WithPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}), + ) + // Set some (default) attributes on our test set + set.Attributes = mustMarshal(map[string]interface{}{ + "foo": "bar", + }) + + // Set some fake sync detail to the set. + set.LastSyncTime = timestamp.New(time.Now()) + set.NeedSync = false + + numSetsUpdated, err := dbRW.Update(ctx, set, []string{"attributes", "LastSyncTime", "NeedSync"}, []string{}) + require.NoError(err) + require.Equal(1, numSetsUpdated) + + // Add some hosts to the host set. + TestSetMembers(t, dbConn, set.PublicId, testHosts) + + t.Cleanup(func() { + t.Helper() + assert := assert.New(t) + n, err := dbRW.Delete(ctx, set) + assert.NoError(err) + assert.Equal(1, n) + }) + + return set, testHosts + } + tests := []struct { name string + startingSet func(*testing.T, context.Context) (*HostSet, []*Host) withScopeId *string withEmptyPluginMap bool withPluginError error @@ -668,74 +742,81 @@ func TestRepository_UpdateSet(t *testing.T) { }{ { name: "nil set", + startingSet: setupBareHostSet, changeFuncs: []changeHostSetFunc{changeSetToNil()}, wantIsErr: errors.InvalidParameter, }, { name: "nil embedded set", + startingSet: setupBareHostSet, changeFuncs: []changeHostSetFunc{changeEmbeddedSetToNil()}, wantIsErr: errors.InvalidParameter, }, { name: "missing public id", + startingSet: setupBareHostSet, changeFuncs: []changeHostSetFunc{changePublicId("")}, wantIsErr: errors.InvalidParameter, }, { name: "missing scope id", + startingSet: setupBareHostSet, withScopeId: func() *string { a := ""; return &a }(), wantIsErr: errors.InvalidParameter, }, { - name: "empty field mask", - fieldMask: nil, // Should be testing on len - wantIsErr: errors.EmptyFieldMask, + name: "empty field mask", + startingSet: setupBareHostSet, + fieldMask: nil, // Should be testing on len + wantIsErr: errors.EmptyFieldMask, }, { name: "bad set id", + startingSet: setupBareHostSet, changeFuncs: []changeHostSetFunc{changePublicId("badid")}, fieldMask: []string{"name"}, wantIsErr: errors.RecordNotFound, }, { name: "version mismatch", + startingSet: setupBareHostSet, changeFuncs: []changeHostSetFunc{changeName("foo")}, - version: 1, + version: 4, fieldMask: []string{"name"}, wantIsErr: errors.VersionMismatch, }, { name: "mismatched scope id to catalog scope", + startingSet: setupBareHostSet, withScopeId: func() *string { a := "badid"; return &a }(), - version: 2, fieldMask: []string{"name"}, wantIsErr: errors.InvalidParameter, }, { name: "plugin lookup error", + startingSet: setupBareHostSet, withEmptyPluginMap: true, - version: 2, fieldMask: []string{"name"}, wantIsErr: errors.Internal, }, { name: "plugin invocation error", + startingSet: setupBareHostSet, withPluginError: errors.New(context.Background(), errors.Internal, "TestRepository_UpdateSet/plugin_invocation_error", "test plugin error"), - version: 2, fieldMask: []string{"name"}, wantIsErr: errors.Internal, }, { name: "update name (duplicate)", + startingSet: setupBareHostSet, changeFuncs: []changeHostSetFunc{changeName(testDuplicateSetName)}, - version: 2, fieldMask: []string{"name"}, wantIsErr: errors.NotUnique, }, { name: "update name", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changeName("foo")}, - version: 2, fieldMask: []string{"name"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentNameNil(), @@ -753,8 +834,8 @@ func TestRepository_UpdateSet(t *testing.T) { }, { name: "update name to same", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changeName("")}, - version: 2, fieldMask: []string{"name"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentNameNil(), @@ -772,8 +853,8 @@ func TestRepository_UpdateSet(t *testing.T) { }, { name: "update description", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changeDescription("foo")}, - version: 2, fieldMask: []string{"description"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentDescriptionNil(), @@ -791,8 +872,8 @@ func TestRepository_UpdateSet(t *testing.T) { }, { name: "update description to same", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changeDescription("")}, - version: 2, fieldMask: []string{"description"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentDescriptionNil(), @@ -808,10 +889,29 @@ func TestRepository_UpdateSet(t *testing.T) { checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), }, }, + { + name: "add preferred endpoints", + startingSet: setupBareHostSet, + changeFuncs: []changeHostSetFunc{changePreferredEndpoints([]string{"cidr:10.0.0.0/24"})}, + fieldMask: []string{"PreferredEndpoints"}, + wantCheckPluginReqFuncs: []checkPluginReqFunc{ + checkUpdateSetRequestCurrentPreferredEndpoints(nil), + checkUpdateSetRequestNewPreferredEndpoints([]string{"cidr:10.0.0.0/24"}), + checkUpdateSetRequestPersistedSecrets(map[string]interface{}{ + "one": "two", + }), + }, + wantCheckSetFuncs: []checkHostSetFunc{ + checkVersion(2), + checkPreferredEndpoints([]string{"cidr:10.0.0.0/24"}), + checkNeedSync(true), + checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE), + }, + }, { name: "update preferred endpoints", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changePreferredEndpoints([]string{"cidr:10.0.0.0/24"})}, - version: 2, fieldMask: []string{"PreferredEndpoints"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}), @@ -829,8 +929,8 @@ func TestRepository_UpdateSet(t *testing.T) { }, { name: "clear preferred endpoints", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changePreferredEndpoints(nil)}, - version: 2, fieldMask: []string{"PreferredEndpoints"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}), @@ -846,11 +946,11 @@ func TestRepository_UpdateSet(t *testing.T) { }, }, { - name: "update attributes (add)", + name: "update attributes (add)", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{ "baz": "qux", })}, - version: 2, fieldMask: []string{"attributes"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentAttributes(map[string]interface{}{ @@ -875,11 +975,11 @@ func TestRepository_UpdateSet(t *testing.T) { }, }, { - name: "update attributes (overwrite)", + name: "update attributes (overwrite)", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{ "foo": "baz", })}, - version: 2, fieldMask: []string{"attributes"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentAttributes(map[string]interface{}{ @@ -902,11 +1002,11 @@ func TestRepository_UpdateSet(t *testing.T) { }, }, { - name: "update attributes (null)", + name: "update attributes (null)", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{ "foo": nil, })}, - version: 2, fieldMask: []string{"attributes"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentAttributes(map[string]interface{}{ @@ -926,8 +1026,8 @@ func TestRepository_UpdateSet(t *testing.T) { }, { name: "update attributes (full null)", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changeAttributesNil()}, - version: 2, fieldMask: []string{"attributes"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentAttributes(map[string]interface{}{ @@ -946,12 +1046,12 @@ func TestRepository_UpdateSet(t *testing.T) { }, }, { - name: "update attributes (combined)", + name: "update attributes (combined)", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{changeAttributes(map[string]interface{}{ "a": "b", "foo": "baz", })}, - version: 2, fieldMask: []string{"attributes.a", "attributes.foo"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentAttributes(map[string]interface{}{ @@ -976,12 +1076,12 @@ func TestRepository_UpdateSet(t *testing.T) { }, }, { - name: "update name and preferred endpoints", + name: "update name and preferred endpoints", + startingSet: setupHostSet, changeFuncs: []changeHostSetFunc{ changePreferredEndpoints([]string{"cidr:10.0.0.0/24"}), changeName("foo"), }, - version: 2, fieldMask: []string{"name", "PreferredEndpoints"}, wantCheckPluginReqFuncs: []checkPluginReqFunc{ checkUpdateSetRequestCurrentNameNil(), @@ -1002,55 +1102,6 @@ func TestRepository_UpdateSet(t *testing.T) { }, } - // Create a host that will be verified as belonging to the created set below. - testHostExternIdPrefix := "test-host-external-id" - testHosts := make([]*Host, 3) - for i := 0; i <= 2; i++ { - testHosts[i] = TestHost(t, dbConn, testCatalog.PublicId, fmt.Sprintf("%s-%d", testHostExternIdPrefix, i)) - } - - // Finally define a function for bringing the test subject host - // set. - setupHostSet := func(t *testing.T, ctx context.Context) *HostSet { - t.Helper() - require := require.New(t) - - set := TestSet( - t, - dbConn, - dbKmsCache, - sched, - testCatalog, - dummyPluginMap, - WithPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}), - ) - // Set some (default) attributes on our test set - set.Attributes = mustMarshal(map[string]interface{}{ - "foo": "bar", - }) - - // Set some fake sync detail to the set. - set.LastSyncTime = timestamp.New(time.Now()) - set.NeedSync = false - - numSetsUpdated, err := dbRW.Update(ctx, set, []string{"attributes", "LastSyncTime", "NeedSync"}, []string{}) - require.NoError(err) - require.Equal(1, numSetsUpdated) - - // Add some hosts to the host set. - TestSetMembers(t, dbConn, set.PublicId, testHosts) - - t.Cleanup(func() { - t.Helper() - assert := assert.New(t) - n, err := dbRW.Delete(ctx, set) - assert.NoError(err) - assert.Equal(1, n) - }) - - return set - } - for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { @@ -1078,7 +1129,7 @@ func TestRepository_UpdateSet(t *testing.T) { }, } - origSet := setupHostSet(t, ctx) + origSet, wantedHosts := tt.startingSet(t, ctx) pluginMap := testPluginMap if tt.withEmptyPluginMap { @@ -1098,7 +1149,12 @@ func TestRepository_UpdateSet(t *testing.T) { scopeId = *tt.withScopeId } - gotUpdatedSet, gotHosts, gotPlugin, gotNumUpdated, err := repo.UpdateSet(ctx, scopeId, workingSet, tt.version, tt.fieldMask) + version := origSet.Version + if tt.version != 0 { + version = tt.version + } + + gotUpdatedSet, gotHosts, gotPlugin, gotNumUpdated, err := repo.UpdateSet(ctx, scopeId, workingSet, version, tt.fieldMask) t.Cleanup(func() { gotOnUpdateCallCount = 0 }) if tt.wantIsErr != 0 { require.Equal(db.NoRowsAffected, gotNumUpdated) @@ -1118,8 +1174,8 @@ func TestRepository_UpdateSet(t *testing.T) { assert.Equal(testCatalog.PluginId, gotPlugin.PublicId) // Also assert that the hosts returned by the request are the ones that belong to the set - wantHostMap := make(map[string]string, len(testHosts)) - for _, h := range testHosts { + wantHostMap := make(map[string]string, len(wantedHosts)) + for _, h := range wantedHosts { wantHostMap[h.PublicId] = h.ExternalId } gotHostMap := make(map[string]string, len(gotHosts))