From f76d5aed9f3c197ef04c683423f87a100c7ef08b Mon Sep 17 00:00:00 2001 From: Todd Knight Date: Fri, 2 Oct 2020 16:45:32 -0700 Subject: [PATCH] Additional validation for add|set|remove methods. (#527) --- internal/perms/grants.go | 2 +- .../authmethods/authmethod_service.go | 4 +- .../handlers/groups/group_service.go | 27 ++- .../handlers/groups/group_service_test.go | 78 +++++++- .../handlers/host_sets/host_set_service.go | 7 +- .../host_sets/host_set_service_test.go | 53 ++++- .../controller/handlers/roles/role_service.go | 59 +++++- .../handlers/roles/role_service_test.go | 184 ++++++++++++++++-- .../handlers/targets/target_service.go | 26 ++- .../handlers/targets/target_service_test.go | 55 +++++- .../controller/handlers/users/user_service.go | 29 ++- .../handlers/users/user_service_test.go | 66 ++++++- 12 files changed, 537 insertions(+), 53 deletions(-) diff --git a/internal/perms/grants.go b/internal/perms/grants.go index 92ac9320d3..81b5fe5243 100644 --- a/internal/perms/grants.go +++ b/internal/perms/grants.go @@ -268,7 +268,7 @@ func Parse(scopeId, grantString string, opt ...Option) (Grant, error) { opts := getOpts(opt...) - // Check for templated values ID, and subtitute in with the authenticated values + // Check for templated values ID, and substitute in with the authenticated values // if so if grant.id != "" && strings.HasPrefix(grant.id, "{{") { id := strings.TrimSuffix(strings.TrimPrefix(grant.id, "{{"), "}}") diff --git a/internal/servers/controller/handlers/authmethods/authmethod_service.go b/internal/servers/controller/handlers/authmethods/authmethod_service.go index 93ffa7ead3..e48c939a99 100644 --- a/internal/servers/controller/handlers/authmethods/authmethod_service.go +++ b/internal/servers/controller/handlers/authmethods/authmethod_service.go @@ -344,7 +344,7 @@ func validateCreateRequest(req *pbs.CreateAuthMethodRequest) error { badFields := map[string]string{} if !handlers.ValidId(scope.Org.Prefix(), req.GetItem().GetScopeId()) && scope.Global.String() != req.GetItem().GetScopeId() { - badFields["scope_id"] = "This field is missing or improperly formatted." + badFields["scope_id"] = "This field must be 'global' or a valid org scope id." } switch auth.SubtypeFromType(req.GetItem().GetType()) { case auth.PasswordSubtype: @@ -386,7 +386,7 @@ func validateListRequest(req *pbs.ListAuthMethodsRequest) error { badFields := map[string]string{} if !handlers.ValidId(scope.Org.Prefix(), req.GetScopeId()) && req.GetScopeId() != scope.Global.String() { - badFields["scope_id"] = "This field must be a valid project scope id." + badFields["scope_id"] = "This field must be 'global' or a valid org scope id." } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Improperly formatted identifier.", badFields) diff --git a/internal/servers/controller/handlers/groups/group_service.go b/internal/servers/controller/handlers/groups/group_service.go index 61adc7008d..ed1cd6cde3 100644 --- a/internal/servers/controller/handlers/groups/group_service.go +++ b/internal/servers/controller/handlers/groups/group_service.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/boundary/internal/types/action" "github.com/hashicorp/boundary/internal/types/resource" "github.com/hashicorp/boundary/internal/types/scope" + "github.com/hashicorp/boundary/sdk/strutil" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -295,7 +296,7 @@ func (s Service) addMembersInRepo(ctx context.Context, groupId string, userIds [ if err != nil { return nil, err } - _, err = repo.AddGroupMembers(ctx, groupId, version, userIds) + _, err = repo.AddGroupMembers(ctx, groupId, version, strutil.RemoveDuplicates(userIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add members to group: %v.", err) @@ -315,7 +316,7 @@ func (s Service) setMembersInRepo(ctx context.Context, groupId string, userIds [ if err != nil { return nil, err } - _, _, err = repo.SetGroupMembers(ctx, groupId, version, userIds) + _, _, err = repo.SetGroupMembers(ctx, groupId, version, strutil.RemoveDuplicates(userIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set members on group: %v.", err) @@ -335,7 +336,7 @@ func (s Service) removeMembersInRepo(ctx context.Context, groupId string, userId if err != nil { return nil, err } - _, err = repo.DeleteGroupMembers(ctx, groupId, version, userIds) + _, err = repo.DeleteGroupMembers(ctx, groupId, version, strutil.RemoveDuplicates(userIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove members from group: %v.", err) @@ -467,8 +468,13 @@ func validateAddGroupMembersRequest(req *pbs.AddGroupMembersRequest) error { badFields["member_ids"] = "Must be non-empty." } for _, id := range req.GetMemberIds() { + if !handlers.ValidId(iam.UserPrefix, id) { + badFields["member_ids"] = "Must only contain valid user ids." + break + } if id == "u_recovery" { - badFields["member_ids"] = "u_recovery cannot be assigned to a group" + badFields["member_ids"] = "u_recovery cannot be assigned to a group." + break } } if len(badFields) > 0 { @@ -486,8 +492,13 @@ func validateSetGroupMembersRequest(req *pbs.SetGroupMembersRequest) error { badFields["version"] = "Required field." } for _, id := range req.GetMemberIds() { + if !handlers.ValidId(iam.UserPrefix, id) { + badFields["member_ids"] = "Must only contain valid user ids." + break + } if id == "u_recovery" { - badFields["member_ids"] = "u_recovery cannot be assigned to a group" + badFields["member_ids"] = "u_recovery cannot be assigned to a group." + break } } if len(badFields) > 0 { @@ -507,6 +518,12 @@ func validateRemoveGroupMembersRequest(req *pbs.RemoveGroupMembersRequest) error if len(req.GetMemberIds()) == 0 { badFields["member_ids"] = "Must be non-empty." } + for _, id := range req.GetMemberIds() { + if !handlers.ValidId(iam.UserPrefix, id) { + badFields["member_ids"] = "Must only contain valid user ids." + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } diff --git a/internal/servers/controller/handlers/groups/group_service_test.go b/internal/servers/controller/handlers/groups/group_service_test.go index 697e1cbbd3..2a080bf607 100644 --- a/internal/servers/controller/handlers/groups/group_service_test.go +++ b/internal/servers/controller/handlers/groups/group_service_test.go @@ -942,6 +942,14 @@ func TestAddMember(t *testing.T) { addUsers: []string{users[1].GetPublicId()}, resultUsers: []string{users[0].GetPublicId(), users[1].GetPublicId()}, }, + { + name: "Add duplicate user on populated group", + setup: func(g *iam.Group) { + iam.TestGroupMember(t, conn, g.GetPublicId(), users[0].GetPublicId()) + }, + addUsers: []string{users[1].GetPublicId(), users[1].GetPublicId()}, + resultUsers: []string{users[0].GetPublicId(), users[1].GetPublicId()}, + }, { name: "Add empty on populated group", setup: func(g *iam.Group) { @@ -998,6 +1006,24 @@ func TestAddMember(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Invalid user id in member list", + req: &pbs.AddGroupMembersRequest{ + Id: grp.GetPublicId(), + Version: grp.GetVersion(), + MemberIds: []string{"invalid"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, + { + name: "u_recovery", + req: &pbs.AddGroupMembersRequest{ + Id: grp.GetPublicId(), + Version: grp.GetVersion(), + MemberIds: []string{"u_recovery"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -1051,6 +1077,14 @@ func TestSetMember(t *testing.T) { setUsers: []string{users[1].GetPublicId()}, resultUsers: []string{users[1].GetPublicId()}, }, + { + name: "Set duplicate user on populated group", + setup: func(r *iam.Group) { + iam.TestGroupMember(t, conn, r.GetPublicId(), users[0].GetPublicId()) + }, + setUsers: []string{users[1].GetPublicId(), users[1].GetPublicId()}, + resultUsers: []string{users[1].GetPublicId()}, + }, { name: "Set empty on populated group", setup: func(r *iam.Group) { @@ -1102,6 +1136,24 @@ func TestSetMember(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Invalid user id in member list", + req: &pbs.SetGroupMembersRequest{ + Id: grp.GetPublicId(), + Version: grp.GetVersion(), + MemberIds: []string{"invalid"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, + { + name: "u_recovery", + req: &pbs.SetGroupMembersRequest{ + Id: grp.GetPublicId(), + Version: grp.GetVersion(), + MemberIds: []string{"u_recovery"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -1165,6 +1217,15 @@ func TestRemoveMember(t *testing.T) { removeUsers: []string{users[0].GetPublicId(), users[1].GetPublicId()}, resultUsers: []string{}, }, + { + name: "Remove duplicate user from group", + setup: func(r *iam.Group) { + iam.TestGroupMember(t, conn, r.GetPublicId(), users[0].GetPublicId()) + iam.TestGroupMember(t, conn, r.GetPublicId(), users[1].GetPublicId()) + }, + removeUsers: []string{users[0].GetPublicId(), users[0].GetPublicId()}, + resultUsers: []string{users[1].GetPublicId()}, + }, { name: "Remove empty on populated group", setup: func(r *iam.Group) { @@ -1203,25 +1264,34 @@ func TestRemoveMember(t *testing.T) { failCases := []struct { name string - req *pbs.AddGroupMembersRequest + req *pbs.RemoveGroupMembersRequest err error }{ { name: "Bad Group Id", - req: &pbs.AddGroupMembersRequest{ + req: &pbs.RemoveGroupMembersRequest{ Id: "bad id", Version: grp.GetVersion(), }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Invalid user id in member list", + req: &pbs.RemoveGroupMembersRequest{ + Id: grp.GetPublicId(), + Version: grp.GetVersion(), + MemberIds: []string{"invalid"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - _, gErr := s.AddGroupMembers(auth.DisabledAuthTestContext(auth.WithScopeId(grp.GetScopeId())), tc.req) + _, gErr := s.RemoveGroupMembers(auth.DisabledAuthTestContext(auth.WithScopeId(grp.GetScopeId())), tc.req) if tc.err != nil { require.Error(gErr) - assert.True(errors.Is(gErr, tc.err), "AddGroupMembers(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) + assert.True(errors.Is(gErr, tc.err), "RemoveGroupMembers(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) } }) } diff --git a/internal/servers/controller/handlers/host_sets/host_set_service.go b/internal/servers/controller/handlers/host_sets/host_set_service.go index d54dc52b52..0b8389afc4 100644 --- a/internal/servers/controller/handlers/host_sets/host_set_service.go +++ b/internal/servers/controller/handlers/host_sets/host_set_service.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/boundary/internal/servers/controller/handlers" "github.com/hashicorp/boundary/internal/types/action" "github.com/hashicorp/boundary/internal/types/resource" + "github.com/hashicorp/boundary/sdk/strutil" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -294,7 +295,7 @@ func (s Service) addInRepo(ctx context.Context, scopeId, setId string, hostIds [ if err != nil { return nil, err } - _, err = repo.AddSetMembers(ctx, scopeId, setId, version, hostIds) + _, err = repo.AddSetMembers(ctx, scopeId, setId, version, strutil.RemoveDuplicates(hostIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add hosts to host set: %v.", err) @@ -314,7 +315,7 @@ func (s Service) setInRepo(ctx context.Context, scopeId, setId string, hostIds [ if err != nil { return nil, err } - _, _, err = repo.SetSetMembers(ctx, scopeId, setId, version, hostIds) + _, _, err = repo.SetSetMembers(ctx, scopeId, setId, version, strutil.RemoveDuplicates(hostIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set hosts in host set: %v.", err) @@ -335,7 +336,7 @@ func (s Service) removeInRepo(ctx context.Context, scopeId, setId string, hostId if err != nil { return nil, err } - _, err = repo.DeleteSetMembers(ctx, scopeId, setId, version, hostIds) + _, err = repo.DeleteSetMembers(ctx, scopeId, setId, version, strutil.RemoveDuplicates(hostIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove hosts from host set: %v.", err) diff --git a/internal/servers/controller/handlers/host_sets/host_set_service_test.go b/internal/servers/controller/handlers/host_sets/host_set_service_test.go index 7ffff3eb68..f393772fc3 100644 --- a/internal/servers/controller/handlers/host_sets/host_set_service_test.go +++ b/internal/servers/controller/handlers/host_sets/host_set_service_test.go @@ -802,6 +802,14 @@ func TestAddHostSetHosts(t *testing.T) { addHosts: []string{hs[1].GetPublicId()}, resultHosts: []string{hs[0].GetPublicId(), hs[1].GetPublicId()}, }, + { + name: "Add duplicate host on populated set", + setup: func(g *static.HostSet) { + static.TestSetMembers(t, conn, g.GetPublicId(), hs[:1]) + }, + addHosts: []string{hs[1].GetPublicId(), hs[1].GetPublicId()}, + resultHosts: []string{hs[0].GetPublicId(), hs[1].GetPublicId()}, + }, } for _, tc := range addCases { @@ -847,6 +855,15 @@ func TestAddHostSetHosts(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Invalid hosts in list", + req: &pbs.AddHostSetHostsRequest{ + Id: ss.GetPublicId(), + Version: ss.GetVersion(), + HostIds: []string{"invalid_id"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -897,6 +914,14 @@ func TestSetHostSetHosts(t *testing.T) { setHosts: []string{hs[1].GetPublicId()}, resultHosts: []string{hs[1].GetPublicId()}, }, + { + name: "Set duplicate host on populated set", + setup: func(r *static.HostSet) { + static.TestSetMembers(t, conn, r.GetPublicId(), hs[:1]) + }, + setHosts: []string{hs[1].GetPublicId(), hs[1].GetPublicId()}, + resultHosts: []string{hs[1].GetPublicId()}, + }, { name: "Set empty on populated set", setup: func(r *static.HostSet) { @@ -938,6 +963,15 @@ func TestSetHostSetHosts(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Bad Host Id", + req: &pbs.SetHostSetHostsRequest{ + Id: ss.GetPublicId(), + Version: ss.GetVersion(), + HostIds: []string{"invalid_id"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -989,6 +1023,14 @@ func TestRemoveHostSetHosts(t *testing.T) { removeHosts: []string{hs[1].GetPublicId()}, resultHosts: []string{hs[0].GetPublicId()}, }, + { + name: "Remove 1 duplicate of 2 hosts from set", + setup: func(r *static.HostSet) { + static.TestSetMembers(t, conn, r.GetPublicId(), hs[:2]) + }, + removeHosts: []string{hs[1].GetPublicId(), hs[1].GetPublicId()}, + resultHosts: []string{hs[0].GetPublicId()}, + }, { name: "Remove all hosts from set", setup: func(r *static.HostSet) { @@ -1041,12 +1083,21 @@ func TestRemoveHostSetHosts(t *testing.T) { { name: "empty hosts", req: &pbs.RemoveHostSetHostsRequest{ - Id: "bad id", + Id: ss.GetPublicId(), Version: ss.GetVersion(), HostIds: []string{}, }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "improperly formatted hosts", + req: &pbs.RemoveHostSetHostsRequest{ + Id: ss.GetPublicId(), + Version: ss.GetVersion(), + HostIds: []string{"invalid_host_id"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { diff --git a/internal/servers/controller/handlers/roles/role_service.go b/internal/servers/controller/handlers/roles/role_service.go index 5d790f0641..8151d11e5b 100644 --- a/internal/servers/controller/handlers/roles/role_service.go +++ b/internal/servers/controller/handlers/roles/role_service.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/boundary/internal/types/action" "github.com/hashicorp/boundary/internal/types/resource" "github.com/hashicorp/boundary/internal/types/scope" + "github.com/hashicorp/boundary/sdk/strutil" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -355,7 +356,7 @@ func (s Service) addPrinciplesInRepo(ctx context.Context, roleId string, princip if err != nil { return nil, err } - _, err = repo.AddPrincipalRoles(ctx, roleId, version, principalIds) + _, err = repo.AddPrincipalRoles(ctx, roleId, version, strutil.RemoveDuplicates(principalIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add principals to role: %v.", err) @@ -375,7 +376,7 @@ func (s Service) setPrinciplesInRepo(ctx context.Context, roleId string, princip if err != nil { return nil, err } - _, _, err = repo.SetPrincipalRoles(ctx, roleId, version, principalIds) + _, _, err = repo.SetPrincipalRoles(ctx, roleId, version, strutil.RemoveDuplicates(principalIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set principals on role: %v.", err) @@ -395,7 +396,7 @@ func (s Service) removePrinciplesInRepo(ctx context.Context, roleId string, prin if err != nil { return nil, err } - _, err = repo.DeletePrincipalRoles(ctx, roleId, version, principalIds) + _, err = repo.DeletePrincipalRoles(ctx, roleId, version, strutil.RemoveDuplicates(principalIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove principals from role: %v.", err) @@ -415,7 +416,7 @@ func (s Service) addGrantsInRepo(ctx context.Context, roleId string, grants []st if err != nil { return nil, err } - _, err = repo.AddRoleGrants(ctx, roleId, version, grants) + _, err = repo.AddRoleGrants(ctx, roleId, version, strutil.RemoveDuplicates(grants, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add grants to role: %v.", err) @@ -439,7 +440,7 @@ func (s Service) setGrantsInRepo(ctx context.Context, roleId string, grants []st if grants == nil { grants = []string{} } - _, _, err = repo.SetRoleGrants(ctx, roleId, version, grants) + _, _, err = repo.SetRoleGrants(ctx, roleId, version, strutil.RemoveDuplicates(grants, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set grants on role: %v.", err) @@ -459,7 +460,7 @@ func (s Service) removeGrantsInRepo(ctx context.Context, roleId string, grants [ if err != nil { return nil, err } - _, err = repo.DeleteRoleGrants(ctx, roleId, version, grants) + _, err = repo.DeleteRoleGrants(ctx, roleId, version, strutil.RemoveDuplicates(grants, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove grants from role: %v", err) @@ -654,8 +655,13 @@ func validateAddRolePrincipalsRequest(req *pbs.AddRolePrincipalsRequest) error { badFields["principal_ids"] = "Must be non-empty." } for _, id := range req.GetPrincipalIds() { + if !handlers.ValidId(iam.GroupPrefix, id) && ! handlers.ValidId(iam.UserPrefix, id) { + badFields["principal_ids"] = "Must only have valid group and/or user ids." + break + } if id == "u_recovery" { badFields["principal_ids"] = "u_recovery cannot be assigned to a role" + break } } if len(badFields) > 0 { @@ -673,8 +679,13 @@ func validateSetRolePrincipalsRequest(req *pbs.SetRolePrincipalsRequest) error { badFields["version"] = "Required field." } for _, id := range req.GetPrincipalIds() { + if !handlers.ValidId(iam.GroupPrefix, id) && ! handlers.ValidId(iam.UserPrefix, id) { + badFields["principal_ids"] = "Must only have valid group and/or user ids." + break + } if id == "u_recovery" { badFields["principal_ids"] = "u_recovery cannot be assigned to a role" + break } } if len(badFields) > 0 { @@ -694,6 +705,12 @@ func validateRemoveRolePrincipalsRequest(req *pbs.RemoveRolePrincipalsRequest) e if len(req.GetPrincipalIds()) == 0 { badFields["principal_ids"] = "Must be non-empty." } + for _, id := range req.GetPrincipalIds() { + if !handlers.ValidId(iam.GroupPrefix, id) && !handlers.ValidId(iam.UserPrefix, id) { + badFields["principal_ids"] = "Must only have valid group and/or user ids." + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } @@ -711,6 +728,16 @@ func validateAddRoleGrantsRequest(req *pbs.AddRoleGrantsRequest) error { if len(req.GetGrantStrings()) == 0 { badFields["grant_strings"] = "Must be non-empty." } + for _, v := range req.GetGrantStrings() { + if len(v) == 0 { + badFields["grant_strings"] = "Grant strings must not be empty." + break + } + if _, err := perms.Parse("p_anything", v); err != nil { + badFields["grant_strings"] = fmt.Sprintf("Improperly formatted grant %q.", v) + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } @@ -725,6 +752,16 @@ func validateSetRoleGrantsRequest(req *pbs.SetRoleGrantsRequest) error { if req.GetVersion() == 0 { badFields["version"] = "Required field." } + for _, v := range req.GetGrantStrings() { + if len(v) == 0 { + badFields["grant_strings"] = "Grant strings must not be empty." + break + } + if _, err := perms.Parse("p_anything", v); err != nil { + badFields["grant_strings"] = fmt.Sprintf("Improperly formatted grant %q.", v) + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } @@ -742,6 +779,16 @@ func validateRemoveRoleGrantsRequest(req *pbs.RemoveRoleGrantsRequest) error { if len(req.GetGrantStrings()) == 0 { badFields["grant_strings"] = "Must be non-empty." } + for _, v := range req.GetGrantStrings() { + if len(v) == 0 { + badFields["grant_strings"] = "Grant strings must not be empty." + break + } + if _, err := perms.Parse("p_anything", v); err != nil { + badFields["grant_strings"] = fmt.Sprintf("Improperly formatted grant %q.", v) + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } diff --git a/internal/servers/controller/handlers/roles/role_service_test.go b/internal/servers/controller/handlers/roles/role_service_test.go index 5d91e7e273..a12e3458ec 100644 --- a/internal/servers/controller/handlers/roles/role_service_test.go +++ b/internal/servers/controller/handlers/roles/role_service_test.go @@ -995,6 +995,14 @@ func TestAddPrincipal(t *testing.T) { addGroups: []string{groups[1].GetPublicId()}, resultGroups: []string{groups[0].GetPublicId(), groups[1].GetPublicId()}, }, + { + name: "Add duplicate group on populated role", + setup: func(r *iam.Role) { + iam.TestGroupRole(t, conn, r.GetPublicId(), groups[0].GetPublicId()) + }, + addGroups: []string{groups[1].GetPublicId(), groups[1].GetPublicId()}, + resultGroups: []string{groups[0].GetPublicId(), groups[1].GetPublicId()}, + }, { name: "Add invalid u_recovery on role", setup: func(r *iam.Role) {}, @@ -1043,6 +1051,24 @@ func TestAddPrincipal(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Bad Principal Id", + req: &pbs.AddRolePrincipalsRequest{ + Id: role.GetPublicId(), + Version: role.GetVersion(), + PrincipalIds: []string{"invalid"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, + { + name: "u_recovery Id", + req: &pbs.AddRolePrincipalsRequest{ + Id: role.GetPublicId(), + Version: role.GetVersion(), + PrincipalIds: []string{"u_recovery"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -1101,6 +1127,14 @@ func TestSetPrincipal(t *testing.T) { setUsers: []string{users[1].GetPublicId()}, resultUsers: []string{users[1].GetPublicId()}, }, + { + name: "Set duplicate user on populated role", + setup: func(r *iam.Role) { + iam.TestUserRole(t, conn, r.GetPublicId(), users[0].GetPublicId()) + }, + setUsers: []string{users[1].GetPublicId(), users[1].GetPublicId()}, + resultUsers: []string{users[1].GetPublicId()}, + }, { name: "Set empty on populated role", setup: func(r *iam.Role) { @@ -1172,6 +1206,24 @@ func TestSetPrincipal(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Bad Principal Id", + req: &pbs.SetRolePrincipalsRequest{ + Id: role.GetPublicId(), + Version: role.GetVersion(), + PrincipalIds: []string{"invalid"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, + { + name: "u_recovery", + req: &pbs.SetRolePrincipalsRequest{ + Id: role.GetPublicId(), + Version: role.GetVersion(), + PrincipalIds: []string{"u_recovery"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -1231,6 +1283,15 @@ func TestRemovePrincipal(t *testing.T) { removeUsers: []string{users[1].GetPublicId()}, resultUsers: []string{users[0].GetPublicId()}, }, + { + name: "Remove 1 duplicate user of 2 users from role", + setup: func(r *iam.Role) { + iam.TestUserRole(t, conn, r.GetPublicId(), users[0].GetPublicId()) + iam.TestUserRole(t, conn, r.GetPublicId(), users[1].GetPublicId()) + }, + removeUsers: []string{users[1].GetPublicId(), users[1].GetPublicId()}, + resultUsers: []string{users[0].GetPublicId()}, + }, { name: "Remove all users from role", setup: func(r *iam.Role) { @@ -1262,6 +1323,15 @@ func TestRemovePrincipal(t *testing.T) { removeGroups: []string{groups[1].GetPublicId()}, resultGroups: []string{groups[0].GetPublicId()}, }, + { + name: "Remove 1 duplicate group of 2 groups from role", + setup: func(r *iam.Role) { + iam.TestGroupRole(t, conn, r.GetPublicId(), groups[0].GetPublicId()) + iam.TestGroupRole(t, conn, r.GetPublicId(), groups[1].GetPublicId()) + }, + removeGroups: []string{groups[1].GetPublicId(), groups[1].GetPublicId()}, + resultGroups: []string{groups[0].GetPublicId()}, + }, { name: "Remove all groups from role", setup: func(r *iam.Role) { @@ -1302,22 +1372,40 @@ func TestRemovePrincipal(t *testing.T) { failCases := []struct { name string - req *pbs.AddRolePrincipalsRequest + req *pbs.RemoveRolePrincipalsRequest err error }{ { name: "Bad Role Id", - req: &pbs.AddRolePrincipalsRequest{ + req: &pbs.RemoveRolePrincipalsRequest{ Id: "bad id", Version: role.GetVersion(), }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Bad User Id", + req: &pbs.RemoveRolePrincipalsRequest{ + Id: role.GetPublicId(), + Version: role.GetVersion(), + PrincipalIds: []string{"g_validgroup", "invaliduser"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, + { + name: "Bad Group Id", + req: &pbs.RemoveRolePrincipalsRequest{ + Id: role.GetPublicId(), + Version: role.GetVersion(), + PrincipalIds: []string{"u_validuser", "invalidgroup"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - _, gErr := s.AddRolePrincipals(auth.DisabledAuthTestContext(auth.WithScopeId(p.GetPublicId())), tc.req) + _, gErr := s.RemoveRolePrincipals(auth.DisabledAuthTestContext(auth.WithScopeId(p.GetPublicId())), tc.req) if tc.err != nil { require.Error(gErr) assert.True(errors.Is(gErr, tc.err), "AddRolePrincipals(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) @@ -1373,6 +1461,12 @@ func TestAddGrants(t *testing.T) { add: []string{"id=*;type=*;actions=delete"}, result: []string{"id=1;actions=read", "id=*;type=*;actions=delete"}, }, + { + name: "Add duplicate grant on role with grant", + existing: []string{"id=1;actions=read"}, + add: []string{"id=*;type=*;actions=delete", "id=*;type=*;actions=delete"}, + result: []string{"id=1;actions=read", "id=*;type=*;actions=delete"}, + }, { name: "Add grant matching existing grant", existing: []string{"id=1;actions=read", "id=*;type=*;actions=delete"}, @@ -1435,6 +1529,24 @@ func TestAddGrants(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Unparseable Grant", + req: &pbs.AddRoleGrantsRequest{ + Id: role.GetPublicId(), + GrantStrings: []string{"id=*;type=*;actions=create", "unparseable"}, + Version: role.GetVersion(), + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, + { + name: "Empty Grant", + req: &pbs.AddRoleGrantsRequest{ + Id: role.GetPublicId(), + GrantStrings: []string{"id=*;type=*;actions=create", ""}, + Version: role.GetVersion(), + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -1484,8 +1596,14 @@ func TestSetGrants(t *testing.T) { result: []string{"id=*;type=*;actions=delete"}, }, { - name: "Set empty on role", + name: "Set duplicate grant matching existing grant", existing: []string{"id=1;actions=read", "id=*;type=*;actions=delete"}, + set: []string{"id=*;type=*;actions=delete", "id=*;type=*;actions=delete"}, + result: []string{"id=*;type=*;actions=delete"}, + }, + { + name: "Set empty on role", + existing: []string{"id=1;type=*;actions=read", "id=*;type=*;actions=delete"}, set: nil, result: nil, }, @@ -1514,8 +1632,7 @@ func TestSetGrants(t *testing.T) { assert.Error(err) return } - s, _ := status.FromError(err) - require.NoError(err, "Got error %v", s) + require.NoError(err, "Got error %v", err) checkEqualGrants(t, tc.result, got.GetItem()) }) } @@ -1547,6 +1664,15 @@ func TestSetGrants(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Unparsable grant", + req: &pbs.SetRoleGrantsRequest{ + Id: role.GetPublicId(), + GrantStrings: []string{"id=*;type=*;actions=create", "unparseable"}, + Version: role.GetVersion(), + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -1579,25 +1705,31 @@ func TestRemoveGrants(t *testing.T) { }{ { name: "Remove all", - existing: []string{"id=1;actions=read"}, - remove: []string{"id=1;actions=read"}, + existing: []string{"id=1;type=*;actions=read"}, + remove: []string{"id=1;type=*;actions=read"}, }, { name: "Remove partial", - existing: []string{"id=1;actions=read", "id=2;actions=delete"}, - remove: []string{"id=1;actions=read"}, - result: []string{"id=2;actions=delete"}, + existing: []string{"id=1;type=*;actions=read", "id=2;type=*;actions=delete"}, + remove: []string{"id=1;type=*;actions=read"}, + result: []string{"id=2;type=*;actions=delete"}, + }, + { + name: "Remove duplicate", + existing: []string{"id=1;type=*;actions=read", "id=2;type=*;actions=delete"}, + remove: []string{"id=1;type=*;actions=read", "id=1;type=*;actions=read"}, + result: []string{"id=2;type=*;actions=delete"}, }, { name: "Remove non existant", - existing: []string{"id=2;actions=delete"}, - remove: []string{"id=1;actions=read"}, - result: []string{"id=2;actions=delete"}, + existing: []string{"id=2;type=*;actions=delete"}, + remove: []string{"id=1;type=*;actions=read"}, + result: []string{"id=2;type=*;actions=delete"}, }, { name: "Remove from empty role", existing: []string{}, - remove: []string{"id=1;actions=read"}, + remove: []string{"id=1;type=*;actions=read"}, result: nil, }, } @@ -1645,7 +1777,7 @@ func TestRemoveGrants(t *testing.T) { name: "Bad Version", req: &pbs.RemoveRoleGrantsRequest{ Id: role.GetPublicId(), - GrantStrings: []string{"id=*;actions=create"}, + GrantStrings: []string{"id=2;type=*;actions=create"}, Version: role.GetVersion() + 2, }, err: handlers.ApiErrorWithCode(codes.Internal), @@ -1654,7 +1786,25 @@ func TestRemoveGrants(t *testing.T) { name: "Bad Role Id", req: &pbs.RemoveRoleGrantsRequest{ Id: "bad id", - GrantStrings: []string{"id=*;actions=create"}, + GrantStrings: []string{"id=*;type=*;actions=create"}, + Version: role.GetVersion(), + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, + { + name: "Empty Grant", + req: &pbs.RemoveRoleGrantsRequest{ + Id: role.GetPublicId(), + GrantStrings: []string{"id=*;type=*;actions=create", ""}, + Version: role.GetVersion(), + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, + { + name: "Unparseable Grant", + req: &pbs.RemoveRoleGrantsRequest{ + Id: role.GetPublicId(), + GrantStrings: []string{"id=*;type=*;actions=create", ";unparsable=2"}, Version: role.GetVersion(), }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), diff --git a/internal/servers/controller/handlers/targets/target_service.go b/internal/servers/controller/handlers/targets/target_service.go index 02ac3e9253..d4d8c3aaf8 100644 --- a/internal/servers/controller/handlers/targets/target_service.go +++ b/internal/servers/controller/handlers/targets/target_service.go @@ -14,6 +14,7 @@ import ( pb "github.com/hashicorp/boundary/internal/gen/controller/api/resources/targets" pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services" "github.com/hashicorp/boundary/internal/host" + "github.com/hashicorp/boundary/internal/host/static" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/servers" "github.com/hashicorp/boundary/internal/servers/controller/common" @@ -24,6 +25,7 @@ import ( "github.com/hashicorp/boundary/internal/types/action" "github.com/hashicorp/boundary/internal/types/resource" "github.com/hashicorp/boundary/internal/types/scope" + "github.com/hashicorp/boundary/sdk/strutil" "github.com/mr-tron/base58" "google.golang.org/grpc/codes" "google.golang.org/protobuf/proto" @@ -556,7 +558,7 @@ func (s Service) addInRepo(ctx context.Context, targetId string, hostSetId []str if err != nil { return nil, err } - out, m, err := repo.AddTargetHostSets(ctx, targetId, version, hostSetId) + out, m, err := repo.AddTargetHostSets(ctx, targetId, version, strutil.RemoveDuplicates(hostSetId, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add host sets to target: %v.", err) @@ -572,7 +574,7 @@ func (s Service) setInRepo(ctx context.Context, targetId string, hostSetIds []st if err != nil { return nil, err } - _, _, err = repo.SetTargetHostSets(ctx, targetId, version, hostSetIds) + _, _, err = repo.SetTargetHostSets(ctx, targetId, version, strutil.RemoveDuplicates(hostSetIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set host sets in target: %v.", err) @@ -593,7 +595,7 @@ func (s Service) removeInRepo(ctx context.Context, targetId string, hostSetIds [ if err != nil { return nil, err } - _, err = repo.DeleteTargeHostSets(ctx, targetId, version, hostSetIds) + _, err = repo.DeleteTargeHostSets(ctx, targetId, version, strutil.RemoveDuplicates(hostSetIds, false)) if err != nil { // TODO: Figure out a way to surface more helpful error info beyond the Internal error. return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove host sets from target: %v.", err) @@ -800,6 +802,12 @@ func validateAddRequest(req *pbs.AddTargetHostSetsRequest) error { if len(req.GetHostSetIds()) == 0 { badFields["host_set_ids"] = "Must be non-empty." } + for _, id := range req.GetHostSetIds() { + if !handlers.ValidId(static.HostSetPrefix, id) { + badFields["host_set_ids"] = "Incorrectly formatted host set identifier." + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } @@ -814,6 +822,12 @@ func validateSetRequest(req *pbs.SetTargetHostSetsRequest) error { if req.GetVersion() == 0 { badFields["version"] = "Required field." } + for _, id := range req.GetHostSetIds() { + if !handlers.ValidId(static.HostSetPrefix, id) { + badFields["host_set_ids"] = "Incorrectly formatted host set identifier." + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } @@ -831,6 +845,12 @@ func validateRemoveRequest(req *pbs.RemoveTargetHostSetsRequest) error { if len(req.GetHostSetIds()) == 0 { badFields["host_set_ids"] = "Must be non-empty." } + for _, id := range req.GetHostSetIds() { + if !handlers.ValidId(static.HostSetPrefix, id) { + badFields["host_set_ids"] = "Incorrectly formatted host set identifier." + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } diff --git a/internal/servers/controller/handlers/targets/target_service_test.go b/internal/servers/controller/handlers/targets/target_service_test.go index 0ae31b0ada..70ef04f8e5 100644 --- a/internal/servers/controller/handlers/targets/target_service_test.go +++ b/internal/servers/controller/handlers/targets/target_service_test.go @@ -824,11 +824,17 @@ func TestAddTargetHostSets(t *testing.T) { resultHostSets: []string{hs[1].GetPublicId()}, }, { - name: "Add host on populated target", + name: "Add set on populated target", tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "populated", target.WithHostSets([]string{hs[0].GetPublicId()})), addHostSets: []string{hs[1].GetPublicId()}, resultHostSets: []string{hs[0].GetPublicId(), hs[1].GetPublicId()}, }, + { + name: "Add duplicated sets on populated target", + tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "duplicated", target.WithHostSets([]string{hs[0].GetPublicId()})), + addHostSets: []string{hs[1].GetPublicId(), hs[1].GetPublicId()}, + resultHostSets: []string{hs[0].GetPublicId(), hs[1].GetPublicId()}, + }, } for _, tc := range addCases { @@ -881,6 +887,15 @@ func TestAddTargetHostSets(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Incorrect host set ids", + req: &pbs.AddTargetHostSetsRequest{ + Id: tar.GetPublicId(), + Version: tar.GetVersion(), + HostSetIds: []string{"incorrect"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -914,19 +929,25 @@ func TestSetTargetHostSets(t *testing.T) { resultHostSets []string }{ { - name: "Set host on empty set", + name: "Set on empty target", tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "empty"), setHostSets: []string{hs[1].GetPublicId()}, resultHostSets: []string{hs[1].GetPublicId()}, }, { - name: "Set host on populated set", + name: "Set on populated target", tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "populated", target.WithHostSets([]string{hs[0].GetPublicId()})), setHostSets: []string{hs[1].GetPublicId()}, resultHostSets: []string{hs[1].GetPublicId()}, }, { - name: "Set empty on populated set", + name: "Set duplicate host set on populated target", + tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "duplicate", target.WithHostSets([]string{hs[0].GetPublicId()})), + setHostSets: []string{hs[1].GetPublicId(), hs[1].GetPublicId()}, + resultHostSets: []string{hs[1].GetPublicId()}, + }, + { + name: "Set empty on populated target", tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "another populated", target.WithHostSets([]string{hs[0].GetPublicId()})), setHostSets: []string{}, resultHostSets: nil, @@ -971,6 +992,15 @@ func TestSetTargetHostSets(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.Internal), }, + { + name: "Bad host set id", + req: &pbs.SetTargetHostSetsRequest{ + Id: tar.GetPublicId(), + Version: tar.GetVersion(), + HostSetIds: []string{"invalid"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -1016,6 +1046,12 @@ func TestRemoveTargetHostSets(t *testing.T) { removeHosts: []string{hs[1].GetPublicId()}, resultHosts: []string{hs[0].GetPublicId()}, }, + { + name: "Remove 1 duplicate set of 2 sets", + tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "remove duplicate", target.WithHostSets([]string{hs[0].GetPublicId(), hs[1].GetPublicId()})), + removeHosts: []string{hs[1].GetPublicId(), hs[1].GetPublicId()}, + resultHosts: []string{hs[0].GetPublicId()}, + }, { name: "Remove all hosts from set", tar: target.TestTcpTarget(t, conn, proj.GetPublicId(), "remove all", target.WithHostSets([]string{hs[0].GetPublicId(), hs[1].GetPublicId()})), @@ -1073,12 +1109,21 @@ func TestRemoveTargetHostSets(t *testing.T) { { name: "empty sets", req: &pbs.RemoveTargetHostSetsRequest{ - Id: "bad id", + Id: tar.GetPublicId(), Version: tar.GetVersion(), HostSetIds: []string{}, }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Invalid set ids", + req: &pbs.RemoveTargetHostSetsRequest{ + Id: tar.GetPublicId(), + Version: tar.GetVersion(), + HostSetIds: []string{"invalid"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { diff --git a/internal/servers/controller/handlers/users/user_service.go b/internal/servers/controller/handlers/users/user_service.go index 9117f6852e..0217595eff 100644 --- a/internal/servers/controller/handlers/users/user_service.go +++ b/internal/servers/controller/handlers/users/user_service.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/auth/password" "github.com/hashicorp/boundary/internal/db" pb "github.com/hashicorp/boundary/internal/gen/controller/api/resources/users" pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services" @@ -16,6 +17,7 @@ import ( "github.com/hashicorp/boundary/internal/types/action" "github.com/hashicorp/boundary/internal/types/resource" "github.com/hashicorp/boundary/internal/types/scope" + "github.com/hashicorp/boundary/sdk/strutil" "google.golang.org/grpc/codes" "google.golang.org/protobuf/types/known/wrapperspb" ) @@ -295,7 +297,7 @@ func (s Service) addInRepo(ctx context.Context, userId string, accountIds []stri if err != nil { return nil, err } - _, err = repo.AddUserAccounts(ctx, userId, version, accountIds) + _, err = repo.AddUserAccounts(ctx, userId, version, strutil.RemoveDuplicates(accountIds, false)) if err != nil { return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to add accounts to user: %v.", err) } @@ -314,7 +316,7 @@ func (s Service) setInRepo(ctx context.Context, userId string, accountIds []stri if err != nil { return nil, err } - _, err = repo.SetUserAccounts(ctx, userId, version, accountIds) + _, err = repo.SetUserAccounts(ctx, userId, version, strutil.RemoveDuplicates(accountIds, false)) if err != nil { return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to set accounts for the user: %v.", err) } @@ -333,7 +335,7 @@ func (s Service) removeInRepo(ctx context.Context, userId string, accountIds []s if err != nil { return nil, err } - _, err = repo.DeleteUserAccounts(ctx, userId, version, accountIds) + _, err = repo.DeleteUserAccounts(ctx, userId, version, strutil.RemoveDuplicates(accountIds, false)) if err != nil { return nil, handlers.ApiErrorWithCodeAndMessage(codes.Internal, "Unable to remove accounts from user: %v.", err) } @@ -462,6 +464,13 @@ func validateAddUserAccountsRequest(req *pbs.AddUserAccountsRequest) error { if len(req.GetAccountIds()) == 0 { badFields["account_ids"] = "Must be non-empty." } + for _, a := range req.GetAccountIds() { + // TODO: Increase the type of auth accounts that can be added to a user. + if !handlers.ValidId(password.AccountPrefix, a) { + badFields["account_ids"] = "Values must be valid account ids." + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } @@ -476,6 +485,13 @@ func validateSetUserAccountsRequest(req *pbs.SetUserAccountsRequest) error { if req.GetVersion() == 0 { badFields["version"] = "Required field." } + for _, a := range req.GetAccountIds() { + // TODO: Increase the type of auth accounts that can be added to a user. + if !handlers.ValidId(password.AccountPrefix, a) { + badFields["account_ids"] = "Values must be valid account ids." + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } @@ -493,6 +509,13 @@ func validateRemoveUserAccountsRequest(req *pbs.RemoveUserAccountsRequest) error if len(req.GetAccountIds()) == 0 { badFields["account_ids"] = "Must be non-empty." } + for _, a := range req.GetAccountIds() { + // TODO: Increase the type of auth accounts that can be added to a user. + if !handlers.ValidId(password.AccountPrefix, a) { + badFields["account_ids"] = "Values must be valid account ids." + break + } + } if len(badFields) > 0 { return handlers.InvalidArgumentErrorf("Errors in provided fields.", badFields) } diff --git a/internal/servers/controller/handlers/users/user_service_test.go b/internal/servers/controller/handlers/users/user_service_test.go index 3085b36f84..08d768ea13 100644 --- a/internal/servers/controller/handlers/users/user_service_test.go +++ b/internal/servers/controller/handlers/users/user_service_test.go @@ -644,6 +644,17 @@ func TestAddAccount(t *testing.T) { addAccounts: []string{accts[1].GetPublicId()}, resultAccounts: []string{accts[0].GetPublicId(), accts[1].GetPublicId()}, }, + { + name: "Add duplicate account on populated user", + setup: func(u *iam.User) { + _, err := iamRepo.SetUserAccounts(context.Background(), u.GetPublicId(), u.GetVersion(), + []string{accts[0].GetPublicId()}) + require.NoError(t, err) + u.Version = u.Version + 1 + }, + addAccounts: []string{accts[1].GetPublicId(), accts[1].GetPublicId()}, + resultAccounts: []string{accts[0].GetPublicId(), accts[1].GetPublicId()}, + }, { name: "Add empty on populated user", setup: func(u *iam.User) { @@ -695,6 +706,15 @@ func TestAddAccount(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Bad account Id", + req: &pbs.AddUserAccountsRequest{ + Id: usr.GetPublicId(), + Version: usr.GetVersion(), + AccountIds: []string{"invalid"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -746,6 +766,17 @@ func TestSetAccount(t *testing.T) { setAccounts: []string{accts[1].GetPublicId()}, resultAccounts: []string{accts[1].GetPublicId()}, }, + { + name: "Set duplicate account on populated user", + setup: func(u *iam.User) { + iamRepo.AddUserAccounts(context.Background(), u.GetPublicId(), u.GetVersion(), + []string{accts[0].GetPublicId()}) + require.NoError(t, err) + u.Version = u.Version + 1 + }, + setAccounts: []string{accts[1].GetPublicId(), accts[1].GetPublicId()}, + resultAccounts: []string{accts[1].GetPublicId()}, + }, { name: "Set empty on populated user", setup: func(u *iam.User) { @@ -798,6 +829,15 @@ func TestSetAccount(t *testing.T) { }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Bad account Id", + req: &pbs.SetUserAccountsRequest{ + Id: usr.GetPublicId(), + Version: usr.GetVersion(), + AccountIds: []string{"invalid"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { @@ -849,6 +889,17 @@ func TestRemoveAccount(t *testing.T) { removeAccounts: []string{accts[1].GetPublicId()}, resultAccounts: []string{accts[0].GetPublicId()}, }, + { + name: "Remove 1 duplicate accounts of 2 accounts from user", + setup: func(u *iam.User) { + _, err := iamRepo.SetUserAccounts(context.Background(), u.GetPublicId(), u.GetVersion(), + []string{accts[0].GetPublicId(), accts[1].GetPublicId()}) + require.NoError(t, err) + u.Version = u.Version + 1 + }, + removeAccounts: []string{accts[1].GetPublicId(), accts[1].GetPublicId()}, + resultAccounts: []string{accts[0].GetPublicId()}, + }, { name: "Remove all accounts from user", setup: func(u *iam.User) { @@ -900,22 +951,31 @@ func TestRemoveAccount(t *testing.T) { failCases := []struct { name string - req *pbs.AddUserAccountsRequest + req *pbs.RemoveUserAccountsRequest err error }{ { name: "Bad User Id", - req: &pbs.AddUserAccountsRequest{ + req: &pbs.RemoveUserAccountsRequest{ Id: "bad id", Version: usr.GetVersion(), }, err: handlers.ApiErrorWithCode(codes.InvalidArgument), }, + { + name: "Bad account Id", + req: &pbs.RemoveUserAccountsRequest{ + Id: usr.GetPublicId(), + Version: usr.GetVersion(), + AccountIds: []string{"invalid"}, + }, + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + }, } for _, tc := range failCases { t.Run(tc.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - _, gErr := s.AddUserAccounts(auth.DisabledAuthTestContext(auth.WithScopeId(usr.GetScopeId())), tc.req) + _, gErr := s.RemoveUserAccounts(auth.DisabledAuthTestContext(auth.WithScopeId(usr.GetScopeId())), tc.req) if tc.err != nil { require.Error(gErr) assert.True(errors.Is(gErr, tc.err), "AddUserAccounts(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)