You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/session/repository.go

195 lines
6.1 KiB

package session
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/perms"
"github.com/hashicorp/boundary/internal/types/action"
"github.com/hashicorp/boundary/internal/types/resource"
)
// Cloneable provides a cloning interface: implementations return a copy of
// themselves from Clone as an interface{} value.
//
// NOTE(review): whether the copy is deep or shallow is left to each
// implementation — nothing in this file constrains it.
type Cloneable interface {
	Clone() interface{}
}
type (
	// Repository is the session database repository. It holds the database
	// reader and writer, the KMS from which database wrappers are fetched
	// (used to decrypt session values), a default result limit, and optional
	// permissions that scope listing.
	Repository struct {
		reader db.Reader
		writer db.Writer
		kms    *kms.Kms

		// defaultLimit is the default maximum number of results returned
		// from repository operations.
		defaultLimit int

		// permissions, when non-nil, are turned into where-clauses that
		// restrict which sessions may be listed.
		permissions *perms.UserPermissions
	}

	// RepositoryFactory is a function that creates a Repository.
	RepositoryFactory func(opt ...Option) (*Repository, error)
)
// NewRepository creates a new session Repository backed by the given reader,
// writer and KMS. All three are required; a nil value yields an
// InvalidParameter error.
//
// Supported options:
//   - WithLimit sets the default limit on results returned by repo
//     operations. A zero limit falls back to db.DefaultLimit.
//   - Permissions supplied through the options must all target the session
//     resource, otherwise an InvalidParameter error is returned. They are
//     stored on the Repository for use when listing.
func NewRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms, opt ...Option) (*Repository, error) {
	const op = "session.NewRepository"

	switch {
	case r == nil:
		return nil, errors.New(ctx, errors.InvalidParameter, op, "nil reader")
	case w == nil:
		return nil, errors.New(ctx, errors.InvalidParameter, op, "nil writer")
	case kms == nil:
		return nil, errors.New(ctx, errors.InvalidParameter, op, "nil kms")
	}

	opts := getOpts(opt...)

	limit := opts.withLimit
	if limit == 0 {
		// Zero signals that the boundary default should be used.
		limit = db.DefaultLimit
	}

	if up := opts.withPermissions; up != nil {
		for _, p := range up.Permissions {
			if p.Resource == resource.Session {
				continue
			}
			return nil, errors.New(ctx, errors.InvalidParameter, op, "permission for incorrect resource")
		}
	}

	return &Repository{
		reader:       r,
		writer:       w,
		kms:          kms,
		defaultLimit: limit,
		permissions:  opts.withPermissions,
	}, nil
}
// listPermissionWhereClauses converts the repository's list permissions into
// SQL where-clause fragments plus the named arguments they reference.
//
// Every permission whose action is action.List yields one parenthesized
// fragment that ANDs together:
//   - project_id = @project_id_N (bound to the permission's ScopeId)
//   - public_id = any(@public_id_N), only when the permission names specific
//     resource IDs (bound to the IDs formatted as a "{a,b,...}" array literal)
//   - user_id = @user_id_N, only when the permission is OnlySelf (bound to
//     the permissions' UserId)
//
// N is a running counter, bumped an extra time for the user_id parameter, so
// every named argument is unique. Both results are nil when the repository
// has no permissions or none of them are list permissions. How the fragments
// are combined is up to the caller (presumably OR-ed together).
func (r *Repository) listPermissionWhereClauses() ([]string, []interface{}) {
	var (
		fragments []string
		namedArgs []interface{}
	)
	if r.permissions == nil {
		return fragments, namedArgs
	}

	seq := 0
	for _, perm := range r.permissions.Permissions {
		if perm.Action != action.List {
			continue
		}
		seq++

		projectParam := fmt.Sprintf("project_id_%d", seq)
		conds := []string{"project_id = @" + projectParam}
		namedArgs = append(namedArgs, sql.Named(projectParam, perm.ScopeId))

		if len(perm.ResourceIds) > 0 {
			idsParam := fmt.Sprintf("public_id_%d", seq)
			conds = append(conds, "public_id = any(@"+idsParam+")")
			namedArgs = append(namedArgs, sql.Named(idsParam, "{"+strings.Join(perm.ResourceIds, ",")+"}"))
		}

		if perm.OnlySelf {
			seq++
			userParam := fmt.Sprintf("user_id_%d", seq)
			conds = append(conds, "user_id = @"+userParam)
			namedArgs = append(namedArgs, sql.Named(userParam, r.permissions.UserId))
		}

		fragments = append(fragments, "("+strings.Join(conds, " and ")+")")
	}
	return fragments, namedArgs
}
// convertToSessions turns the flattened rows of the session list view into
// Session values, one per run of rows sharing a PublicId.
//
// Each row contributes one State to its Session. A row whose EndTime value
// has already been seen for the current session is treated as a duplicate
// and skipped. The States of every Session are sorted newest first
// (descending StartTime).
//
// NOTE(review): rows belonging to one session are assumed to be contiguous
// in sessionList (i.e. ordered by public_id). Otherwise the same session
// would be emitted more than once — confirm against the calling queries.
//
// When the listing-convert option is set, CtTofuToken, TofuToken and KeyId
// are cleared because they must not be returned in lists. Otherwise a
// non-empty CtTofuToken is decrypted with the project's database wrapper
// from KMS, and an empty one is normalized to nil.
//
// It returns nil, nil for an empty sessionList. It returns a wrapped error
// if the database wrapper cannot be fetched or decryption fails.
func (r *Repository) convertToSessions(ctx context.Context, sessionList []*sessionListView, opt ...Option) ([]*Session, error) {
	const op = "session.(Repository).convertToSessions"
	opts := getOpts(opt...)
	if len(sessionList) == 0 {
		return nil, nil
	}

	// endTimeKey identifies a state by the *value* of its end time. Keying
	// the dedup set on the *timestamp.Timestamp pointer never matches real
	// duplicates, because every scanned row carries its own pointer; only
	// nil end times would ever collide. A missing end time (the current
	// state) maps to the zero key.
	// NOTE(review): relies on end_time being unique per session in the
	// session_state table — confirm against the schema.
	type endTimeKey struct {
		set     bool
		seconds int64
		nanos   int32
	}
	keyOf := func(ts *timestamp.Timestamp) endTimeKey {
		pb := ts.GetTimestamp()
		if pb == nil {
			return endTimeKey{}
		}
		return endTimeKey{set: true, seconds: pb.GetSeconds(), nanos: pb.GetNanos()}
	}

	var (
		sessions       []*Session
		workingSession *Session
		seenEndTimes   map[endTimeKey]struct{}
	)

	// finishSession orders the working session's states newest first and
	// appends the session to the result. A stable sort keeps the output
	// deterministic even if two states ever share a start time.
	finishSession := func() {
		states := workingSession.States
		sort.SliceStable(states, func(i, j int) bool {
			return states[i].StartTime.GetTimestamp().AsTime().After(states[j].StartTime.GetTimestamp().AsTime())
		})
		sessions = append(sessions, workingSession)
	}

	for _, sv := range sessionList {
		// Comparing against the working session (instead of a separate
		// "previous id" string) also copes with an empty first PublicId
		// without dereferencing a nil session.
		if workingSession == nil || sv.PublicId != workingSession.PublicId {
			if workingSession != nil {
				finishSession()
			}
			workingSession = &Session{
				PublicId:          sv.PublicId,
				UserId:            sv.UserId,
				HostId:            sv.HostId,
				TargetId:          sv.TargetId,
				HostSetId:         sv.HostSetId,
				AuthTokenId:       sv.AuthTokenId,
				ProjectId:         sv.ProjectId,
				Certificate:       sv.Certificate,
				ExpirationTime:    sv.ExpirationTime,
				CtTofuToken:       sv.CtTofuToken,
				TofuToken:         sv.TofuToken, // will always be nil since it's not stored in the database.
				TerminationReason: sv.TerminationReason,
				CreateTime:        sv.CreateTime,
				UpdateTime:        sv.UpdateTime,
				Version:           sv.Version,
				Endpoint:          sv.Endpoint,
				ConnectionLimit:   sv.ConnectionLimit,
				KeyId:             sv.KeyId,
			}
			seenEndTimes = map[endTimeKey]struct{}{}

			switch {
			case opts.withListingConvert:
				workingSession.CtTofuToken = nil // CtTofuToken should not be returned in lists
				workingSession.TofuToken = nil   // TofuToken should not be returned in lists
				workingSession.KeyId = ""        // KeyId should not be returned in lists
			case len(workingSession.CtTofuToken) > 0:
				databaseWrapper, err := r.kms.GetWrapper(ctx, workingSession.ProjectId, kms.KeyPurposeDatabase)
				if err != nil {
					return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
				}
				if err := workingSession.decrypt(ctx, databaseWrapper); err != nil {
					return nil, errors.Wrap(ctx, err, op, errors.WithMsg("cannot decrypt session value"))
				}
			default:
				// Normalize an empty (possibly non-nil) ciphertext to nil.
				workingSession.CtTofuToken = nil
			}
		}

		key := keyOf(sv.EndTime)
		if _, dup := seenEndTimes[key]; dup {
			continue
		}
		seenEndTimes[key] = struct{}{}
		workingSession.States = append(workingSession.States, &State{
			SessionId:       sv.PublicId,
			Status:          Status(sv.Status),
			PreviousEndTime: sv.PreviousEndTime,
			StartTime:       sv.StartTime,
			EndTime:         sv.EndTime,
		})
	}
	finishSession()
	return sessions, nil
}