added UpdateSession() with unit tests

pull/347/head
Jim Lambert 6 years ago
parent e8c44af322
commit 3395bfa12d

@ -7,6 +7,7 @@ import (
"strings"
"github.com/hashicorp/boundary/internal/db"
dbcommon "github.com/hashicorp/boundary/internal/db/common"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
)
@ -52,7 +53,9 @@ func NewRepository(r db.Reader, w db.Writer, kms *kms.Kms, opt ...Option) (*Repo
}
// CreateSession inserts into the repository and returns the new Session with
// its State of "Pending". No options are currently supported.
// its State of "Pending". The following fields must be empty when creating a
// session: Address, Port, ServerId, ServerType, and PublicId. No options are
// currently supported.
func (r *Repository) CreateSession(ctx context.Context, newSession *Session, opt ...Option) (*Session, *State, error) {
if newSession == nil {
return nil, nil, fmt.Errorf("create session: missing session: %w", db.ErrInvalidParameter)
@ -63,12 +66,6 @@ func (r *Repository) CreateSession(ctx context.Context, newSession *Session, opt
if newSession.PublicId != "" {
return nil, nil, fmt.Errorf("create session: public id is not empty: %w", db.ErrInvalidParameter)
}
if newSession.ServerId == "" {
return nil, nil, fmt.Errorf("create session: server id is empty: %w", db.ErrInvalidParameter)
}
if newSession.ServerType == "" {
return nil, nil, fmt.Errorf("create session: server type is empty: %w", db.ErrInvalidParameter)
}
if newSession.TargetId == "" {
return nil, nil, fmt.Errorf("create session: target id is empty: %w", db.ErrInvalidParameter)
}
@ -87,11 +84,17 @@ func (r *Repository) CreateSession(ctx context.Context, newSession *Session, opt
if newSession.ScopeId == "" {
return nil, nil, fmt.Errorf("create session: scope id is empty: %w", db.ErrInvalidParameter)
}
if newSession.Address == "" {
return nil, nil, fmt.Errorf("create session: address is empty: %w", db.ErrInvalidParameter)
if newSession.Address != "" {
return nil, nil, fmt.Errorf("create session: address must empty: %w", db.ErrInvalidParameter)
}
if newSession.Port != "" {
return nil, nil, fmt.Errorf("create session: port id must empty: %w", db.ErrInvalidParameter)
}
if newSession.Port == "" {
return nil, nil, fmt.Errorf("create session: port id is empty: %w", db.ErrInvalidParameter)
if newSession.ServerId != "" {
return nil, nil, fmt.Errorf("create session: server id must empty: %w", db.ErrInvalidParameter)
}
if newSession.ServerType != "" {
return nil, nil, fmt.Errorf("create session: server type must empty: %w", db.ErrInvalidParameter)
}
id, err := newId()
@ -126,7 +129,7 @@ func (r *Repository) CreateSession(ctx context.Context, newSession *Session, opt
return fmt.Errorf("%d states found for new session %s", len(foundStates), returnedSession.PublicId)
}
returnedState = foundStates[0]
if returnedState.Status != Pending.String() {
if returnedState.Status != StatusPending.String() {
return fmt.Errorf("new session %s state is not valid: %s", returnedSession.PublicId, returnedState.Status)
}
return nil
@ -196,8 +199,104 @@ func (r *Repository) DeleteSession(ctx context.Context, publicId string, opt ...
panic("not implemented")
}
func (r *Repository) UpdateSession(ctx context.Context, s *Session, version uint32, fieldMaskPaths []string, opt ...Option) (*Session, []*State, int, error) {
panic("not implemented")
// UpdateSession updates the repository entry for the session, using the
// fieldMaskPaths. Only BytesUp, BytesDown, TerminationReason, ServerId and
// ServerType a muttable and will be set to NULL if set to a zero value and
// included in the fieldMaskPaths.
func (r *Repository) UpdateSession(ctx context.Context, session *Session, version uint32, fieldMaskPaths []string, opt ...Option) (*Session, []*State, int, error) {
if session == nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: missing session %w", db.ErrInvalidParameter)
}
if session.Session == nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: missing session store %w", db.ErrInvalidParameter)
}
if session.PublicId == "" {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: missing session public id %w", db.ErrInvalidParameter)
}
for _, f := range fieldMaskPaths {
switch {
case strings.EqualFold("BytesUp", f):
case strings.EqualFold("BytesDown", f):
case strings.EqualFold("TerminationReason", f):
case strings.EqualFold("ServerId", f):
case strings.EqualFold("ServerType", f):
case strings.EqualFold("Address", f):
case strings.EqualFold("Port", f):
default:
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: field: %s: %w", f, db.ErrInvalidFieldMask)
}
}
var dbMask, nullFields []string
dbMask, nullFields = dbcommon.BuildUpdatePaths(
map[string]interface{}{
"BytesUp": session.BytesUp,
"BytesDown": session.BytesDown,
"TerminationReason": session.TerminationReason,
"ServerId": session.ServerId,
"ServerType": session.ServerType,
"Address": session.Address,
"Port": session.Port,
},
fieldMaskPaths,
)
if len(dbMask) == 0 && len(nullFields) == 0 {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w", db.ErrEmptyFieldMask)
}
var sessionScopeId string
switch {
case session.ScopeId != "":
sessionScopeId = session.ScopeId
default:
ses, _, err := r.LookupSession(ctx, session.PublicId)
if err != nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w", err)
}
if ses == nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: unable to look up session for %s: %w", session.PublicId, err)
}
sessionScopeId = ses.ScopeId
}
oplogWrapper, err := r.kms.GetWrapper(ctx, sessionScopeId, kms.KeyPurposeOplog)
if err != nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("unable to get oplog wrapper: %w", err)
}
var s *Session
var states []*State
var rowsUpdated int
_, err = r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
var err error
s = session.Clone().(*Session)
metadata := s.oplog(oplog.OpType_OP_TYPE_UPDATE)
metadata["scope-id"] = []string{sessionScopeId}
rowsUpdated, err = w.Update(
ctx,
s,
dbMask,
nullFields,
db.WithOplog(oplogWrapper, metadata),
)
if err == nil && rowsUpdated > 1 {
// return err, which will result in a rollback of the update
return errors.New("error more than 1 session would have been updated ")
}
states, err = fetchStates(ctx, reader, s.PublicId, db.WithOrder("start_time desc"))
if err != nil {
return err
}
return nil
},
)
if err != nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w for %s", err, session.PublicId)
}
return s, states, rowsUpdated, err
}
// UpdateState will update the session's state using the session id and its

