add SetTargetHostSets()

jimlambrt-targets-store
Jim Lambert 6 years ago
parent cb6f4e90b1
commit 738972b76c

@ -239,7 +239,8 @@ func (r *Repository) AddTargeHostSets(ctx context.Context, targetId string, targ
tcpT.PublicId = t.PublicId
tcpT.Version = targetVersion + 1
target = &tcpT
metadata = tcpT.oplog(oplog.OpType_OP_TYPE_CREATE)
metadata = tcpT.oplog(oplog.OpType_OP_TYPE_UPDATE)
metadata["op-type"] = append(metadata["op-type"], oplog.OpType_OP_TYPE_CREATE.String())
default:
return nil, nil, fmt.Errorf("delete target host sets: %s is an unsupported target type %s", t.PublicId, t.Type)
}
@ -329,7 +330,8 @@ func (r *Repository) DeleteTargeHostSets(ctx context.Context, targetId string, t
tcpT.PublicId = t.PublicId
tcpT.Version = targetVersion + 1
target = &tcpT
metadata = tcpT.oplog(oplog.OpType_OP_TYPE_DELETE)
metadata = tcpT.oplog(oplog.OpType_OP_TYPE_UPDATE)
metadata["op-type"] = append(metadata["op-type"], oplog.OpType_OP_TYPE_DELETE.String())
default:
return db.NoRowsAffected, fmt.Errorf("delete target host sets: %s is an unsupported target type %s", t.PublicId, t.Type)
}
@ -382,3 +384,146 @@ func (r *Repository) DeleteTargeHostSets(ctx context.Context, targetId string, t
}
return totalRowsDeleted, nil
}
// SetTargetHostSets will set the target's host sets. Set add and/or delete
// target host sets as need to reconcile the existing sets with the sets
// requested. If hostSetIds is empty, the target host sets will be cleared. Zero
// is not a valid value for the WithVersion option and will return an error.
func (r *Repository) SetTargetHostSets(ctx context.Context, targetId string, targetVersion uint32, hostSetIds []string, opt ...Option) ([]string, int, error) {
if targetId == "" {
return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: missing role id: %w", db.ErrInvalidParameter)
}
if targetVersion == 0 {
return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: version cannot be zero: %w", db.ErrInvalidParameter)
}
t := allocTargetView()
t.PublicId = targetId
if err := r.reader.LookupByPublicId(ctx, &t); err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: failed %w for %s", err, targetId)
}
// NOTE: calculating that to set can safely happen outside of the write
// transaction since we're using targetVersion to ensure that the only
// operate on the same set of data from these queries that calculate the
// set.
foundThs, err := fetchHostSets(ctx, r.reader, targetId)
if err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: unable to search for existing target host sets: %w", err)
}
found := map[string]string{}
for _, id := range foundThs {
found[id] = id
}
addHostSets := make([]interface{}, 0, len(hostSetIds))
for _, id := range hostSetIds {
if _, ok := found[id]; ok {
// found a match, so do nothing (we want to keep it), but remove it
// from found
delete(found, id)
continue
}
hs, err := NewTargetHostSet(targetId, id)
if err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("set target host set: unable to create in memory target host set: %w", err)
}
addHostSets = append(addHostSets, hs)
}
deleteHostSets := make([]interface{}, 0, len(hostSetIds))
if len(found) > 0 {
for _, id := range found {
hs, err := NewTargetHostSet(targetId, id)
if err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("set target host set: unable to create in memory target host set: %w", err)
}
deleteHostSets = append(deleteHostSets, hs)
}
}
currentHostSets := make([]string, 0, len(hostSetIds)+len(found))
if len(addHostSets) == 0 && len(deleteHostSets) == 0 {
for _, id := range foundThs {
currentHostSets = append(currentHostSets, id)
}
return currentHostSets, db.NoRowsAffected, nil
}
var metadata oplog.Metadata
var target interface{}
switch t.Type {
case TcpTargetType.String():
tcpT := allocTcpTarget()
tcpT.PublicId = t.PublicId
tcpT.Version = targetVersion + 1
target = &tcpT
metadata = tcpT.oplog(oplog.OpType_OP_TYPE_UPDATE)
default:
return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: %s is an unsupported target type %s", t.PublicId, t.Type)
}
oplogWrapper, err := r.kms.GetWrapper(ctx, t.GetScopeId(), kms.KeyPurposeOplog)
if err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("set target host sets: unable to get oplog wrapper: %w", err)
}
var totalRowsAffected int
_, err = r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
msgs := make([]*oplog.Message, 0, 2)
targetTicket, err := w.GetTicket(target)
if err != nil {
return fmt.Errorf("set target host sets: unable to get ticket: %w", err)
}
updatedTarget := target.(Cloneable).Clone()
var targetOplogMsg oplog.Message
rowsUpdated, err := w.Update(ctx, updatedTarget, []string{"Version"}, nil, db.NewOplogMsg(&targetOplogMsg), db.WithVersion(&targetVersion))
if err != nil {
return fmt.Errorf("set target host sets: unable to update target version: %w", err)
}
if rowsUpdated != 1 {
return fmt.Errorf("set target host sets: updated target and %d rows updated", rowsUpdated)
}
msgs = append(msgs, &targetOplogMsg)
// Write the new ones in
if len(addHostSets) > 0 {
hostSetOplogMsgs := make([]*oplog.Message, 0, len(addHostSets))
if err := w.CreateItems(ctx, addHostSets, db.NewOplogMsgs(&hostSetOplogMsgs)); err != nil {
return fmt.Errorf("unable to add target host sets during set: %w", err)
}
totalRowsAffected += len(addHostSets)
msgs = append(msgs, hostSetOplogMsgs...)
metadata["op-type"] = append(metadata["op-type"], oplog.OpType_OP_TYPE_CREATE.String())
}
// Anything we didn't take out of found needs to be removed
if len(deleteHostSets) > 0 {
hostSetOplogMsgs := make([]*oplog.Message, 0, len(deleteHostSets))
rowsDeleted, err := w.DeleteItems(ctx, deleteHostSets, db.NewOplogMsgs(&hostSetOplogMsgs))
if err != nil {
return fmt.Errorf("set target host sets: unable to delete target host set: %w", err)
}
if rowsDeleted != len(deleteHostSets) {
return fmt.Errorf("set target host sets: target host sets deleted %d did not match request for %d", rowsDeleted, len(deleteHostSets))
}
totalRowsAffected += rowsDeleted
msgs = append(msgs, hostSetOplogMsgs...)
metadata["op-type"] = append(metadata["op-type"], oplog.OpType_OP_TYPE_DELETE.String())
}
if err := w.WriteOplogEntryWith(ctx, oplogWrapper, targetTicket, metadata, msgs); err != nil {
return fmt.Errorf("set target host sets: unable to write oplog: %w", err)
}
currentHostSets, err = fetchHostSets(ctx, reader, targetId)
if err != nil {
return fmt.Errorf("set target host sets: unable to retrieve current target host sets after set: %w", err)
}
return nil
},
)
if err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("target host sets: error setting target host sets: %w", err)
}
return currentHostSets, totalRowsAffected, nil
}

