diff --git a/internal/session/query.go b/internal/session/query.go new file mode 100644 index 0000000000..0f1bf137c8 --- /dev/null +++ b/internal/session/query.go @@ -0,0 +1,20 @@ +package session + +const ( + activateStateCte = ` +insert into session_state +with not_active as ( + select session_id, 'active' as state + from + session s, + session_state ss + where + s.public_id = ss.session_id and + ss.state = 'pending' and + ss.session_id = $1 and + s.version = $2 and + s.public_id not in(select public_id from session_state where session_id = $1 and state = 'active') +) +select * from not_active; +` +) diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index bb0c43f8f3..7bf229aa12 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -308,6 +308,47 @@ func (r *Repository) UpdateSession(ctx context.Context, session *Session, versio return s, states, rowsUpdated, err } +// ActivateSession will activate the session and is called by a worker after +// authenticating the session. +func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sessionVersion uint32) (*Session, []*State, error) { + if sessionId == "" { + return nil, nil, fmt.Errorf("activate session state: missing session id %w", db.ErrInvalidParameter) + } + if sessionVersion == 0 { + return nil, nil, fmt.Errorf("activate session state: version cannot be zero: %w", db.ErrInvalidParameter) + } + + updatedSession := AllocSession() + updatedSession.PublicId = sessionId + var returnedStates []*State + _, err := r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + rowsAffected, err := w.Exec(activateStateCte, []interface{}{sessionId, sessionVersion}) + if err != nil { + return fmt.Errorf("unable to activate session %s: %w", sessionId, err) + } + if rowsAffected == 0 { + return fmt.Errorf("unable to activate session %s", sessionId) + } + if err := r.reader.LookupById(ctx, &updatedSession); err != nil { + return fmt.Errorf("lookup session: failed %w for %s", err, sessionId) + } + returnedStates, err = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc")) + if err != nil { + return err + } + return nil + }, + ) + if err != nil { + return nil, nil, fmt.Errorf("activate session: %w", err) + } + return &updatedSession, returnedStates, nil +} + // UpdateState will update the session's state using the session id and its // version. No options are currently supported. func (r *Repository) UpdateState(ctx context.Context, sessionId string, sessionVersion uint32, s Status, opt ...Option) (*Session, []*State, error) { @@ -320,6 +361,9 @@ func (r *Repository) UpdateState(ctx context.Context, sessionId string, sessionV if s == "" { return nil, nil, fmt.Errorf("update session state: missing session status: %w", db.ErrInvalidParameter) } + if s == StatusActive { + return nil, nil, fmt.Errorf("update session: you must call ActivateSession to update a session's state to active: %w", db.ErrInvalidParameter) + } newState, err := NewState(sessionId, s) if err != nil { diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index e69d58b38e..53419a6c7d 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -302,9 +302,9 @@ func TestRepository_UpdateState(t *testing.T) { wantIsError error }{ { - name: "connected", + name: "cancelling", session: TestDefaultSession(t, conn, wrapper, iamRepo), - newStatus: StatusActive, + newStatus: StatusCanceling, wantStateCnt: 2, wantErr: false, }, @@ -322,7 +322,7 @@ func TestRepository_UpdateState(t *testing.T) { { name: "bad-version", session: TestDefaultSession(t, conn, wrapper, iamRepo), - newStatus: StatusActive, + newStatus: StatusCanceling, overrideSessionVersion: func() *uint32 { v := uint32(22) return &v @@ -332,7 +332,7 @@ func TestRepository_UpdateState(t *testing.T) { { name: "empty-version", session: TestDefaultSession(t, conn, wrapper, iamRepo), - newStatus: StatusActive, + newStatus: StatusCanceling, overrideSessionVersion: func() *uint32 { v := uint32(0) return &v @@ -343,7 +343,7 @@ func TestRepository_UpdateState(t *testing.T) { { name: "bad-sessionId", session: TestDefaultSession(t, conn, wrapper, iamRepo), - newStatus: StatusActive, + newStatus: StatusCanceling, overrideSessionId: func() *string { s := "s_thisIsNotValid" return &s @@ -353,7 +353,7 @@ func TestRepository_UpdateState(t *testing.T) { { name: "empty-session", session: TestDefaultSession(t, conn, wrapper, iamRepo), - newStatus: StatusActive, + newStatus: StatusCanceling, overrideSessionId: func() *string { s := "" return &s @@ -397,6 +397,127 @@ func TestRepository_UpdateState(t *testing.T) { } } +func TestRepository_ActivateState(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + tests := []struct { + name string + session *Session + overrideSessionId *string + overrideSessionVersion *uint32 + wantErr bool + wantIsError error + }{ + { + name: "valid", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + wantErr: false, + }, + { + name: "already-active", + session: func() *Session { + s := TestDefaultSession(t, conn, wrapper, iamRepo) + activeSession, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version) + require.NoError(t, err) + return activeSession + }(), + wantErr: true, + }, + { + name: "bad-session-id", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + overrideSessionId: func() *string { + id, err := newId() + require.NoError(t, err) + return &id + }(), + wantErr: true, + }, + { + name: "bad-session-version", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + overrideSessionVersion: func() *uint32 { + v := uint32(100) + return &v + }(), + wantErr: true, + }, + { + name: "empty-session-id", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + overrideSessionId: func() *string { + id := "" + return &id + }(), + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + { + name: "empty-session-version", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + overrideSessionVersion: func() *uint32 { + v := uint32(0) + return &v + }(), + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + var id string + var version uint32 + switch { + case tt.overrideSessionId != nil: + id = *tt.overrideSessionId + default: + id = tt.session.PublicId + } + switch { + case tt.overrideSessionVersion != nil: + version = *tt.overrideSessionVersion + default: + version = tt.session.Version + } + s, ss, err := repo.ActivateSession(context.Background(), id, version) + if tt.wantErr { + require.Error(err) + if tt.wantIsError != nil { + assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error %s", err.Error()) + } + return + } + require.NoError(err) + require.NotNil(s) + require.NotNil(ss) + assert.Equal(2, len(ss)) + assert.Equal(StatusActive.String(), ss[0].Status) + }) + t.Run("already active", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + session := TestDefaultSession(t, conn, wrapper, iamRepo) + s, ss, err := repo.ActivateSession(context.Background(), session.PublicId, 1) + require.NoError(err) + require.NotNil(s) + require.NotNil(ss) + assert.Equal(2, len(ss)) + assert.Equal(StatusActive.String(), ss[0].Status) + + _, _, err = repo.ActivateSession(context.Background(), session.PublicId, 1) + require.Error(err) + + _, _, err = repo.ActivateSession(context.Background(), session.PublicId, 2) + require.Error(err) + }) + } +} func TestRepository_UpdateSession(t *testing.T) { t.Parallel() conn, _ := db.TestSetup(t, "postgres")