diff --git a/internal/session/query.go b/internal/session/query.go index 050ac0dc85..e449be497f 100644 --- a/internal/session/query.go +++ b/internal/session/query.go @@ -19,7 +19,8 @@ select * from not_active; ` // updateSessionState checks that we don't already have a row for the new - // state before inserting a new state. + // state or it's not already terminated (final state) before inserting a new + // state. updateSessionState = ` insert into session_state(session_id, state) select @@ -34,8 +35,21 @@ where from session_state where - session_id = $1::text and - state = $2 + -- already in the updated state + ( + session_id = $1::text and + state = $2 + ) or + -- already terminated + session_id in ( + select + session_id + from + session_state + where + session_id = $1::text and + state = 'terminated' + ) ) ` diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index f889a6564d..b50a942ce0 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -1057,10 +1057,32 @@ func TestRepository_CancelSession(t *testing.T) { overrideSessionVersion *uint32 wantErr bool wantIsError error + wantStatus Status }{ { - name: "valid", - session: setupFn(), + name: "valid", + session: setupFn(), + wantStatus: StatusCanceling, + }, + { + name: "already-terminated", + session: func() *Session { + session := TestDefaultSession(t, conn, wrapper, iamRepo) + c := TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) + cw := CloseWith{ + ConnectionId: c.PublicId, + BytesUp: 1, + BytesDown: 1, + ClosedReason: ConnectionClosedByUser, + } + _, err = repo.CloseConnections(context.Background(), []CloseWith{cw}) + require.NoError(t, err) + s, _, err := repo.LookupSession(context.Background(), session.PublicId) + require.NoError(t, err) + assert.Equal(t, StatusTerminated, s.States[0].Status) + return session + }(), + wantStatus: StatusTerminated, }, { name: "bad-session-id", @@ -1070,7 +1092,8 @@ func TestRepository_CancelSession(t *testing.T) { require.NoError(t, err) return &id }(), - wantErr: true, + wantErr: true, + wantStatus: StatusCanceling, }, { name: "missing-session-id", @@ -1079,6 +1102,7 @@ func TestRepository_CancelSession(t *testing.T) { id := "" return &id }(), + wantStatus: StatusCanceling, wantErr: true, wantIsError: db.ErrInvalidParameter, }, @@ -1089,7 +1113,8 @@ func TestRepository_CancelSession(t *testing.T) { v := uint32(101) return &v }(), - wantErr: true, + wantStatus: StatusCanceling, + wantErr: true, }, { name: "missing-version-id", @@ -1098,6 +1123,7 @@ func TestRepository_CancelSession(t *testing.T) { v := uint32(0) return &v }(), + wantStatus: StatusCanceling, wantErr: true, wantIsError: db.ErrInvalidParameter, }, @@ -1130,7 +1156,7 @@ func TestRepository_CancelSession(t *testing.T) { require.NoError(err) require.NotNil(s) require.NotNil(s.States) - assert.Equal(StatusCanceling, s.States[0].Status) + assert.Equal(tt.wantStatus, s.States[0].Status) stateCnt := len(s.States) origStartTime := s.States[0].StartTime @@ -1140,7 +1166,7 @@ func TestRepository_CancelSession(t *testing.T) { require.NotNil(s2) require.NotNil(s2.States) assert.Equal(stateCnt, len(s2.States)) - assert.Equal(StatusCanceling, s.States[0].Status) + assert.Equal(tt.wantStatus, s.States[0].Status) assert.Equal(origStartTime, s2.States[0].StartTime) }) }