@ -3,6 +3,7 @@ package target
import (
"context"
"errors"
"sort"
"strconv"
"testing"
"time"
@ -611,3 +612,159 @@ func TestRepository_DeleteTargetHosts(t *testing.T) {
})
}
}
func TestRepository_SetTargetHostSets(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
testKms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, testKms)
require.NoError(t, err)
iamRepo := iam.TestRepo(t, conn, wrapper)
org, proj := iam.TestScopes(t, iamRepo)
testCats := static.TestCatalogs(t, conn, org.PublicId, 1)
hsets := static.TestSets(t, conn, testCats[0].GetPublicId(), 5)
testHostSetIds := make([]string, 0, len(hsets))
for _, hs := range hsets {
testHostSetIds = append(testHostSetIds, hs.PublicId)
}
createHostSetsFn := func() []string {
results := []string{}
for _, publicId := range []string{org.PublicId, proj.PublicId} {
for i := 0; i < 5; i++ {
cats := static.TestCatalogs(t, conn, publicId, 1)
hsets := static.TestSets(t, conn, cats[0].GetPublicId(), 1)
results = append(results, hsets[0].PublicId)
}
}
return results
}
setupFn := func(target Target) []string {
hs := createHostSetsFn()
_, created, err := repo.AddTargeHostSets(context.Background(), target.GetPublicId(), 1, hs)
require.NoError(t, err)
require.Equal(t, 10, len(created))
return created
}
type args struct {
target Target
targetVersion uint32
hostSetIds []string
addToOrigHostSets bool
opt []Option
}
tests := []struct {
name string
setup func(Target) []string
args args
wantAffectedRows int
wantErr bool
}{
{
name: "clear",
setup: setupFn,
args: args{
target: TestTcpTarget(t, conn, proj.PublicId, "clear"),
targetVersion: 2, // yep, since setupFn will increment it to 2
hostSetIds: []string{},
},
wantErr: false,
wantAffectedRows: 10,
},
{
name: "no-change",
setup: setupFn,
args: args{
target: TestTcpTarget(t, conn, proj.PublicId, "no-change"),
targetVersion: 2, // yep, since setupFn will increment it to 2
hostSetIds: []string{},
addToOrigHostSets: true,
},
wantErr: false,
wantAffectedRows: 0,
},
{
name: "add-sets",
setup: setupFn,
args: args{
target: TestTcpTarget(t, conn, proj.PublicId, "add-sets"),
targetVersion: 2, // yep, since setupFn will increment it to 2
hostSetIds: []string{testHostSetIds[0], testHostSetIds[1]},
addToOrigHostSets: true,
},
wantErr: false,
wantAffectedRows: 2,
},
{
name: "add host sets with zero version",
setup: setupFn,
args: args{
target: TestTcpTarget(t, conn, proj.PublicId, "add host sets with zero version"),
targetVersion: 0,
hostSetIds: []string{testHostSetIds[0], testHostSetIds[1]},
addToOrigHostSets: true,
},
wantErr: true,
},
{
name: "remove existing and add users and grps",
setup: setupFn,
args: args{
target: TestTcpTarget(t, conn, proj.PublicId, "remove existing and add users and grps"),
targetVersion: 2, // yep, since setupFn will increment it to 2
hostSetIds: []string{testHostSetIds[0], testHostSetIds[1]},
addToOrigHostSets: false,
},
wantErr: false,
wantAffectedRows: 12,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
var origHostSets []string
if tt.setup != nil {
origHostSets = tt.setup(tt.args.target)
}
if tt.args.addToOrigHostSets {
tt.args.hostSetIds = append(tt.args.hostSetIds, origHostSets...)
}
origTarget, lookedUpHs, err := repo.LookupTarget(context.Background(), tt.args.target.GetPublicId())
require.NoError(err)
assert.Equal(len(origHostSets), len(lookedUpHs))
got, affectedRows, err := repo.SetTargetHostSets(context.Background(), tt.args.target.GetPublicId(), tt.args.targetVersion, tt.args.hostSetIds, tt.args.opt...)
if tt.wantErr {
require.Error(err)
t.Log(err)
return
}
t.Log(err)
require.NoError(err)
assert.Equal(tt.wantAffectedRows, affectedRows)
assert.Equal(len(tt.args.hostSetIds), len(got))
var gotIds []string
for _, id := range got {
gotIds = append(gotIds, id)
}
var wantIds []string
wantIds = append(wantIds, tt.args.hostSetIds...)
sort.Strings(wantIds)
sort.Strings(gotIds)
assert.Equal(wantIds, gotIds)
foundTarget, _, err := repo.LookupTarget(context.Background(), tt.args.target.GetPublicId())
require.NoError(err)
if tt.name != "no-change" {
assert.Equalf(tt.args.targetVersion+1, foundTarget.GetVersion(), "%s unexpected version: %d/%d", tt.name, tt.args.targetVersion+1, foundTarget.GetVersion())
assert.Equalf(origTarget.GetVersion(), foundTarget.GetVersion()-1, "%s unexpected version: %d/%d", tt.name, origTarget.GetVersion(), foundTarget.GetVersion()-1)
}
t.Logf("target: %v and origVersion/newVersion: %d/%d", foundTarget.GetPublicId(), origTarget.GetVersion(), foundTarget.GetVersion())
})
}
}

Loading…
Cancel
Save