refactor CancelSession and updateStates to be idempotent (#390)

pull/393/head
Jim 5 years ago committed by GitHub
parent 72de1a6916
commit b334aeff41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -18,6 +18,27 @@ with not_active as (
select * from not_active;
`
// updateSessionState checks that we don't already have a row for the new
// state before inserting a new state.
updateSessionState = `
insert into session_state(session_id, state)
select
$1::text, $2
from
session s
where
s.public_id = $1::text and
s.public_id not in (
select
session_id
from
session_state
where
session_id = $1::text and
state = $2
)
`
terminateSessionCte = `
insert into session_state
with terminated as (

@ -263,7 +263,8 @@ func (r *Repository) DeleteSession(ctx context.Context, publicId string, opt ...
// CancelSession sets a session's state to "cancelling" in the repo. It's
// called when the user cancels a session and the controller wants to update the
// session state to "cancelling" for the given reason, so the workers can get
// the "cancelling signal" during their next status heartbeat.
// the "cancelling signal" during their next status heartbeat. CancelSession is
// idempotent.
func (r *Repository) CancelSession(ctx context.Context, sessionId string, sessionVersion uint32) (*Session, error) {
if sessionId == "" {
return nil, fmt.Errorf("cancel session: missing session id: %w", db.ErrInvalidParameter)
@ -605,8 +606,8 @@ func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sess
}
// updateState will update the session's state using the session id and its
// version. States are ordered by start time descending. No options are
// currently supported.
// version. updateState is idempotent. States are ordered by start time
// descending. No options are currently supported.
func (r *Repository) updateState(ctx context.Context, sessionId string, sessionVersion uint32, s Status, opt ...Option) (*Session, []*State, error) {
if sessionId == "" {
return nil, nil, fmt.Errorf("update session state: missing session id %w", db.ErrInvalidParameter)
@ -621,45 +622,57 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV
return nil, nil, fmt.Errorf("update session: you must call ActivateSession to update a session's state to active: %w", db.ErrInvalidParameter)
}
newState, err := NewState(sessionId, s)
if err != nil {
return nil, nil, fmt.Errorf("update session state: %w", err)
}
var rowsAffected int
updatedSession := AllocSession()
var returnedStates []*State
_, err = r.writer.DoTx(
_, err := r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
var err error
// We need to update the session version as that's the aggregate
updatedSession.PublicId = sessionId
updatedSession.Version = uint32(sessionVersion) + 1
rowsUpdated, err := w.Update(ctx, &updatedSession, []string{"Version"}, nil, db.WithVersion(&sessionVersion))
if err != nil {
return fmt.Errorf("unable to update session version: %w", err)
return err
}
if rowsUpdated != 1 {
return fmt.Errorf("updated session and %d rows updated", rowsUpdated)
}
if err := w.Create(ctx, newState); err != nil {
return fmt.Errorf("unable to add new state: %w", err)
if len(updatedSession.CtTofuToken) > 0 {
databaseWrapper, err := r.kms.GetWrapper(ctx, updatedSession.ScopeId, kms.KeyPurposeDatabase, kms.WithKeyId(updatedSession.KeyId))
if err != nil {
return fmt.Errorf("lookup session: unable to get database wrapper: %w", err)
}
if err := updatedSession.decrypt(ctx, databaseWrapper); err != nil {
return fmt.Errorf("lookup session: cannot decrypt session value: %w", err)
}
} else {
updatedSession.CtTofuToken = nil
}
rowsAffected, err = w.Exec(updateSessionState, []interface{}{sessionId, s.String()})
if err != nil {
return fmt.Errorf("unable to update session %s state to %s: %w", sessionId, s.String(), err)
}
if rowsAffected != 0 && rowsAffected != 1 {
return fmt.Errorf("updated session %s to state %s and %d rows inserted (should be 0 or 1)", sessionId, s.String(), rowsAffected)
}
returnedStates, err = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc"))
if err != nil {
return err
}
if len(returnedStates) < 1 && returnedStates[0].Status != s {
return fmt.Errorf("failed to update %s to a state of %s", sessionId, s.String())
}
return nil
},
)
if err != nil {
return nil, nil, fmt.Errorf("update session state: error creating new state: %w", err)
}
if len(updatedSession.CtTofuToken) == 0 {
updatedSession.CtTofuToken = nil
}
return &updatedSession, returnedStates, nil
}

@ -287,7 +287,7 @@ func TestRepository_CreateSession(t *testing.T) {
}
}
func TestRepository_UpdateState(t *testing.T) {
func TestRepository_updateState(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
@ -849,7 +849,6 @@ func TestRepository_CancelSession(t *testing.T) {
default:
version = tt.session.Version
}
s, err := repo.CancelSession(context.Background(), id, version)
if tt.wantErr {
require.Error(err)
@ -862,6 +861,17 @@ func TestRepository_CancelSession(t *testing.T) {
require.NotNil(s)
require.NotNil(s.States)
assert.Equal(StatusCancelling, s.States[0].Status)
stateCnt := len(s.States)
origStartTime := s.States[0].StartTime
// check idempontency
s2, err := repo.CancelSession(context.Background(), id, version+1)
require.NoError(err)
require.NotNil(s2)
require.NotNil(s2.States)
assert.Equal(stateCnt, len(s2.States))
assert.Equal(StatusCancelling, s.States[0].Status)
assert.Equal(origStartTime, s2.States[0].StartTime)
})
}
}

Loading…
Cancel
Save