From 2a57b689b2aa73de98dfd45653b0eecdddab96be Mon Sep 17 00:00:00 2001 From: Jim Lambert Date: Tue, 8 Sep 2020 14:11:28 -0400 Subject: [PATCH] add UpdateState and unit tests --- internal/session/repository.go | 85 +++++++++++++++++++- internal/session/repository_test.go | 116 ++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+), 3 deletions(-) diff --git a/internal/session/repository.go b/internal/session/repository.go index dfc8c68f92..47ae563ca7 100644 --- a/internal/session/repository.go +++ b/internal/session/repository.go @@ -157,7 +157,7 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt .. return fmt.Errorf("lookup session: failed %w for %s", err, sessionId) } var err error - if states, err = fetchStates(ctx, read, sessionId); err != nil { + if states, err = fetchStates(ctx, read, sessionId, db.WithOrder("start_time desc")); err != nil { return err } return nil @@ -200,8 +200,87 @@ func (r *Repository) UpdateSession(ctx context.Context, s *Session, version uint panic("not implemented") } -func (r *Repository) UpdateState(ctx context.Context, sessionId string, sessionVersion uint32, s Status, opt ...Option) (*Session, []*State, int, error) { - panic("not implemented") +// UpdateState will update the session's state using the session id and its +// version. No options are currently supported. +func (r *Repository) UpdateState(ctx context.Context, sessionId string, sessionVersion uint32, s Status, opt ...Option) (*Session, []*State, error) { + if sessionId == "" { + return nil, nil, fmt.Errorf("update session state: missing session id %w", db.ErrInvalidParameter) + } + if sessionVersion == 0 { + return nil, nil, fmt.Errorf("update session state: version cannot be zero: %w", db.ErrInvalidParameter) + } + if s == "" { + return nil, nil, fmt.Errorf("update session state: missing session status: %w", db.ErrInvalidParameter) + } + + newState, err := NewState(sessionId, s) + if err != nil { + return nil, nil, fmt.Errorf("update session state: %w", err) + } + ses, _, err := r.LookupSession(ctx, sessionId) + if err != nil { + return nil, nil, fmt.Errorf("update session state: %w", err) + } + if ses == nil { + return nil, nil, fmt.Errorf("update session state: unable to look up session for %s: %w", sessionId, err) + } + + oplogWrapper, err := r.kms.GetWrapper(ctx, ses.ScopeId, kms.KeyPurposeOplog) + if err != nil { + return nil, nil, fmt.Errorf("update session state: unable to get oplog wrapper: %w", err) + } + + updatedSession := allocSession() + var returnedStates []*State + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + msgs := make([]*oplog.Message, 0, 2) + sessionTicket, err := w.GetTicket(ses) + if err != nil { + return fmt.Errorf("unable to get ticket: %w", err) + } + + // We need to update the session version as that's the aggregate + updatedSession.PublicId = sessionId + updatedSession.Version = uint32(sessionVersion) + 1 + var sessionOplogMsg oplog.Message + rowsUpdated, err := w.Update(ctx, &updatedSession, []string{"Version"}, nil, db.NewOplogMsg(&sessionOplogMsg), db.WithVersion(&sessionVersion)) + if err != nil { + return fmt.Errorf("unable to update session version: %w", err) + } + if rowsUpdated != 1 { + return fmt.Errorf("updated session and %d rows updated", rowsUpdated) + } + msgs = append(msgs, &sessionOplogMsg) + var stateOplogMsg oplog.Message + if err := w.Create(ctx, newState, db.NewOplogMsg(&stateOplogMsg)); err != nil { + return fmt.Errorf("unable to add new state: %w", err) + } + msgs = append(msgs, &stateOplogMsg) + + metadata := oplog.Metadata{ + "op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()}, + "scope-id": []string{ses.ScopeId}, + "scope-type": []string{"project"}, + "resource-public-id": []string{sessionId}, + } + if err := w.WriteOplogEntryWith(ctx, oplogWrapper, sessionTicket, metadata, msgs); err != nil { + return fmt.Errorf("unable to write oplog: %w", err) + } + returnedStates, err = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc")) + if err != nil { + return err + } + return nil + }, + ) + if err != nil { + return nil, nil, fmt.Errorf("update session state: error creating new state: %w", err) + } + return &updatedSession, returnedStates, nil } // list will return a listing of resources and honor the WithLimit option or the diff --git a/internal/session/repository_test.go b/internal/session/repository_test.go index 78de8554ba..1bbf60de78 100644 --- a/internal/session/repository_test.go +++ b/internal/session/repository_test.go @@ -390,3 +390,119 @@ func TestRepository_CreateSession(t *testing.T) { }) } } + +func TestRepository_UpdateState(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + + tests := []struct { + name string + session *Session + newStatus Status + overrideSessionId *string + overrideSessionVersion *uint32 + wantStateCnt int + wantErr bool + wantIsError error + }{ + { + name: "connected", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + newStatus: Connected, + wantStateCnt: 2, + wantErr: false, + }, + { + name: "closed", + session: func() *Session { + s := TestDefaultSession(t, conn, wrapper, iamRepo) + _ = TestState(t, conn, s.PublicId, Connected) + return s + }(), + newStatus: Closed, + wantStateCnt: 3, + wantErr: false, + }, + { + name: "bad-version", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + newStatus: Connected, + overrideSessionVersion: func() *uint32 { + v := uint32(22) + return &v + }(), + wantErr: true, + }, + { + name: "empty-version", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + newStatus: Connected, + overrideSessionVersion: func() *uint32 { + v := uint32(0) + return &v + }(), + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + { + name: "bad-sessionId", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + newStatus: Connected, + overrideSessionId: func() *string { + s := "s_thisIsNotValid" + return &s + }(), + wantErr: true, + }, + { + name: "empty-session", + session: TestDefaultSession(t, conn, wrapper, iamRepo), + newStatus: Connected, + overrideSessionId: func() *string { + s := "" + return &s + }(), + wantErr: true, + wantIsError: db.ErrInvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + var id string + var version uint32 + switch { + case tt.overrideSessionId != nil: + id = *tt.overrideSessionId + default: + id = tt.session.PublicId + } + switch { + case tt.overrideSessionVersion != nil: + version = *tt.overrideSessionVersion + default: + version = tt.session.Version + } + + s, ss, err := repo.UpdateState(context.Background(), id, version, tt.newStatus) + if tt.wantErr { + require.Error(err) + if tt.wantIsError != nil { + assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error %s", err.Error()) + } + return + } + require.NoError(err) + require.NotNil(s) + require.NotNil(ss) + assert.Equal(tt.wantStateCnt, len(ss)) + assert.Equal(tt.newStatus.String(), ss[0].Status) + }) + } +}