diff --git a/internal/iam/repository_principal_role_test.go b/internal/iam/repository_principal_role_test.go index f11b4cda28..a013aa2b5b 100644 --- a/internal/iam/repository_principal_role_test.go +++ b/internal/iam/repository_principal_role_test.go @@ -3,6 +3,7 @@ package iam import ( "context" "errors" + "sort" "testing" "time" @@ -419,3 +420,330 @@ func TestRepository_DeletePrincipalRoles(t *testing.T) { }) } } + +func TestRepository_SetPrincipalRoles(t *testing.T) { + t.Parallel() + cleanup, conn, _ := db.TestSetup(t, "postgres") + defer func() { + err := cleanup() + assert.NoError(t, err) + err = conn.Close() + assert.NoError(t, err) + }() + rw := db.New(conn) + wrapper := db.TestWrapper(t) + + repo, err := NewRepository(rw, rw, wrapper) + require.NoError(t, err) + + org, proj := TestScopes(t, conn) + testUser := TestUser(t, conn, org.PublicId) + testGrp := TestGroup(t, conn, proj.PublicId) + + createUsersFn := func() []string { + results := []string{} + for i := 0; i < 5; i++ { + u := TestUser(t, conn, org.PublicId) + results = append(results, u.PublicId) + } + return results + } + createGrpsFn := func() []string { + results := []string{} + for i := 0; i < 5; i++ { + g := TestGroup(t, conn, proj.PublicId) + results = append(results, g.PublicId) + } + return results + } + setupFn := func(role *Role) ([]string, []string) { + users := createUsersFn() + grps := createGrpsFn() + _, err := repo.AddPrincipalRoles(context.Background(), role.PublicId, 1, users, grps) + require.NoError(t, err) + return users, grps + } + type args struct { + role *Role + roleVersion int + userIds []string + groupIds []string + addToOrigUsers bool + addToOrigGrps bool + opt []Option + } + tests := []struct { + name string + setup func(*Role) ([]string, []string) + args args + wantAffectedRows int + wantErr bool + }{ + { + name: "clear", + setup: setupFn, + args: args{ + role: TestRole(t, conn, proj.PublicId), + roleVersion: 2, // yep, since setupFn will increment it to 2 + userIds: []string{}, + groupIds: []string{}, + }, + wantErr: false, + wantAffectedRows: 10, + }, + { + name: "no change", + setup: setupFn, + args: args{ + role: TestRole(t, conn, proj.PublicId), + roleVersion: 2, // yep, since setupFn will increment it to 2 + userIds: []string{}, + groupIds: []string{}, + addToOrigUsers: true, + addToOrigGrps: true, + }, + wantErr: false, + wantAffectedRows: 0, + }, + { + name: "add users and grps", + setup: setupFn, + args: args{ + role: TestRole(t, conn, proj.PublicId), + roleVersion: 2, // yep, since setupFn will increment it to 2 + userIds: []string{testUser.PublicId}, + groupIds: []string{testGrp.PublicId}, + addToOrigUsers: true, + addToOrigGrps: true, + }, + wantErr: false, + wantAffectedRows: 2, + }, + { + name: "remove existing and add users and grps", + setup: setupFn, + args: args{ + role: TestRole(t, conn, proj.PublicId), + roleVersion: 2, // yep, since setupFn will increment it to 2 + userIds: []string{testUser.PublicId}, + groupIds: []string{testGrp.PublicId}, + addToOrigUsers: false, + addToOrigGrps: false, + }, + wantErr: false, + wantAffectedRows: 12, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + var origUsers, origGrps []string + if tt.setup != nil { + origUsers, origGrps = tt.setup(tt.args.role) + } + setUsers := tt.args.userIds + setGrps := tt.args.groupIds + if tt.args.addToOrigUsers { + setUsers = append(setUsers, origUsers...) + } + if tt.args.addToOrigGrps { + setGrps = append(setGrps, origGrps...) + } + + got, affectedRows, err := repo.SetPrincipalRoles(context.Background(), tt.args.role.PublicId, tt.args.roleVersion, setUsers, setGrps, tt.args.opt...) + if tt.wantErr { + require.Error(err) + return + } + require.NoError(err) + assert.Equal(tt.wantAffectedRows, affectedRows) + var gotIds []string + for _, r := range got { + gotIds = append(gotIds, r.GetPrincipalId()) + } + var wantIds []string + wantIds = append(wantIds, tt.args.userIds...) + wantIds = append(wantIds, tt.args.groupIds...) + sort.Strings(wantIds) + sort.Strings(gotIds) + assert.Equal(wantIds, wantIds) + }) + } +} + +func TestRepository_principalsToSet(t *testing.T) { + t.Parallel() + cleanup, conn, _ := db.TestSetup(t, "postgres") + defer func() { + err := cleanup() + assert.NoError(t, err) + err = conn.Close() + assert.NoError(t, err) + }() + rw := db.New(conn) + wrapper := db.TestWrapper(t) + repo, err := NewRepository(rw, rw, wrapper) + require.NoError(t, err) + org, proj := TestScopes(t, conn) + createUsersFn := func() []string { + results := []string{} + for i := 0; i < 5; i++ { + u := TestUser(t, conn, org.PublicId) + results = append(results, u.PublicId) + } + return results + } + createGrpsFn := func() []string { + results := []string{} + for i := 0; i < 5; i++ { + g := TestGroup(t, conn, proj.PublicId) + results = append(results, g.PublicId) + } + return results + } + setupFn := func() (*Role, []string, []string) { + users := createUsersFn() + grps := createGrpsFn() + role := TestRole(t, conn, proj.PublicId) + _, err := repo.AddPrincipalRoles(context.Background(), role.PublicId, 1, users, grps) + require.NoError(t, err) + return role, users, grps + } + + type args struct { + userIds []string + groupIds []string + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "all new", + args: args{ + userIds: createUsersFn(), + groupIds: createGrpsFn(), + }, + wantErr: false, + }, + { + name: "clear all", + args: args{ + userIds: nil, + groupIds: nil, + }, + wantErr: false, + }, + { + name: "just new users", + args: args{ + userIds: createUsersFn(), + groupIds: nil, + }, + wantErr: false, + }, + { + name: "just new groups", + args: args{ + userIds: nil, + groupIds: createGrpsFn(), + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require := require.New(t) + r, origUsers, origGrps := setupFn() + got, err := repo.principalsToSet(context.Background(), r, tt.args.userIds, tt.args.groupIds) + if tt.wantErr { + require.Error(err) + return + } + require.NoError(err) + assertSetResults(t, got, tt.args.userIds, tt.args.groupIds, origUsers, origGrps) + }) + } + t.Run("nil role", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + _, users, grps := setupFn() + got, err := repo.principalsToSet(context.Background(), nil, users, grps) + require.Error(err) + assert.Nil(got) + assert.Truef(errors.Is(err, db.ErrNilParameter), "unexpected error %s", err.Error()) + }) + t.Run("no change", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + r, users, grps := setupFn() + got, err := repo.principalsToSet(context.Background(), r, users, grps) + require.NoError(err) + assert.Empty(got.addUserRoles) + assert.Empty(got.addGroupRoles) + assert.Empty(got.deleteUserRoles) + assert.Empty(got.deleteGroupRoles) + }) + t.Run("mixed", func(t *testing.T) { + require := require.New(t) + r, users, grps := setupFn() + var wantSetUsers, wantSetGrps, wantDeleteUsers, wantDeleteGrps []string + for i, id := range users { + if i < 2 { + wantSetUsers = append(wantSetUsers, id) + } else { + wantDeleteUsers = append(wantDeleteUsers, id) + } + } + for i, id := range grps { + if i < 2 { + wantSetGrps = append(wantSetGrps, id) + } else { + wantDeleteGrps = append(wantDeleteGrps, id) + } + } + newUser := TestUser(t, conn, org.PublicId) + newGrp := TestGroup(t, conn, proj.PublicId) + wantSetUsers = append(wantSetUsers, newUser.PublicId) + wantSetGrps = append(wantSetGrps, newGrp.PublicId) + + got, err := repo.principalsToSet(context.Background(), r, wantSetUsers, wantSetGrps) + require.NoError(err) + assertSetResults(t, got, []string{newUser.PublicId}, []string{newGrp.PublicId}, wantDeleteUsers, wantDeleteGrps) + }) +} + +func assertSetResults(t *testing.T, got *principalSet, wantAddUsers, wantAddGroups, wantDeleteUsers, wantDeleteGroups []string) { + t.Helper() + assert := assert.New(t) + var gotAddUsers []string + for _, r := range got.addUserRoles { + gotAddUsers = append(gotAddUsers, r.(*UserRole).PrincipalId) + } + // sort.Strings(wantAddUsers) + // sort.Strings(gotAddUsers) + assert.Equal(wantAddUsers, gotAddUsers) + + var gotAddGrps []string + for _, r := range got.addGroupRoles { + gotAddGrps = append(gotAddGrps, r.(*GroupRole).PrincipalId) + } + // sort.Strings(wantAddGroups) + // sort.Strings(gotAddGrps) + assert.Equal(wantAddGroups, gotAddGrps) + + var gotDeleteUsers []string + for _, r := range got.deleteUserRoles { + gotDeleteUsers = append(gotDeleteUsers, r.(*UserRole).PrincipalId) + } + sort.Strings(wantDeleteUsers) + sort.Strings(gotDeleteUsers) + assert.Equal(wantDeleteUsers, gotDeleteUsers) + + var gotDeleteGroups []string + for _, r := range got.deleteGroupRoles { + gotDeleteGroups = append(gotDeleteGroups, r.(*GroupRole).PrincipalId) + } + sort.Strings(wantDeleteGroups) + sort.Strings(gotDeleteGroups) + assert.Equal(wantDeleteGroups, gotDeleteGroups) +}