add UpdateState and unit tests

pull/347/head
Jim Lambert 6 years ago
parent baa619bf32
commit 2a57b689b2

@ -157,7 +157,7 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt ..
return fmt.Errorf("lookup session: failed %w for %s", err, sessionId)
}
var err error
if states, err = fetchStates(ctx, read, sessionId); err != nil {
if states, err = fetchStates(ctx, read, sessionId, db.WithOrder("start_time desc")); err != nil {
return err
}
return nil
@ -200,8 +200,87 @@ func (r *Repository) UpdateSession(ctx context.Context, s *Session, version uint
panic("not implemented")
}
func (r *Repository) UpdateState(ctx context.Context, sessionId string, sessionVersion uint32, s Status, opt ...Option) (*Session, []*State, int, error) {
panic("not implemented")
// UpdateState will update the session's state using the session id and its
// version. No options are currently supported.
func (r *Repository) UpdateState(ctx context.Context, sessionId string, sessionVersion uint32, s Status, opt ...Option) (*Session, []*State, error) {
if sessionId == "" {
return nil, nil, fmt.Errorf("update session state: missing session id %w", db.ErrInvalidParameter)
}
if sessionVersion == 0 {
return nil, nil, fmt.Errorf("update session state: version cannot be zero: %w", db.ErrInvalidParameter)
}
if s == "" {
return nil, nil, fmt.Errorf("update session state: missing session status: %w", db.ErrInvalidParameter)
}
newState, err := NewState(sessionId, s)
if err != nil {
return nil, nil, fmt.Errorf("update session state: %w", err)
}
ses, _, err := r.LookupSession(ctx, sessionId)
if err != nil {
return nil, nil, fmt.Errorf("update session state: %w", err)
}
if ses == nil {
return nil, nil, fmt.Errorf("update session state: unable to look up session for %s: %w", sessionId, err)
}
oplogWrapper, err := r.kms.GetWrapper(ctx, ses.ScopeId, kms.KeyPurposeOplog)
if err != nil {
return nil, nil, fmt.Errorf("update session state: unable to get oplog wrapper: %w", err)
}
updatedSession := allocSession()
var returnedStates []*State
_, err = r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
msgs := make([]*oplog.Message, 0, 2)
sessionTicket, err := w.GetTicket(ses)
if err != nil {
return fmt.Errorf("unable to get ticket: %w", err)
}
// We need to update the session version as that's the aggregate
updatedSession.PublicId = sessionId
updatedSession.Version = uint32(sessionVersion) + 1
var sessionOplogMsg oplog.Message
rowsUpdated, err := w.Update(ctx, &updatedSession, []string{"Version"}, nil, db.NewOplogMsg(&sessionOplogMsg), db.WithVersion(&sessionVersion))
if err != nil {
return fmt.Errorf("unable to update session version: %w", err)
}
if rowsUpdated != 1 {
return fmt.Errorf("updated session and %d rows updated", rowsUpdated)
}
msgs = append(msgs, &sessionOplogMsg)
var stateOplogMsg oplog.Message
if err := w.Create(ctx, newState, db.NewOplogMsg(&stateOplogMsg)); err != nil {
return fmt.Errorf("unable to add new state: %w", err)
}
msgs = append(msgs, &stateOplogMsg)
metadata := oplog.Metadata{
"op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()},
"scope-id": []string{ses.ScopeId},
"scope-type": []string{"project"},
"resource-public-id": []string{sessionId},
}
if err := w.WriteOplogEntryWith(ctx, oplogWrapper, sessionTicket, metadata, msgs); err != nil {
return fmt.Errorf("unable to write oplog: %w", err)
}
returnedStates, err = fetchStates(ctx, reader, sessionId, db.WithOrder("start_time desc"))
if err != nil {
return err
}
return nil
},
)
if err != nil {
return nil, nil, fmt.Errorf("update session state: error creating new state: %w", err)
}
return &updatedSession, returnedStates, nil
}
// list will return a listing of resources and honor the WithLimit option or the

@ -390,3 +390,119 @@ func TestRepository_CreateSession(t *testing.T) {
})
}
}
func TestRepository_UpdateState(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, kms)
require.NoError(t, err)
tests := []struct {
name string
session *Session
newStatus Status
overrideSessionId *string
overrideSessionVersion *uint32
wantStateCnt int
wantErr bool
wantIsError error
}{
{
name: "connected",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: Connected,
wantStateCnt: 2,
wantErr: false,
},
{
name: "closed",
session: func() *Session {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
_ = TestState(t, conn, s.PublicId, Connected)
return s
}(),
newStatus: Closed,
wantStateCnt: 3,
wantErr: false,
},
{
name: "bad-version",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: Connected,
overrideSessionVersion: func() *uint32 {
v := uint32(22)
return &v
}(),
wantErr: true,
},
{
name: "empty-version",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: Connected,
overrideSessionVersion: func() *uint32 {
v := uint32(0)
return &v
}(),
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
{
name: "bad-sessionId",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: Connected,
overrideSessionId: func() *string {
s := "s_thisIsNotValid"
return &s
}(),
wantErr: true,
},
{
name: "empty-session",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: Connected,
overrideSessionId: func() *string {
s := ""
return &s
}(),
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
var id string
var version uint32
switch {
case tt.overrideSessionId != nil:
id = *tt.overrideSessionId
default:
id = tt.session.PublicId
}
switch {
case tt.overrideSessionVersion != nil:
version = *tt.overrideSessionVersion
default:
version = tt.session.Version
}
s, ss, err := repo.UpdateState(context.Background(), id, version, tt.newStatus)
if tt.wantErr {
require.Error(err)
if tt.wantIsError != nil {
assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error %s", err.Error())
}
return
}
require.NoError(err)
require.NotNil(s)
require.NotNil(ss)
assert.Equal(tt.wantStateCnt, len(ss))
assert.Equal(tt.newStatus.String(), ss[0].Status)
})
}
}

Loading…
Cancel
Save