|
|
|
|
@ -263,7 +263,8 @@ func (r *Repository) DeleteSession(ctx context.Context, publicId string, opt ...
|
|
|
|
|
// CancelSession sets a session's state to "cancelling" in the repo. It's
|
|
|
|
|
// called when the user cancels a session and the controller wants to update the
|
|
|
|
|
// session state to "cancelling" for the given reason, so the workers can get
|
|
|
|
|
// the "cancelling signal" during their next status heartbeat.
|
|
|
|
|
// the "cancelling signal" during their next status heartbeat. CancelSession is
|
|
|
|
|
// idempotent.
|
|
|
|
|
func (r *Repository) CancelSession(ctx context.Context, sessionId string, sessionVersion uint32) (*Session, error) {
|
|
|
|
|
if sessionId == "" {
|
|
|
|
|
return nil, fmt.Errorf("cancel session: missing session id: %w", db.ErrInvalidParameter)
|
|
|
|
|
@ -605,8 +606,8 @@ func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sess
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// updateState will update the session's state using the session id and its
|
|
|
|
|
// version. States are ordered by start time descending. No options are
|
|
|
|
|
// currently supported.
|
|
|
|
|
// version. updateState is idempotent. States are ordered by start time
|
|
|
|
|
// descending. No options are currently supported.
|
|
|
|
|
func (r *Repository) updateState(ctx context.Context, sessionId string, sessionVersion uint32, s Status, opt ...Option) (*Session, []*State, error) {
|
|
|
|
|
if sessionId == "" {
|
|
|
|
|
return nil, nil, fmt.Errorf("update session state: missing session id %w", db.ErrInvalidParameter)
|
|
|
|
|
@ -621,45 +622,57 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV
|
|
|
|
|
return nil, nil, fmt.Errorf("update session: you must call ActivateSession to update a session's state to active: %w", db.ErrInvalidParameter)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
newState, err := NewState(sessionId, s)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, nil, fmt.Errorf("update session state: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var rowsAffected int
|
|
|
|
|
updatedSession := AllocSession()
|
|
|
|
|
var returnedStates []*State
|
|
|
|
|
_, err = r.writer.DoTx(
|
|
|
|
|
_, err := r.writer.DoTx(
|
|
|
|
|
ctx,
|
|
|
|
|
db.StdRetryCnt,
|
|
|
|
|
db.ExpBackoff{},
|
|
|
|
|
func(reader db.Reader, w db.Writer) error {
|
|
|
|
|
var err error
|
|
|
|
|
// We need to update the session version as that's the aggregate
|
|
|
|
|
updatedSession.PublicId = sessionId
|
|
|
|
|
updatedSession.Version = uint32(sessionVersion) + 1
|
|
|
|
|
rowsUpdated, err := w.Update(ctx, &updatedSession, []string{"Version"}, nil, db.WithVersion(&sessionVersion))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("unable to update session version: %w", err)
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if rowsUpdated != 1 {
|
|
|
|
|
return fmt.Errorf("updated session and %d rows updated", rowsUpdated)
|
|
|
|
|
}
|
|
|
|
|
if err := w.Create(ctx, newState); err != nil {
|
|
|
|
|
return fmt.Errorf("unable to add new state: %w", err)
|
|
|
|
|
if len(updatedSession.CtTofuToken) > 0 {
|
|
|
|
|
databaseWrapper, err := r.kms.GetWrapper(ctx, updatedSession.ScopeId, kms.KeyPurposeDatabase, kms.WithKeyId(updatedSession.KeyId))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("lookup session: unable to get database wrapper: %w", err)
|
|
|
|
|
}
|
|
|
|
|
if err := updatedSession.decrypt(ctx, databaseWrapper); err != nil {
|
|
|
|
|
return fmt.Errorf("lookup session: cannot decrypt session value: %w", err)
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
updatedSession.CtTofuToken = nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rowsAffected, err = w.Exec(updateSessionState, []interface{}{sessionId, s.String()})
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("unable to update session %s state to %s: %w", sessionId, s.String(), err)
|
|
|
|
|
}
|
|
|
|
|
if rowsAffected != 0 && rowsAffected != 1 {
|
|
|
|
|
return fmt.Errorf("updated session %s to state %s and %d rows inserted (should be 0 or 1)", sessionId, s.String(), rowsAffected)
|
|
|
|
|
}
|
|
|
|
|
returnedStates, err = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc"))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if len(returnedStates) < 1 && returnedStates[0].Status != s {
|
|
|
|
|
return fmt.Errorf("failed to update %s to a state of %s", sessionId, s.String())
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, nil, fmt.Errorf("update session state: error creating new state: %w", err)
|
|
|
|
|
}
|
|
|
|
|
if len(updatedSession.CtTofuToken) == 0 {
|
|
|
|
|
updatedSession.CtTofuToken = nil
|
|
|
|
|
}
|
|
|
|
|
return &updatedSession, returnedStates, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|