From b41e983503be89689d7121c3e8d6c876dcf3690c Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Fri, 29 Apr 2022 09:52:01 -0400 Subject: [PATCH] Reorder authz check for sessions (#2042) * Reorder session authz checks This implements a mechanism for service handlers to pass a repo satisfying an interface into the scope IDs listing function. The purpose of this interface is to allow the same function to perform authorization checking against resources rather than requiring service handlers to fetch all resources in the determined set of scopes and then perform this checking. While this could be done in each service handler, this approach centralizes the logic, which has the nice benefit of removing some boilerplate from service handler list functions that have adapted to it. Tests haven't changed, which is expected -- the function should return exactly what was being returned before, simply faster. --- internal/boundary/boundary.go | 13 +- internal/servers/controller/auth/auth.go | 2 +- .../controller/common/scopeids/scope_ids.go | 309 ++++++++++++++---- .../authmethods/authmethod_service.go | 2 +- .../handlers/authtokens/authtoken_service.go | 2 +- .../credentialstore_service.go | 2 +- .../handlers/groups/group_service.go | 2 +- .../host_catalogs/host_catalog_service.go | 2 +- .../controller/handlers/roles/role_service.go | 2 +- .../handlers/scopes/scope_service.go | 2 +- .../handlers/sessions/session_service.go | 48 +-- .../handlers/targets/target_service.go | 2 +- .../controller/handlers/users/user_service.go | 2 +- internal/session/query.go | 7 + internal/session/repository_session.go | 77 ++++- internal/session/session.go | 16 +- 16 files changed, 375 insertions(+), 115 deletions(-) diff --git a/internal/boundary/boundary.go b/internal/boundary/boundary.go index a637c85153..8081495c61 100644 --- a/internal/boundary/boundary.go +++ b/internal/boundary/boundary.go @@ -2,7 +2,9 @@ // define the Boundary domain. package boundary -import "github.com/hashicorp/boundary/internal/db/timestamp" +import ( + "github.com/hashicorp/boundary/internal/db/timestamp" +) // An Entity is an object distinguished by its identity, rather than its // attributes. It can contain value objects and other entities. @@ -25,3 +27,12 @@ type Resource interface { GetName() string GetDescription() string } + +// AuthzProtectedEntity is used by some functions (primarily +// scopeids.AuthzProtectedEntityProvider-conforming implementations) to deliver +// some common information necessary for calculating authz. +type AuthzProtectedEntity interface { + Entity + GetScopeId() string + GetUserId() string +} diff --git a/internal/servers/controller/auth/auth.go b/internal/servers/controller/auth/auth.go index 309f825894..56fa2861cb 100644 --- a/internal/servers/controller/auth/auth.go +++ b/internal/servers/controller/auth/auth.go @@ -196,7 +196,7 @@ func Verify(ctx context.Context, opt ...Option) (ret VerifyResults) { return } if scp == nil { - ret.Error = errors.New(ctx, errors.InvalidParameter, op, fmt.Sprint("non-existent scope $q", ret.Scope.Id)) + ret.Error = errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("non-existent scope %q", ret.Scope.Id)) return } ret.Scope = &scopes.ScopeInfo{ diff --git a/internal/servers/controller/common/scopeids/scope_ids.go b/internal/servers/controller/common/scopeids/scope_ids.go index 7a6fec7cbb..fd6be78d8d 100644 --- a/internal/servers/controller/common/scopeids/scope_ids.go +++ b/internal/servers/controller/common/scopeids/scope_ids.go @@ -3,6 +3,7 @@ package scopeids import ( "context" + "github.com/hashicorp/boundary/internal/boundary" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/perms" @@ -15,70 +16,127 @@ import ( "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/scopes" ) -// GetListingScopeIds, given common parameters for List calls, returns the set of scope -// IDs in which to search for resources. It also returns a memoized map of the -// scopes to their info for populating returned values. -// -// Note: This was originally pulled out 1:1 from the role service. It and other -// tests in the other service handlers test this function extensively as it -// forms the basis for all recursive listing tests; see those tests for list -// functionality in the various service handlers. -func GetListingScopeIds( +type authzProtectedEntityProvider interface { + // Fetches basic resource info for the given scopes. Note that this is a + // "where clause" style of argument: if the set of scopes is populated these + // are the scopes to limit to (e.g. to put in a where clause). An empty set + // of scopes means to look in *all* scopes, not none! + FetchAuthzProtectedEntitiesByScope(ctx context.Context, scopeIds []string) (map[string][]boundary.AuthzProtectedEntity, error) +} + +// ResourceInfo contains information about a particular resource +type ResourceInfo struct { + AuthorizedActions action.ActionSet +} + +// ScopeInfoWithResourceIds contains information about a scope and the resources +// found within it +type ScopeInfoWithResourceIds struct { + *scopes.ScopeInfo + Resources map[string]ResourceInfo +} + +// GetListingResourceInformationInput contains input parameters to the function +type GetListingResourceInformationInput struct { + // An IAM repo function to use for a scope listing call + IamRepoFn common.IamRepoFactory + + // The original auth results from the list command + AuthResults auth.VerifyResults + + // The scope ID to use, or the starting point for a recursive search + RootScopeId string + + // The type of resource being listed + Type resource.Type + + // Whether the search is recursive + Recursive bool + + // A repo to fetch resources + AuthzProtectedEntityProvider authzProtectedEntityProvider + + // The available actions for the resource type + ActionSet action.ActionSet +} + +// GetListingResourceInformationOutput contains results from the function +type GetListingResourceInformationOutput struct { + // The calculated list of relevant scope IDs + ScopeIds []string + + // The specific resource IDs calculated to be authorized for listing + ResourceIds []string + + // A map of scope ID to scope information and a map of resource IDs in that + // scope and specific information about that resource, such as available + // actions + ScopeResourceMap map[string]*ScopeInfoWithResourceIds +} + +// GetListingResourceInformation, given common parameters for List calls, +// returns useful information: the set of scope IDs in which to search for +// resources; the IDs of the resources known to be authorized for that user; and +// a memoized map of the scopes to their info for populating returned values. +func GetListingResourceInformation( // The context to use when listing in the DB, if required ctx context.Context, - // An IAM repo function to use for a listing call, if required - repoFn common.IamRepoFactory, - // The original auth results from the list command - authResults auth.VerifyResults, - // The scope ID to use, or to use as the starting point for a recursive - // search - rootScopeId string, - // The type of resource we are listing - typ resource.Type, - // Whether or not the search should be recursive - recursive bool, - // Whether to only return scopes with exact permissions, or whether parent - // scopes with appropriate permissions are sufficient - directOnly bool, -) ([]string, map[string]*scopes.ScopeInfo, error) { - const op = "GetListingScopeIds" + // The input struct + input GetListingResourceInformationInput, +) (*GetListingResourceInformationOutput, error) { + const op = "scopeids.GetListingResourceInformation" + + output := new(GetListingResourceInformationOutput) // Validation switch { - case typ == resource.Unknown: - return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "unknown resource") - case repoFn == nil: - return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "nil iam repo") - case rootScopeId == "": - return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing root scope id") - case authResults.Scope == nil: - return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "nil scope in auth results") + case input.Type == resource.Unknown: + return nil, errors.New(ctx, errors.InvalidParameter, op, "unknown resource") + case input.IamRepoFn == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "nil iam repo") + case input.RootScopeId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing root scope id") + case input.AuthResults.Scope == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "nil scope in auth results") + case !input.Recursive && input.AuthResults.Scope.Id != input.RootScopeId: + return nil, errors.New(ctx, errors.InvalidParameter, op, "non-recursive search but auth results scope does not match root scope") } + // This will be used to memoize scope info so we can put the right scope + // info for each returned value + output.ScopeResourceMap = map[string]*ScopeInfoWithResourceIds{} + // Base case: if not recursive, return the scope we were given and the // already-looked-up info - if !recursive { - return []string{authResults.Scope.Id}, map[string]*scopes.ScopeInfo{authResults.Scope.Id: authResults.Scope}, nil + if !input.Recursive { + output.ScopeResourceMap[input.AuthResults.Scope.Id] = &ScopeInfoWithResourceIds{ScopeInfo: input.AuthResults.Scope} + // If we don't have information do to the resource lookup ourselves, + // return what we have + if input.AuthzProtectedEntityProvider == nil { + output.ScopeIds = []string{input.AuthResults.Scope.Id} + return output, nil + } + // Otherwise filter on this one scope and return + if err := filterAuthorizedResourceIds(ctx, input, output); err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error filtering to only authorized resources")) + } + return output, nil } - // This will be used to memoize scope info so we can put the right scope - // info for each returned value - scopeInfoMap := map[string]*scopes.ScopeInfo{} - - repo, err := repoFn() + repo, err := input.IamRepoFn() if err != nil { - return nil, nil, err + return nil, err } // Get all scopes recursively. Start at global because we need to take into // account permissions in parent scopes even if they want to scale back the // returned information to a child scope and its children. scps, err := repo.ListScopesRecursively(ctx, scope.Global.String()) if err != nil { - return nil, nil, err + return nil, err } res := perms.Resource{ - Type: typ, + Type: input.Type, } // For each scope, see if we have permission to list that type in that // scope @@ -88,7 +146,7 @@ func GetListingScopeIds( for _, scp := range scps { scpId := scp.GetPublicId() res.ScopeId = scpId - aSet := authResults.FetchActionSetForType(ctx, + aSet := input.AuthResults.FetchActionSetForType(ctx, // This is overridden by WithResource resource.Unknown, action.ActionSet{action.List}, @@ -97,16 +155,14 @@ func GetListingScopeIds( switch len(aSet) { case 0: // Defer until we've read all scopes. We do this because if the - // ordering coming back isn't in parent-first ording our map + // ordering coming back isn't in parent-first ordering our map // lookup might fail. - if !directOnly { - deferredScopes = append(deferredScopes, scp) - } + deferredScopes = append(deferredScopes, scp) case 1: if aSet[0] != action.List { - return nil, nil, errors.New(ctx, errors.Internal, op, "unexpected action in set") + return nil, errors.New(ctx, errors.Internal, op, "unexpected action in set") } - if scopeInfoMap[scpId] == nil { + if output.ScopeResourceMap[scpId] == nil { scopeInfo := &scopes.ScopeInfo{ Id: scp.GetPublicId(), Type: scp.GetType(), @@ -114,13 +170,13 @@ func GetListingScopeIds( Description: scp.GetDescription(), ParentScopeId: scp.GetParentId(), } - scopeInfoMap[scpId] = scopeInfo + output.ScopeResourceMap[scpId] = &ScopeInfoWithResourceIds{ScopeInfo: scopeInfo} } if scpId == scope.Global.String() { globalHasList = true } default: - return nil, nil, errors.New(ctx, errors.Internal, op, "unexpected number of actions back in set") + return nil, errors.New(ctx, errors.Internal, op, "unexpected number of actions back in set") } } @@ -129,9 +185,9 @@ func GetListingScopeIds( // If they had list on global scope anything else is automatically // included; otherwise if they had list on the parent scope, this // scope is included in the map and is sufficient here. - if globalHasList || scopeInfoMap[scp.GetParentId()] != nil { + if globalHasList || output.ScopeResourceMap[scp.GetParentId()] != nil { scpId := scp.GetPublicId() - if scopeInfoMap[scpId] == nil { + if output.ScopeResourceMap[scpId] == nil { scopeInfo := &scopes.ScopeInfo{ Id: scp.GetPublicId(), Type: scp.GetType(), @@ -139,21 +195,15 @@ func GetListingScopeIds( Description: scp.GetDescription(), ParentScopeId: scp.GetParentId(), } - scopeInfoMap[scpId] = scopeInfo + output.ScopeResourceMap[scpId] = &ScopeInfoWithResourceIds{ScopeInfo: scopeInfo} } } } - // If we have nothing in scopeInfoMap at this point, we aren't authorized - // anywhere so return 403. - if len(scopeInfoMap) == 0 { - return nil, nil, handlers.ForbiddenError() - } - // Now elide out any that aren't under the root scope ID - elideScopes := make([]string, 0, len(scopeInfoMap)) - for scpId, scp := range scopeInfoMap { - switch rootScopeId { + elideScopes := make([]string, 0, len(output.ScopeResourceMap)) + for scpId, scp := range output.ScopeResourceMap { + switch input.RootScopeId { // If the root is global, it matches case scope.Global.String(): // If the current scope matches the root, it matches @@ -168,13 +218,134 @@ func GetListingScopeIds( } for _, scpId := range elideScopes { - delete(scopeInfoMap, scpId) + delete(output.ScopeResourceMap, scpId) + } + + // If we have nothing in scopeInfoMap at this point, we aren't authorized + // anywhere so return 403. + if len(output.ScopeResourceMap) == 0 { + return nil, handlers.ForbiddenError() } - scopeIds := make([]string, 0, len(scopeInfoMap)) - for k := range scopeInfoMap { - scopeIds = append(scopeIds, k) + if input.AuthzProtectedEntityProvider == nil { + output.populateScopeIdsFromScopeResourceMap() + return output, nil } - return scopeIds, scopeInfoMap, nil + if err := filterAuthorizedResourceIds(ctx, input, output); err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error filtering to only authorized resources")) + } + + return output, nil +} + +// filterAuthorizedResourceIds calls the passed in function to get IDs for +// resources in the given scopes and then figures out which ones are actually +// authorized for listing by the user. +// +// It also populates the scope IDs in the output +func filterAuthorizedResourceIds( + // The context to use when listing in the DB, if required + ctx context.Context, + // The input struct + input GetListingResourceInformationInput, + // The scope information to fill out + output *GetListingResourceInformationOutput, +) error { + const op = "scopeids.filterAuthorizedResources" + + // Populate scopeIds and determine if we found global + output.populateScopeIdsFromScopeResourceMap() + + // The calling function is giving us a complete set with any recursive + // lookup already performed (that's the point of the function). + scopedResourceInfo, err := input.AuthzProtectedEntityProvider.FetchAuthzProtectedEntitiesByScope(ctx, output.ScopeIds) + if err != nil { + return errors.Wrap(ctx, err, op) + } + + res := perms.Resource{ + Type: resource.Session, + } + + // Now run authorization checks against each so we know if there is a point + // in fetching the full resource, and cache the authorized actions + for scopeId, resourceInfos := range scopedResourceInfo { + for _, resourceInfo := range resourceInfos { + res.Id = resourceInfo.GetPublicId() + res.ScopeId = scopeId + authorizedActions := input.AuthResults.FetchActionSetForId(ctx, resourceInfo.GetPublicId(), input.ActionSet, auth.WithResource(&res)) + if len(authorizedActions) == 0 { + continue + } + + if resourceInfo.GetUserId() != "" { + if authorizedActions.OnlySelf() && resourceInfo.GetUserId() != input.AuthResults.UserId { + continue + } + } + + if output.ScopeResourceMap[scopeId].Resources == nil { + output.ScopeResourceMap[scopeId].Resources = make(map[string]ResourceInfo) + } + + output.ScopeResourceMap[scopeId].Resources[resourceInfo.GetPublicId()] = ResourceInfo{AuthorizedActions: authorizedActions} + output.ResourceIds = append(output.ResourceIds, resourceInfo.GetPublicId()) + } + } + + return nil +} + +// populateScopeIdsFromScopeResourceMap populates the ScopeIds field and returns +// whether global scope was found +func (i *GetListingResourceInformationOutput) populateScopeIdsFromScopeResourceMap() { + for k := range i.ScopeResourceMap { + i.ScopeIds = append(i.ScopeIds, k) + } +} + +// GetListingScopeIds is provided for backwards compatibility with existing +// services; services should eventually migrate to +// GetListingResourceInformation. +func GetListingScopeIds( + // The context to use when listing in the DB, if required + ctx context.Context, + // An IAM repo function to use for a listing call, if required + repoFn common.IamRepoFactory, + // The original auth results from the list command + authResults auth.VerifyResults, + // The scope ID to use, or to use as the starting point for a recursive + // search + rootScopeId string, + // The type of resource we are listing + typ resource.Type, + // Whether or not the search should be recursive + recursive bool, +) ([]string, map[string]*scopes.ScopeInfo, error) { + const op = "scopeids.GetListingScopeIds" + scopeResourceInfo, err := GetListingResourceInformation(ctx, + GetListingResourceInformationInput{ + IamRepoFn: repoFn, + AuthResults: authResults, + RootScopeId: rootScopeId, + Type: typ, + Recursive: recursive, + }, + ) + if err != nil { + if err == handlers.ForbiddenError() { + return nil, nil, err + } + return nil, nil, errors.Wrap(ctx, err, op) + } + if len(scopeResourceInfo.ScopeIds) == 0 { + // This should have already happened in the other function, but... + return nil, nil, handlers.ForbiddenError() + } + scopeMap := make(map[string]*scopes.ScopeInfo, len(scopeResourceInfo.ScopeResourceMap)) + for k, v := range scopeResourceInfo.ScopeResourceMap { + scopeMap[k] = v.ScopeInfo + } + return scopeResourceInfo.ScopeIds, scopeMap, nil } diff --git a/internal/servers/controller/handlers/authmethods/authmethod_service.go b/internal/servers/controller/handlers/authmethods/authmethod_service.go index 1349818960..7d58b43f1d 100644 --- a/internal/servers/controller/handlers/authmethods/authmethod_service.go +++ b/internal/servers/controller/handlers/authmethods/authmethod_service.go @@ -130,7 +130,7 @@ func (s Service) ListAuthMethods(ctx context.Context, req *pbs.ListAuthMethodsRe } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthMethod, req.GetRecursive(), false) + ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthMethod, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/authtokens/authtoken_service.go b/internal/servers/controller/handlers/authtokens/authtoken_service.go index 6034645b72..2ee04e1bf4 100644 --- a/internal/servers/controller/handlers/authtokens/authtoken_service.go +++ b/internal/servers/controller/handlers/authtokens/authtoken_service.go @@ -81,7 +81,7 @@ func (s Service) ListAuthTokens(ctx context.Context, req *pbs.ListAuthTokensRequ } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthToken, req.GetRecursive(), false) + ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthToken, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/credentialstores/credentialstore_service.go b/internal/servers/controller/handlers/credentialstores/credentialstore_service.go index 8f97b943d4..9fb8ff03ba 100644 --- a/internal/servers/controller/handlers/credentialstores/credentialstore_service.go +++ b/internal/servers/controller/handlers/credentialstores/credentialstore_service.go @@ -112,7 +112,7 @@ func (s Service) ListCredentialStores(ctx context.Context, req *pbs.ListCredenti } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.CredentialStore, req.GetRecursive(), false) + ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.CredentialStore, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/groups/group_service.go b/internal/servers/controller/handlers/groups/group_service.go index 0f562699e7..e7ffb5389f 100644 --- a/internal/servers/controller/handlers/groups/group_service.go +++ b/internal/servers/controller/handlers/groups/group_service.go @@ -92,7 +92,7 @@ func (s Service) ListGroups(ctx context.Context, req *pbs.ListGroupsRequest) (*p } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.repoFn, authResults, req.GetScopeId(), resource.Group, req.GetRecursive(), false) + ctx, s.repoFn, authResults, req.GetScopeId(), resource.Group, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/host_catalogs/host_catalog_service.go b/internal/servers/controller/handlers/host_catalogs/host_catalog_service.go index c83fb1518e..3f961f990e 100644 --- a/internal/servers/controller/handlers/host_catalogs/host_catalog_service.go +++ b/internal/servers/controller/handlers/host_catalogs/host_catalog_service.go @@ -129,7 +129,7 @@ func (s Service) ListHostCatalogs(ctx context.Context, req *pbs.ListHostCatalogs } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.HostCatalog, req.GetRecursive(), false) + ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.HostCatalog, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/roles/role_service.go b/internal/servers/controller/handlers/roles/role_service.go index dc94d4d50a..c9e867cb45 100644 --- a/internal/servers/controller/handlers/roles/role_service.go +++ b/internal/servers/controller/handlers/roles/role_service.go @@ -96,7 +96,7 @@ func (s Service) ListRoles(ctx context.Context, req *pbs.ListRolesRequest) (*pbs } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.repoFn, authResults, req.GetScopeId(), resource.Role, req.GetRecursive(), false) + ctx, s.repoFn, authResults, req.GetScopeId(), resource.Role, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/scopes/scope_service.go b/internal/servers/controller/handlers/scopes/scope_service.go index f35ee93ac9..58eed4ea8d 100644 --- a/internal/servers/controller/handlers/scopes/scope_service.go +++ b/internal/servers/controller/handlers/scopes/scope_service.go @@ -133,7 +133,7 @@ func (s Service) ListScopes(ctx context.Context, req *pbs.ListScopesRequest) (*p } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.repoFn, authResults, req.GetScopeId(), resource.Scope, req.GetRecursive(), false) + ctx, s.repoFn, authResults, req.GetScopeId(), resource.Scope, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/sessions/session_service.go b/internal/servers/controller/handlers/sessions/session_service.go index e65d47043f..0565879966 100644 --- a/internal/servers/controller/handlers/sessions/session_service.go +++ b/internal/servers/controller/handlers/sessions/session_service.go @@ -119,6 +119,8 @@ func (s Service) GetSession(ctx context.Context, req *pbs.GetSessionRequest) (*p // ListSessions implements the interface pbs.SessionServiceServer. func (s Service) ListSessions(ctx context.Context, req *pbs.ListSessionsRequest) (*pbs.ListSessionsResponse, error) { + const op = "session.(Service).ListSessions" + if err := validateListRequest(req); err != nil { return nil, err } @@ -137,17 +139,34 @@ func (s Service) ListSessions(ctx context.Context, req *pbs.ListSessionsRequest) } } - scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(ctx, - s.iamRepoFn, authResults, req.GetScopeId(), resource.Session, req.GetRecursive(), false) + repo, err := s.repoFn() + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + + scopeResourceInfo, err := scopeids.GetListingResourceInformation( + ctx, + scopeids.GetListingResourceInformationInput{ + IamRepoFn: s.iamRepoFn, + AuthResults: authResults, + RootScopeId: req.GetScopeId(), + Type: resource.Session, + Recursive: req.GetRecursive(), + AuthzProtectedEntityProvider: repo, + ActionSet: IdActions, + }, + ) if err != nil { return nil, err } - // If no scopes match, return an empty response - if len(scopeIds) == 0 { + // If no scopes match or we match scopes but there are no resources in them + // that we are authorized to see, return an empty response + if len(scopeResourceInfo.ScopeIds) == 0 || + len(scopeResourceInfo.ResourceIds) == 0 { return &pbs.ListSessionsResponse{}, nil } - sesList, err := s.listFromRepo(ctx, scopeIds) + sesList, err := s.listFromRepo(ctx, scopeResourceInfo.ResourceIds) if err != nil { return nil, err } @@ -164,25 +183,14 @@ func (s Service) ListSessions(ctx context.Context, req *pbs.ListSessionsRequest) Type: resource.Session, } for _, item := range sesList { - res.Id = item.GetPublicId() - res.ScopeId = item.ScopeId - authorizedActions := authResults.FetchActionSetForId(ctx, item.GetPublicId(), IdActions, auth.WithResource(&res)) - if len(authorizedActions) == 0 { - continue - } - - if authorizedActions.OnlySelf() && item.UserId != authResults.UserId { - continue - } - outputFields := authResults.FetchOutputFields(res, action.List).SelfOrDefaults(authResults.UserId) outputOpts := make([]handlers.Option, 0, 3) outputOpts = append(outputOpts, handlers.WithOutputFields(&outputFields)) if outputFields.Has(globals.ScopeField) { - outputOpts = append(outputOpts, handlers.WithScope(scopeInfoMap[item.ScopeId])) + outputOpts = append(outputOpts, handlers.WithScope(scopeResourceInfo.ScopeResourceMap[item.ScopeId].ScopeInfo)) } if outputFields.Has(globals.AuthorizedActionsField) { - outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authorizedActions.Strings())) + outputOpts = append(outputOpts, handlers.WithAuthorizedActions(scopeResourceInfo.ScopeResourceMap[item.ScopeId].Resources[item.PublicId].AuthorizedActions.Strings())) } item, err := toProto(ctx, item, outputOpts...) @@ -287,12 +295,12 @@ func (s Service) getFromRepo(ctx context.Context, id string) (*session.Session, return sess, nil } -func (s Service) listFromRepo(ctx context.Context, scopeIds []string) ([]*session.Session, error) { +func (s Service) listFromRepo(ctx context.Context, sessionIds []string) ([]*session.Session, error) { repo, err := s.repoFn() if err != nil { return nil, err } - sesList, err := repo.ListSessions(ctx, session.WithScopeIds(scopeIds)) + sesList, err := repo.ListSessions(ctx, session.WithSessionIds(sessionIds...)) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/targets/target_service.go b/internal/servers/controller/handlers/targets/target_service.go index e0b49fa7db..02b0d956c5 100644 --- a/internal/servers/controller/handlers/targets/target_service.go +++ b/internal/servers/controller/handlers/targets/target_service.go @@ -164,7 +164,7 @@ func (s Service) ListTargets(ctx context.Context, req *pbs.ListTargetsRequest) ( } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.Target, req.GetRecursive(), false) + ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.Target, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/users/user_service.go b/internal/servers/controller/handlers/users/user_service.go index 349dda738b..0b0e417d5e 100644 --- a/internal/servers/controller/handlers/users/user_service.go +++ b/internal/servers/controller/handlers/users/user_service.go @@ -94,7 +94,7 @@ func (s Service) ListUsers(ctx context.Context, req *pbs.ListUsersRequest) (*pbs } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.repoFn, authResults, req.GetScopeId(), resource.User, req.GetRecursive(), false) + ctx, s.repoFn, authResults, req.GetScopeId(), resource.User, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/session/query.go b/internal/session/query.go index 430ae4a08d..541c4b4759 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -118,6 +118,13 @@ select expiration_time, connection_limit, current_connection_count from session_connection_limit, session_connection_count; ` + + sessionPublicIdList = ` +select public_id, scope_id, user_id from session +%s +; +` + sessionList = ` select * from diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 6e28a42ab6..78fee5005c 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -8,6 +8,7 @@ import ( "fmt" "strings" + "github.com/hashicorp/boundary/internal/boundary" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" @@ -205,6 +206,51 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, _ ...O return &session, authzSummary, nil } +// FetchAuthzProtectedEntitiesByScope implements boundary.AuthzProtectedEntityProvider +func (r *Repository) FetchAuthzProtectedEntitiesByScope(ctx context.Context, scopeIds []string) (map[string][]boundary.AuthzProtectedEntity, error) { + const op = "session.(Repository).FetchAuthzProtectedEntityInfo" + + var where string + var args []interface{} + + inClauseCnt := 0 + + switch len(scopeIds) { + case 0: + return nil, errors.New(ctx, errors.InvalidParameter, op, "no scopes given") + case 1: + inClauseCnt += 1 + where, args = fmt.Sprintf("where scope_id = @%d", inClauseCnt), append(args, sql.Named("1", scopeIds[0])) + default: + idsInClause := make([]string, 0, len(scopeIds)) + for _, id := range scopeIds { + inClauseCnt += 1 + idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id)) + } + where = fmt.Sprintf("where scope_id in (%s)", strings.Join(idsInClause, ",")) + } + + q := sessionPublicIdList + query := fmt.Sprintf(q, where) + + rows, err := r.reader.Query(ctx, query, args) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + defer rows.Close() + + sessionsMap := map[string][]boundary.AuthzProtectedEntity{} + for rows.Next() { + var ses Session + if err := r.reader.ScanRows(ctx, rows, &ses); err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("scan row failed")) + } + sessionsMap[ses.GetScopeId()] = append(sessionsMap[ses.GetScopeId()], ses) + } + + return sessionsMap, nil +} + // ListSessions will sessions. Supports the WithLimit, WithScopeId, WithSessionIds, and WithServerId options. func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Session, error) { const op = "session.(Repository).ListSessions" @@ -213,25 +259,31 @@ func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Sessio var args []interface{} inClauseCnt := 0 - if len(opts.withScopeIds) != 0 { - switch len(opts.withScopeIds) { - case 1: + switch len(opts.withScopeIds) { + case 0: + case 1: + inClauseCnt += 1 + where, args = append(where, fmt.Sprintf("scope_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withScopeIds[0])) + default: + idsInClause := make([]string, 0, len(opts.withScopeIds)) + for _, id := range opts.withScopeIds { inClauseCnt += 1 - where, args = append(where, fmt.Sprintf("scope_id = @%d", inClauseCnt)), append(args, sql.Named("1", opts.withScopeIds[0])) - default: - idsInClause := make([]string, 0, len(opts.withScopeIds)) - for _, id := range opts.withScopeIds { - inClauseCnt += 1 - idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id)) - } - where = append(where, fmt.Sprintf("scope_id in (%s)", strings.Join(idsInClause, ","))) + idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id)) } + where = append(where, fmt.Sprintf("scope_id in (%s)", strings.Join(idsInClause, ","))) } + if opts.withUserId != "" { inClauseCnt += 1 where, args = append(where, fmt.Sprintf("user_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withUserId)) } - if len(opts.withSessionIds) > 0 { + + switch len(opts.withSessionIds) { + case 0: + case 1: + inClauseCnt += 1 + where, args = append(where, fmt.Sprintf("s.public_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withSessionIds[0])) + default: idsInClause := make([]string, 0, len(opts.withSessionIds)) for _, id := range opts.withSessionIds { inClauseCnt += 1 @@ -239,6 +291,7 @@ func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Sessio } where = append(where, fmt.Sprintf("s.public_id in (%s)", strings.Join(idsInClause, ","))) } + if opts.withServerId != "" { inClauseCnt += 1 where, args = append(where, fmt.Sprintf("server_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withServerId)) diff --git a/internal/session/session.go b/internal/session/session.go index baaa5331f0..7026656e73 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/hashicorp/boundary/internal/boundary" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" @@ -119,13 +120,22 @@ type Session struct { tableName string `gorm:"-"` } -func (s *Session) GetPublicId() string { +func (s Session) GetPublicId() string { return s.PublicId } +func (s Session) GetScopeId() string { + return s.ScopeId +} + +func (s Session) GetUserId() string { + return s.UserId +} + var ( - _ Cloneable = (*Session)(nil) - _ db.VetForWriter = (*Session)(nil) + _ Cloneable = (*Session)(nil) + _ db.VetForWriter = (*Session)(nil) + _ boundary.AuthzProtectedEntity = (*Session)(nil) ) // New creates a new in memory session.