Add ActivateSession() which uses a CTE to only active a session that is pending

jimlambrt-session-basics
Jim Lambert 6 years ago
parent e410ce9278
commit 90f9bedc66

@ -0,0 +1,20 @@
package session
const (
activateStateCte = `
insert into session_state
with not_active as (
select session_id, 'active' as state
from
session s,
session_state ss
where
s.public_id = ss.session_id and
ss.state = 'pending' and
ss.session_id = $1 and
s.version = $2 and
s.public_id not in(select public_id from session_state where session_id = $1 and state = 'active')
)
select * from not_active;
`
)

@ -308,6 +308,47 @@ func (r *Repository) UpdateSession(ctx context.Context, session *Session, versio
return s, states, rowsUpdated, err
}
// ActivateSession will activate the session and is called by a worker after
// authenticating the session.
func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sessionVersion uint32) (*Session, []*State, error) {
if sessionId == "" {
return nil, nil, fmt.Errorf("activate session state: missing session id %w", db.ErrInvalidParameter)
}
if sessionVersion == 0 {
return nil, nil, fmt.Errorf("activate session state: version cannot be zero: %w", db.ErrInvalidParameter)
}
updatedSession := AllocSession()
updatedSession.PublicId = sessionId
var returnedStates []*State
_, err := r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
rowsAffected, err := w.Exec(activateStateCte, []interface{}{sessionId, sessionVersion})
if err != nil {
return fmt.Errorf("unable to activate session %s: %w", sessionId, err)
}
if rowsAffected == 0 {
return fmt.Errorf("unable to activate session %s", sessionId)
}
if err := r.reader.LookupById(ctx, &updatedSession); err != nil {
return fmt.Errorf("lookup session: failed %w for %s", err, sessionId)
}
returnedStates, err = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc"))
if err != nil {
return err
}
return nil
},
)
if err != nil {
return nil, nil, fmt.Errorf("activate session: %w", err)
}
return &updatedSession, returnedStates, nil
}
// UpdateState will update the session's state using the session id and its
// version. No options are currently supported.
func (r *Repository) UpdateState(ctx context.Context, sessionId string, sessionVersion uint32, s Status, opt ...Option) (*Session, []*State, error) {
@ -320,6 +361,9 @@ func (r *Repository) UpdateState(ctx context.Context, sessionId string, sessionV
if s == "" {
return nil, nil, fmt.Errorf("update session state: missing session status: %w", db.ErrInvalidParameter)
}
if s == StatusActive {
return nil, nil, fmt.Errorf("update session: you must call ActivateSession to update a session's state to active: %w", db.ErrInvalidParameter)
}
newState, err := NewState(sessionId, s)
if err != nil {

@ -302,9 +302,9 @@ func TestRepository_UpdateState(t *testing.T) {
wantIsError error
}{
{
name: "connected",
name: "cancelling",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: StatusActive,
newStatus: StatusCanceling,
wantStateCnt: 2,
wantErr: false,
},
@ -322,7 +322,7 @@ func TestRepository_UpdateState(t *testing.T) {
{
name: "bad-version",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: StatusActive,
newStatus: StatusCanceling,
overrideSessionVersion: func() *uint32 {
v := uint32(22)
return &v
@ -332,7 +332,7 @@ func TestRepository_UpdateState(t *testing.T) {
{
name: "empty-version",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: StatusActive,
newStatus: StatusCanceling,
overrideSessionVersion: func() *uint32 {
v := uint32(0)
return &v
@ -343,7 +343,7 @@ func TestRepository_UpdateState(t *testing.T) {
{
name: "bad-sessionId",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: StatusActive,
newStatus: StatusCanceling,
overrideSessionId: func() *string {
s := "s_thisIsNotValid"
return &s
@ -353,7 +353,7 @@ func TestRepository_UpdateState(t *testing.T) {
{
name: "empty-session",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: StatusActive,
newStatus: StatusCanceling,
overrideSessionId: func() *string {
s := ""
return &s
@ -397,6 +397,127 @@ func TestRepository_UpdateState(t *testing.T) {
}
}
func TestRepository_ActivateState(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, kms)
require.NoError(t, err)
tests := []struct {
name string
session *Session
overrideSessionId *string
overrideSessionVersion *uint32
wantErr bool
wantIsError error
}{
{
name: "valid",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
wantErr: false,
},
{
name: "already-active",
session: func() *Session {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
activeSession, _, err := repo.ActivateSession(context.Background(), s.PublicId, s.Version)
require.NoError(t, err)
return activeSession
}(),
wantErr: true,
},
{
name: "bad-session-id",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
overrideSessionId: func() *string {
id, err := newId()
require.NoError(t, err)
return &id
}(),
wantErr: true,
},
{
name: "bad-session-version",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
overrideSessionVersion: func() *uint32 {
v := uint32(100)
return &v
}(),
wantErr: true,
},
{
name: "empty-session-id",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
overrideSessionId: func() *string {
id := ""
return &id
}(),
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
{
name: "empty-session-version",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
overrideSessionVersion: func() *uint32 {
v := uint32(0)
return &v
}(),
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
var id string
var version uint32
switch {
case tt.overrideSessionId != nil:
id = *tt.overrideSessionId
default:
id = tt.session.PublicId
}
switch {
case tt.overrideSessionVersion != nil:
version = *tt.overrideSessionVersion
default:
version = tt.session.Version
}
s, ss, err := repo.ActivateSession(context.Background(), id, version)
if tt.wantErr {
require.Error(err)
if tt.wantIsError != nil {
assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error %s", err.Error())
}
return
}
require.NoError(err)
require.NotNil(s)
require.NotNil(ss)
assert.Equal(2, len(ss))
assert.Equal(StatusActive.String(), ss[0].Status)
})
t.Run("already active", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
session := TestDefaultSession(t, conn, wrapper, iamRepo)
s, ss, err := repo.ActivateSession(context.Background(), session.PublicId, 1)
require.NoError(err)
require.NotNil(s)
require.NotNil(ss)
assert.Equal(2, len(ss))
assert.Equal(StatusActive.String(), ss[0].Status)
_, _, err = repo.ActivateSession(context.Background(), session.PublicId, 1)
require.Error(err)
_, _, err = repo.ActivateSession(context.Background(), session.PublicId, 2)
require.Error(err)
})
}
}
func TestRepository_UpdateSession(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")

Loading…
Cancel
Save