diff --git a/internal/boundary/boundary.go b/internal/boundary/boundary.go index a637c85153..8081495c61 100644 --- a/internal/boundary/boundary.go +++ b/internal/boundary/boundary.go @@ -2,7 +2,9 @@ // define the Boundary domain. package boundary -import "github.com/hashicorp/boundary/internal/db/timestamp" +import ( + "github.com/hashicorp/boundary/internal/db/timestamp" +) // An Entity is an object distinguished by its identity, rather than its // attributes. It can contain value objects and other entities. @@ -25,3 +27,12 @@ type Resource interface { GetName() string GetDescription() string } + +// AuthzProtectedEntity is used by some functions (primarily +// scopeids.AuthzProtectedEntityProvider-conforming implementations) to deliver +// some common information necessary for calculating authz. +type AuthzProtectedEntity interface { + Entity + GetScopeId() string + GetUserId() string +} diff --git a/internal/servers/controller/auth/auth.go b/internal/servers/controller/auth/auth.go index 309f825894..56fa2861cb 100644 --- a/internal/servers/controller/auth/auth.go +++ b/internal/servers/controller/auth/auth.go @@ -196,7 +196,7 @@ func Verify(ctx context.Context, opt ...Option) (ret VerifyResults) { return } if scp == nil { - ret.Error = errors.New(ctx, errors.InvalidParameter, op, fmt.Sprint("non-existent scope $q", ret.Scope.Id)) + ret.Error = errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("non-existent scope %q", ret.Scope.Id)) return } ret.Scope = &scopes.ScopeInfo{ diff --git a/internal/servers/controller/common/scopeids/scope_ids.go b/internal/servers/controller/common/scopeids/scope_ids.go index 7a6fec7cbb..fd6be78d8d 100644 --- a/internal/servers/controller/common/scopeids/scope_ids.go +++ b/internal/servers/controller/common/scopeids/scope_ids.go @@ -3,6 +3,7 @@ package scopeids import ( "context" + "github.com/hashicorp/boundary/internal/boundary" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/perms" @@ -15,70 +16,127 @@ import ( "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/scopes" ) -// GetListingScopeIds, given common parameters for List calls, returns the set of scope -// IDs in which to search for resources. It also returns a memoized map of the -// scopes to their info for populating returned values. -// -// Note: This was originally pulled out 1:1 from the role service. It and other -// tests in the other service handlers test this function extensively as it -// forms the basis for all recursive listing tests; see those tests for list -// functionality in the various service handlers. -func GetListingScopeIds( +type authzProtectedEntityProvider interface { + // Fetches basic resource info for the given scopes. Note that this is a + // "where clause" style of argument: if the set of scopes is populated these + // are the scopes to limit to (e.g. to put in a where clause). An empty set + // of scopes means to look in *all* scopes, not none! + FetchAuthzProtectedEntitiesByScope(ctx context.Context, scopeIds []string) (map[string][]boundary.AuthzProtectedEntity, error) +} + +// ResourceInfo contains information about a particular resource +type ResourceInfo struct { + AuthorizedActions action.ActionSet +} + +// ScopeInfoWithResourceIds contains information about a scope and the resources +// found within it +type ScopeInfoWithResourceIds struct { + *scopes.ScopeInfo + Resources map[string]ResourceInfo +} + +// GetListingResourceInformationInput contains input parameters to the function +type GetListingResourceInformationInput struct { + // An IAM repo function to use for a scope listing call + IamRepoFn common.IamRepoFactory + + // The original auth results from the list command + AuthResults auth.VerifyResults + + // The scope ID to use, or the starting point for a recursive search + RootScopeId string + + // The type of resource being listed + Type resource.Type + + // Whether the search is recursive + Recursive bool + + // A repo to fetch resources + AuthzProtectedEntityProvider authzProtectedEntityProvider + + // The available actions for the resource type + ActionSet action.ActionSet +} + +// GetListingResourceInformationOutput contains results from the function +type GetListingResourceInformationOutput struct { + // The calculated list of relevant scope IDs + ScopeIds []string + + // The specific resource IDs calculated to be authorized for listing + ResourceIds []string + + // A map of scope ID to scope information and a map of resource IDs in that + // scope and specific information about that resource, such as available + // actions + ScopeResourceMap map[string]*ScopeInfoWithResourceIds +} + +// GetListingResourceInformation, given common parameters for List calls, +// returns useful information: the set of scope IDs in which to search for +// resources; the IDs of the resources known to be authorized for that user; and +// a memoized map of the scopes to their info for populating returned values. +func GetListingResourceInformation( // The context to use when listing in the DB, if required ctx context.Context, - // An IAM repo function to use for a listing call, if required - repoFn common.IamRepoFactory, - // The original auth results from the list command - authResults auth.VerifyResults, - // The scope ID to use, or to use as the starting point for a recursive - // search - rootScopeId string, - // The type of resource we are listing - typ resource.Type, - // Whether or not the search should be recursive - recursive bool, - // Whether to only return scopes with exact permissions, or whether parent - // scopes with appropriate permissions are sufficient - directOnly bool, -) ([]string, map[string]*scopes.ScopeInfo, error) { - const op = "GetListingScopeIds" + // The input struct + input GetListingResourceInformationInput, +) (*GetListingResourceInformationOutput, error) { + const op = "scopeids.GetListingResourceInformation" + + output := new(GetListingResourceInformationOutput) // Validation switch { - case typ == resource.Unknown: - return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "unknown resource") - case repoFn == nil: - return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "nil iam repo") - case rootScopeId == "": - return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing root scope id") - case authResults.Scope == nil: - return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "nil scope in auth results") + case input.Type == resource.Unknown: + return nil, errors.New(ctx, errors.InvalidParameter, op, "unknown resource") + case input.IamRepoFn == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "nil iam repo") + case input.RootScopeId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing root scope id") + case input.AuthResults.Scope == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "nil scope in auth results") + case !input.Recursive && input.AuthResults.Scope.Id != input.RootScopeId: + return nil, errors.New(ctx, errors.InvalidParameter, op, "non-recursive search but auth results scope does not match root scope") } + // This will be used to memoize scope info so we can put the right scope + // info for each returned value + output.ScopeResourceMap = map[string]*ScopeInfoWithResourceIds{} + // Base case: if not recursive, return the scope we were given and the // already-looked-up info - if !recursive { - return []string{authResults.Scope.Id}, map[string]*scopes.ScopeInfo{authResults.Scope.Id: authResults.Scope}, nil + if !input.Recursive { + output.ScopeResourceMap[input.AuthResults.Scope.Id] = &ScopeInfoWithResourceIds{ScopeInfo: input.AuthResults.Scope} + // If we don't have information do to the resource lookup ourselves, + // return what we have + if input.AuthzProtectedEntityProvider == nil { + output.ScopeIds = []string{input.AuthResults.Scope.Id} + return output, nil + } + // Otherwise filter on this one scope and return + if err := filterAuthorizedResourceIds(ctx, input, output); err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error filtering to only authorized resources")) + } + return output, nil } - // This will be used to memoize scope info so we can put the right scope - // info for each returned value - scopeInfoMap := map[string]*scopes.ScopeInfo{} - - repo, err := repoFn() + repo, err := input.IamRepoFn() if err != nil { - return nil, nil, err + return nil, err } // Get all scopes recursively. Start at global because we need to take into // account permissions in parent scopes even if they want to scale back the // returned information to a child scope and its children. scps, err := repo.ListScopesRecursively(ctx, scope.Global.String()) if err != nil { - return nil, nil, err + return nil, err } res := perms.Resource{ - Type: typ, + Type: input.Type, } // For each scope, see if we have permission to list that type in that // scope @@ -88,7 +146,7 @@ func GetListingScopeIds( for _, scp := range scps { scpId := scp.GetPublicId() res.ScopeId = scpId - aSet := authResults.FetchActionSetForType(ctx, + aSet := input.AuthResults.FetchActionSetForType(ctx, // This is overridden by WithResource resource.Unknown, action.ActionSet{action.List}, @@ -97,16 +155,14 @@ func GetListingScopeIds( switch len(aSet) { case 0: // Defer until we've read all scopes. We do this because if the - // ordering coming back isn't in parent-first ording our map + // ordering coming back isn't in parent-first ordering our map // lookup might fail. - if !directOnly { - deferredScopes = append(deferredScopes, scp) - } + deferredScopes = append(deferredScopes, scp) case 1: if aSet[0] != action.List { - return nil, nil, errors.New(ctx, errors.Internal, op, "unexpected action in set") + return nil, errors.New(ctx, errors.Internal, op, "unexpected action in set") } - if scopeInfoMap[scpId] == nil { + if output.ScopeResourceMap[scpId] == nil { scopeInfo := &scopes.ScopeInfo{ Id: scp.GetPublicId(), Type: scp.GetType(), @@ -114,13 +170,13 @@ func GetListingScopeIds( Description: scp.GetDescription(), ParentScopeId: scp.GetParentId(), } - scopeInfoMap[scpId] = scopeInfo + output.ScopeResourceMap[scpId] = &ScopeInfoWithResourceIds{ScopeInfo: scopeInfo} } if scpId == scope.Global.String() { globalHasList = true } default: - return nil, nil, errors.New(ctx, errors.Internal, op, "unexpected number of actions back in set") + return nil, errors.New(ctx, errors.Internal, op, "unexpected number of actions back in set") } } @@ -129,9 +185,9 @@ func GetListingScopeIds( // If they had list on global scope anything else is automatically // included; otherwise if they had list on the parent scope, this // scope is included in the map and is sufficient here. - if globalHasList || scopeInfoMap[scp.GetParentId()] != nil { + if globalHasList || output.ScopeResourceMap[scp.GetParentId()] != nil { scpId := scp.GetPublicId() - if scopeInfoMap[scpId] == nil { + if output.ScopeResourceMap[scpId] == nil { scopeInfo := &scopes.ScopeInfo{ Id: scp.GetPublicId(), Type: scp.GetType(), @@ -139,21 +195,15 @@ func GetListingScopeIds( Description: scp.GetDescription(), ParentScopeId: scp.GetParentId(), } - scopeInfoMap[scpId] = scopeInfo + output.ScopeResourceMap[scpId] = &ScopeInfoWithResourceIds{ScopeInfo: scopeInfo} } } } - // If we have nothing in scopeInfoMap at this point, we aren't authorized - // anywhere so return 403. - if len(scopeInfoMap) == 0 { - return nil, nil, handlers.ForbiddenError() - } - // Now elide out any that aren't under the root scope ID - elideScopes := make([]string, 0, len(scopeInfoMap)) - for scpId, scp := range scopeInfoMap { - switch rootScopeId { + elideScopes := make([]string, 0, len(output.ScopeResourceMap)) + for scpId, scp := range output.ScopeResourceMap { + switch input.RootScopeId { // If the root is global, it matches case scope.Global.String(): // If the current scope matches the root, it matches @@ -168,13 +218,134 @@ func GetListingScopeIds( } for _, scpId := range elideScopes { - delete(scopeInfoMap, scpId) + delete(output.ScopeResourceMap, scpId) + } + + // If we have nothing in scopeInfoMap at this point, we aren't authorized + // anywhere so return 403. + if len(output.ScopeResourceMap) == 0 { + return nil, handlers.ForbiddenError() } - scopeIds := make([]string, 0, len(scopeInfoMap)) - for k := range scopeInfoMap { - scopeIds = append(scopeIds, k) + if input.AuthzProtectedEntityProvider == nil { + output.populateScopeIdsFromScopeResourceMap() + return output, nil } - return scopeIds, scopeInfoMap, nil + if err := filterAuthorizedResourceIds(ctx, input, output); err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error filtering to only authorized resources")) + } + + return output, nil +} + +// filterAuthorizedResourceIds calls the passed in function to get IDs for +// resources in the given scopes and then figures out which ones are actually +// authorized for listing by the user. +// +// It also populates the scope IDs in the output +func filterAuthorizedResourceIds( + // The context to use when listing in the DB, if required + ctx context.Context, + // The input struct + input GetListingResourceInformationInput, + // The scope information to fill out + output *GetListingResourceInformationOutput, +) error { + const op = "scopeids.filterAuthorizedResources" + + // Populate scopeIds and determine if we found global + output.populateScopeIdsFromScopeResourceMap() + + // The calling function is giving us a complete set with any recursive + // lookup already performed (that's the point of the function). + scopedResourceInfo, err := input.AuthzProtectedEntityProvider.FetchAuthzProtectedEntitiesByScope(ctx, output.ScopeIds) + if err != nil { + return errors.Wrap(ctx, err, op) + } + + res := perms.Resource{ + Type: resource.Session, + } + + // Now run authorization checks against each so we know if there is a point + // in fetching the full resource, and cache the authorized actions + for scopeId, resourceInfos := range scopedResourceInfo { + for _, resourceInfo := range resourceInfos { + res.Id = resourceInfo.GetPublicId() + res.ScopeId = scopeId + authorizedActions := input.AuthResults.FetchActionSetForId(ctx, resourceInfo.GetPublicId(), input.ActionSet, auth.WithResource(&res)) + if len(authorizedActions) == 0 { + continue + } + + if resourceInfo.GetUserId() != "" { + if authorizedActions.OnlySelf() && resourceInfo.GetUserId() != input.AuthResults.UserId { + continue + } + } + + if output.ScopeResourceMap[scopeId].Resources == nil { + output.ScopeResourceMap[scopeId].Resources = make(map[string]ResourceInfo) + } + + output.ScopeResourceMap[scopeId].Resources[resourceInfo.GetPublicId()] = ResourceInfo{AuthorizedActions: authorizedActions} + output.ResourceIds = append(output.ResourceIds, resourceInfo.GetPublicId()) + } + } + + return nil +} + +// populateScopeIdsFromScopeResourceMap populates the ScopeIds field and returns +// whether global scope was found +func (i *GetListingResourceInformationOutput) populateScopeIdsFromScopeResourceMap() { + for k := range i.ScopeResourceMap { + i.ScopeIds = append(i.ScopeIds, k) + } +} + +// GetListingScopeIds is provided for backwards compatibility with existing +// services; services should eventually migrate to +// GetListingResourceInformation. +func GetListingScopeIds( + // The context to use when listing in the DB, if required + ctx context.Context, + // An IAM repo function to use for a listing call, if required + repoFn common.IamRepoFactory, + // The original auth results from the list command + authResults auth.VerifyResults, + // The scope ID to use, or to use as the starting point for a recursive + // search + rootScopeId string, + // The type of resource we are listing + typ resource.Type, + // Whether or not the search should be recursive + recursive bool, +) ([]string, map[string]*scopes.ScopeInfo, error) { + const op = "scopeids.GetListingScopeIds" + scopeResourceInfo, err := GetListingResourceInformation(ctx, + GetListingResourceInformationInput{ + IamRepoFn: repoFn, + AuthResults: authResults, + RootScopeId: rootScopeId, + Type: typ, + Recursive: recursive, + }, + ) + if err != nil { + if err == handlers.ForbiddenError() { + return nil, nil, err + } + return nil, nil, errors.Wrap(ctx, err, op) + } + if len(scopeResourceInfo.ScopeIds) == 0 { + // This should have already happened in the other function, but... + return nil, nil, handlers.ForbiddenError() + } + scopeMap := make(map[string]*scopes.ScopeInfo, len(scopeResourceInfo.ScopeResourceMap)) + for k, v := range scopeResourceInfo.ScopeResourceMap { + scopeMap[k] = v.ScopeInfo + } + return scopeResourceInfo.ScopeIds, scopeMap, nil } diff --git a/internal/servers/controller/handlers/authmethods/authmethod_service.go b/internal/servers/controller/handlers/authmethods/authmethod_service.go index 1349818960..7d58b43f1d 100644 --- a/internal/servers/controller/handlers/authmethods/authmethod_service.go +++ b/internal/servers/controller/handlers/authmethods/authmethod_service.go @@ -130,7 +130,7 @@ func (s Service) ListAuthMethods(ctx context.Context, req *pbs.ListAuthMethodsRe } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthMethod, req.GetRecursive(), false) + ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthMethod, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/authtokens/authtoken_service.go b/internal/servers/controller/handlers/authtokens/authtoken_service.go index 6034645b72..2ee04e1bf4 100644 --- a/internal/servers/controller/handlers/authtokens/authtoken_service.go +++ b/internal/servers/controller/handlers/authtokens/authtoken_service.go @@ -81,7 +81,7 @@ func (s Service) ListAuthTokens(ctx context.Context, req *pbs.ListAuthTokensRequ } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthToken, req.GetRecursive(), false) + ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthToken, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/credentialstores/credentialstore_service.go b/internal/servers/controller/handlers/credentialstores/credentialstore_service.go index 8f97b943d4..9fb8ff03ba 100644 --- a/internal/servers/controller/handlers/credentialstores/credentialstore_service.go +++ b/internal/servers/controller/handlers/credentialstores/credentialstore_service.go @@ -112,7 +112,7 @@ func (s Service) ListCredentialStores(ctx context.Context, req *pbs.ListCredenti } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.CredentialStore, req.GetRecursive(), false) + ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.CredentialStore, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/groups/group_service.go b/internal/servers/controller/handlers/groups/group_service.go index 0f562699e7..e7ffb5389f 100644 --- a/internal/servers/controller/handlers/groups/group_service.go +++ b/internal/servers/controller/handlers/groups/group_service.go @@ -92,7 +92,7 @@ func (s Service) ListGroups(ctx context.Context, req *pbs.ListGroupsRequest) (*p } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.repoFn, authResults, req.GetScopeId(), resource.Group, req.GetRecursive(), false) + ctx, s.repoFn, authResults, req.GetScopeId(), resource.Group, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/host_catalogs/host_catalog_service.go b/internal/servers/controller/handlers/host_catalogs/host_catalog_service.go index c83fb1518e..3f961f990e 100644 --- a/internal/servers/controller/handlers/host_catalogs/host_catalog_service.go +++ b/internal/servers/controller/handlers/host_catalogs/host_catalog_service.go @@ -129,7 +129,7 @@ func (s Service) ListHostCatalogs(ctx context.Context, req *pbs.ListHostCatalogs } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.HostCatalog, req.GetRecursive(), false) + ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.HostCatalog, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/roles/role_service.go b/internal/servers/controller/handlers/roles/role_service.go index dc94d4d50a..c9e867cb45 100644 --- a/internal/servers/controller/handlers/roles/role_service.go +++ b/internal/servers/controller/handlers/roles/role_service.go @@ -96,7 +96,7 @@ func (s Service) ListRoles(ctx context.Context, req *pbs.ListRolesRequest) (*pbs } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.repoFn, authResults, req.GetScopeId(), resource.Role, req.GetRecursive(), false) + ctx, s.repoFn, authResults, req.GetScopeId(), resource.Role, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/scopes/scope_service.go b/internal/servers/controller/handlers/scopes/scope_service.go index f35ee93ac9..58eed4ea8d 100644 --- a/internal/servers/controller/handlers/scopes/scope_service.go +++ b/internal/servers/controller/handlers/scopes/scope_service.go @@ -133,7 +133,7 @@ func (s Service) ListScopes(ctx context.Context, req *pbs.ListScopesRequest) (*p } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.repoFn, authResults, req.GetScopeId(), resource.Scope, req.GetRecursive(), false) + ctx, s.repoFn, authResults, req.GetScopeId(), resource.Scope, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/sessions/session_service.go b/internal/servers/controller/handlers/sessions/session_service.go index e65d47043f..0565879966 100644 --- a/internal/servers/controller/handlers/sessions/session_service.go +++ b/internal/servers/controller/handlers/sessions/session_service.go @@ -119,6 +119,8 @@ func (s Service) GetSession(ctx context.Context, req *pbs.GetSessionRequest) (*p // ListSessions implements the interface pbs.SessionServiceServer. func (s Service) ListSessions(ctx context.Context, req *pbs.ListSessionsRequest) (*pbs.ListSessionsResponse, error) { + const op = "session.(Service).ListSessions" + if err := validateListRequest(req); err != nil { return nil, err } @@ -137,17 +139,34 @@ func (s Service) ListSessions(ctx context.Context, req *pbs.ListSessionsRequest) } } - scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(ctx, - s.iamRepoFn, authResults, req.GetScopeId(), resource.Session, req.GetRecursive(), false) + repo, err := s.repoFn() + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + + scopeResourceInfo, err := scopeids.GetListingResourceInformation( + ctx, + scopeids.GetListingResourceInformationInput{ + IamRepoFn: s.iamRepoFn, + AuthResults: authResults, + RootScopeId: req.GetScopeId(), + Type: resource.Session, + Recursive: req.GetRecursive(), + AuthzProtectedEntityProvider: repo, + ActionSet: IdActions, + }, + ) if err != nil { return nil, err } - // If no scopes match, return an empty response - if len(scopeIds) == 0 { + // If no scopes match or we match scopes but there are no resources in them + // that we are authorized to see, return an empty response + if len(scopeResourceInfo.ScopeIds) == 0 || + len(scopeResourceInfo.ResourceIds) == 0 { return &pbs.ListSessionsResponse{}, nil } - sesList, err := s.listFromRepo(ctx, scopeIds) + sesList, err := s.listFromRepo(ctx, scopeResourceInfo.ResourceIds) if err != nil { return nil, err } @@ -164,25 +183,14 @@ func (s Service) ListSessions(ctx context.Context, req *pbs.ListSessionsRequest) Type: resource.Session, } for _, item := range sesList { - res.Id = item.GetPublicId() - res.ScopeId = item.ScopeId - authorizedActions := authResults.FetchActionSetForId(ctx, item.GetPublicId(), IdActions, auth.WithResource(&res)) - if len(authorizedActions) == 0 { - continue - } - - if authorizedActions.OnlySelf() && item.UserId != authResults.UserId { - continue - } - outputFields := authResults.FetchOutputFields(res, action.List).SelfOrDefaults(authResults.UserId) outputOpts := make([]handlers.Option, 0, 3) outputOpts = append(outputOpts, handlers.WithOutputFields(&outputFields)) if outputFields.Has(globals.ScopeField) { - outputOpts = append(outputOpts, handlers.WithScope(scopeInfoMap[item.ScopeId])) + outputOpts = append(outputOpts, handlers.WithScope(scopeResourceInfo.ScopeResourceMap[item.ScopeId].ScopeInfo)) } if outputFields.Has(globals.AuthorizedActionsField) { - outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authorizedActions.Strings())) + outputOpts = append(outputOpts, handlers.WithAuthorizedActions(scopeResourceInfo.ScopeResourceMap[item.ScopeId].Resources[item.PublicId].AuthorizedActions.Strings())) } item, err := toProto(ctx, item, outputOpts...) @@ -287,12 +295,12 @@ func (s Service) getFromRepo(ctx context.Context, id string) (*session.Session, return sess, nil } -func (s Service) listFromRepo(ctx context.Context, scopeIds []string) ([]*session.Session, error) { +func (s Service) listFromRepo(ctx context.Context, sessionIds []string) ([]*session.Session, error) { repo, err := s.repoFn() if err != nil { return nil, err } - sesList, err := repo.ListSessions(ctx, session.WithScopeIds(scopeIds)) + sesList, err := repo.ListSessions(ctx, session.WithSessionIds(sessionIds...)) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/targets/target_service.go b/internal/servers/controller/handlers/targets/target_service.go index e0b49fa7db..02b0d956c5 100644 --- a/internal/servers/controller/handlers/targets/target_service.go +++ b/internal/servers/controller/handlers/targets/target_service.go @@ -164,7 +164,7 @@ func (s Service) ListTargets(ctx context.Context, req *pbs.ListTargetsRequest) ( } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.Target, req.GetRecursive(), false) + ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.Target, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/servers/controller/handlers/users/user_service.go b/internal/servers/controller/handlers/users/user_service.go index 349dda738b..0b0e417d5e 100644 --- a/internal/servers/controller/handlers/users/user_service.go +++ b/internal/servers/controller/handlers/users/user_service.go @@ -94,7 +94,7 @@ func (s Service) ListUsers(ctx context.Context, req *pbs.ListUsersRequest) (*pbs } scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds( - ctx, s.repoFn, authResults, req.GetScopeId(), resource.User, req.GetRecursive(), false) + ctx, s.repoFn, authResults, req.GetScopeId(), resource.User, req.GetRecursive()) if err != nil { return nil, err } diff --git a/internal/session/query.go b/internal/session/query.go index 430ae4a08d..541c4b4759 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -118,6 +118,13 @@ select expiration_time, connection_limit, current_connection_count from session_connection_limit, session_connection_count; ` + + sessionPublicIdList = ` +select public_id, scope_id, user_id from session +%s +; +` + sessionList = ` select * from diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 6e28a42ab6..78fee5005c 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -8,6 +8,7 @@ import ( "fmt" "strings" + "github.com/hashicorp/boundary/internal/boundary" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" @@ -205,6 +206,51 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, _ ...O return &session, authzSummary, nil } +// FetchAuthzProtectedEntitiesByScope implements boundary.AuthzProtectedEntityProvider +func (r *Repository) FetchAuthzProtectedEntitiesByScope(ctx context.Context, scopeIds []string) (map[string][]boundary.AuthzProtectedEntity, error) { + const op = "session.(Repository).FetchAuthzProtectedEntityInfo" + + var where string + var args []interface{} + + inClauseCnt := 0 + + switch len(scopeIds) { + case 0: + return nil, errors.New(ctx, errors.InvalidParameter, op, "no scopes given") + case 1: + inClauseCnt += 1 + where, args = fmt.Sprintf("where scope_id = @%d", inClauseCnt), append(args, sql.Named("1", scopeIds[0])) + default: + idsInClause := make([]string, 0, len(scopeIds)) + for _, id := range scopeIds { + inClauseCnt += 1 + idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id)) + } + where = fmt.Sprintf("where scope_id in (%s)", strings.Join(idsInClause, ",")) + } + + q := sessionPublicIdList + query := fmt.Sprintf(q, where) + + rows, err := r.reader.Query(ctx, query, args) + if err != nil { + return nil, errors.Wrap(ctx, err, op) + } + defer rows.Close() + + sessionsMap := map[string][]boundary.AuthzProtectedEntity{} + for rows.Next() { + var ses Session + if err := r.reader.ScanRows(ctx, rows, &ses); err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("scan row failed")) + } + sessionsMap[ses.GetScopeId()] = append(sessionsMap[ses.GetScopeId()], ses) + } + + return sessionsMap, nil +} + // ListSessions will sessions. Supports the WithLimit, WithScopeId, WithSessionIds, and WithServerId options. func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Session, error) { const op = "session.(Repository).ListSessions" @@ -213,25 +259,31 @@ func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Sessio var args []interface{} inClauseCnt := 0 - if len(opts.withScopeIds) != 0 { - switch len(opts.withScopeIds) { - case 1: + switch len(opts.withScopeIds) { + case 0: + case 1: + inClauseCnt += 1 + where, args = append(where, fmt.Sprintf("scope_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withScopeIds[0])) + default: + idsInClause := make([]string, 0, len(opts.withScopeIds)) + for _, id := range opts.withScopeIds { inClauseCnt += 1 - where, args = append(where, fmt.Sprintf("scope_id = @%d", inClauseCnt)), append(args, sql.Named("1", opts.withScopeIds[0])) - default: - idsInClause := make([]string, 0, len(opts.withScopeIds)) - for _, id := range opts.withScopeIds { - inClauseCnt += 1 - idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id)) - } - where = append(where, fmt.Sprintf("scope_id in (%s)", strings.Join(idsInClause, ","))) + idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id)) } + where = append(where, fmt.Sprintf("scope_id in (%s)", strings.Join(idsInClause, ","))) } + if opts.withUserId != "" { inClauseCnt += 1 where, args = append(where, fmt.Sprintf("user_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withUserId)) } - if len(opts.withSessionIds) > 0 { + + switch len(opts.withSessionIds) { + case 0: + case 1: + inClauseCnt += 1 + where, args = append(where, fmt.Sprintf("s.public_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withSessionIds[0])) + default: idsInClause := make([]string, 0, len(opts.withSessionIds)) for _, id := range opts.withSessionIds { inClauseCnt += 1 @@ -239,6 +291,7 @@ func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Sessio } where = append(where, fmt.Sprintf("s.public_id in (%s)", strings.Join(idsInClause, ","))) } + if opts.withServerId != "" { inClauseCnt += 1 where, args = append(where, fmt.Sprintf("server_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withServerId)) diff --git a/internal/session/session.go b/internal/session/session.go index baaa5331f0..7026656e73 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -11,6 +11,7 @@ import ( "strings" "time" + "github.com/hashicorp/boundary/internal/boundary" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" @@ -119,13 +120,22 @@ type Session struct { tableName string `gorm:"-"` } -func (s *Session) GetPublicId() string { +func (s Session) GetPublicId() string { return s.PublicId } +func (s Session) GetScopeId() string { + return s.ScopeId +} + +func (s Session) GetUserId() string { + return s.UserId +} + var ( - _ Cloneable = (*Session)(nil) - _ db.VetForWriter = (*Session)(nil) + _ Cloneable = (*Session)(nil) + _ db.VetForWriter = (*Session)(nil) + _ boundary.AuthzProtectedEntity = (*Session)(nil) ) // New creates a new in memory session.