diff --git a/internal/session/query.go b/internal/session/query.go index 2e5817d63e..967a758ec6 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -18,6 +18,27 @@ with not_active as ( select * from not_active; ` + // updateSessionState checks that we don't already have a row for the new + // state before inserting a new state. + updateSessionState = ` +insert into session_state(session_id, state) +select + $1::text, $2 +from + session s +where + s.public_id = $1::text and + s.public_id not in ( + select + session_id + from + session_state + where + session_id = $1::text and + state = $2 + ) +` + terminateSessionCte = ` insert into session_state with terminated as ( diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 66a3de601e..07b68684b5 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -263,7 +263,8 @@ func (r *Repository) DeleteSession(ctx context.Context, publicId string, opt ... // CancelSession sets a session's state to "cancelling" in the repo. It's // called when the user cancels a session and the controller wants to update the // session state to "cancelling" for the given reason, so the workers can get -// the "cancelling signal" during their next status heartbeat. +// the "cancelling signal" during their next status heartbeat. CancelSession is +// idempotent. func (r *Repository) CancelSession(ctx context.Context, sessionId string, sessionVersion uint32) (*Session, error) { if sessionId == "" { return nil, fmt.Errorf("cancel session: missing session id: %w", db.ErrInvalidParameter) @@ -605,8 +606,8 @@ func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sess } // updateState will update the session's state using the session id and its -// version. States are ordered by start time descending. No options are -// currently supported. +// version. updateState is idempotent. States are ordered by start time +// descending. No options are currently supported. func (r *Repository) updateState(ctx context.Context, sessionId string, sessionVersion uint32, s Status, opt ...Option) (*Session, []*State, error) { if sessionId == "" { return nil, nil, fmt.Errorf("update session state: missing session id %w", db.ErrInvalidParameter) @@ -621,45 +622,57 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV return nil, nil, fmt.Errorf("update session: you must call ActivateSession to update a session's state to active: %w", db.ErrInvalidParameter) } - newState, err := NewState(sessionId, s) - if err != nil { - return nil, nil, fmt.Errorf("update session state: %w", err) - } - + var rowsAffected int updatedSession := AllocSession() var returnedStates []*State - _, err = r.writer.DoTx( + _, err := r.writer.DoTx( ctx, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { + var err error // We need to update the session version as that's the aggregate updatedSession.PublicId = sessionId updatedSession.Version = uint32(sessionVersion) + 1 rowsUpdated, err := w.Update(ctx, &updatedSession, []string{"Version"}, nil, db.WithVersion(&sessionVersion)) if err != nil { - return fmt.Errorf("unable to update session version: %w", err) + return err } if rowsUpdated != 1 { return fmt.Errorf("updated session and %d rows updated", rowsUpdated) } - if err := w.Create(ctx, newState); err != nil { - return fmt.Errorf("unable to add new state: %w", err) + if len(updatedSession.CtTofuToken) > 0 { + databaseWrapper, err := r.kms.GetWrapper(ctx, updatedSession.ScopeId, kms.KeyPurposeDatabase, kms.WithKeyId(updatedSession.KeyId)) + if err != nil { + return fmt.Errorf("lookup session: unable to get database wrapper: %w", err) + } + if err := updatedSession.decrypt(ctx, databaseWrapper); err != nil { + return fmt.Errorf("lookup session: cannot decrypt session value: %w", err) + } + } else { + updatedSession.CtTofuToken = nil } + rowsAffected, err = w.Exec(updateSessionState, []interface{}{sessionId, s.String()}) + if err != nil { + return fmt.Errorf("unable to update session %s state to %s: %w", sessionId, s.String(), err) + } + if rowsAffected != 0 && rowsAffected != 1 { + return fmt.Errorf("updated session %s to state %s and %d rows inserted (should be 0 or 1)", sessionId, s.String(), rowsAffected) + } returnedStates, err = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc")) if err != nil { return err } + if len(returnedStates) < 1 && returnedStates[0].Status != s { + return fmt.Errorf("failed to update %s to a state of %s", sessionId, s.String()) + } return nil }, ) if err != nil { return nil, nil, fmt.Errorf("update session state: error creating new state: %w", err) } - if len(updatedSession.CtTofuToken) == 0 { - updatedSession.CtTofuToken = nil - } return &updatedSession, returnedStates, nil } diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 690eec0333..2c938c8aca 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -287,7 +287,7 @@ func TestRepository_CreateSession(t *testing.T) { } } -func TestRepository_UpdateState(t *testing.T) { +func TestRepository_updateState(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) @@ -849,7 +849,6 @@ func TestRepository_CancelSession(t *testing.T) { default: version = tt.session.Version } - s, err := repo.CancelSession(context.Background(), id, version) if tt.wantErr { require.Error(err) @@ -862,6 +861,17 @@ func TestRepository_CancelSession(t *testing.T) { require.NotNil(s) require.NotNil(s.States) assert.Equal(StatusCancelling, s.States[0].Status) + + stateCnt := len(s.States) + origStartTime := s.States[0].StartTime + // check idempontency + s2, err := repo.CancelSession(context.Background(), id, version+1) + require.NoError(err) + require.NotNil(s2) + require.NotNil(s2.States) + assert.Equal(stateCnt, len(s2.States)) + assert.Equal(StatusCancelling, s.States[0].Status) + assert.Equal(origStartTime, s2.States[0].StartTime) }) } }