From 2fbdcf6ce096002e339ab59b563fac1b5dc3adc8 Mon Sep 17 00:00:00 2001 From: Timothy Messier Date: Tue, 9 Aug 2022 12:42:58 +0000 Subject: [PATCH] feat(session): Use permissions for limiting list results The session repository can now be passed an optional set of permissions. The permissions are used to restrict results from the repository. Currently this only impacts the List method. This enabled the sessions service to use a single query to retrieve sessions and replaces the previous behavior where the session service would first request all sessions for a given set of scopes, evaluate the permissions, and then request sessions based on a set of session ids. --- .../postgres/50/01_session_list_index.up.sql | 11 + internal/session/options.go | 10 + internal/session/query.go | 15 - internal/session/repository.go | 57 +++ internal/session/repository_session.go | 97 +---- internal/session/repository_session_test.go | 385 +++++------------- .../session/service_list_for_authz_check.go | 23 -- 7 files changed, 184 insertions(+), 414 deletions(-) create mode 100644 internal/db/schema/migrations/oss/postgres/50/01_session_list_index.up.sql delete mode 100644 internal/session/service_list_for_authz_check.go diff --git a/internal/db/schema/migrations/oss/postgres/50/01_session_list_index.up.sql b/internal/db/schema/migrations/oss/postgres/50/01_session_list_index.up.sql new file mode 100644 index 0000000000..b92cf056fc --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/50/01_session_list_index.up.sql @@ -0,0 +1,11 @@ +begin; + -- Partial index to aid session list requests + -- + -- If a session list request is made using the default list request options + -- and using the standard grants created by boundary by default, + -- it will include where clauses that: + -- * include a project_id paired with a user_id + -- * and where termination_reason is null + create index session_list_pix on session (project_id, user_id, termination_reason) where termination_reason is null; + analyze session; +end; diff --git a/internal/session/options.go b/internal/session/options.go index fa1e90285a..4e90a396ab 100644 --- a/internal/session/options.go +++ b/internal/session/options.go @@ -5,6 +5,7 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" + "github.com/hashicorp/boundary/internal/perms" ) // getOpts - iterate the inbound Options and return a struct @@ -32,6 +33,7 @@ type options struct { withDbOpts []db.Option withWorkerStateDelay time.Duration withTerminated bool + withPermissions *perms.UserPermissions } func getDefaultOptions() options { @@ -120,3 +122,11 @@ func WithTerminated(withTerminated bool) Option { o.withTerminated = withTerminated } } + +// WithPermissions is used to include user permissions when constructing a +// Repository. +func WithPermissions(p *perms.UserPermissions) Option { + return func(o *options) { + o.withPermissions = p + } +} diff --git a/internal/session/query.go b/internal/session/query.go index 8cbfb570cc..21c0015478 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -117,21 +117,6 @@ session_connection_limit(expiration_time, connection_limit) as ( select expiration_time, connection_limit, current_connection_count from session_connection_limit, session_connection_count; -` - nonTerminatedSessionPublicIdList = ` -select public_id, project_id, user_id from session -where - session.termination_reason is null -and - project_id = any(@project_ids) -; -` - - sessionPublicIdList = ` -select public_id, project_id, user_id from session -where - project_id = any(@project_ids) -; ` sessionList = ` diff --git a/internal/session/repository.go b/internal/session/repository.go index 27957382a8..8080ae56f1 100644 --- a/internal/session/repository.go +++ b/internal/session/repository.go @@ -2,12 +2,18 @@ package session import ( "context" + "database/sql" + "fmt" "sort" + "strings" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/perms" + "github.com/hashicorp/boundary/internal/types/action" + "github.com/hashicorp/boundary/internal/types/resource" ) // Clonable provides a cloning interface @@ -23,8 +29,13 @@ type Repository struct { // defaultLimit provides a default for limiting the number of results returned from the repo defaultLimit int + + permissions *perms.UserPermissions } +// RepositoryFactory is a function that creates a Repository. +type RepositoryFactory func(opt ...Option) (*Repository, error) + // NewRepository creates a new session Repository. Supports the options: WithLimit // which sets a default limit on results returned by repo operations. func NewRepository(r db.Reader, w db.Writer, kms *kms.Kms, opt ...Option) (*Repository, error) { @@ -43,14 +54,60 @@ func NewRepository(r db.Reader, w db.Writer, kms *kms.Kms, opt ...Option) (*Repo // zero signals the boundary defaults should be used. opts.withLimit = db.DefaultLimit } + + if opts.withPermissions != nil { + for _, p := range opts.withPermissions.Permissions { + if p.Resource != resource.Session { + return nil, errors.NewDeprecated(errors.InvalidParameter, op, "permission for incorrect resource") + } + } + } + return &Repository{ reader: r, writer: w, kms: kms, defaultLimit: opts.withLimit, + permissions: opts.withPermissions, }, nil } +func (r *Repository) listPermissionWhereClauses() ([]string, []interface{}) { + var where []string + var args []interface{} + + if r.permissions == nil { + return where, args + } + + inClauseCnt := 0 + for _, p := range r.permissions.Permissions { + if p.Action != action.List { + continue + } + + inClauseCnt++ + + var clauses []string + clauses = append(clauses, fmt.Sprintf("project_id = @project_id_%d", inClauseCnt)) + args = append(args, sql.Named(fmt.Sprintf("project_id_%d", inClauseCnt), p.ScopeId)) + + if len(p.ResourceIds) > 0 { + clauses = append(clauses, fmt.Sprintf("public_id = any(@public_id_%d)", inClauseCnt)) + args = append(args, sql.Named(fmt.Sprintf("public_id_%d", inClauseCnt), "{"+strings.Join(p.ResourceIds, ",")+"}")) + } + + if p.OnlySelf { + inClauseCnt++ + clauses = append(clauses, fmt.Sprintf("user_id = @user_id_%d", inClauseCnt)) + args = append(args, sql.Named(fmt.Sprintf("user_id_%d", inClauseCnt), r.permissions.UserId)) + } + + where = append(where, fmt.Sprintf("(%s)", strings.Join(clauses, " and "))) + } + return where, args +} + // list will return a listing of resources and honor the WithLimit option or the // repo defaultLimit. Supports WithOrder option. func (r *Repository) list(ctx context.Context, resources interface{}, where string, args []interface{}, opt ...Option) error { diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index a9c0228f29..f0aa64751b 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -9,7 +9,6 @@ import ( "strings" "time" - "github.com/hashicorp/boundary/internal/boundary" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" @@ -229,89 +228,28 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, _ ...O return &session, authzSummary, nil } -// fetchAuthzProtectedSessionsByProject fetches sessions for the given projects. -// Note that the sessions are not fully populated, and only contain the -// necessary information to implement the boundary.AuthzProtectedEntity -// interface. Supports the WithTerminated option. -func (r *Repository) fetchAuthzProtectedSessionsByProject( - ctx context.Context, projectIds []string, opt ...Option, -) (map[string][]boundary.AuthzProtectedEntity, error) { - const op = "session.(Repository).fetchAuthzProtectedSessionsByProject" - - opts := getOpts(opt...) - - if len(projectIds) == 0 { - return nil, errors.New(ctx, errors.InvalidParameter, op, "no project ids given") - } - - args := []interface{}{ - sql.Named("project_ids", "{"+strings.Join(projectIds, ",")+"}"), - } - - var query string - if opts.withTerminated { - query = sessionPublicIdList - } else { - query = nonTerminatedSessionPublicIdList - } - - rows, err := r.reader.Query(ctx, query, args) - if err != nil { - return nil, errors.Wrap(ctx, err, op) - } - defer rows.Close() - - sessionsMap := map[string][]boundary.AuthzProtectedEntity{} - for rows.Next() { - var ses Session - if err := r.reader.ScanRows(ctx, rows, &ses); err != nil { - return nil, errors.Wrap(ctx, err, op, errors.WithMsg("scan row failed")) - } - sessionsMap[ses.GetProjectId()] = append(sessionsMap[ses.GetProjectId()], ses) - } - - return sessionsMap, nil -} - -// ListSessions lists sessions. Supports the WithLimit, WithProjectId, and WithSessionIds options. +// ListSessions lists sessions. Sessions returned will be limited by the list +// permissions of the repository. Supports the WithTerminated, WithLimit, +// WithOrderByCreateTime options. func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Session, error) { const op = "session.(Repository).ListSessions" opts := getOpts(opt...) - var where []string - var args []interface{} - - inClauseCnt := 0 - switch len(opts.withProjectIds) { - case 0: - case 1: - inClauseCnt += 1 - where, args = append(where, fmt.Sprintf("project_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withProjectIds[0])) - default: - idsInClause := make([]string, 0, len(opts.withProjectIds)) - for _, id := range opts.withProjectIds { - inClauseCnt += 1 - idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id)) - } - where = append(where, fmt.Sprintf("project_id in (%s)", strings.Join(idsInClause, ","))) - } - if opts.withUserId != "" { - inClauseCnt += 1 - where, args = append(where, fmt.Sprintf("user_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withUserId)) + where, args := r.listPermissionWhereClauses() + if len(where) == 0 { + return nil, nil } - switch len(opts.withSessionIds) { - case 0: - case 1: - inClauseCnt += 1 - where, args = append(where, fmt.Sprintf("s.public_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withSessionIds[0])) - default: - idsInClause := make([]string, 0, len(opts.withSessionIds)) - for _, id := range opts.withSessionIds { - inClauseCnt += 1 - idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id)) + var whereClause string + if len(where) > 0 { + whereClause = " where (" + strings.Join(where, " or ") + ")" + if !opts.withTerminated { + whereClause += "and termination_reason is null" + } + } else { + if !opts.withTerminated { + whereClause = "where termination_reason is null" } - where = append(where, fmt.Sprintf("s.public_id in (%s)", strings.Join(idsInClause, ","))) } var limit string @@ -323,7 +261,6 @@ func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Sessio // non-zero signals an override of the default limit for the repo. limit = fmt.Sprintf("limit %d", opts.withLimit) } - var withOrder string switch opts.withOrderByCreateTime { case db.AscendingOrderBy: @@ -334,10 +271,6 @@ func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Sessio withOrder = "order by create_time" } - var whereClause string - if len(where) > 0 { - whereClause = " where " + strings.Join(where, " and ") - } q := sessionList query := fmt.Sprintf(q, whereClause, withOrder, limit, withOrder) diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 0bcf6e0555..30ef301a2a 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -21,9 +21,12 @@ import ( iamStore "github.com/hashicorp/boundary/internal/iam/store" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" + "github.com/hashicorp/boundary/internal/perms" "github.com/hashicorp/boundary/internal/target" "github.com/hashicorp/boundary/internal/target/tcp" tcpStore "github.com/hashicorp/boundary/internal/target/tcp/store" + "github.com/hashicorp/boundary/internal/types/action" + "github.com/hashicorp/boundary/internal/types/resource" wrapping "github.com/hashicorp/go-kms-wrapping/v2" "github.com/jackc/pgconn" "github.com/stretchr/testify/assert" @@ -39,10 +42,18 @@ func TestRepository_ListSession(t *testing.T) { iamRepo := iam.TestRepo(t, conn, wrapper) rw := db.New(conn) kms := kms.TestKms(t, conn, wrapper) - repo, err := NewRepository(rw, rw, kms, WithLimit(testLimit)) - require.NoError(t, err) composedOf := TestSessionParams(t, conn, wrapper, iamRepo) + listPerms := &perms.UserPermissions{ + UserId: composedOf.UserId, + Permissions: []perms.Permission{ + { + ScopeId: composedOf.ProjectId, + Resource: resource.Session, + Action: action.List, + }, + }, + } type args struct { opt []Option } @@ -50,58 +61,85 @@ func TestRepository_ListSession(t *testing.T) { name string createCnt int args args + perms *perms.UserPermissions wantCnt int wantErr bool withConnections int }{ { name: "no-limit", - createCnt: repo.defaultLimit + 1, + createCnt: testLimit + 1, args: args{ opt: []Option{WithLimit(-1)}, }, - wantCnt: repo.defaultLimit + 1, + perms: listPerms, + wantCnt: testLimit + 1, wantErr: false, }, { name: "default-limit", - createCnt: repo.defaultLimit + 1, + createCnt: testLimit + 1, args: args{}, - wantCnt: repo.defaultLimit, + perms: listPerms, + wantCnt: testLimit, wantErr: false, }, { name: "custom-limit", - createCnt: repo.defaultLimit + 1, + createCnt: testLimit + 1, args: args{ opt: []Option{WithLimit(3)}, }, + perms: listPerms, wantCnt: 3, wantErr: false, }, { - name: "withProjectIds", - createCnt: repo.defaultLimit + 1, - args: args{ - opt: []Option{WithProjectIds([]string{composedOf.ProjectId})}, + name: "withNoPerms", + createCnt: testLimit + 1, + args: args{}, + perms: &perms.UserPermissions{}, + wantCnt: 0, + wantErr: false, + }, + { + name: "withPermsDifferentScopeId", + createCnt: testLimit + 1, + args: args{}, + perms: &perms.UserPermissions{ + Permissions: []perms.Permission{ + { + ScopeId: "o_thisIsNotValid", + Resource: resource.Session, + Action: action.List, + }, + }, }, - wantCnt: repo.defaultLimit, + wantCnt: 0, wantErr: false, }, { - name: "bad-withProjectId", - createCnt: repo.defaultLimit + 1, - args: args{ - opt: []Option{WithProjectIds([]string{"o_thisIsNotValid"})}, + name: "withPermsNonListAction", + createCnt: testLimit + 1, + args: args{}, + perms: &perms.UserPermissions{ + Permissions: []perms.Permission{ + { + ScopeId: composedOf.ProjectId, + Resource: resource.Session, + Action: action.Read, + }, + }, }, wantCnt: 0, wantErr: false, }, { name: "multiple-connections", - createCnt: repo.defaultLimit + 1, + createCnt: testLimit + 1, args: args{}, - wantCnt: repo.defaultLimit, + perms: listPerms, + wantCnt: testLimit, wantErr: false, withConnections: 3, }, @@ -109,6 +147,10 @@ func TestRepository_ListSession(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) + + repo, err := NewRepository(rw, rw, kms, WithLimit(testLimit), WithPermissions(tt.perms)) + require.NoError(err) + db.TestDeleteWhere(t, conn, func() interface{} { i := AllocSession(); return &i }(), "1=1") testSessions := []*Session{} for i := 0; i < tt.createCnt; i++ { @@ -150,6 +192,10 @@ func TestRepository_ListSession(t *testing.T) { for i := 0; i < wantCnt; i++ { _ = TestSession(t, conn, wrapper, composedOf) } + + repo, err := NewRepository(rw, rw, kms, WithLimit(testLimit), WithPermissions(listPerms)) + require.NoError(err) + got, err := repo.ListSessions(context.Background(), WithOrderByCreateTime(db.AscendingOrderBy)) require.NoError(err) assert.Equal(wantCnt, len(got)) @@ -160,7 +206,7 @@ func TestRepository_ListSession(t *testing.T) { assert.True(first.Before(second)) } }) - t.Run("withUserId", func(t *testing.T) { + t.Run("onlySelf", func(t *testing.T) { assert, require := assert.New(t), require.New(t) db.TestDeleteWhere(t, conn, func() interface{} { i := AllocSession(); return &i }(), "1=1") wantCnt := 5 @@ -168,51 +214,25 @@ func TestRepository_ListSession(t *testing.T) { _ = TestSession(t, conn, wrapper, composedOf) } s := TestDefaultSession(t, conn, wrapper, iamRepo) + + p := &perms.UserPermissions{ + UserId: s.UserId, + Permissions: []perms.Permission{ + { + ScopeId: s.ProjectId, + Resource: resource.Session, + Action: action.List, + OnlySelf: true, + }, + }, + } + repo, err := NewRepository(rw, rw, kms, WithLimit(testLimit), WithPermissions(p)) + require.NoError(err) got, err := repo.ListSessions(context.Background(), WithUserId(s.UserId)) require.NoError(err) assert.Equal(1, len(got)) assert.Equal(s.UserId, got[0].UserId) }) - t.Run("withUserIdAndwithScopeId", func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - db.TestDeleteWhere(t, conn, func() interface{} { i := AllocSession(); return &i }(), "1=1") - wantCnt := 5 - for i := 0; i < wantCnt; i++ { - // Scope 1 User 1 - _ = TestSession(t, conn, wrapper, composedOf) - } - // Scope 2 User 2 - s := TestDefaultSession(t, conn, wrapper, iamRepo) - - // Scope 1 User 2 - coDiffUser := composedOf - coDiffUser.AuthTokenId = s.AuthTokenId - coDiffUser.UserId = s.UserId - wantS := TestSession(t, conn, wrapper, coDiffUser) - - got, err := repo.ListSessions(context.Background(), WithUserId(coDiffUser.UserId), WithProjectIds([]string{coDiffUser.ProjectId})) - require.NoError(err) - assert.Equal(1, len(got)) - assert.Equal(wantS.UserId, got[0].UserId) - assert.Equal(wantS.ProjectId, got[0].ProjectId) - }) - t.Run("WithSessionIds", func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - db.TestDeleteWhere(t, conn, func() interface{} { i := AllocSession(); return &i }(), "1=1") - testSessions := []*Session{} - for i := 0; i < 10; i++ { - s := TestSession(t, conn, wrapper, composedOf) - _ = TestState(t, conn, s.PublicId, StatusActive) - testSessions = append(testSessions, s) - } - assert.Equal(10, len(testSessions)) - withIds := []string{testSessions[0].PublicId, testSessions[1].PublicId} - got, err := repo.ListSessions(context.Background(), WithSessionIds(withIds...), WithOrderByCreateTime(db.AscendingOrderBy)) - require.NoError(err) - assert.Equal(2, len(got)) - assert.Equal(StatusActive, got[0].States[0].Status) - assert.Equal(StatusPending, got[0].States[1].Status) - }) } func TestRepository_ListSessions_Multiple_Scopes(t *testing.T) { @@ -222,23 +242,29 @@ func TestRepository_ListSessions_Multiple_Scopes(t *testing.T) { iamRepo := iam.TestRepo(t, conn, wrapper) rw := db.New(conn) kms := kms.TestKms(t, conn, wrapper) - repo, err := NewRepository(rw, rw, kms) - require.NoError(t, err) db.TestDeleteWhere(t, conn, func() interface{} { i := AllocSession(); return &i }(), "1=1") const numPerScope = 10 - var projs []string + var p []perms.Permission for i := 0; i < numPerScope; i++ { composedOf := TestSessionParams(t, conn, wrapper, iamRepo) - projs = append(projs, composedOf.ProjectId) + p = append(p, perms.Permission{ + ScopeId: composedOf.ProjectId, + Resource: resource.Session, + Action: action.List, + }) s := TestSession(t, conn, wrapper, composedOf) _ = TestState(t, conn, s.PublicId, StatusActive) } - got, err := repo.ListSessions(context.Background(), WithProjectIds(projs)) + repo, err := NewRepository(rw, rw, kms, WithPermissions(&perms.UserPermissions{ + Permissions: p, + })) require.NoError(t, err) - assert.Equal(t, len(projs), len(got)) + got, err := repo.ListSessions(context.Background()) + require.NoError(t, err) + assert.Equal(t, len(p), len(got)) } func TestRepository_CreateSession(t *testing.T) { @@ -1623,232 +1649,3 @@ func TestRepository_deleteTerminated(t *testing.T) { }) } } - -func TestFetchAuthzProtectedSessionsByScopes(t *testing.T) { - conn, _ := db.TestSetup(t, "postgres") - ctx := context.Background() - wrapper := db.TestWrapper(t) - iamRepo := iam.TestRepo(t, conn, wrapper) - rw := db.New(conn) - kms := kms.TestKms(t, conn, wrapper) - repo, err := NewRepository(rw, rw, kms) - require.NoError(t, err) - composedOf := TestSessionParams(t, conn, wrapper, iamRepo) - - _, pWithOtherSessions := iam.TestScopes(t, iamRepo) - - hcOther := static.TestCatalogs(t, conn, pWithOtherSessions.GetPublicId(), 1)[0] - hsOther := static.TestSets(t, conn, hcOther.GetPublicId(), 1)[0] - hOther := static.TestHosts(t, conn, hcOther.GetPublicId(), 1)[0] - static.TestSetMembers(t, conn, hsOther.GetPublicId(), []*static.Host{hOther}) - tarOther := tcp.TestTarget(ctx, t, conn, pWithOtherSessions.GetPublicId(), "test", target.WithHostSources([]string{hsOther.GetPublicId()})) - - composedOfOther := ComposedOf{ - UserId: composedOf.UserId, - HostId: hOther.GetPublicId(), - TargetId: tarOther.GetPublicId(), - HostSetId: hsOther.GetPublicId(), - AuthTokenId: composedOf.AuthTokenId, - ProjectId: pWithOtherSessions.GetPublicId(), - Endpoint: "tcp://127.0.0.1:22", - } - - type testCase struct { - name string - createCnt int - terminateCnt int - otherCnt int - otherTerminateCnt int - opts []Option - reqScopes []string - wantCnt int - wantOtherCnt int - wantErr bool - } - - cases := []testCase{ - { - name: "NonTerminated/none", - createCnt: 0, - reqScopes: []string{composedOf.ProjectId}, - wantCnt: 0, - wantErr: false, - }, - { - name: "NonTerminated/one", - createCnt: 1, - reqScopes: []string{composedOf.ProjectId}, - wantCnt: 1, - wantErr: false, - }, - { - name: "NonTerminated/many", - createCnt: 5, - reqScopes: []string{composedOf.ProjectId}, - wantCnt: 5, - wantErr: false, - }, - { - name: "NonTerminated/many one terminated", - createCnt: 5, - terminateCnt: 1, - reqScopes: []string{composedOf.ProjectId}, - wantCnt: 4, - wantErr: false, - }, - { - name: "NonTerminated/many terminated", - createCnt: 5, - terminateCnt: 3, - reqScopes: []string{composedOf.ProjectId}, - wantCnt: 2, - wantErr: false, - }, - { - name: "NonTerminated/many multiple projects", - createCnt: 5, - terminateCnt: 3, - reqScopes: []string{composedOf.ProjectId, composedOfOther.ProjectId}, - wantCnt: 2, - wantErr: false, - }, - { - name: "NonTerminated/many multiple projects", - createCnt: 5, - terminateCnt: 3, - otherCnt: 3, - otherTerminateCnt: 1, - reqScopes: []string{composedOf.ProjectId, composedOfOther.ProjectId}, - wantCnt: 2, - wantOtherCnt: 2, - wantErr: false, - }, - { - name: "NonTerminated/no projects", - createCnt: 2, - terminateCnt: 1, - otherCnt: 2, - otherTerminateCnt: 1, - reqScopes: []string{}, - wantErr: true, - }, - { - name: "none", - opts: []Option{WithTerminated(true)}, - createCnt: 0, - reqScopes: []string{composedOf.ProjectId}, - wantCnt: 0, - wantErr: false, - }, - { - name: "one", - opts: []Option{WithTerminated(true)}, - createCnt: 1, - reqScopes: []string{composedOf.ProjectId}, - wantCnt: 1, - wantErr: false, - }, - { - name: "many", - opts: []Option{WithTerminated(true)}, - createCnt: 5, - reqScopes: []string{composedOf.ProjectId}, - wantCnt: 5, - wantErr: false, - }, - { - name: "many one terminated", - opts: []Option{WithTerminated(true)}, - createCnt: 5, - terminateCnt: 1, - reqScopes: []string{composedOf.ProjectId}, - wantCnt: 5, - wantErr: false, - }, - { - name: "many terminated", - opts: []Option{WithTerminated(true)}, - createCnt: 5, - terminateCnt: 3, - reqScopes: []string{composedOf.ProjectId}, - wantCnt: 5, - wantErr: false, - }, - { - name: "many multiple projects", - opts: []Option{WithTerminated(true)}, - createCnt: 5, - terminateCnt: 3, - reqScopes: []string{composedOf.ProjectId, composedOfOther.ProjectId}, - wantCnt: 5, - wantErr: false, - }, - { - name: "many multiple projects", - opts: []Option{WithTerminated(true)}, - createCnt: 5, - terminateCnt: 3, - otherCnt: 3, - otherTerminateCnt: 1, - reqScopes: []string{composedOf.ProjectId, composedOfOther.ProjectId}, - wantCnt: 5, - wantOtherCnt: 3, - wantErr: false, - }, - { - name: "no projects", - opts: []Option{WithTerminated(true)}, - createCnt: 2, - terminateCnt: 1, - otherCnt: 2, - otherTerminateCnt: 1, - reqScopes: []string{}, - wantErr: true, - }, - } - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - assert, require := assert.New(t), require.New(t) - db.TestDeleteWhere(t, conn, func() interface{} { i := AllocSession(); return &i }(), "1=1") - - testSessions := []*Session{} - for i := 0; i < tt.createCnt; i++ { - s := TestSession(t, conn, wrapper, composedOf) - _ = TestState(t, conn, s.PublicId, StatusActive) - testSessions = append(testSessions, s) - } - for i := 0; i < tt.terminateCnt; i++ { - _, err := repo.CancelSession(ctx, testSessions[i].PublicId, testSessions[i].Version) - require.NoError(err) - } - terminated, err := repo.TerminateCompletedSessions(ctx) - require.NoError(err) - require.Equal(tt.terminateCnt, terminated) - - otherTestSessions := []*Session{} - for i := 0; i < tt.otherCnt; i++ { - s := TestSession(t, conn, wrapper, composedOfOther) - _ = TestState(t, conn, s.PublicId, StatusActive) - otherTestSessions = append(otherTestSessions, s) - } - for i := 0; i < tt.otherTerminateCnt; i++ { - _, err := repo.CancelSession(ctx, otherTestSessions[i].PublicId, otherTestSessions[i].Version) - require.NoError(err) - } - terminated, err = repo.TerminateCompletedSessions(ctx) - require.NoError(err) - require.Equal(tt.otherTerminateCnt, terminated) - - assert.Equal(tt.otherCnt, len(otherTestSessions)) - - got, err := repo.fetchAuthzProtectedSessionsByProject(ctx, tt.reqScopes, tt.opts...) - if tt.wantErr { - require.Error(err) - return - } - require.NoError(err) - assert.Equal(tt.wantCnt, len(got[composedOf.ProjectId])) - assert.Equal(tt.wantOtherCnt, len(got[composedOfOther.ProjectId])) - }) - } -} diff --git a/internal/session/service_list_for_authz_check.go b/internal/session/service_list_for_authz_check.go deleted file mode 100644 index 98687c7bb1..0000000000 --- a/internal/session/service_list_for_authz_check.go +++ /dev/null @@ -1,23 +0,0 @@ -package session - -import ( - "context" - - "github.com/hashicorp/boundary/internal/boundary" -) - -type listForAuthzCheckFunc func(ctx context.Context, projectIds []string) (map[string][]boundary.AuthzProtectedEntity, error) - -func (a listForAuthzCheckFunc) FetchAuthzProtectedEntitiesByScope(ctx context.Context, projectIds []string) (map[string][]boundary.AuthzProtectedEntity, error) { - return a(ctx, projectIds) -} - -// ListForAuthzCheck returns a functions that fetches sessions for the given -// projects. Note that the sessions are not fully populated, and only contain the -// necessary information to implement the boundary.AuthzProtectedEntity -// interface. Supports the WithTerminated option. -func ListForAuthzCheck(repo *Repository, opt ...Option) listForAuthzCheckFunc { - return func(ctx context.Context, projectIds []string) (map[string][]boundary.AuthzProtectedEntity, error) { - return repo.fetchAuthzProtectedSessionsByProject(ctx, projectIds, opt...) - } -}