From ee33a18caa3088c1dd0d4c7714f810d0cfb4a899 Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 21 Jan 2021 21:03:30 -0500 Subject: [PATCH] Add scope repo function to recursively list (#879) A simple function that returns an appropriate set of scopes based on the incoming scope ID. Includes a test. --- internal/iam/repository_scope.go | 32 ++++++++++++++++- internal/iam/repository_scope_test.go | 51 +++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/internal/iam/repository_scope.go b/internal/iam/repository_scope.go index 64c0f8eb85..e2b6e38157 100644 --- a/internal/iam/repository_scope.go +++ b/internal/iam/repository_scope.go @@ -438,7 +438,7 @@ func (r *Repository) ListProjects(ctx context.Context, withOrgId string, opt ... return projects, nil } -// ListOrgs and supports the WithLimit option. +// ListOrgs supports the WithLimit option. func (r *Repository) ListOrgs(ctx context.Context, opt ...Option) ([]*Scope, error) { const op = "iam.(Repository).ListOrgs" var orgs []*Scope @@ -448,3 +448,33 @@ func (r *Repository) ListOrgs(ctx context.Context, opt ...Option) ([]*Scope, err } return orgs, nil } + +// ListRecursively allows for recursive listing of scopes based on a root scope +// ID. It returns the root scope ID as a part of the set. +func (r *Repository) ListRecursively(ctx context.Context, rootScopeId string, opt ...Option) ([]*Scope, error) { + const op = "iam.(Repository).ListRecursively" + var orgs []*Scope + var where string + var args []interface{} + switch { + case rootScopeId == "global": + // Nothing -- we want all scopes + case strings.HasPrefix(rootScopeId, "o_"): + // The org itself and any projects that have it as parent + where = "public_id = ? or parent_id = ?" + args = append(args, rootScopeId, rootScopeId) + case strings.HasPrefix(rootScopeId, "p_"): + // No scopes can (currently) live under projects, so just the project + // itself + where = "public_id = ?" + args = append(args, rootScopeId) + default: + // We have no idea what scope type this is so bail + return nil, errors.New(errors.InvalidPublicId, op+":TypeSwitch", "invalid scope ID") + } + err := r.list(ctx, &orgs, where, args, opt...) + if err != nil { + return nil, errors.Wrap(err, op+":ListQuery") + } + return orgs, nil +} diff --git a/internal/iam/repository_scope_test.go b/internal/iam/repository_scope_test.go index 1fb4cb157e..7f1b464c04 100644 --- a/internal/iam/repository_scope_test.go +++ b/internal/iam/repository_scope_test.go @@ -614,3 +614,54 @@ func Test_Repository_ListOrgs(t *testing.T) { }) } } + +func Test_Repository_ListRecursive(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + repo := TestRepo(t, conn, wrapper) + var testOrgs []*Scope + var testProjects []*Scope + const subPerScope = 5 + for i := 0; i < subPerScope; i++ { + org := testOrg(t, repo, fmt.Sprint(i), "") + testOrgs = append(testOrgs, org) + for j := 0; j < subPerScope; j++ { + testProjects = append(testProjects, testProject(t, repo, org.PublicId, WithName(fmt.Sprintf("%d-%d", i, j)))) + } + } + tests := []struct { + name string + rootScopeId string + wantCnt int + wantErr bool + }{ + { + name: "global", + rootScopeId: "global", + wantCnt: 1 + len(testOrgs) + len(testProjects), + }, + { + name: "org", + rootScopeId: testOrgs[0].PublicId, + wantCnt: 1 + subPerScope, + }, + { + name: "project", + rootScopeId: testProjects[16].PublicId, + wantCnt: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + got, err := repo.ListRecursively(context.Background(), tt.rootScopeId) + if tt.wantErr { + require.Error(err) + return + } + require.NoError(err) + assert.Equal(tt.wantCnt, len(got)) + }) + } +}