diff --git a/internal/target/repository.go b/internal/target/repository.go index c2ff99cd0c..a3cf72a2fa 100644 --- a/internal/target/repository.go +++ b/internal/target/repository.go @@ -239,7 +239,8 @@ func (r *Repository) AddTargeHostSets(ctx context.Context, targetId string, targ tcpT.PublicId = t.PublicId tcpT.Version = targetVersion + 1 target = &tcpT - metadata = tcpT.oplog(oplog.OpType_OP_TYPE_CREATE) + metadata = tcpT.oplog(oplog.OpType_OP_TYPE_UPDATE) + metadata["op-type"] = append(metadata["op-type"], oplog.OpType_OP_TYPE_CREATE.String()) default: return nil, nil, fmt.Errorf("delete target host sets: %s is an unsupported target type %s", t.PublicId, t.Type) } @@ -329,7 +330,8 @@ func (r *Repository) DeleteTargeHostSets(ctx context.Context, targetId string, t tcpT.PublicId = t.PublicId tcpT.Version = targetVersion + 1 target = &tcpT - metadata = tcpT.oplog(oplog.OpType_OP_TYPE_DELETE) + metadata = tcpT.oplog(oplog.OpType_OP_TYPE_UPDATE) + metadata["op-type"] = append(metadata["op-type"], oplog.OpType_OP_TYPE_DELETE.String()) default: return db.NoRowsAffected, fmt.Errorf("delete target host sets: %s is an unsupported target type %s", t.PublicId, t.Type) } @@ -382,3 +384,146 @@ func (r *Repository) DeleteTargeHostSets(ctx context.Context, targetId string, t } return totalRowsDeleted, nil } + +// SetTargetHostSets will set the target's host sets. Set add and/or delete +// target host sets as need to reconcile the existing sets with the sets +// requested. If hostSetIds is empty, the target host sets will be cleared. Zero +// is not a valid value for the WithVersion option and will return an error. +func (r *Repository) SetTargetHostSets(ctx context.Context, targetId string, targetVersion uint32, hostSetIds []string, opt ...Option) ([]string, int, error) { + if targetId == "" { + return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: missing role id: %w", db.ErrInvalidParameter) + } + if targetVersion == 0 { + return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: version cannot be zero: %w", db.ErrInvalidParameter) + } + t := allocTargetView() + t.PublicId = targetId + if err := r.reader.LookupByPublicId(ctx, &t); err != nil { + return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: failed %w for %s", err, targetId) + } + + // NOTE: calculating that to set can safely happen outside of the write + // transaction since we're using targetVersion to ensure that the only + // operate on the same set of data from these queries that calculate the + // set. + + foundThs, err := fetchHostSets(ctx, r.reader, targetId) + if err != nil { + return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: unable to search for existing target host sets: %w", err) + } + found := map[string]string{} + for _, id := range foundThs { + found[id] = id + } + addHostSets := make([]interface{}, 0, len(hostSetIds)) + for _, id := range hostSetIds { + if _, ok := found[id]; ok { + // found a match, so do nothing (we want to keep it), but remove it + // from found + delete(found, id) + continue + } + hs, err := NewTargetHostSet(targetId, id) + if err != nil { + return nil, db.NoRowsAffected, fmt.Errorf("set target host set: unable to create in memory target host set: %w", err) + } + addHostSets = append(addHostSets, hs) + } + deleteHostSets := make([]interface{}, 0, len(hostSetIds)) + if len(found) > 0 { + for _, id := range found { + hs, err := NewTargetHostSet(targetId, id) + if err != nil { + return nil, db.NoRowsAffected, fmt.Errorf("set target host set: unable to create in memory target host set: %w", err) + } + deleteHostSets = append(deleteHostSets, hs) + } + } + currentHostSets := make([]string, 0, len(hostSetIds)+len(found)) + if len(addHostSets) == 0 && len(deleteHostSets) == 0 { + for _, id := range foundThs { + currentHostSets = append(currentHostSets, id) + } + return currentHostSets, db.NoRowsAffected, nil + } + + var metadata oplog.Metadata + var target interface{} + switch t.Type { + case TcpTargetType.String(): + tcpT := allocTcpTarget() + tcpT.PublicId = t.PublicId + tcpT.Version = targetVersion + 1 + target = &tcpT + metadata = tcpT.oplog(oplog.OpType_OP_TYPE_UPDATE) + default: + return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: %s is an unsupported target type %s", t.PublicId, t.Type) + } + oplogWrapper, err := r.kms.GetWrapper(ctx, t.GetScopeId(), kms.KeyPurposeOplog) + if err != nil { + return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: unable to get oplog wrapper: %w", err) + } + + var totalRowsAffected int + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + msgs := make([]*oplog.Message, 0, 2) + targetTicket, err := w.GetTicket(target) + if err != nil { + return fmt.Errorf("set target host sets: unable to get ticket: %w", err) + } + updatedTarget := target.(Cloneable).Clone() + var targetOplogMsg oplog.Message + rowsUpdated, err := w.Update(ctx, updatedTarget, []string{"Version"}, nil, db.NewOplogMsg(&targetOplogMsg), db.WithVersion(&targetVersion)) + if err != nil { + return fmt.Errorf("set target host sets: unable to update target version: %w", err) + } + if rowsUpdated != 1 { + return fmt.Errorf("set target host sets: updated target and %d rows updated", rowsUpdated) + } + msgs = append(msgs, &targetOplogMsg) + + // Write the new ones in + if len(addHostSets) > 0 { + hostSetOplogMsgs := make([]*oplog.Message, 0, len(addHostSets)) + if err := w.CreateItems(ctx, addHostSets, db.NewOplogMsgs(&hostSetOplogMsgs)); err != nil { + return fmt.Errorf("unable to add target host sets during set: %w", err) + } + totalRowsAffected += len(addHostSets) + msgs = append(msgs, hostSetOplogMsgs...) + metadata["op-type"] = append(metadata["op-type"], oplog.OpType_OP_TYPE_CREATE.String()) + } + + // Anything we didn't take out of found needs to be removed + if len(deleteHostSets) > 0 { + hostSetOplogMsgs := make([]*oplog.Message, 0, len(deleteHostSets)) + rowsDeleted, err := w.DeleteItems(ctx, deleteHostSets, db.NewOplogMsgs(&hostSetOplogMsgs)) + if err != nil { + return fmt.Errorf("set target host sets: unable to delete target host set: %w", err) + } + if rowsDeleted != len(deleteHostSets) { + return fmt.Errorf("set target host sets: target host sets deleted %d did not match request for %d", rowsDeleted, len(deleteHostSets)) + } + totalRowsAffected += rowsDeleted + msgs = append(msgs, hostSetOplogMsgs...) + metadata["op-type"] = append(metadata["op-type"], oplog.OpType_OP_TYPE_DELETE.String()) + } + if err := w.WriteOplogEntryWith(ctx, oplogWrapper, targetTicket, metadata, msgs); err != nil { + return fmt.Errorf("set target host sets: unable to write oplog: %w", err) + } + + currentHostSets, err = fetchHostSets(ctx, reader, targetId) + if err != nil { + return fmt.Errorf("set target host sets: unable to retrieve current target host sets after set: %w", err) + } + return nil + }, + ) + if err != nil { + return nil, db.NoRowsAffected, fmt.Errorf("target host sets: error setting target host sets: %w", err) + } + return currentHostSets, totalRowsAffected, nil +} diff --git a/internal/target/repository_test.go b/internal/target/repository_test.go index 15c3f8b219..890094ce3a 100644 --- a/internal/target/repository_test.go +++ b/internal/target/repository_test.go @@ -3,6 +3,7 @@ package target import ( "context" "errors" + "sort" "strconv" "testing" "time" @@ -611,3 +612,159 @@ func TestRepository_DeleteTargetHosts(t *testing.T) { }) } } + +func TestRepository_SetTargetHostSets(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + testKms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, testKms) + require.NoError(t, err) + + iamRepo := iam.TestRepo(t, conn, wrapper) + org, proj := iam.TestScopes(t, iamRepo) + + testCats := static.TestCatalogs(t, conn, org.PublicId, 1) + hsets := static.TestSets(t, conn, testCats[0].GetPublicId(), 5) + testHostSetIds := make([]string, 0, len(hsets)) + for _, hs := range hsets { + testHostSetIds = append(testHostSetIds, hs.PublicId) + } + + createHostSetsFn := func() []string { + results := []string{} + for _, publicId := range []string{org.PublicId, proj.PublicId} { + for i := 0; i < 5; i++ { + cats := static.TestCatalogs(t, conn, publicId, 1) + hsets := static.TestSets(t, conn, cats[0].GetPublicId(), 1) + results = append(results, hsets[0].PublicId) + } + } + return results + } + + setupFn := func(target Target) []string { + hs := createHostSetsFn() + _, created, err := repo.AddTargeHostSets(context.Background(), target.GetPublicId(), 1, hs) + require.NoError(t, err) + require.Equal(t, 10, len(created)) + return created + } + type args struct { + target Target + targetVersion uint32 + hostSetIds []string + addToOrigHostSets bool + opt []Option + } + tests := []struct { + name string + setup func(Target) []string + args args + wantAffectedRows int + wantErr bool + }{ + { + name: "clear", + setup: setupFn, + args: args{ + target: TestTcpTarget(t, conn, proj.PublicId, "clear"), + targetVersion: 2, // yep, since setupFn will increment it to 2 + hostSetIds: []string{}, + }, + wantErr: false, + wantAffectedRows: 10, + }, + { + name: "no-change", + setup: setupFn, + args: args{ + target: TestTcpTarget(t, conn, proj.PublicId, "no-change"), + targetVersion: 2, // yep, since setupFn will increment it to 2 + hostSetIds: []string{}, + addToOrigHostSets: true, + }, + wantErr: false, + wantAffectedRows: 0, + }, + { + name: "add-sets", + setup: setupFn, + args: args{ + target: TestTcpTarget(t, conn, proj.PublicId, "add-sets"), + targetVersion: 2, // yep, since setupFn will increment it to 2 + hostSetIds: []string{testHostSetIds[0], testHostSetIds[1]}, + addToOrigHostSets: true, + }, + wantErr: false, + wantAffectedRows: 2, + }, + { + name: "add host sets with zero version", + setup: setupFn, + args: args{ + target: TestTcpTarget(t, conn, proj.PublicId, "add host sets with zero version"), + targetVersion: 0, + hostSetIds: []string{testHostSetIds[0], testHostSetIds[1]}, + addToOrigHostSets: true, + }, + wantErr: true, + }, + { + name: "remove existing and add users and grps", + setup: setupFn, + args: args{ + target: TestTcpTarget(t, conn, proj.PublicId, "remove existing and add users and grps"), + targetVersion: 2, // yep, since setupFn will increment it to 2 + hostSetIds: []string{testHostSetIds[0], testHostSetIds[1]}, + addToOrigHostSets: false, + }, + wantErr: false, + wantAffectedRows: 12, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + var origHostSets []string + if tt.setup != nil { + origHostSets = tt.setup(tt.args.target) + } + if tt.args.addToOrigHostSets { + tt.args.hostSetIds = append(tt.args.hostSetIds, origHostSets...) + } + origTarget, lookedUpHs, err := repo.LookupTarget(context.Background(), tt.args.target.GetPublicId()) + require.NoError(err) + assert.Equal(len(origHostSets), len(lookedUpHs)) + + got, affectedRows, err := repo.SetTargetHostSets(context.Background(), tt.args.target.GetPublicId(), tt.args.targetVersion, tt.args.hostSetIds, tt.args.opt...) + if tt.wantErr { + require.Error(err) + t.Log(err) + return + } + t.Log(err) + require.NoError(err) + assert.Equal(tt.wantAffectedRows, affectedRows) + assert.Equal(len(tt.args.hostSetIds), len(got)) + var gotIds []string + for _, id := range got { + gotIds = append(gotIds, id) + } + var wantIds []string + wantIds = append(wantIds, tt.args.hostSetIds...) + sort.Strings(wantIds) + sort.Strings(gotIds) + assert.Equal(wantIds, gotIds) + + foundTarget, _, err := repo.LookupTarget(context.Background(), tt.args.target.GetPublicId()) + require.NoError(err) + if tt.name != "no-change" { + assert.Equalf(tt.args.targetVersion+1, foundTarget.GetVersion(), "%s unexpected version: %d/%d", tt.name, tt.args.targetVersion+1, foundTarget.GetVersion()) + assert.Equalf(origTarget.GetVersion(), foundTarget.GetVersion()-1, "%s unexpected version: %d/%d", tt.name, origTarget.GetVersion(), foundTarget.GetVersion()-1) + } + t.Logf("target: %v and origVersion/newVersion: %d/%d", foundTarget.GetPublicId(), origTarget.GetVersion(), foundTarget.GetVersion()) + }) + } +}