From 3395bfa12de056fdbe6db6f5b8489dd89e6b8ced Mon Sep 17 00:00:00 2001 From: Jim Lambert Date: Tue, 8 Sep 2020 21:27:24 -0400 Subject: [PATCH] added UpdateSession() with unit tests --- internal/session/repository.go | 127 +++++++++-- internal/session/repository_test.go | 324 ++++++++++++++++++++++------ 2 files changed, 376 insertions(+), 75 deletions(-) diff --git a/internal/session/repository.go b/internal/session/repository.go index 47ae563ca7..49369bc968 100644 --- a/internal/session/repository.go +++ b/internal/session/repository.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/hashicorp/boundary/internal/db" + dbcommon "github.com/hashicorp/boundary/internal/db/common" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" ) @@ -52,7 +53,9 @@ func NewRepository(r db.Reader, w db.Writer, kms *kms.Kms, opt ...Option) (*Repo } // CreateSession inserts into the repository and returns the new Session with -// its State of "Pending". No options are currently supported. +// its State of "Pending". The following fields must be empty when creating a +// session: Address, Port, ServerId, ServerType, and PublicId. No options are +// currently supported. func (r *Repository) CreateSession(ctx context.Context, newSession *Session, opt ...Option) (*Session, *State, error) { if newSession == nil { return nil, nil, fmt.Errorf("create session: missing session: %w", db.ErrInvalidParameter) @@ -63,12 +66,6 @@ func (r *Repository) CreateSession(ctx context.Context, newSession *Session, opt if newSession.PublicId != "" { return nil, nil, fmt.Errorf("create session: public id is not empty: %w", db.ErrInvalidParameter) } - if newSession.ServerId == "" { - return nil, nil, fmt.Errorf("create session: server id is empty: %w", db.ErrInvalidParameter) - } - if newSession.ServerType == "" { - return nil, nil, fmt.Errorf("create session: server type is empty: %w", db.ErrInvalidParameter) - } if newSession.TargetId == "" { return nil, nil, fmt.Errorf("create session: target id is empty: %w", db.ErrInvalidParameter) } @@ -87,11 +84,17 @@ func (r *Repository) CreateSession(ctx context.Context, newSession *Session, opt if newSession.ScopeId == "" { return nil, nil, fmt.Errorf("create session: scope id is empty: %w", db.ErrInvalidParameter) } - if newSession.Address == "" { - return nil, nil, fmt.Errorf("create session: address is empty: %w", db.ErrInvalidParameter) + if newSession.Address != "" { + return nil, nil, fmt.Errorf("create session: address must empty: %w", db.ErrInvalidParameter) + } + if newSession.Port != "" { + return nil, nil, fmt.Errorf("create session: port id must empty: %w", db.ErrInvalidParameter) } - if newSession.Port == "" { - return nil, nil, fmt.Errorf("create session: port id is empty: %w", db.ErrInvalidParameter) + if newSession.ServerId != "" { + return nil, nil, fmt.Errorf("create session: server id must empty: %w", db.ErrInvalidParameter) + } + if newSession.ServerType != "" { + return nil, nil, fmt.Errorf("create session: server type must empty: %w", db.ErrInvalidParameter) } id, err := newId() @@ -126,7 +129,7 @@ func (r *Repository) CreateSession(ctx context.Context, newSession *Session, opt return fmt.Errorf("%d states found for new session %s", len(foundStates), returnedSession.PublicId) } returnedState = foundStates[0] - if returnedState.Status != Pending.String() { + if returnedState.Status != StatusPending.String() { return fmt.Errorf("new session %s state is not valid: %s", returnedSession.PublicId, returnedState.Status) } return nil @@ -196,8 +199,104 @@ func (r *Repository) DeleteSession(ctx context.Context, publicId string, opt ... panic("not implemented") } -func (r *Repository) UpdateSession(ctx context.Context, s *Session, version uint32, fieldMaskPaths []string, opt ...Option) (*Session, []*State, int, error) { - panic("not implemented") +// UpdateSession updates the repository entry for the session, using the +// fieldMaskPaths. Only BytesUp, BytesDown, TerminationReason, ServerId and +// ServerType a muttable and will be set to NULL if set to a zero value and +// included in the fieldMaskPaths. +func (r *Repository) UpdateSession(ctx context.Context, session *Session, version uint32, fieldMaskPaths []string, opt ...Option) (*Session, []*State, int, error) { + if session == nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: missing session %w", db.ErrInvalidParameter) + } + if session.Session == nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: missing session store %w", db.ErrInvalidParameter) + } + if session.PublicId == "" { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: missing session public id %w", db.ErrInvalidParameter) + } + for _, f := range fieldMaskPaths { + switch { + case strings.EqualFold("BytesUp", f): + case strings.EqualFold("BytesDown", f): + case strings.EqualFold("TerminationReason", f): + case strings.EqualFold("ServerId", f): + case strings.EqualFold("ServerType", f): + case strings.EqualFold("Address", f): + case strings.EqualFold("Port", f): + default: + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: field: %s: %w", f, db.ErrInvalidFieldMask) + } + } + var dbMask, nullFields []string + dbMask, nullFields = dbcommon.BuildUpdatePaths( + map[string]interface{}{ + "BytesUp": session.BytesUp, + "BytesDown": session.BytesDown, + "TerminationReason": session.TerminationReason, + "ServerId": session.ServerId, + "ServerType": session.ServerType, + "Address": session.Address, + "Port": session.Port, + }, + fieldMaskPaths, + ) + if len(dbMask) == 0 && len(nullFields) == 0 { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w", db.ErrEmptyFieldMask) + } + + var sessionScopeId string + switch { + case session.ScopeId != "": + sessionScopeId = session.ScopeId + default: + ses, _, err := r.LookupSession(ctx, session.PublicId) + if err != nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w", err) + } + if ses == nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: unable to look up session for %s: %w", session.PublicId, err) + } + sessionScopeId = ses.ScopeId + } + + oplogWrapper, err := r.kms.GetWrapper(ctx, sessionScopeId, kms.KeyPurposeOplog) + if err != nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("unable to get oplog wrapper: %w", err) + } + + var s *Session + var states []*State + var rowsUpdated int + _, err = r.writer.DoTx( + ctx, + db.StdRetryCnt, + db.ExpBackoff{}, + func(reader db.Reader, w db.Writer) error { + var err error + s = session.Clone().(*Session) + metadata := s.oplog(oplog.OpType_OP_TYPE_UPDATE) + metadata["scope-id"] = []string{sessionScopeId} + rowsUpdated, err = w.Update( + ctx, + s, + dbMask, + nullFields, + db.WithOplog(oplogWrapper, metadata), + ) + if err == nil && rowsUpdated > 1 { + // return err, which will result in a rollback of the update + return errors.New("error more than 1 session would have been updated ") + } + states, err = fetchStates(ctx, reader, s.PublicId, db.WithOrder("start_time desc")) + if err != nil { + return err + } + return nil + }, + ) + if err != nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w for %s", err, session.PublicId) + } + return s, states, rowsUpdated, err } // UpdateState will update the session's state using the session id and its diff --git a/internal/session/repository_test.go b/internal/session/repository_test.go index c1f153ae8d..f3de52602e 100644 --- a/internal/session/repository_test.go +++ b/internal/session/repository_test.go @@ -7,11 +7,18 @@ import ( "time" "github.com/golang/protobuf/ptypes" + "github.com/hashicorp/boundary/internal/auth/password" + "github.com/hashicorp/boundary/internal/authtoken" "github.com/hashicorp/boundary/internal/db" + dbassert "github.com/hashicorp/boundary/internal/db/assert" + "github.com/hashicorp/boundary/internal/host/static" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" + "github.com/hashicorp/boundary/internal/servers" "github.com/hashicorp/boundary/internal/session/store" + "github.com/hashicorp/boundary/internal/target" + "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" @@ -250,30 +257,6 @@ func TestRepository_CreateSession(t *testing.T) { wantErr: true, wantIsError: db.ErrInvalidParameter, }, - { - name: "empty-serverId", - args: args{ - composedOf: func() ComposedOf { - c := TestSessionParams(t, conn, wrapper, iamRepo) - c.ServerId = "" - return c - }(), - }, - wantErr: true, - wantIsError: db.ErrInvalidParameter, - }, - { - name: "empty-serverType", - args: args{ - composedOf: func() ComposedOf { - c := TestSessionParams(t, conn, wrapper, iamRepo) - c.ServerType = "" - return c - }(), - }, - wantErr: true, - wantIsError: db.ErrInvalidParameter, - }, { name: "empty-targetId", args: args{ @@ -322,30 +305,6 @@ func TestRepository_CreateSession(t *testing.T) { wantErr: true, wantIsError: db.ErrInvalidParameter, }, - { - name: "empty-address", - args: args{ - composedOf: func() ComposedOf { - c := TestSessionParams(t, conn, wrapper, iamRepo) - c.Address = "" - return c - }(), - }, - wantErr: true, - wantIsError: db.ErrInvalidParameter, - }, - { - name: "empty-port", - args: args{ - composedOf: func() ComposedOf { - c := TestSessionParams(t, conn, wrapper, iamRepo) - c.Port = "" - return c - }(), - }, - wantErr: true, - wantIsError: db.ErrInvalidParameter, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -354,14 +313,10 @@ func TestRepository_CreateSession(t *testing.T) { Session: &store.Session{ UserId: tt.args.composedOf.UserId, HostId: tt.args.composedOf.HostId, - ServerId: tt.args.composedOf.ServerId, - ServerType: tt.args.composedOf.ServerType.String(), TargetId: tt.args.composedOf.TargetId, SetId: tt.args.composedOf.HostSetId, AuthTokenId: tt.args.composedOf.AuthTokenId, ScopeId: tt.args.composedOf.ScopeId, - Address: tt.args.composedOf.Address, - Port: tt.args.composedOf.Port, }, } ses, st, err := repo.CreateSession(context.Background(), s) @@ -377,7 +332,7 @@ func TestRepository_CreateSession(t *testing.T) { require.NoError(err) assert.NotNil(ses.CreateTime) assert.NotNil(st.StartTime) - assert.Equal(st.GetStatus(), Pending.String()) + assert.Equal(st.GetStatus(), StatusPending.String()) foundSession, foundStates, err := repo.LookupSession(context.Background(), ses.PublicId) assert.NoError(err) assert.True(proto.Equal(foundSession, ses)) @@ -386,7 +341,7 @@ func TestRepository_CreateSession(t *testing.T) { assert.NoError(err) require.Equal(1, len(foundStates)) - assert.Equal(foundStates[0].GetStatus(), Pending.String()) + assert.Equal(foundStates[0].GetStatus(), StatusPending.String()) }) } } @@ -414,7 +369,7 @@ func TestRepository_UpdateState(t *testing.T) { { name: "connected", session: TestDefaultSession(t, conn, wrapper, iamRepo), - newStatus: Active, + newStatus: StatusActive, wantStateCnt: 2, wantErr: false, }, @@ -422,17 +377,17 @@ func TestRepository_UpdateState(t *testing.T) { name: "closed", session: func() *Session { s := TestDefaultSession(t, conn, wrapper, iamRepo) - _ = TestState(t, conn, s.PublicId, Active) + _ = TestState(t, conn, s.PublicId, StatusActive) return s }(), - newStatus: Closed, + newStatus: StatusClosed, wantStateCnt: 3, wantErr: false, }, { name: "bad-version", session: TestDefaultSession(t, conn, wrapper, iamRepo), - newStatus: Active, + newStatus: StatusActive, overrideSessionVersion: func() *uint32 { v := uint32(22) return &v @@ -442,7 +397,7 @@ func TestRepository_UpdateState(t *testing.T) { { name: "empty-version", session: TestDefaultSession(t, conn, wrapper, iamRepo), - newStatus: Active, + newStatus: StatusActive, overrideSessionVersion: func() *uint32 { v := uint32(0) return &v @@ -453,7 +408,7 @@ func TestRepository_UpdateState(t *testing.T) { { name: "bad-sessionId", session: TestDefaultSession(t, conn, wrapper, iamRepo), - newStatus: Active, + newStatus: StatusActive, overrideSessionId: func() *string { s := "s_thisIsNotValid" return &s @@ -463,7 +418,7 @@ func TestRepository_UpdateState(t *testing.T) { { name: "empty-session", session: TestDefaultSession(t, conn, wrapper, iamRepo), - newStatus: Active, + newStatus: StatusActive, overrideSessionId: func() *string { s := "" return &s @@ -506,3 +461,250 @@ func TestRepository_UpdateState(t *testing.T) { }) } } + +func TestRepository_UpdateSession(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + serversRepo, err := servers.NewRepository(rw, rw, kms) + require.NoError(t, err) + + newServerFunc := func() string { + id, err := uuid.GenerateUUID() + require.NoError(t, err) + worker := &servers.Server{ + Name: "test-session-worker-" + id, + Type: servers.ServerTypeWorker.String(), + Description: "Test Session Worker", + Address: "127.0.0.1", + } + _, _, err = serversRepo.UpsertServer(context.Background(), worker) + require.NoError(t, err) + return worker.Name + } + + type args struct { + bytesUp uint64 + bytesDown uint64 + terminationReason TerminationReason + serverId string + serverType string + fieldMaskPaths []string + opt []Option + publicId *string // not updateable - db.ErrInvalidFieldMask + userId string // not updateable - db.ErrInvalidFieldMask + hostId string // not updateable - db.ErrInvalidFieldMask + targetId string // not updateable - db.ErrInvalidFieldMask + setId string // not updateable - db.ErrInvalidFieldMask + authTokenId string // not updateable - db.ErrInvalidFieldMask + scopeId string // not updateable - db.ErrInvalidFieldMask + } + tests := []struct { + name string + args args + wantRowsUpdate int + wantErr bool + wantIsError error + }{ + { + name: "valid", + args: args{ + bytesUp: 100, + bytesDown: 110, + terminationReason: Terminated, + serverId: newServerFunc(), + serverType: servers.ServerTypeWorker.String(), + fieldMaskPaths: []string{"BytesUp", "BytesDown", "TerminationReason", "ServerId", "ServerType"}, + }, + wantErr: false, + wantRowsUpdate: 1, + }, + { + name: "publicId", + args: args{ + publicId: func() *string { + id, err := newId() + require.NoError(t, err) + return &id + }(), + fieldMaskPaths: []string{"PublicId"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "userId", + args: args{ + userId: func() string { + org, _ := iam.TestScopes(t, iamRepo) + u := iam.TestUser(t, iamRepo, org.PublicId) + return u.PublicId + }(), + fieldMaskPaths: []string{"UserId"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "hostId", + args: args{ + hostId: func() string { + _, proj := iam.TestScopes(t, iamRepo) + cats := static.TestCatalogs(t, conn, proj.PublicId, 1) + hosts := static.TestHosts(t, conn, cats[0].PublicId, 1) + return hosts[0].PublicId + }(), + fieldMaskPaths: []string{"HostId"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "targetId", + args: args{ + targetId: func() string { + _, proj := iam.TestScopes(t, iamRepo) + tcpTarget := target.TestTcpTarget(t, conn, proj.PublicId, "test target") + return tcpTarget.PublicId + }(), + fieldMaskPaths: []string{"TargetId"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "setId", + args: args{ + setId: func() string { + _, proj := iam.TestScopes(t, iamRepo) + cats := static.TestCatalogs(t, conn, proj.PublicId, 1) + sets := static.TestSets(t, conn, cats[0].PublicId, 1) + return sets[0].PublicId + }(), + fieldMaskPaths: []string{"SetId"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "AuthTokenId", + args: args{ + authTokenId: func() string { + ctx := context.Background() + org, _ := iam.TestScopes(t, iamRepo) + authMethod := password.TestAuthMethods(t, conn, org.PublicId, 1)[0] + acct := password.TestAccounts(t, conn, authMethod.GetPublicId(), 1)[0] + user, err := iamRepo.LookupUserWithLogin(ctx, acct.GetPublicId(), iam.WithAutoVivify(true)) + require.NoError(t, err) + + authTokenRepo, err := authtoken.NewRepository(rw, rw, kms) + require.NoError(t, err) + at, err := authTokenRepo.CreateAuthToken(ctx, user, acct.GetPublicId()) + require.NoError(t, err) + return at.PublicId + }(), + fieldMaskPaths: []string{"AuthTokenId"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + { + name: "ScopeId", + args: args{ + scopeId: func() string { + _, proj := iam.TestScopes(t, iamRepo) + return proj.PublicId + }(), + fieldMaskPaths: []string{"ScopeId"}, + }, + wantErr: true, + wantRowsUpdate: 0, + wantIsError: db.ErrInvalidFieldMask, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + composedOf := TestSessionParams(t, conn, wrapper, iamRepo) + s := TestSession(t, conn, composedOf) + + updateSession := allocSession() + updateSession.PublicId = s.PublicId + if tt.args.publicId != nil { + updateSession.PublicId = *tt.args.publicId + } + updateSession.BytesUp = tt.args.bytesUp + updateSession.BytesDown = tt.args.bytesDown + updateSession.ServerId = tt.args.serverId + updateSession.ServerType = tt.args.serverType + updateSession.TerminationReason = tt.args.terminationReason.String() + updateSession.Version = s.Version + afterUpdateSession, afterUpdateState, updatedRows, err := repo.UpdateSession(context.Background(), &updateSession, updateSession.Version, tt.args.fieldMaskPaths, tt.args.opt...) + + if tt.wantErr { + require.Error(err) + if tt.wantIsError != nil { + assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error: %s", err.Error()) + } + assert.Nil(afterUpdateSession) + assert.Nil(afterUpdateState) + assert.Equal(0, updatedRows) + err = db.TestVerifyOplog(t, rw, s.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second)) + assert.Error(err) + assert.True(errors.Is(db.ErrRecordNotFound, err)) + return + } + require.NoError(err) + assert.Equal(tt.wantRowsUpdate, updatedRows) + require.NotNil(afterUpdateSession) + require.NotNil(afterUpdateState) + switch tt.name { + case "valid-no-op": + assert.Equal(s.UpdateTime, afterUpdateSession.UpdateTime) + default: + assert.NotEqual(s.UpdateTime, afterUpdateSession.UpdateTime) + } + foundSession, foundStates, err := repo.LookupSession(context.Background(), s.PublicId) + require.NoError(err) + assert.True(proto.Equal(afterUpdateSession, foundSession)) + dbassrt := dbassert.New(t, rw) + if tt.args.bytesUp == 0 { + dbassrt.IsNull(foundSession, "BytesUp") + } + dbassrt = dbassert.New(t, rw) + if tt.args.bytesDown == 0 { + dbassrt.IsNull(foundSession, "BytesDown") + } + if tt.args.serverId == "" { + dbassrt.IsNull(foundSession, "ServerId") + } + if tt.args.serverType == "" { + dbassrt.IsNull(foundSession, "ServerType") + } + assert.Equal(tt.args.bytesUp, foundSession.BytesUp) + assert.Equal(tt.args.bytesDown, foundSession.BytesDown) + assert.Equal(tt.args.terminationReason.String(), foundSession.TerminationReason) + assert.Equal(tt.args.serverId, foundSession.ServerId) + assert.Equal(tt.args.serverType, foundSession.ServerType) + + err = db.TestVerifyOplog(t, rw, s.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second)) + assert.NoError(err) + + require.Equal(1, len(foundStates)) + assert.Equal(StatusPending.String(), foundStates[0].Status) + }) + } + +}