diff --git a/internal/daemon/controller/handlers/accounts/account_service.go b/internal/daemon/controller/handlers/accounts/account_service.go index 18e2c5932d..a132fc3a79 100644 --- a/internal/daemon/controller/handlers/accounts/account_service.go +++ b/internal/daemon/controller/handlers/accounts/account_service.go @@ -606,7 +606,7 @@ func (s Service) getFromRepo(ctx context.Context, id string) (auth.Account, []st } return nil, nil, err } - mgs, err := repo.ListManagedGroupMembershipsByMember(ctx, a.GetPublicId()) + mgs, err := repo.ListManagedGroupMembershipsByMember(ctx, a.GetPublicId(), oidc.WithLimit(-1)) if err != nil { return nil, nil, err } @@ -629,7 +629,7 @@ func (s Service) getFromRepo(ctx context.Context, id string) (auth.Account, []st } return nil, nil, err } - mgs, err := repo.ListManagedGroupMembershipsByMember(ctx, a.GetPublicId()) + mgs, err := repo.ListManagedGroupMembershipsByMember(ctx, a.GetPublicId(), ldap.WithLimit(ctx, -1)) if err != nil { return nil, nil, err } diff --git a/internal/daemon/controller/handlers/accounts/account_service_test.go b/internal/daemon/controller/handlers/accounts/account_service_test.go index 5e33ea3741..7a5e549d2f 100644 --- a/internal/daemon/controller/handlers/accounts/account_service_test.go +++ b/internal/daemon/controller/handlers/accounts/account_service_test.go @@ -134,13 +134,15 @@ func TestGet(t *testing.T) { return password.NewRepository(ctx, rw, rw, kmsCache) } oidcRepoFn := func() (*oidc.Repository, error) { - return oidc.NewRepository(ctx, rw, rw, kmsCache) + // Use a small limit to test that membership lookup is explicitly unlimited + return oidc.NewRepository(ctx, rw, rw, kmsCache, oidc.WithLimit(1)) } iamRepoFn := func() (*iam.Repository, error) { return iam.NewRepository(ctx, rw, rw, kmsCache) } ldapRepoFn := func() (*ldap.Repository, error) { - return ldap.NewRepository(ctx, rw, rw, kmsCache) + // Use a small limit to test that membership lookup is explicitly unlimited + return ldap.NewRepository(ctx, rw, rw, kmsCache, ldap.WithLimit(ctx, 1)) } s, err := accounts.NewService(ctx, pwRepoFn, oidcRepoFn, ldapRepoFn, 1000) @@ -175,9 +177,10 @@ func TestGet(t *testing.T) { oidc.WithApiUrl(oidc.TestConvertToUrls(t, "https://www.alice.com/callback")[0]), ) oidcA := oidc.TestAccount(t, conn, oidcAm, "test-subject") - // Create a managed group that will always match, so we can test that it is + // Create some managed groups that will always match, so we can test that it is // returned in results mg := oidc.TestManagedGroup(t, conn, oidcAm, `"/token/sub" matches ".*"`) + mg2 := oidc.TestManagedGroup(t, conn, oidcAm, `"/token/sub" matches ".*"`) oidcWireAccount := pb.Account{ Id: oidcA.GetPublicId(), AuthMethodId: oidcA.GetAuthMethodId(), @@ -193,7 +196,7 @@ func TestGet(t *testing.T) { }, }, AuthorizedActions: oidcAuthorizedActions, - ManagedGroupIds: []string{mg.GetPublicId()}, + ManagedGroupIds: []string{mg.GetPublicId(), mg2.GetPublicId()}, } ldapAm := ldap.TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) @@ -204,6 +207,7 @@ func TestGet(t *testing.T) { ldap.WithDn(ctx, "test-dn"), ) ldapMg := ldap.TestManagedGroup(t, conn, ldapAm, []string{"admin"}) + ldapMg2 := ldap.TestManagedGroup(t, conn, ldapAm, []string{"admin"}) ldapWireAccount := pb.Account{ Id: ldapAcct.GetPublicId(), AuthMethodId: ldapAm.GetPublicId(), @@ -222,7 +226,7 @@ func TestGet(t *testing.T) { }, Type: ldap.Subtype.String(), AuthorizedActions: ldapAuthorizedActions, - ManagedGroupIds: []string{ldapMg.GetPublicId()}, + ManagedGroupIds: []string{ldapMg.GetPublicId(), ldapMg2.GetPublicId()}, } cases := []struct { @@ -289,12 +293,14 @@ func TestGet(t *testing.T) { if globals.ResourceInfoFromPrefix(tc.req.Id).Subtype == oidc.Subtype { // Set up managed groups before getting. First get the current - // managed group to make sure we have the right version. + // managed groups to make sure we have the right version. oidcRepo, err := oidcRepoFn() require.NoError(err) currMg, err := oidcRepo.LookupManagedGroup(ctx, mg.GetPublicId()) require.NoError(err) - _, _, err = oidcRepo.SetManagedGroupMemberships(ctx, oidcAm, oidcA, []*oidc.ManagedGroup{currMg}) + currMg2, err := oidcRepo.LookupManagedGroup(ctx, mg2.GetPublicId()) + require.NoError(err) + _, _, err = oidcRepo.SetManagedGroupMemberships(ctx, oidcAm, oidcA, []*oidc.ManagedGroup{currMg, currMg2}) require.NoError(err) } diff --git a/internal/daemon/controller/handlers/managed_groups/managed_group_service.go b/internal/daemon/controller/handlers/managed_groups/managed_group_service.go index dd7c780917..d8c03480d5 100644 --- a/internal/daemon/controller/handlers/managed_groups/managed_group_service.go +++ b/internal/daemon/controller/handlers/managed_groups/managed_group_service.go @@ -446,7 +446,7 @@ func (s Service) getFromRepo(ctx context.Context, id string) (auth.ManagedGroup, } return nil, nil, err } - ids, err := repo.ListManagedGroupMembershipsByGroup(ctx, mg.GetPublicId()) + ids, err := repo.ListManagedGroupMembershipsByGroup(ctx, mg.GetPublicId(), oidc.WithLimit(-1)) if err != nil { return nil, nil, err } @@ -469,7 +469,7 @@ func (s Service) getFromRepo(ctx context.Context, id string) (auth.ManagedGroup, } return nil, nil, err } - ids, err := repo.ListManagedGroupMembershipsByGroup(ctx, mg.GetPublicId()) + ids, err := repo.ListManagedGroupMembershipsByGroup(ctx, mg.GetPublicId(), ldap.WithLimit(ctx, -1)) if err != nil { return nil, nil, err } diff --git a/internal/daemon/controller/handlers/managed_groups/managed_group_service_test.go b/internal/daemon/controller/handlers/managed_groups/managed_group_service_test.go index 8c83d142f4..14762802b9 100644 --- a/internal/daemon/controller/handlers/managed_groups/managed_group_service_test.go +++ b/internal/daemon/controller/handlers/managed_groups/managed_group_service_test.go @@ -118,13 +118,15 @@ func TestGet(t *testing.T) { wrap := db.TestWrapper(t) kmsCache := kms.TestKms(t, conn, wrap) oidcRepoFn := func() (*oidc.Repository, error) { - return oidc.NewRepository(ctx, rw, rw, kmsCache) + // Use a small limit to test that membership lookup is explicitly unlimited + return oidc.NewRepository(ctx, rw, rw, kmsCache, oidc.WithLimit(1)) } iamRepoFn := func() (*iam.Repository, error) { return iam.NewRepository(ctx, rw, rw, kmsCache) } ldapRepoFn := func() (*ldap.Repository, error) { - return ldap.NewRepository(ctx, rw, rw, kmsCache) + // Use a small limit to test that membership lookup is explicitly unlimited + return ldap.NewRepository(ctx, rw, rw, kmsCache, ldap.WithLimit(ctx, 1)) } s, err := managed_groups.NewService(ctx, oidcRepoFn, ldapRepoFn, 1000) @@ -142,6 +144,7 @@ func TestGet(t *testing.T) { oidc.WithApiUrl(oidc.TestConvertToUrls(t, "https://www.alice.com/callback")[0]), ) oidcA := oidc.TestAccount(t, conn, oidcAm, "test-subject") + oidcB := oidc.TestAccount(t, conn, oidcAm, "test-subject-2") omg := oidc.TestManagedGroup(t, conn, oidcAm, oidc.TestFakeManagedGroupFilter) // Set up managed group before getting. First get the current @@ -153,6 +156,10 @@ func TestGet(t *testing.T) { require.NoError(t, err) _, _, err = oidcRepo.SetManagedGroupMemberships(ctx, oidcAm, oidcA, []*oidc.ManagedGroup{currMg}) require.NoError(t, err) + currMg, err = oidcRepo.LookupManagedGroup(ctx, omg.GetPublicId()) + require.NoError(t, err) + _, _, err = oidcRepo.SetManagedGroupMemberships(ctx, oidcAm, oidcB, []*oidc.ManagedGroup{currMg}) + require.NoError(t, err) // Fetch the group once more to get the updated time currMg, err = oidcRepo.LookupManagedGroup(ctx, omg.GetPublicId()) require.NoError(t, err) @@ -171,11 +178,12 @@ func TestGet(t *testing.T) { }, }, AuthorizedActions: oidcAuthorizedActions, - MemberIds: []string{oidcA.GetPublicId()}, + MemberIds: []string{oidcA.GetPublicId(), oidcB.GetPublicId()}, } ldapAm := ldap.TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) ldapAcct := ldap.TestAccount(t, conn, ldapAm, "test-login-name", ldap.WithMemberOfGroups(ctx, "admin")) + ldapAcct2 := ldap.TestAccount(t, conn, ldapAm, "test-login-name-2", ldap.WithMemberOfGroups(ctx, "admin")) ldapMg := ldap.TestManagedGroup(t, conn, ldapAm, []string{"admin"}) ldapWireManagedGroup := pb.ManagedGroup{ Id: ldapMg.GetPublicId(), @@ -191,7 +199,7 @@ func TestGet(t *testing.T) { }, }, AuthorizedActions: ldapAuthorizedActions, - MemberIds: []string{ldapAcct.GetPublicId()}, + MemberIds: []string{ldapAcct.GetPublicId(), ldapAcct2.GetPublicId()}, } cases := []struct {