Additional validation for add|set|remove methods. (#527)

pull/521/head
Todd Knight 6 years ago committed by GitHub
parent f8a2100603
commit f76d5aed9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -268,7 +268,7 @@ func Parse(scopeId, grantString string, opt ...Option) (Grant, error) {
opts := getOpts(opt...)
// Check for templated values ID, and subtitute in with the authenticated values
// Check for templated values ID, and substitute in with the authenticated values
// if so
if grant.id != "" && strings.HasPrefix(grant.id, "{{") {
id := strings.TrimSuffix(strings.TrimPrefix(grant.id, "{{"), "}}")

@ -344,7 +344,7 @@ func validateCreateRequest(req *pbs.CreateAuthMethodRequest) error {
badFields := map[string]string{}
if !handlers.ValidId(scope.Org.Prefix(), req.GetItem().GetScopeId()) &&
scope.Global.String() != req.GetItem().GetScopeId() {
badFields["scope_id"] = "This field is missing or improperly formatted."
badFields["scope_id"] = "This field must be 'global' or a valid org scope id."
}
switch auth.SubtypeFromType(req.GetItem().GetType()) {
case auth.PasswordSubtype:
@ -386,7 +386,7 @@ func validateListRequest(req *pbs.ListAuthMethodsRequest) error {
badFields := map[string]string{}
if !handlers.ValidId(scope.Org.Prefix(), req.GetScopeId()) &&
req.GetScopeId() != scope.Global.String() {
badFields["scope_id"] = "This field must be a valid project scope id."
badFields["scope_id"] = "This field must be 'global' or a valid org scope id."
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Improperly formatted identifier.", badFields)

@ -16,6 +16,7 @@ import (
"github.com/hashicorp/boundary/internal/types/action"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/sdk/strutil"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/types/known/wrapperspb"
)
@ -295,7 +296,7 @@ func (s Service) addMembersInRepo(ctx context.Context, groupId string, userIds [
if err != nil {
return nil, err
}
_, err = repo.AddGroupMembers(ctx, groupId, version, userIds)
_, err = repo.AddGroupMembers(ctx, groupId, version, strutil.RemoveDuplicates(userIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add members to group: %v.", err)
@ -315,7 +316,7 @@ func (s Service) setMembersInRepo(ctx context.Context, groupId string, userIds [
if err != nil {
return nil, err
}
_, _, err = repo.SetGroupMembers(ctx, groupId, version, userIds)
_, _, err = repo.SetGroupMembers(ctx, groupId, version, strutil.RemoveDuplicates(userIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set members on group: %v.", err)
@ -335,7 +336,7 @@ func (s Service) removeMembersInRepo(ctx context.Context, groupId string, userId
if err != nil {
return nil, err
}
_, err = repo.DeleteGroupMembers(ctx, groupId, version, userIds)
_, err = repo.DeleteGroupMembers(ctx, groupId, version, strutil.RemoveDuplicates(userIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove members from group: %v.", err)
@ -467,8 +468,13 @@ func validateAddGroupMembersRequest(req *pbs.AddGroupMembersRequest) error {
badFields["member_ids"] = "Must be non-empty."
}
for _, id := range req.GetMemberIds() {
if !handlers.ValidId(iam.UserPrefix, id) {
badFields["member_ids"] = "Must only contain valid user ids."
break
}
if id == "u_recovery" {
badFields["member_ids"] = "u_recovery cannot be assigned to a group"
badFields["member_ids"] = "u_recovery cannot be assigned to a group."
break
}
}
if len(badFields) > 0 {
@ -486,8 +492,13 @@ func validateSetGroupMembersRequest(req *pbs.SetGroupMembersRequest) error {
badFields["version"] = "Required field."
}
for _, id := range req.GetMemberIds() {
if !handlers.ValidId(iam.UserPrefix, id) {
badFields["member_ids"] = "Must only contain valid user ids."
break
}
if id == "u_recovery" {
badFields["member_ids"] = "u_recovery cannot be assigned to a group"
badFields["member_ids"] = "u_recovery cannot be assigned to a group."
break
}
}
if len(badFields) > 0 {
@ -507,6 +518,12 @@ func validateRemoveGroupMembersRequest(req *pbs.RemoveGroupMembersRequest) error
if len(req.GetMemberIds()) == 0 {
badFields["member_ids"] = "Must be non-empty."
}
for _, id := range req.GetMemberIds() {
if !handlers.ValidId(iam.UserPrefix, id) {
badFields["member_ids"] = "Must only contain valid user ids."
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}

@ -942,6 +942,14 @@ func TestAddMember(t *testing.T) {
addUsers: []string{users[1].GetPublicId()},
resultUsers: []string{users[0].GetPublicId(), users[1].GetPublicId()},
},
{
name: "Add duplicate user on populated group",
setup: func(g *iam.Group) {
iam.TestGroupMember(t, conn, g.GetPublicId(), users[0].GetPublicId())
},
addUsers: []string{users[1].GetPublicId(), users[1].GetPublicId()},
resultUsers: []string{users[0].GetPublicId(), users[1].GetPublicId()},
},
{
name: "Add empty on populated group",
setup: func(g *iam.Group) {
@ -998,6 +1006,24 @@ func TestAddMember(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Invalid user id in member list",
req: &pbs.AddGroupMembersRequest{
Id: grp.GetPublicId(),
Version: grp.GetVersion(),
MemberIds: []string{"invalid"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "u_recovery",
req: &pbs.AddGroupMembersRequest{
Id: grp.GetPublicId(),
Version: grp.GetVersion(),
MemberIds: []string{"u_recovery"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -1051,6 +1077,14 @@ func TestSetMember(t *testing.T) {
setUsers: []string{users[1].GetPublicId()},
resultUsers: []string{users[1].GetPublicId()},
},
{
name: "Set duplicate user on populated group",
setup: func(r *iam.Group) {
iam.TestGroupMember(t, conn, r.GetPublicId(), users[0].GetPublicId())
},
setUsers: []string{users[1].GetPublicId(), users[1].GetPublicId()},
resultUsers: []string{users[1].GetPublicId()},
},
{
name: "Set empty on populated group",
setup: func(r *iam.Group) {
@ -1102,6 +1136,24 @@ func TestSetMember(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Invalid user id in member list",
req: &pbs.SetGroupMembersRequest{
Id: grp.GetPublicId(),
Version: grp.GetVersion(),
MemberIds: []string{"invalid"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "u_recovery",
req: &pbs.SetGroupMembersRequest{
Id: grp.GetPublicId(),
Version: grp.GetVersion(),
MemberIds: []string{"u_recovery"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -1165,6 +1217,15 @@ func TestRemoveMember(t *testing.T) {
removeUsers: []string{users[0].GetPublicId(), users[1].GetPublicId()},
resultUsers: []string{},
},
{
name: "Remove duplicate user from group",
setup: func(r *iam.Group) {
iam.TestGroupMember(t, conn, r.GetPublicId(), users[0].GetPublicId())
iam.TestGroupMember(t, conn, r.GetPublicId(), users[1].GetPublicId())
},
removeUsers: []string{users[0].GetPublicId(), users[0].GetPublicId()},
resultUsers: []string{users[1].GetPublicId()},
},
{
name: "Remove empty on populated group",
setup: func(r *iam.Group) {
@ -1203,25 +1264,34 @@ func TestRemoveMember(t *testing.T) {
failCases := []struct {
name string
req *pbs.AddGroupMembersRequest
req *pbs.RemoveGroupMembersRequest
err error
}{
{
name: "Bad Group Id",
req: &pbs.AddGroupMembersRequest{
req: &pbs.RemoveGroupMembersRequest{
Id: "bad id",
Version: grp.GetVersion(),
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Invalid user id in member list",
req: &pbs.RemoveGroupMembersRequest{
Id: grp.GetPublicId(),
Version: grp.GetVersion(),
MemberIds: []string{"invalid"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
_, gErr := s.AddGroupMembers(auth.DisabledAuthTestContext(auth.WithScopeId(grp.GetScopeId())), tc.req)
_, gErr := s.RemoveGroupMembers(auth.DisabledAuthTestContext(auth.WithScopeId(grp.GetScopeId())), tc.req)
if tc.err != nil {
require.Error(gErr)
assert.True(errors.Is(gErr, tc.err), "AddGroupMembers(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
assert.True(errors.Is(gErr, tc.err), "RemoveGroupMembers(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
}
})
}

@ -16,6 +16,7 @@ import (
"github.com/hashicorp/boundary/internal/servers/controller/handlers"
"github.com/hashicorp/boundary/internal/types/action"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/boundary/sdk/strutil"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/types/known/wrapperspb"
)
@ -294,7 +295,7 @@ func (s Service) addInRepo(ctx context.Context, scopeId, setId string, hostIds [
if err != nil {
return nil, err
}
_, err = repo.AddSetMembers(ctx, scopeId, setId, version, hostIds)
_, err = repo.AddSetMembers(ctx, scopeId, setId, version, strutil.RemoveDuplicates(hostIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add hosts to host set: %v.", err)
@ -314,7 +315,7 @@ func (s Service) setInRepo(ctx context.Context, scopeId, setId string, hostIds [
if err != nil {
return nil, err
}
_, _, err = repo.SetSetMembers(ctx, scopeId, setId, version, hostIds)
_, _, err = repo.SetSetMembers(ctx, scopeId, setId, version, strutil.RemoveDuplicates(hostIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set hosts in host set: %v.", err)
@ -335,7 +336,7 @@ func (s Service) removeInRepo(ctx context.Context, scopeId, setId string, hostId
if err != nil {
return nil, err
}
_, err = repo.DeleteSetMembers(ctx, scopeId, setId, version, hostIds)
_, err = repo.DeleteSetMembers(ctx, scopeId, setId, version, strutil.RemoveDuplicates(hostIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove hosts from host set: %v.", err)

@ -802,6 +802,14 @@ func TestAddHostSetHosts(t *testing.T) {
addHosts: []string{hs[1].GetPublicId()},
resultHosts: []string{hs[0].GetPublicId(), hs[1].GetPublicId()},
},
{
name: "Add duplicate host on populated set",
setup: func(g *static.HostSet) {
static.TestSetMembers(t, conn, g.GetPublicId(), hs[:1])
},
addHosts: []string{hs[1].GetPublicId(), hs[1].GetPublicId()},
resultHosts: []string{hs[0].GetPublicId(), hs[1].GetPublicId()},
},
}
for _, tc := range addCases {
@ -847,6 +855,15 @@ func TestAddHostSetHosts(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Invalid hosts in list",
req: &pbs.AddHostSetHostsRequest{
Id: ss.GetPublicId(),
Version: ss.GetVersion(),
HostIds: []string{"invalid_id"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -897,6 +914,14 @@ func TestSetHostSetHosts(t *testing.T) {
setHosts: []string{hs[1].GetPublicId()},
resultHosts: []string{hs[1].GetPublicId()},
},
{
name: "Set duplicate host on populated set",
setup: func(r *static.HostSet) {
static.TestSetMembers(t, conn, r.GetPublicId(), hs[:1])
},
setHosts: []string{hs[1].GetPublicId(), hs[1].GetPublicId()},
resultHosts: []string{hs[1].GetPublicId()},
},
{
name: "Set empty on populated set",
setup: func(r *static.HostSet) {
@ -938,6 +963,15 @@ func TestSetHostSetHosts(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Bad Host Id",
req: &pbs.SetHostSetHostsRequest{
Id: ss.GetPublicId(),
Version: ss.GetVersion(),
HostIds: []string{"invalid_id"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -989,6 +1023,14 @@ func TestRemoveHostSetHosts(t *testing.T) {
removeHosts: []string{hs[1].GetPublicId()},
resultHosts: []string{hs[0].GetPublicId()},
},
{
name: "Remove 1 duplicate of 2 hosts from set",
setup: func(r *static.HostSet) {
static.TestSetMembers(t, conn, r.GetPublicId(), hs[:2])
},
removeHosts: []string{hs[1].GetPublicId(), hs[1].GetPublicId()},
resultHosts: []string{hs[0].GetPublicId()},
},
{
name: "Remove all hosts from set",
setup: func(r *static.HostSet) {
@ -1041,12 +1083,21 @@ func TestRemoveHostSetHosts(t *testing.T) {
{
name: "empty hosts",
req: &pbs.RemoveHostSetHostsRequest{
Id: "bad id",
Id: ss.GetPublicId(),
Version: ss.GetVersion(),
HostIds: []string{},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "improperly formatted hosts",
req: &pbs.RemoveHostSetHostsRequest{
Id: ss.GetPublicId(),
Version: ss.GetVersion(),
HostIds: []string{"invalid_host_id"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {

@ -17,6 +17,7 @@ import (
"github.com/hashicorp/boundary/internal/types/action"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/sdk/strutil"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/types/known/wrapperspb"
)
@ -355,7 +356,7 @@ func (s Service) addPrinciplesInRepo(ctx context.Context, roleId string, princip
if err != nil {
return nil, err
}
_, err = repo.AddPrincipalRoles(ctx, roleId, version, principalIds)
_, err = repo.AddPrincipalRoles(ctx, roleId, version, strutil.RemoveDuplicates(principalIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add principals to role: %v.", err)
@ -375,7 +376,7 @@ func (s Service) setPrinciplesInRepo(ctx context.Context, roleId string, princip
if err != nil {
return nil, err
}
_, _, err = repo.SetPrincipalRoles(ctx, roleId, version, principalIds)
_, _, err = repo.SetPrincipalRoles(ctx, roleId, version, strutil.RemoveDuplicates(principalIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set principals on role: %v.", err)
@ -395,7 +396,7 @@ func (s Service) removePrinciplesInRepo(ctx context.Context, roleId string, prin
if err != nil {
return nil, err
}
_, err = repo.DeletePrincipalRoles(ctx, roleId, version, principalIds)
_, err = repo.DeletePrincipalRoles(ctx, roleId, version, strutil.RemoveDuplicates(principalIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove principals from role: %v.", err)
@ -415,7 +416,7 @@ func (s Service) addGrantsInRepo(ctx context.Context, roleId string, grants []st
if err != nil {
return nil, err
}
_, err = repo.AddRoleGrants(ctx, roleId, version, grants)
_, err = repo.AddRoleGrants(ctx, roleId, version, strutil.RemoveDuplicates(grants, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add grants to role: %v.", err)
@ -439,7 +440,7 @@ func (s Service) setGrantsInRepo(ctx context.Context, roleId string, grants []st
if grants == nil {
grants = []string{}
}
_, _, err = repo.SetRoleGrants(ctx, roleId, version, grants)
_, _, err = repo.SetRoleGrants(ctx, roleId, version, strutil.RemoveDuplicates(grants, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set grants on role: %v.", err)
@ -459,7 +460,7 @@ func (s Service) removeGrantsInRepo(ctx context.Context, roleId string, grants [
if err != nil {
return nil, err
}
_, err = repo.DeleteRoleGrants(ctx, roleId, version, grants)
_, err = repo.DeleteRoleGrants(ctx, roleId, version, strutil.RemoveDuplicates(grants, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove grants from role: %v", err)
@ -654,8 +655,13 @@ func validateAddRolePrincipalsRequest(req *pbs.AddRolePrincipalsRequest) error {
badFields["principal_ids"] = "Must be non-empty."
}
for _, id := range req.GetPrincipalIds() {
if !handlers.ValidId(iam.GroupPrefix, id) && ! handlers.ValidId(iam.UserPrefix, id) {
badFields["principal_ids"] = "Must only have valid group and/or user ids."
break
}
if id == "u_recovery" {
badFields["principal_ids"] = "u_recovery cannot be assigned to a role"
break
}
}
if len(badFields) > 0 {
@ -673,8 +679,13 @@ func validateSetRolePrincipalsRequest(req *pbs.SetRolePrincipalsRequest) error {
badFields["version"] = "Required field."
}
for _, id := range req.GetPrincipalIds() {
if !handlers.ValidId(iam.GroupPrefix, id) && ! handlers.ValidId(iam.UserPrefix, id) {
badFields["principal_ids"] = "Must only have valid group and/or user ids."
break
}
if id == "u_recovery" {
badFields["principal_ids"] = "u_recovery cannot be assigned to a role"
break
}
}
if len(badFields) > 0 {
@ -694,6 +705,12 @@ func validateRemoveRolePrincipalsRequest(req *pbs.RemoveRolePrincipalsRequest) e
if len(req.GetPrincipalIds()) == 0 {
badFields["principal_ids"] = "Must be non-empty."
}
for _, id := range req.GetPrincipalIds() {
if !handlers.ValidId(iam.GroupPrefix, id) && !handlers.ValidId(iam.UserPrefix, id) {
badFields["principal_ids"] = "Must only have valid group and/or user ids."
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}
@ -711,6 +728,16 @@ func validateAddRoleGrantsRequest(req *pbs.AddRoleGrantsRequest) error {
if len(req.GetGrantStrings()) == 0 {
badFields["grant_strings"] = "Must be non-empty."
}
for _, v := range req.GetGrantStrings() {
if len(v) == 0 {
badFields["grant_strings"] = "Grant strings must not be empty."
break
}
if _, err := perms.Parse("p_anything", v); err != nil {
badFields["grant_strings"] = fmt.Sprintf("Improperly formatted grant %q.", v)
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}
@ -725,6 +752,16 @@ func validateSetRoleGrantsRequest(req *pbs.SetRoleGrantsRequest) error {
if req.GetVersion() == 0 {
badFields["version"] = "Required field."
}
for _, v := range req.GetGrantStrings() {
if len(v) == 0 {
badFields["grant_strings"] = "Grant strings must not be empty."
break
}
if _, err := perms.Parse("p_anything", v); err != nil {
badFields["grant_strings"] = fmt.Sprintf("Improperly formatted grant %q.", v)
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}
@ -742,6 +779,16 @@ func validateRemoveRoleGrantsRequest(req *pbs.RemoveRoleGrantsRequest) error {
if len(req.GetGrantStrings()) == 0 {
badFields["grant_strings"] = "Must be non-empty."
}
for _, v := range req.GetGrantStrings() {
if len(v) == 0 {
badFields["grant_strings"] = "Grant strings must not be empty."
break
}
if _, err := perms.Parse("p_anything", v); err != nil {
badFields["grant_strings"] = fmt.Sprintf("Improperly formatted grant %q.", v)
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}

@ -995,6 +995,14 @@ func TestAddPrincipal(t *testing.T) {
addGroups: []string{groups[1].GetPublicId()},
resultGroups: []string{groups[0].GetPublicId(), groups[1].GetPublicId()},
},
{
name: "Add duplicate group on populated role",
setup: func(r *iam.Role) {
iam.TestGroupRole(t, conn, r.GetPublicId(), groups[0].GetPublicId())
},
addGroups: []string{groups[1].GetPublicId(), groups[1].GetPublicId()},
resultGroups: []string{groups[0].GetPublicId(), groups[1].GetPublicId()},
},
{
name: "Add invalid u_recovery on role",
setup: func(r *iam.Role) {},
@ -1043,6 +1051,24 @@ func TestAddPrincipal(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Bad Principal Id",
req: &pbs.AddRolePrincipalsRequest{
Id: role.GetPublicId(),
Version: role.GetVersion(),
PrincipalIds: []string{"invalid"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "u_recovery Id",
req: &pbs.AddRolePrincipalsRequest{
Id: role.GetPublicId(),
Version: role.GetVersion(),
PrincipalIds: []string{"u_recovery"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -1101,6 +1127,14 @@ func TestSetPrincipal(t *testing.T) {
setUsers: []string{users[1].GetPublicId()},
resultUsers: []string{users[1].GetPublicId()},
},
{
name: "Set duplicate user on populated role",
setup: func(r *iam.Role) {
iam.TestUserRole(t, conn, r.GetPublicId(), users[0].GetPublicId())
},
setUsers: []string{users[1].GetPublicId(), users[1].GetPublicId()},
resultUsers: []string{users[1].GetPublicId()},
},
{
name: "Set empty on populated role",
setup: func(r *iam.Role) {
@ -1172,6 +1206,24 @@ func TestSetPrincipal(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Bad Principal Id",
req: &pbs.SetRolePrincipalsRequest{
Id: role.GetPublicId(),
Version: role.GetVersion(),
PrincipalIds: []string{"invalid"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "u_recovery",
req: &pbs.SetRolePrincipalsRequest{
Id: role.GetPublicId(),
Version: role.GetVersion(),
PrincipalIds: []string{"u_recovery"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -1231,6 +1283,15 @@ func TestRemovePrincipal(t *testing.T) {
removeUsers: []string{users[1].GetPublicId()},
resultUsers: []string{users[0].GetPublicId()},
},
{
name: "Remove 1 duplicate user of 2 users from role",
setup: func(r *iam.Role) {
iam.TestUserRole(t, conn, r.GetPublicId(), users[0].GetPublicId())
iam.TestUserRole(t, conn, r.GetPublicId(), users[1].GetPublicId())
},
removeUsers: []string{users[1].GetPublicId(), users[1].GetPublicId()},
resultUsers: []string{users[0].GetPublicId()},
},
{
name: "Remove all users from role",
setup: func(r *iam.Role) {
@ -1262,6 +1323,15 @@ func TestRemovePrincipal(t *testing.T) {
removeGroups: []string{groups[1].GetPublicId()},
resultGroups: []string{groups[0].GetPublicId()},
},
{
name: "Remove 1 duplicate group of 2 groups from role",
setup: func(r *iam.Role) {
iam.TestGroupRole(t, conn, r.GetPublicId(), groups[0].GetPublicId())
iam.TestGroupRole(t, conn, r.GetPublicId(), groups[1].GetPublicId())
},
removeGroups: []string{groups[1].GetPublicId(), groups[1].GetPublicId()},
resultGroups: []string{groups[0].GetPublicId()},
},
{
name: "Remove all groups from role",
setup: func(r *iam.Role) {
@ -1302,22 +1372,40 @@ func TestRemovePrincipal(t *testing.T) {
failCases := []struct {
name string
req *pbs.AddRolePrincipalsRequest
req *pbs.RemoveRolePrincipalsRequest
err error
}{
{
name: "Bad Role Id",
req: &pbs.AddRolePrincipalsRequest{
req: &pbs.RemoveRolePrincipalsRequest{
Id: "bad id",
Version: role.GetVersion(),
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Bad User Id",
req: &pbs.RemoveRolePrincipalsRequest{
Id: role.GetPublicId(),
Version: role.GetVersion(),
PrincipalIds: []string{"g_validgroup", "invaliduser"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Bad Group Id",
req: &pbs.RemoveRolePrincipalsRequest{
Id: role.GetPublicId(),
Version: role.GetVersion(),
PrincipalIds: []string{"u_validuser", "invalidgroup"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
_, gErr := s.AddRolePrincipals(auth.DisabledAuthTestContext(auth.WithScopeId(p.GetPublicId())), tc.req)
_, gErr := s.RemoveRolePrincipals(auth.DisabledAuthTestContext(auth.WithScopeId(p.GetPublicId())), tc.req)
if tc.err != nil {
require.Error(gErr)
assert.True(errors.Is(gErr, tc.err), "AddRolePrincipals(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
@ -1373,6 +1461,12 @@ func TestAddGrants(t *testing.T) {
add: []string{"id=*;type=*;actions=delete"},
result: []string{"id=1;actions=read", "id=*;type=*;actions=delete"},
},
{
name: "Add duplicate grant on role with grant",
existing: []string{"id=1;actions=read"},
add: []string{"id=*;type=*;actions=delete", "id=*;type=*;actions=delete"},
result: []string{"id=1;actions=read", "id=*;type=*;actions=delete"},
},
{
name: "Add grant matching existing grant",
existing: []string{"id=1;actions=read", "id=*;type=*;actions=delete"},
@ -1435,6 +1529,24 @@ func TestAddGrants(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Unparseable Grant",
req: &pbs.AddRoleGrantsRequest{
Id: role.GetPublicId(),
GrantStrings: []string{"id=*;type=*;actions=create", "unparseable"},
Version: role.GetVersion(),
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Empty Grant",
req: &pbs.AddRoleGrantsRequest{
Id: role.GetPublicId(),
GrantStrings: []string{"id=*;type=*;actions=create", ""},
Version: role.GetVersion(),
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -1484,8 +1596,14 @@ func TestSetGrants(t *testing.T) {
result: []string{"id=*;type=*;actions=delete"},
},
{
name: "Set empty on role",
name: "Set duplicate grant matching existing grant",
existing: []string{"id=1;actions=read", "id=*;type=*;actions=delete"},
set: []string{"id=*;type=*;actions=delete", "id=*;type=*;actions=delete"},
result: []string{"id=*;type=*;actions=delete"},
},
{
name: "Set empty on role",
existing: []string{"id=1;type=*;actions=read", "id=*;type=*;actions=delete"},
set: nil,
result: nil,
},
@ -1514,8 +1632,7 @@ func TestSetGrants(t *testing.T) {
assert.Error(err)
return
}
s, _ := status.FromError(err)
require.NoError(err, "Got error %v", s)
require.NoError(err, "Got error %v", err)
checkEqualGrants(t, tc.result, got.GetItem())
})
}
@ -1547,6 +1664,15 @@ func TestSetGrants(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Unparsable grant",
req: &pbs.SetRoleGrantsRequest{
Id: role.GetPublicId(),
GrantStrings: []string{"id=*;type=*;actions=create", "unparseable"},
Version: role.GetVersion(),
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -1579,25 +1705,31 @@ func TestRemoveGrants(t *testing.T) {
}{
{
name: "Remove all",
existing: []string{"id=1;actions=read"},
remove: []string{"id=1;actions=read"},
existing: []string{"id=1;type=*;actions=read"},
remove: []string{"id=1;type=*;actions=read"},
},
{
name: "Remove partial",
existing: []string{"id=1;actions=read", "id=2;actions=delete"},
remove: []string{"id=1;actions=read"},
result: []string{"id=2;actions=delete"},
existing: []string{"id=1;type=*;actions=read", "id=2;type=*;actions=delete"},
remove: []string{"id=1;type=*;actions=read"},
result: []string{"id=2;type=*;actions=delete"},
},
{
name: "Remove duplicate",
existing: []string{"id=1;type=*;actions=read", "id=2;type=*;actions=delete"},
remove: []string{"id=1;type=*;actions=read", "id=1;type=*;actions=read"},
result: []string{"id=2;type=*;actions=delete"},
},
{
name: "Remove non existant",
existing: []string{"id=2;actions=delete"},
remove: []string{"id=1;actions=read"},
result: []string{"id=2;actions=delete"},
existing: []string{"id=2;type=*;actions=delete"},
remove: []string{"id=1;type=*;actions=read"},
result: []string{"id=2;type=*;actions=delete"},
},
{
name: "Remove from empty role",
existing: []string{},
remove: []string{"id=1;actions=read"},
remove: []string{"id=1;type=*;actions=read"},
result: nil,
},
}
@ -1645,7 +1777,7 @@ func TestRemoveGrants(t *testing.T) {
name: "Bad Version",
req: &pbs.RemoveRoleGrantsRequest{
Id: role.GetPublicId(),
GrantStrings: []string{"id=*;actions=create"},
GrantStrings: []string{"id=2;type=*;actions=create"},
Version: role.GetVersion() + 2,
},
err: handlers.ApiErrorWithCode(codes.Internal),
@ -1654,7 +1786,25 @@ func TestRemoveGrants(t *testing.T) {
name: "Bad Role Id",
req: &pbs.RemoveRoleGrantsRequest{
Id: "bad id",
GrantStrings: []string{"id=*;actions=create"},
GrantStrings: []string{"id=*;type=*;actions=create"},
Version: role.GetVersion(),
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Empty Grant",
req: &pbs.RemoveRoleGrantsRequest{
Id: role.GetPublicId(),
GrantStrings: []string{"id=*;type=*;actions=create", ""},
Version: role.GetVersion(),
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Unparseable Grant",
req: &pbs.RemoveRoleGrantsRequest{
Id: role.GetPublicId(),
GrantStrings: []string{"id=*;type=*;actions=create", ";unparsable=2"},
Version: role.GetVersion(),
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),

@ -14,6 +14,7 @@ import (
pb "github.com/hashicorp/boundary/internal/gen/controller/api/resources/targets"
pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services"
"github.com/hashicorp/boundary/internal/host"
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/servers"
"github.com/hashicorp/boundary/internal/servers/controller/common"
@ -24,6 +25,7 @@ import (
"github.com/hashicorp/boundary/internal/types/action"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/sdk/strutil"
"github.com/mr-tron/base58"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/proto"
@ -556,7 +558,7 @@ func (s Service) addInRepo(ctx context.Context, targetId string, hostSetId []str
if err != nil {
return nil, err
}
out, m, err := repo.AddTargetHostSets(ctx, targetId, version, hostSetId)
out, m, err := repo.AddTargetHostSets(ctx, targetId, version, strutil.RemoveDuplicates(hostSetId, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add host sets to target: %v.", err)
@ -572,7 +574,7 @@ func (s Service) setInRepo(ctx context.Context, targetId string, hostSetIds []st
if err != nil {
return nil, err
}
_, _, err = repo.SetTargetHostSets(ctx, targetId, version, hostSetIds)
_, _, err = repo.SetTargetHostSets(ctx, targetId, version, strutil.RemoveDuplicates(hostSetIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set host sets in target: %v.", err)
@ -593,7 +595,7 @@ func (s Service) removeInRepo(ctx context.Context, targetId string, hostSetIds [
if err != nil {
return nil, err
}
_, err = repo.DeleteTargeHostSets(ctx, targetId, version, hostSetIds)
_, err = repo.DeleteTargeHostSets(ctx, targetId, version, strutil.RemoveDuplicates(hostSetIds, false))
if err != nil {
// TODO: Figure out a way to surface more helpful error info beyond the Internal error.
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove host sets from target: %v.", err)
@ -800,6 +802,12 @@ func validateAddRequest(req *pbs.AddTargetHostSetsRequest) error {
if len(req.GetHostSetIds()) == 0 {
badFields["host_set_ids"] = "Must be non-empty."
}
for _, id := range req.GetHostSetIds() {
if !handlers.ValidId(static.HostSetPrefix, id) {
badFields["host_set_ids"] = "Incorrectly formatted host set identifier."
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}
@ -814,6 +822,12 @@ func validateSetRequest(req *pbs.SetTargetHostSetsRequest) error {
if req.GetVersion() == 0 {
badFields["version"] = "Required field."
}
for _, id := range req.GetHostSetIds() {
if !handlers.ValidId(static.HostSetPrefix, id) {
badFields["host_set_ids"] = "Incorrectly formatted host set identifier."
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}
@ -831,6 +845,12 @@ func validateRemoveRequest(req *pbs.RemoveTargetHostSetsRequest) error {
if len(req.GetHostSetIds()) == 0 {
badFields["host_set_ids"] = "Must be non-empty."
}
for _, id := range req.GetHostSetIds() {
if !handlers.ValidId(static.HostSetPrefix, id) {
badFields["host_set_ids"] = "Incorrectly formatted host set identifier."
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}

@ -824,11 +824,17 @@ func TestAddTargetHostSets(t *testing.T) {
resultHostSets: []string{hs[1].GetPublicId()},
},
{
name: "Add host on populated target",
name: "Add set on populated target",
tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "populated", target.WithHostSets([]string{hs[0].GetPublicId()})),
addHostSets: []string{hs[1].GetPublicId()},
resultHostSets: []string{hs[0].GetPublicId(), hs[1].GetPublicId()},
},
{
name: "Add duplicated sets on populated target",
tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "duplicated", target.WithHostSets([]string{hs[0].GetPublicId()})),
addHostSets: []string{hs[1].GetPublicId(), hs[1].GetPublicId()},
resultHostSets: []string{hs[0].GetPublicId(), hs[1].GetPublicId()},
},
}
for _, tc := range addCases {
@ -881,6 +887,15 @@ func TestAddTargetHostSets(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Incorrect host set ids",
req: &pbs.AddTargetHostSetsRequest{
Id: tar.GetPublicId(),
Version: tar.GetVersion(),
HostSetIds: []string{"incorrect"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -914,19 +929,25 @@ func TestSetTargetHostSets(t *testing.T) {
resultHostSets []string
}{
{
name: "Set host on empty set",
name: "Set on empty target",
tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "empty"),
setHostSets: []string{hs[1].GetPublicId()},
resultHostSets: []string{hs[1].GetPublicId()},
},
{
name: "Set host on populated set",
name: "Set on populated target",
tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "populated", target.WithHostSets([]string{hs[0].GetPublicId()})),
setHostSets: []string{hs[1].GetPublicId()},
resultHostSets: []string{hs[1].GetPublicId()},
},
{
name: "Set empty on populated set",
name: "Set duplicate host set on populated target",
tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "duplicate", target.WithHostSets([]string{hs[0].GetPublicId()})),
setHostSets: []string{hs[1].GetPublicId(), hs[1].GetPublicId()},
resultHostSets: []string{hs[1].GetPublicId()},
},
{
name: "Set empty on populated target",
tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "another populated", target.WithHostSets([]string{hs[0].GetPublicId()})),
setHostSets: []string{},
resultHostSets: nil,
@ -971,6 +992,15 @@ func TestSetTargetHostSets(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.Internal),
},
{
name: "Bad host set id",
req: &pbs.SetTargetHostSetsRequest{
Id: tar.GetPublicId(),
Version: tar.GetVersion(),
HostSetIds: []string{"invalid"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -1016,6 +1046,12 @@ func TestRemoveTargetHostSets(t *testing.T) {
removeHosts: []string{hs[1].GetPublicId()},
resultHosts: []string{hs[0].GetPublicId()},
},
{
name: "Remove 1 duplicate set of 2 sets",
tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "remove duplicate", target.WithHostSets([]string{hs[0].GetPublicId(), hs[1].GetPublicId()})),
removeHosts: []string{hs[1].GetPublicId(), hs[1].GetPublicId()},
resultHosts: []string{hs[0].GetPublicId()},
},
{
name: "Remove all hosts from set",
tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "remove all", target.WithHostSets([]string{hs[0].GetPublicId(), hs[1].GetPublicId()})),
@ -1073,12 +1109,21 @@ func TestRemoveTargetHostSets(t *testing.T) {
{
name: "empty sets",
req: &pbs.RemoveTargetHostSetsRequest{
Id: "bad id",
Id: tar.GetPublicId(),
Version: tar.GetVersion(),
HostSetIds: []string{},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Invalid set ids",
req: &pbs.RemoveTargetHostSetsRequest{
Id: tar.GetPublicId(),
Version: tar.GetVersion(),
HostSetIds: []string{"invalid"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {

@ -6,6 +6,7 @@ import (
"fmt"
"github.com/hashicorp/boundary/internal/auth"
"github.com/hashicorp/boundary/internal/auth/password"
"github.com/hashicorp/boundary/internal/db"
pb "github.com/hashicorp/boundary/internal/gen/controller/api/resources/users"
pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services"
@ -16,6 +17,7 @@ import (
"github.com/hashicorp/boundary/internal/types/action"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/sdk/strutil"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/types/known/wrapperspb"
)
@ -295,7 +297,7 @@ func (s Service) addInRepo(ctx context.Context, userId string, accountIds []stri
if err != nil {
return nil, err
}
_, err = repo.AddUserAccounts(ctx, userId, version, accountIds)
_, err = repo.AddUserAccounts(ctx, userId, version, strutil.RemoveDuplicates(accountIds, false))
if err != nil {
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add accounts to user: %v.", err)
}
@ -314,7 +316,7 @@ func (s Service) setInRepo(ctx context.Context, userId string, accountIds []stri
if err != nil {
return nil, err
}
_, err = repo.SetUserAccounts(ctx, userId, version, accountIds)
_, err = repo.SetUserAccounts(ctx, userId, version, strutil.RemoveDuplicates(accountIds, false))
if err != nil {
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set accounts for the user: %v.", err)
}
@ -333,7 +335,7 @@ func (s Service) removeInRepo(ctx context.Context, userId string, accountIds []s
if err != nil {
return nil, err
}
_, err = repo.DeleteUserAccounts(ctx, userId, version, accountIds)
_, err = repo.DeleteUserAccounts(ctx, userId, version, strutil.RemoveDuplicates(accountIds, false))
if err != nil {
return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove accounts from user: %v.", err)
}
@ -462,6 +464,13 @@ func validateAddUserAccountsRequest(req *pbs.AddUserAccountsRequest) error {
if len(req.GetAccountIds()) == 0 {
badFields["account_ids"] = "Must be non-empty."
}
for _, a := range req.GetAccountIds() {
// TODO: Increase the type of auth accounts that can be added to a user.
if !handlers.ValidId(password.AccountPrefix, a) {
badFields["account_ids"] = "Values must be valid account ids."
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}
@ -476,6 +485,13 @@ func validateSetUserAccountsRequest(req *pbs.SetUserAccountsRequest) error {
if req.GetVersion() == 0 {
badFields["version"] = "Required field."
}
for _, a := range req.GetAccountIds() {
// TODO: Increase the type of auth accounts that can be added to a user.
if !handlers.ValidId(password.AccountPrefix, a) {
badFields["account_ids"] = "Values must be valid account ids."
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}
@ -493,6 +509,13 @@ func validateRemoveUserAccountsRequest(req *pbs.RemoveUserAccountsRequest) error
if len(req.GetAccountIds()) == 0 {
badFields["account_ids"] = "Must be non-empty."
}
for _, a := range req.GetAccountIds() {
// TODO: Increase the type of auth accounts that can be added to a user.
if !handlers.ValidId(password.AccountPrefix, a) {
badFields["account_ids"] = "Values must be valid account ids."
break
}
}
if len(badFields) > 0 {
return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields)
}

@ -644,6 +644,17 @@ func TestAddAccount(t *testing.T) {
addAccounts: []string{accts[1].GetPublicId()},
resultAccounts: []string{accts[0].GetPublicId(), accts[1].GetPublicId()},
},
{
name: "Add duplicate account on populated user",
setup: func(u *iam.User) {
_, err := iamRepo.SetUserAccounts(context.Background(), u.GetPublicId(), u.GetVersion(),
[]string{accts[0].GetPublicId()})
require.NoError(t, err)
u.Version = u.Version + 1
},
addAccounts: []string{accts[1].GetPublicId(), accts[1].GetPublicId()},
resultAccounts: []string{accts[0].GetPublicId(), accts[1].GetPublicId()},
},
{
name: "Add empty on populated user",
setup: func(u *iam.User) {
@ -695,6 +706,15 @@ func TestAddAccount(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Bad account Id",
req: &pbs.AddUserAccountsRequest{
Id: usr.GetPublicId(),
Version: usr.GetVersion(),
AccountIds: []string{"invalid"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -746,6 +766,17 @@ func TestSetAccount(t *testing.T) {
setAccounts: []string{accts[1].GetPublicId()},
resultAccounts: []string{accts[1].GetPublicId()},
},
{
name: "Set duplicate account on populated user",
setup: func(u *iam.User) {
iamRepo.AddUserAccounts(context.Background(), u.GetPublicId(), u.GetVersion(),
[]string{accts[0].GetPublicId()})
require.NoError(t, err)
u.Version = u.Version + 1
},
setAccounts: []string{accts[1].GetPublicId(), accts[1].GetPublicId()},
resultAccounts: []string{accts[1].GetPublicId()},
},
{
name: "Set empty on populated user",
setup: func(u *iam.User) {
@ -798,6 +829,15 @@ func TestSetAccount(t *testing.T) {
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Bad account Id",
req: &pbs.SetUserAccountsRequest{
Id: usr.GetPublicId(),
Version: usr.GetVersion(),
AccountIds: []string{"invalid"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
@ -849,6 +889,17 @@ func TestRemoveAccount(t *testing.T) {
removeAccounts: []string{accts[1].GetPublicId()},
resultAccounts: []string{accts[0].GetPublicId()},
},
{
name: "Remove 1 duplicate accounts of 2 accounts from user",
setup: func(u *iam.User) {
_, err := iamRepo.SetUserAccounts(context.Background(), u.GetPublicId(), u.GetVersion(),
[]string{accts[0].GetPublicId(), accts[1].GetPublicId()})
require.NoError(t, err)
u.Version = u.Version + 1
},
removeAccounts: []string{accts[1].GetPublicId(), accts[1].GetPublicId()},
resultAccounts: []string{accts[0].GetPublicId()},
},
{
name: "Remove all accounts from user",
setup: func(u *iam.User) {
@ -900,22 +951,31 @@ func TestRemoveAccount(t *testing.T) {
failCases := []struct {
name string
req *pbs.AddUserAccountsRequest
req *pbs.RemoveUserAccountsRequest
err error
}{
{
name: "Bad User Id",
req: &pbs.AddUserAccountsRequest{
req: &pbs.RemoveUserAccountsRequest{
Id: "bad id",
Version: usr.GetVersion(),
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
{
name: "Bad account Id",
req: &pbs.RemoveUserAccountsRequest{
Id: usr.GetPublicId(),
Version: usr.GetVersion(),
AccountIds: []string{"invalid"},
},
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
},
}
for _, tc := range failCases {
t.Run(tc.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
_, gErr := s.AddUserAccounts(auth.DisabledAuthTestContext(auth.WithScopeId(usr.GetScopeId())), tc.req)
_, gErr := s.RemoveUserAccounts(auth.DisabledAuthTestContext(auth.WithScopeId(usr.GetScopeId())), tc.req)
if tc.err != nil {
require.Error(gErr)
assert.True(errors.Is(gErr, tc.err), "AddUserAccounts(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)

Loading…
Cancel
Save