@ -7,11 +7,18 @@ import (
"time"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/boundary/internal/auth/password"
"github.com/hashicorp/boundary/internal/authtoken"
"github.com/hashicorp/boundary/internal/db"
dbassert "github.com/hashicorp/boundary/internal/db/assert"
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/boundary/internal/servers"
"github.com/hashicorp/boundary/internal/session/store"
"github.com/hashicorp/boundary/internal/target"
"github.com/hashicorp/go-uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
@ -250,30 +257,6 @@ func TestRepository_CreateSession(t *testing.T) {
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
{
name: "empty-serverId",
args: args{
composedOf: func() ComposedOf {
c := TestSessionParams(t, conn, wrapper, iamRepo)
c.ServerId = ""
return c
}(),
},
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
{
name: "empty-serverType",
args: args{
composedOf: func() ComposedOf {
c := TestSessionParams(t, conn, wrapper, iamRepo)
c.ServerType = ""
return c
}(),
},
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
{
name: "empty-targetId",
args: args{
@ -322,30 +305,6 @@ func TestRepository_CreateSession(t *testing.T) {
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
{
name: "empty-address",
args: args{
composedOf: func() ComposedOf {
c := TestSessionParams(t, conn, wrapper, iamRepo)
c.Address = ""
return c
}(),
},
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
{
name: "empty-port",
args: args{
composedOf: func() ComposedOf {
c := TestSessionParams(t, conn, wrapper, iamRepo)
c.Port = ""
return c
}(),
},
wantErr: true,
wantIsError: db.ErrInvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -354,14 +313,10 @@ func TestRepository_CreateSession(t *testing.T) {
Session: &store.Session{
UserId: tt.args.composedOf.UserId,
HostId: tt.args.composedOf.HostId,
ServerId: tt.args.composedOf.ServerId,
ServerType: tt.args.composedOf.ServerType.String(),
TargetId: tt.args.composedOf.TargetId,
SetId: tt.args.composedOf.HostSetId,
AuthTokenId: tt.args.composedOf.AuthTokenId,
ScopeId: tt.args.composedOf.ScopeId,
Address: tt.args.composedOf.Address,
Port: tt.args.composedOf.Port,
},
}
ses, st, err := repo.CreateSession(context.Background(), s)
@ -377,7 +332,7 @@ func TestRepository_CreateSession(t *testing.T) {
require.NoError(err)
assert.NotNil(ses.CreateTime)
assert.NotNil(st.StartTime)
assert.Equal(st.GetStatus(), Pending.String())
assert.Equal(st.GetStatus(), StatusPending.String())
foundSession, foundStates, err := repo.LookupSession(context.Background(), ses.PublicId)
assert.NoError(err)
assert.True(proto.Equal(foundSession, ses))
@ -386,7 +341,7 @@ func TestRepository_CreateSession(t *testing.T) {
assert.NoError(err)
require.Equal(1, len(foundStates))
assert.Equal(foundStates[0].GetStatus(), Pending.String())
assert.Equal(foundStates[0].GetStatus(), StatusPending.String())
})
}
}
@ -414,7 +369,7 @@ func TestRepository_UpdateState(t *testing.T) {
{
name: "connected",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: Active,
newStatus: StatusActive,
wantStateCnt: 2,
wantErr: false,
},
@ -422,17 +377,17 @@ func TestRepository_UpdateState(t *testing.T) {
name: "closed",
session: func() *Session {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
_ = TestState(t, conn, s.PublicId, Active)
_ = TestState(t, conn, s.PublicId, StatusActive)
return s
}(),
newStatus: Closed,
newStatus: StatusClosed,
wantStateCnt: 3,
wantErr: false,
},
{
name: "bad-version",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: Active,
newStatus: StatusActive,
overrideSessionVersion: func() *uint32 {
v := uint32(22)
return &v
@ -442,7 +397,7 @@ func TestRepository_UpdateState(t *testing.T) {
{
name: "empty-version",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: Active,
newStatus: StatusActive,
overrideSessionVersion: func() *uint32 {
v := uint32(0)
return &v
@ -453,7 +408,7 @@ func TestRepository_UpdateState(t *testing.T) {
{
name: "bad-sessionId",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: Active,
newStatus: StatusActive,
overrideSessionId: func() *string {
s := "s_thisIsNotValid"
return &s
@ -463,7 +418,7 @@ func TestRepository_UpdateState(t *testing.T) {
{
name: "empty-session",
session: TestDefaultSession(t, conn, wrapper, iamRepo),
newStatus: Active,
newStatus: StatusActive,
overrideSessionId: func() *string {
s := ""
return &s
@ -506,3 +461,250 @@ func TestRepository_UpdateState(t *testing.T) {
})
}
}
func TestRepository_UpdateSession(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, kms)
require.NoError(t, err)
serversRepo, err := servers.NewRepository(rw, rw, kms)
require.NoError(t, err)
newServerFunc := func() string {
id, err := uuid.GenerateUUID()
require.NoError(t, err)
worker := &servers.Server{
Name: "test-session-worker-" + id,
Type: servers.ServerTypeWorker.String(),
Description: "Test Session Worker",
Address: "127.0.0.1",
}
_, _, err = serversRepo.UpsertServer(context.Background(), worker)
require.NoError(t, err)
return worker.Name
}
type args struct {
bytesUp uint64
bytesDown uint64
terminationReason TerminationReason
serverId string
serverType string
fieldMaskPaths []string
opt []Option
publicId *string // not updateable - db.ErrInvalidFieldMask
userId string // not updateable - db.ErrInvalidFieldMask
hostId string // not updateable - db.ErrInvalidFieldMask
targetId string // not updateable - db.ErrInvalidFieldMask
setId string // not updateable - db.ErrInvalidFieldMask
authTokenId string // not updateable - db.ErrInvalidFieldMask
scopeId string // not updateable - db.ErrInvalidFieldMask
}
tests := []struct {
name string
args args
wantRowsUpdate int
wantErr bool
wantIsError error
}{
{
name: "valid",
args: args{
bytesUp: 100,
bytesDown: 110,
terminationReason: Terminated,
serverId: newServerFunc(),
serverType: servers.ServerTypeWorker.String(),
fieldMaskPaths: []string{"BytesUp", "BytesDown", "TerminationReason", "ServerId", "ServerType"},
},
wantErr: false,
wantRowsUpdate: 1,
},
{
name: "publicId",
args: args{
publicId: func() *string {
id, err := newId()
require.NoError(t, err)
return &id
}(),
fieldMaskPaths: []string{"PublicId"},
},
wantErr: true,
wantRowsUpdate: 0,
wantIsError: db.ErrInvalidFieldMask,
},
{
name: "userId",
args: args{
userId: func() string {
org, _ := iam.TestScopes(t, iamRepo)
u := iam.TestUser(t, iamRepo, org.PublicId)
return u.PublicId
}(),
fieldMaskPaths: []string{"UserId"},
},
wantErr: true,
wantRowsUpdate: 0,
wantIsError: db.ErrInvalidFieldMask,
},
{
name: "hostId",
args: args{
hostId: func() string {
_, proj := iam.TestScopes(t, iamRepo)
cats := static.TestCatalogs(t, conn, proj.PublicId, 1)
hosts := static.TestHosts(t, conn, cats[0].PublicId, 1)
return hosts[0].PublicId
}(),
fieldMaskPaths: []string{"HostId"},
},
wantErr: true,
wantRowsUpdate: 0,
wantIsError: db.ErrInvalidFieldMask,
},
{
name: "targetId",
args: args{
targetId: func() string {
_, proj := iam.TestScopes(t, iamRepo)
tcpTarget := target.TestTcpTarget(t, conn, proj.PublicId, "test target")
return tcpTarget.PublicId
}(),
fieldMaskPaths: []string{"TargetId"},
},
wantErr: true,
wantRowsUpdate: 0,
wantIsError: db.ErrInvalidFieldMask,
},
{
name: "setId",
args: args{
setId: func() string {
_, proj := iam.TestScopes(t, iamRepo)
cats := static.TestCatalogs(t, conn, proj.PublicId, 1)
sets := static.TestSets(t, conn, cats[0].PublicId, 1)
return sets[0].PublicId
}(),
fieldMaskPaths: []string{"SetId"},
},
wantErr: true,
wantRowsUpdate: 0,
wantIsError: db.ErrInvalidFieldMask,
},
{
name: "AuthTokenId",
args: args{
authTokenId: func() string {
ctx := context.Background()
org, _ := iam.TestScopes(t, iamRepo)
authMethod := password.TestAuthMethods(t, conn, org.PublicId, 1)[0]
acct := password.TestAccounts(t, conn, authMethod.GetPublicId(), 1)[0]
user, err := iamRepo.LookupUserWithLogin(ctx, acct.GetPublicId(), iam.WithAutoVivify(true))
require.NoError(t, err)
authTokenRepo, err := authtoken.NewRepository(rw, rw, kms)
require.NoError(t, err)
at, err := authTokenRepo.CreateAuthToken(ctx, user, acct.GetPublicId())
require.NoError(t, err)
return at.PublicId
}(),
fieldMaskPaths: []string{"AuthTokenId"},
},
wantErr: true,
wantRowsUpdate: 0,
wantIsError: db.ErrInvalidFieldMask,
},
{
name: "ScopeId",
args: args{
scopeId: func() string {
_, proj := iam.TestScopes(t, iamRepo)
return proj.PublicId
}(),
fieldMaskPaths: []string{"ScopeId"},
},
wantErr: true,
wantRowsUpdate: 0,
wantIsError: db.ErrInvalidFieldMask,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
composedOf := TestSessionParams(t, conn, wrapper, iamRepo)
s := TestSession(t, conn, composedOf)
updateSession := allocSession()
updateSession.PublicId = s.PublicId
if tt.args.publicId != nil {
updateSession.PublicId = *tt.args.publicId
}
updateSession.BytesUp = tt.args.bytesUp
updateSession.BytesDown = tt.args.bytesDown
updateSession.ServerId = tt.args.serverId
updateSession.ServerType = tt.args.serverType
updateSession.TerminationReason = tt.args.terminationReason.String()
updateSession.Version = s.Version
afterUpdateSession, afterUpdateState, updatedRows, err := repo.UpdateSession(context.Background(), &updateSession, updateSession.Version, tt.args.fieldMaskPaths, tt.args.opt...)
if tt.wantErr {
require.Error(err)
if tt.wantIsError != nil {
assert.Truef(errors.Is(err, tt.wantIsError), "unexpected error: %s", err.Error())
}
assert.Nil(afterUpdateSession)
assert.Nil(afterUpdateState)
assert.Equal(0, updatedRows)
err = db.TestVerifyOplog(t, rw, s.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
assert.Error(err)
assert.True(errors.Is(db.ErrRecordNotFound, err))
return
}
require.NoError(err)
assert.Equal(tt.wantRowsUpdate, updatedRows)
require.NotNil(afterUpdateSession)
require.NotNil(afterUpdateState)
switch tt.name {
case "valid-no-op":
assert.Equal(s.UpdateTime, afterUpdateSession.UpdateTime)
default:
assert.NotEqual(s.UpdateTime, afterUpdateSession.UpdateTime)
}
foundSession, foundStates, err := repo.LookupSession(context.Background(), s.PublicId)
require.NoError(err)
assert.True(proto.Equal(afterUpdateSession, foundSession))
dbassrt := dbassert.New(t, rw)
if tt.args.bytesUp == 0 {
dbassrt.IsNull(foundSession, "BytesUp")
}
dbassrt = dbassert.New(t, rw)
if tt.args.bytesDown == 0 {
dbassrt.IsNull(foundSession, "BytesDown")
}
if tt.args.serverId == "" {
dbassrt.IsNull(foundSession, "ServerId")
}
if tt.args.serverType == "" {
dbassrt.IsNull(foundSession, "ServerType")
}
assert.Equal(tt.args.bytesUp, foundSession.BytesUp)
assert.Equal(tt.args.bytesDown, foundSession.BytesDown)
assert.Equal(tt.args.terminationReason.String(), foundSession.TerminationReason)
assert.Equal(tt.args.serverId, foundSession.ServerId)
assert.Equal(tt.args.serverType, foundSession.ServerType)
err = db.TestVerifyOplog(t, rw, s.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
assert.NoError(err)
require.Equal(1, len(foundStates))
assert.Equal(StatusPending.String(), foundStates[0].Status)
})
}
}

Loading…
Cancel
Save