Reorder authz check for sessions (#2042)

* Reorder session authz checks

This implements a mechanism for service handlers to pass a repo
satisfying an interface into the scope IDs listing function. The purpose
of this interface is to allow the same function to perform authorization
checking against resources rather than requiring service handlers to
fetch all resources in the determined set of scopes and then perform
this checking. While this could be done in each service handler, this
approach centralizes the logic, which has the nice benefit of removing
some boilerplate from service handler list functions that have adapted
to it.

Tests haven't changed, which is expected -- the function should return
exactly what was being returned before, simply faster.
pull/2049/head
Jeff Mitchell 4 years ago committed by Timothy Messier
parent 0f961eb5f6
commit b41e983503
No known key found for this signature in database
GPG Key ID: EFD2F184F7600572

@ -2,7 +2,9 @@
// define the Boundary domain.
package boundary
import "github.com/hashicorp/boundary/internal/db/timestamp"
import (
"github.com/hashicorp/boundary/internal/db/timestamp"
)
// An Entity is an object distinguished by its identity, rather than its
// attributes. It can contain value objects and other entities.
@ -25,3 +27,12 @@ type Resource interface {
GetName() string
GetDescription() string
}
// AuthzProtectedEntity is used by some functions (primarily
// scopeids.AuthzProtectedEntityProvider-conforming implementations) to deliver
// some common information necessary for calculating authz.
type AuthzProtectedEntity interface {
Entity
GetScopeId() string
GetUserId() string
}

@ -196,7 +196,7 @@ func Verify(ctx context.Context, opt ...Option) (ret VerifyResults) {
return
}
if scp == nil {
ret.Error = errors.New(ctx, errors.InvalidParameter, op, fmt.Sprint("non-existent scope $q", ret.Scope.Id))
ret.Error = errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("non-existent scope %q", ret.Scope.Id))
return
}
ret.Scope = &scopes.ScopeInfo{

@ -3,6 +3,7 @@ package scopeids
import (
"context"
"github.com/hashicorp/boundary/internal/boundary"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/perms"
@ -15,70 +16,127 @@ import (
"github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/scopes"
)
// GetListingScopeIds, given common parameters for List calls, returns the set of scope
// IDs in which to search for resources. It also returns a memoized map of the
// scopes to their info for populating returned values.
//
// Note: This was originally pulled out 1:1 from the role service. It and other
// tests in the other service handlers test this function extensively as it
// forms the basis for all recursive listing tests; see those tests for list
// functionality in the various service handlers.
func GetListingScopeIds(
type authzProtectedEntityProvider interface {
// Fetches basic resource info for the given scopes. Note that this is a
// "where clause" style of argument: if the set of scopes is populated these
// are the scopes to limit to (e.g. to put in a where clause). An empty set
// of scopes means to look in *all* scopes, not none!
FetchAuthzProtectedEntitiesByScope(ctx context.Context, scopeIds []string) (map[string][]boundary.AuthzProtectedEntity, error)
}
// ResourceInfo contains information about a particular resource
type ResourceInfo struct {
AuthorizedActions action.ActionSet
}
// ScopeInfoWithResourceIds contains information about a scope and the resources
// found within it
type ScopeInfoWithResourceIds struct {
*scopes.ScopeInfo
Resources map[string]ResourceInfo
}
// GetListingResourceInformationInput contains input parameters to the function
type GetListingResourceInformationInput struct {
// An IAM repo function to use for a scope listing call
IamRepoFn common.IamRepoFactory
// The original auth results from the list command
AuthResults auth.VerifyResults
// The scope ID to use, or the starting point for a recursive search
RootScopeId string
// The type of resource being listed
Type resource.Type
// Whether the search is recursive
Recursive bool
// A repo to fetch resources
AuthzProtectedEntityProvider authzProtectedEntityProvider
// The available actions for the resource type
ActionSet action.ActionSet
}
// GetListingResourceInformationOutput contains results from the function
type GetListingResourceInformationOutput struct {
// The calculated list of relevant scope IDs
ScopeIds []string
// The specific resource IDs calculated to be authorized for listing
ResourceIds []string
// A map of scope ID to scope information and a map of resource IDs in that
// scope and specific information about that resource, such as available
// actions
ScopeResourceMap map[string]*ScopeInfoWithResourceIds
}
// GetListingResourceInformation, given common parameters for List calls,
// returns useful information: the set of scope IDs in which to search for
// resources; the IDs of the resources known to be authorized for that user; and
// a memoized map of the scopes to their info for populating returned values.
func GetListingResourceInformation(
// The context to use when listing in the DB, if required
ctx context.Context,
// An IAM repo function to use for a listing call, if required
repoFn common.IamRepoFactory,
// The original auth results from the list command
authResults auth.VerifyResults,
// The scope ID to use, or to use as the starting point for a recursive
// search
rootScopeId string,
// The type of resource we are listing
typ resource.Type,
// Whether or not the search should be recursive
recursive bool,
// Whether to only return scopes with exact permissions, or whether parent
// scopes with appropriate permissions are sufficient
directOnly bool,
) ([]string, map[string]*scopes.ScopeInfo, error) {
const op = "GetListingScopeIds"
// The input struct
input GetListingResourceInformationInput,
) (*GetListingResourceInformationOutput, error) {
const op = "scopeids.GetListingResourceInformation"
output := new(GetListingResourceInformationOutput)
// Validation
switch {
case typ == resource.Unknown:
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "unknown resource")
case repoFn == nil:
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "nil iam repo")
case rootScopeId == "":
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing root scope id")
case authResults.Scope == nil:
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "nil scope in auth results")
case input.Type == resource.Unknown:
return nil, errors.New(ctx, errors.InvalidParameter, op, "unknown resource")
case input.IamRepoFn == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "nil iam repo")
case input.RootScopeId == "":
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing root scope id")
case input.AuthResults.Scope == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "nil scope in auth results")
case !input.Recursive && input.AuthResults.Scope.Id != input.RootScopeId:
return nil, errors.New(ctx, errors.InvalidParameter, op, "non-recursive search but auth results scope does not match root scope")
}
// This will be used to memoize scope info so we can put the right scope
// info for each returned value
output.ScopeResourceMap = map[string]*ScopeInfoWithResourceIds{}
// Base case: if not recursive, return the scope we were given and the
// already-looked-up info
if !recursive {
return []string{authResults.Scope.Id}, map[string]*scopes.ScopeInfo{authResults.Scope.Id: authResults.Scope}, nil
if !input.Recursive {
output.ScopeResourceMap[input.AuthResults.Scope.Id] = &ScopeInfoWithResourceIds{ScopeInfo: input.AuthResults.Scope}
// If we don't have information do to the resource lookup ourselves,
// return what we have
if input.AuthzProtectedEntityProvider == nil {
output.ScopeIds = []string{input.AuthResults.Scope.Id}
return output, nil
}
// Otherwise filter on this one scope and return
if err := filterAuthorizedResourceIds(ctx, input, output); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error filtering to only authorized resources"))
}
return output, nil
}
// This will be used to memoize scope info so we can put the right scope
// info for each returned value
scopeInfoMap := map[string]*scopes.ScopeInfo{}
repo, err := repoFn()
repo, err := input.IamRepoFn()
if err != nil {
return nil, nil, err
return nil, err
}
// Get all scopes recursively. Start at global because we need to take into
// account permissions in parent scopes even if they want to scale back the
// returned information to a child scope and its children.
scps, err := repo.ListScopesRecursively(ctx, scope.Global.String())
if err != nil {
return nil, nil, err
return nil, err
}
res := perms.Resource{
Type: typ,
Type: input.Type,
}
// For each scope, see if we have permission to list that type in that
// scope
@ -88,7 +146,7 @@ func GetListingScopeIds(
for _, scp := range scps {
scpId := scp.GetPublicId()
res.ScopeId = scpId
aSet := authResults.FetchActionSetForType(ctx,
aSet := input.AuthResults.FetchActionSetForType(ctx,
// This is overridden by WithResource
resource.Unknown,
action.ActionSet{action.List},
@ -97,16 +155,14 @@ func GetListingScopeIds(
switch len(aSet) {
case 0:
// Defer until we've read all scopes. We do this because if the
// ordering coming back isn't in parent-first ording our map
// ordering coming back isn't in parent-first ordering our map
// lookup might fail.
if !directOnly {
deferredScopes = append(deferredScopes, scp)
}
deferredScopes = append(deferredScopes, scp)
case 1:
if aSet[0] != action.List {
return nil, nil, errors.New(ctx, errors.Internal, op, "unexpected action in set")
return nil, errors.New(ctx, errors.Internal, op, "unexpected action in set")
}
if scopeInfoMap[scpId] == nil {
if output.ScopeResourceMap[scpId] == nil {
scopeInfo := &scopes.ScopeInfo{
Id: scp.GetPublicId(),
Type: scp.GetType(),
@ -114,13 +170,13 @@ func GetListingScopeIds(
Description: scp.GetDescription(),
ParentScopeId: scp.GetParentId(),
}
scopeInfoMap[scpId] = scopeInfo
output.ScopeResourceMap[scpId] = &ScopeInfoWithResourceIds{ScopeInfo: scopeInfo}
}
if scpId == scope.Global.String() {
globalHasList = true
}
default:
return nil, nil, errors.New(ctx, errors.Internal, op, "unexpected number of actions back in set")
return nil, errors.New(ctx, errors.Internal, op, "unexpected number of actions back in set")
}
}
@ -129,9 +185,9 @@ func GetListingScopeIds(
// If they had list on global scope anything else is automatically
// included; otherwise if they had list on the parent scope, this
// scope is included in the map and is sufficient here.
if globalHasList || scopeInfoMap[scp.GetParentId()] != nil {
if globalHasList || output.ScopeResourceMap[scp.GetParentId()] != nil {
scpId := scp.GetPublicId()
if scopeInfoMap[scpId] == nil {
if output.ScopeResourceMap[scpId] == nil {
scopeInfo := &scopes.ScopeInfo{
Id: scp.GetPublicId(),
Type: scp.GetType(),
@ -139,21 +195,15 @@ func GetListingScopeIds(
Description: scp.GetDescription(),
ParentScopeId: scp.GetParentId(),
}
scopeInfoMap[scpId] = scopeInfo
output.ScopeResourceMap[scpId] = &ScopeInfoWithResourceIds{ScopeInfo: scopeInfo}
}
}
}
// If we have nothing in scopeInfoMap at this point, we aren't authorized
// anywhere so return 403.
if len(scopeInfoMap) == 0 {
return nil, nil, handlers.ForbiddenError()
}
// Now elide out any that aren't under the root scope ID
elideScopes := make([]string, 0, len(scopeInfoMap))
for scpId, scp := range scopeInfoMap {
switch rootScopeId {
elideScopes := make([]string, 0, len(output.ScopeResourceMap))
for scpId, scp := range output.ScopeResourceMap {
switch input.RootScopeId {
// If the root is global, it matches
case scope.Global.String():
// If the current scope matches the root, it matches
@ -168,13 +218,134 @@ func GetListingScopeIds(
}
for _, scpId := range elideScopes {
delete(scopeInfoMap, scpId)
delete(output.ScopeResourceMap, scpId)
}
// If we have nothing in scopeInfoMap at this point, we aren't authorized
// anywhere so return 403.
if len(output.ScopeResourceMap) == 0 {
return nil, handlers.ForbiddenError()
}
scopeIds := make([]string, 0, len(scopeInfoMap))
for k := range scopeInfoMap {
scopeIds = append(scopeIds, k)
if input.AuthzProtectedEntityProvider == nil {
output.populateScopeIdsFromScopeResourceMap()
return output, nil
}
return scopeIds, scopeInfoMap, nil
if err := filterAuthorizedResourceIds(ctx, input, output); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error filtering to only authorized resources"))
}
return output, nil
}
// filterAuthorizedResourceIds calls the passed in function to get IDs for
// resources in the given scopes and then figures out which ones are actually
// authorized for listing by the user.
//
// It also populates the scope IDs in the output
func filterAuthorizedResourceIds(
// The context to use when listing in the DB, if required
ctx context.Context,
// The input struct
input GetListingResourceInformationInput,
// The scope information to fill out
output *GetListingResourceInformationOutput,
) error {
const op = "scopeids.filterAuthorizedResources"
// Populate scopeIds and determine if we found global
output.populateScopeIdsFromScopeResourceMap()
// The calling function is giving us a complete set with any recursive
// lookup already performed (that's the point of the function).
scopedResourceInfo, err := input.AuthzProtectedEntityProvider.FetchAuthzProtectedEntitiesByScope(ctx, output.ScopeIds)
if err != nil {
return errors.Wrap(ctx, err, op)
}
res := perms.Resource{
Type: resource.Session,
}
// Now run authorization checks against each so we know if there is a point
// in fetching the full resource, and cache the authorized actions
for scopeId, resourceInfos := range scopedResourceInfo {
for _, resourceInfo := range resourceInfos {
res.Id = resourceInfo.GetPublicId()
res.ScopeId = scopeId
authorizedActions := input.AuthResults.FetchActionSetForId(ctx, resourceInfo.GetPublicId(), input.ActionSet, auth.WithResource(&res))
if len(authorizedActions) == 0 {
continue
}
if resourceInfo.GetUserId() != "" {
if authorizedActions.OnlySelf() && resourceInfo.GetUserId() != input.AuthResults.UserId {
continue
}
}
if output.ScopeResourceMap[scopeId].Resources == nil {
output.ScopeResourceMap[scopeId].Resources = make(map[string]ResourceInfo)
}
output.ScopeResourceMap[scopeId].Resources[resourceInfo.GetPublicId()] = ResourceInfo{AuthorizedActions: authorizedActions}
output.ResourceIds = append(output.ResourceIds, resourceInfo.GetPublicId())
}
}
return nil
}
// populateScopeIdsFromScopeResourceMap populates the ScopeIds field and returns
// whether global scope was found
func (i *GetListingResourceInformationOutput) populateScopeIdsFromScopeResourceMap() {
for k := range i.ScopeResourceMap {
i.ScopeIds = append(i.ScopeIds, k)
}
}
// GetListingScopeIds is provided for backwards compatibility with existing
// services; services should eventually migrate to
// GetListingResourceInformation.
func GetListingScopeIds(
// The context to use when listing in the DB, if required
ctx context.Context,
// An IAM repo function to use for a listing call, if required
repoFn common.IamRepoFactory,
// The original auth results from the list command
authResults auth.VerifyResults,
// The scope ID to use, or to use as the starting point for a recursive
// search
rootScopeId string,
// The type of resource we are listing
typ resource.Type,
// Whether or not the search should be recursive
recursive bool,
) ([]string, map[string]*scopes.ScopeInfo, error) {
const op = "scopeids.GetListingScopeIds"
scopeResourceInfo, err := GetListingResourceInformation(ctx,
GetListingResourceInformationInput{
IamRepoFn: repoFn,
AuthResults: authResults,
RootScopeId: rootScopeId,
Type: typ,
Recursive: recursive,
},
)
if err != nil {
if err == handlers.ForbiddenError() {
return nil, nil, err
}
return nil, nil, errors.Wrap(ctx, err, op)
}
if len(scopeResourceInfo.ScopeIds) == 0 {
// This should have already happened in the other function, but...
return nil, nil, handlers.ForbiddenError()
}
scopeMap := make(map[string]*scopes.ScopeInfo, len(scopeResourceInfo.ScopeResourceMap))
for k, v := range scopeResourceInfo.ScopeResourceMap {
scopeMap[k] = v.ScopeInfo
}
return scopeResourceInfo.ScopeIds, scopeMap, nil
}

@ -130,7 +130,7 @@ func (s Service) ListAuthMethods(ctx context.Context, req *pbs.ListAuthMethodsRe
}
scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(
ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthMethod, req.GetRecursive(), false)
ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthMethod, req.GetRecursive())
if err != nil {
return nil, err
}

@ -81,7 +81,7 @@ func (s Service) ListAuthTokens(ctx context.Context, req *pbs.ListAuthTokensRequ
}
scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(
ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthToken, req.GetRecursive(), false)
ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.AuthToken, req.GetRecursive())
if err != nil {
return nil, err
}

@ -112,7 +112,7 @@ func (s Service) ListCredentialStores(ctx context.Context, req *pbs.ListCredenti
}
scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(
ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.CredentialStore, req.GetRecursive(), false)
ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.CredentialStore, req.GetRecursive())
if err != nil {
return nil, err
}

@ -92,7 +92,7 @@ func (s Service) ListGroups(ctx context.Context, req *pbs.ListGroupsRequest) (*p
}
scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(
ctx, s.repoFn, authResults, req.GetScopeId(), resource.Group, req.GetRecursive(), false)
ctx, s.repoFn, authResults, req.GetScopeId(), resource.Group, req.GetRecursive())
if err != nil {
return nil, err
}

@ -129,7 +129,7 @@ func (s Service) ListHostCatalogs(ctx context.Context, req *pbs.ListHostCatalogs
}
scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(
ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.HostCatalog, req.GetRecursive(), false)
ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.HostCatalog, req.GetRecursive())
if err != nil {
return nil, err
}

@ -96,7 +96,7 @@ func (s Service) ListRoles(ctx context.Context, req *pbs.ListRolesRequest) (*pbs
}
scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(
ctx, s.repoFn, authResults, req.GetScopeId(), resource.Role, req.GetRecursive(), false)
ctx, s.repoFn, authResults, req.GetScopeId(), resource.Role, req.GetRecursive())
if err != nil {
return nil, err
}

@ -133,7 +133,7 @@ func (s Service) ListScopes(ctx context.Context, req *pbs.ListScopesRequest) (*p
}
scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(
ctx, s.repoFn, authResults, req.GetScopeId(), resource.Scope, req.GetRecursive(), false)
ctx, s.repoFn, authResults, req.GetScopeId(), resource.Scope, req.GetRecursive())
if err != nil {
return nil, err
}

@ -119,6 +119,8 @@ func (s Service) GetSession(ctx context.Context, req *pbs.GetSessionRequest) (*p
// ListSessions implements the interface pbs.SessionServiceServer.
func (s Service) ListSessions(ctx context.Context, req *pbs.ListSessionsRequest) (*pbs.ListSessionsResponse, error) {
const op = "session.(Service).ListSessions"
if err := validateListRequest(req); err != nil {
return nil, err
}
@ -137,17 +139,34 @@ func (s Service) ListSessions(ctx context.Context, req *pbs.ListSessionsRequest)
}
}
scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(ctx,
s.iamRepoFn, authResults, req.GetScopeId(), resource.Session, req.GetRecursive(), false)
repo, err := s.repoFn()
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
scopeResourceInfo, err := scopeids.GetListingResourceInformation(
ctx,
scopeids.GetListingResourceInformationInput{
IamRepoFn: s.iamRepoFn,
AuthResults: authResults,
RootScopeId: req.GetScopeId(),
Type: resource.Session,
Recursive: req.GetRecursive(),
AuthzProtectedEntityProvider: repo,
ActionSet: IdActions,
},
)
if err != nil {
return nil, err
}
// If no scopes match, return an empty response
if len(scopeIds) == 0 {
// If no scopes match or we match scopes but there are no resources in them
// that we are authorized to see, return an empty response
if len(scopeResourceInfo.ScopeIds) == 0 ||
len(scopeResourceInfo.ResourceIds) == 0 {
return &pbs.ListSessionsResponse{}, nil
}
sesList, err := s.listFromRepo(ctx, scopeIds)
sesList, err := s.listFromRepo(ctx, scopeResourceInfo.ResourceIds)
if err != nil {
return nil, err
}
@ -164,25 +183,14 @@ func (s Service) ListSessions(ctx context.Context, req *pbs.ListSessionsRequest)
Type: resource.Session,
}
for _, item := range sesList {
res.Id = item.GetPublicId()
res.ScopeId = item.ScopeId
authorizedActions := authResults.FetchActionSetForId(ctx, item.GetPublicId(), IdActions, auth.WithResource(&res))
if len(authorizedActions) == 0 {
continue
}
if authorizedActions.OnlySelf() && item.UserId != authResults.UserId {
continue
}
outputFields := authResults.FetchOutputFields(res, action.List).SelfOrDefaults(authResults.UserId)
outputOpts := make([]handlers.Option, 0, 3)
outputOpts = append(outputOpts, handlers.WithOutputFields(&outputFields))
if outputFields.Has(globals.ScopeField) {
outputOpts = append(outputOpts, handlers.WithScope(scopeInfoMap[item.ScopeId]))
outputOpts = append(outputOpts, handlers.WithScope(scopeResourceInfo.ScopeResourceMap[item.ScopeId].ScopeInfo))
}
if outputFields.Has(globals.AuthorizedActionsField) {
outputOpts = append(outputOpts, handlers.WithAuthorizedActions(authorizedActions.Strings()))
outputOpts = append(outputOpts, handlers.WithAuthorizedActions(scopeResourceInfo.ScopeResourceMap[item.ScopeId].Resources[item.PublicId].AuthorizedActions.Strings()))
}
item, err := toProto(ctx, item, outputOpts...)
@ -287,12 +295,12 @@ func (s Service) getFromRepo(ctx context.Context, id string) (*session.Session,
return sess, nil
}
func (s Service) listFromRepo(ctx context.Context, scopeIds []string) ([]*session.Session, error) {
func (s Service) listFromRepo(ctx context.Context, sessionIds []string) ([]*session.Session, error) {
repo, err := s.repoFn()
if err != nil {
return nil, err
}
sesList, err := repo.ListSessions(ctx, session.WithScopeIds(scopeIds))
sesList, err := repo.ListSessions(ctx, session.WithSessionIds(sessionIds...))
if err != nil {
return nil, err
}

@ -164,7 +164,7 @@ func (s Service) ListTargets(ctx context.Context, req *pbs.ListTargetsRequest) (
}
scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(
ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.Target, req.GetRecursive(), false)
ctx, s.iamRepoFn, authResults, req.GetScopeId(), resource.Target, req.GetRecursive())
if err != nil {
return nil, err
}

@ -94,7 +94,7 @@ func (s Service) ListUsers(ctx context.Context, req *pbs.ListUsersRequest) (*pbs
}
scopeIds, scopeInfoMap, err := scopeids.GetListingScopeIds(
ctx, s.repoFn, authResults, req.GetScopeId(), resource.User, req.GetRecursive(), false)
ctx, s.repoFn, authResults, req.GetScopeId(), resource.User, req.GetRecursive())
if err != nil {
return nil, err
}

@ -118,6 +118,13 @@ select expiration_time, connection_limit, current_connection_count
from
session_connection_limit, session_connection_count;
`
sessionPublicIdList = `
select public_id, scope_id, user_id from session
%s
;
`
sessionList = `
select *
from

@ -8,6 +8,7 @@ import (
"fmt"
"strings"
"github.com/hashicorp/boundary/internal/boundary"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
@ -205,6 +206,51 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, _ ...O
return &session, authzSummary, nil
}
// FetchAuthzProtectedEntitiesByScope implements boundary.AuthzProtectedEntityProvider
func (r *Repository) FetchAuthzProtectedEntitiesByScope(ctx context.Context, scopeIds []string) (map[string][]boundary.AuthzProtectedEntity, error) {
const op = "session.(Repository).FetchAuthzProtectedEntityInfo"
var where string
var args []interface{}
inClauseCnt := 0
switch len(scopeIds) {
case 0:
return nil, errors.New(ctx, errors.InvalidParameter, op, "no scopes given")
case 1:
inClauseCnt += 1
where, args = fmt.Sprintf("where scope_id = @%d", inClauseCnt), append(args, sql.Named("1", scopeIds[0]))
default:
idsInClause := make([]string, 0, len(scopeIds))
for _, id := range scopeIds {
inClauseCnt += 1
idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id))
}
where = fmt.Sprintf("where scope_id in (%s)", strings.Join(idsInClause, ","))
}
q := sessionPublicIdList
query := fmt.Sprintf(q, where)
rows, err := r.reader.Query(ctx, query, args)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}
defer rows.Close()
sessionsMap := map[string][]boundary.AuthzProtectedEntity{}
for rows.Next() {
var ses Session
if err := r.reader.ScanRows(ctx, rows, &ses); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("scan row failed"))
}
sessionsMap[ses.GetScopeId()] = append(sessionsMap[ses.GetScopeId()], ses)
}
return sessionsMap, nil
}
// ListSessions will sessions. Supports the WithLimit, WithScopeId, WithSessionIds, and WithServerId options.
func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Session, error) {
const op = "session.(Repository).ListSessions"
@ -213,25 +259,31 @@ func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Sessio
var args []interface{}
inClauseCnt := 0
if len(opts.withScopeIds) != 0 {
switch len(opts.withScopeIds) {
case 1:
switch len(opts.withScopeIds) {
case 0:
case 1:
inClauseCnt += 1
where, args = append(where, fmt.Sprintf("scope_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withScopeIds[0]))
default:
idsInClause := make([]string, 0, len(opts.withScopeIds))
for _, id := range opts.withScopeIds {
inClauseCnt += 1
where, args = append(where, fmt.Sprintf("scope_id = @%d", inClauseCnt)), append(args, sql.Named("1", opts.withScopeIds[0]))
default:
idsInClause := make([]string, 0, len(opts.withScopeIds))
for _, id := range opts.withScopeIds {
inClauseCnt += 1
idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id))
}
where = append(where, fmt.Sprintf("scope_id in (%s)", strings.Join(idsInClause, ",")))
idsInClause, args = append(idsInClause, fmt.Sprintf("@%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), id))
}
where = append(where, fmt.Sprintf("scope_id in (%s)", strings.Join(idsInClause, ",")))
}
if opts.withUserId != "" {
inClauseCnt += 1
where, args = append(where, fmt.Sprintf("user_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withUserId))
}
if len(opts.withSessionIds) > 0 {
switch len(opts.withSessionIds) {
case 0:
case 1:
inClauseCnt += 1
where, args = append(where, fmt.Sprintf("s.public_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withSessionIds[0]))
default:
idsInClause := make([]string, 0, len(opts.withSessionIds))
for _, id := range opts.withSessionIds {
inClauseCnt += 1
@ -239,6 +291,7 @@ func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Sessio
}
where = append(where, fmt.Sprintf("s.public_id in (%s)", strings.Join(idsInClause, ",")))
}
if opts.withServerId != "" {
inClauseCnt += 1
where, args = append(where, fmt.Sprintf("server_id = @%d", inClauseCnt)), append(args, sql.Named(fmt.Sprintf("%d", inClauseCnt), opts.withServerId))

@ -11,6 +11,7 @@ import (
"strings"
"time"
"github.com/hashicorp/boundary/internal/boundary"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
@ -119,13 +120,22 @@ type Session struct {
tableName string `gorm:"-"`
}
func (s *Session) GetPublicId() string {
func (s Session) GetPublicId() string {
return s.PublicId
}
func (s Session) GetScopeId() string {
return s.ScopeId
}
func (s Session) GetUserId() string {
return s.UserId
}
var (
_ Cloneable = (*Session)(nil)
_ db.VetForWriter = (*Session)(nil)
_ Cloneable = (*Session)(nil)
_ db.VetForWriter = (*Session)(nil)
_ boundary.AuthzProtectedEntity = (*Session)(nil)
)
// New creates a new in memory session.

Loading…
Cancel
Save