diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 87ca1e40e8..901c542c71 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -34,21 +34,18 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping. if newSession.TargetId == "" { return nil, errors.New(ctx, errors.InvalidParameter, op, "missing target id") } - if newSession.HostId == "" { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing host id") - } if newSession.UserId == "" { return nil, errors.New(ctx, errors.InvalidParameter, op, "missing user id") } - if newSession.HostSetId == "" { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing host set id") - } if newSession.AuthTokenId == "" { return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth token id") } if newSession.ProjectId == "" { return nil, errors.New(ctx, errors.InvalidParameter, op, "missing project id") } + if newSession.HostId == "" && newSession.HostSetId == "" && newSession.Endpoint == "" { + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing host source and endpoint") + } if newSession.CtTofuToken != nil { return nil, errors.New(ctx, errors.InvalidParameter, op, "ct is not empty") } @@ -91,6 +88,26 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping. return errors.Wrap(ctx, err, op) } + if newSession.HostSetId != "" && newSession.HostId != "" { + hs, err := NewSessionHostSetHost(newSession.PublicId, newSession.HostSetId, newSession.HostId) + if err != nil { + return errors.Wrap(ctx, err, op) + } + if err = w.Create(ctx, hs); err != nil { + return errors.Wrap(ctx, err, op) + } + returnedSession.HostSetId = hs.HostSetId + returnedSession.HostId = hs.HostId + } else if newSession.Endpoint != "" { + ta, err := NewSessionTargetAddress(newSession.PublicId, newSession.TargetId) + if err != nil { + return errors.Wrap(ctx, err, op) + } + if err = w.Create(ctx, ta); err != nil { + return errors.Wrap(ctx, err, op) + } + } + for _, cred := range newSession.DynamicCredentials { cred.SessionId = newSession.PublicId } @@ -197,6 +214,13 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt .. session.StaticCredentials = staticCreds } + sessionHostSetHost := AllocSessionHostSetHost() + if err := read.LookupWhere(ctx, sessionHostSetHost, "session_id = ?", []any{sessionId}); err != nil && !errors.IsNotFoundError(err) { + return errors.Wrap(ctx, err, op) + } + session.HostSetId = sessionHostSetHost.HostSetId + session.HostId = sessionHostSetHost.HostId + connections, err := fetchConnections(ctx, read, sessionId, db.WithOrder("create_time desc")) if err != nil { return errors.Wrap(ctx, err, op) diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 73db29fa2c..0524d55400 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -308,24 +308,21 @@ func TestRepository_CreateSession(t *testing.T) { wantErr: false, }, { - name: "empty-userId", + name: "valid-static-address", args: args{ - composedOf: func() ComposedOf { - c := TestSessionParams(t, conn, wrapper, iamRepo) - c.UserId = "" - return c - }(), + composedOf: TestSessionTargetAddressParams(t, conn, wrapper, iamRepo), workerAddresses: workerAddresses, }, - wantErr: true, - wantIsError: errors.InvalidParameter, + wantErr: false, }, { - name: "empty-hostId", + name: "empty-host-source-endpoint", args: args{ composedOf: func() ComposedOf { c := TestSessionParams(t, conn, wrapper, iamRepo) c.HostId = "" + c.HostSetId = "" + c.Endpoint = "" return c }(), workerAddresses: workerAddresses, @@ -334,11 +331,11 @@ func TestRepository_CreateSession(t *testing.T) { wantIsError: errors.InvalidParameter, }, { - name: "empty-targetId", + name: "empty-userId", args: args{ composedOf: func() ComposedOf { c := TestSessionParams(t, conn, wrapper, iamRepo) - c.TargetId = "" + c.UserId = "" return c }(), workerAddresses: workerAddresses, @@ -347,11 +344,11 @@ func TestRepository_CreateSession(t *testing.T) { wantIsError: errors.InvalidParameter, }, { - name: "empty-hostSetId", + name: "empty-targetId", args: args{ composedOf: func() ComposedOf { c := TestSessionParams(t, conn, wrapper, iamRepo) - c.HostSetId = "" + c.TargetId = "" return c }(), workerAddresses: workerAddresses, @@ -440,7 +437,7 @@ func TestRepository_CreateSession(t *testing.T) { HostSetId: tt.args.composedOf.HostSetId, AuthTokenId: tt.args.composedOf.AuthTokenId, ProjectId: tt.args.composedOf.ProjectId, - Endpoint: "tcp://127.0.0.1:22", + Endpoint: tt.args.composedOf.Endpoint, ExpirationTime: tt.args.composedOf.ExpirationTime, ConnectionLimit: tt.args.composedOf.ConnectionLimit, DynamicCredentials: tt.args.composedOf.DynamicCredentials, @@ -1254,19 +1251,19 @@ func TestRepository_CancelSessionViaFKNull(t *testing.T) { name: "canceled-only-once", cancelFk: func() cancelFk { s := setupFn() - var err error - s, err = repo.CancelSession(context.Background(), s.PublicId, s.Version) - require.NoError(t, err) - require.Equal(t, StatusCanceling, s.States[0].Status) - - t := &static.Host{ + h := &static.Host{ Host: &staticStore.Host{ PublicId: s.HostId, }, } + + var err error + s, err = repo.CancelSession(context.Background(), s.PublicId, s.Version) + require.NoError(t, err) + require.Equal(t, StatusCanceling, s.States[0].Status) return cancelFk{ s: s, - fkType: t, + fkType: h, } }(), }, diff --git a/internal/session/session.go b/internal/session/session.go index 3a89ecd18c..25de9bf822 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -68,12 +68,8 @@ type Session struct { PublicId string `json:"public_id,omitempty" gorm:"primary_key"` // UserId for the session UserId string `json:"user_id,omitempty" gorm:"default:null"` - // HostId of the session - HostId string `json:"host_id,omitempty" gorm:"default:null"` // TargetId for the session TargetId string `json:"target_id,omitempty" gorm:"default:null"` - // HostSetId for the session - HostSetId string `json:"host_set_id,omitempty" gorm:"default:null"` // AuthTokenId for the session AuthTokenId string `json:"auth_token_id,omitempty" gorm:"default:null"` // ProjectId for the session @@ -124,6 +120,12 @@ type Session struct { // StaticCredentials for the session. StaticCredentials []*StaticCredential `gorm:"-"` + // HostSetId for the session + HostSetId string `gorm:"-"` + + // HostId of the session + HostId string `gorm:"-"` + // Connections for the session are for read only and are ignored during write operations Connections []*Connection `gorm:"-"` @@ -347,15 +349,9 @@ func (s *Session) validateNewSession() error { if s.UserId == "" { return errors.NewDeprecated(errors.InvalidParameter, op, "missing user id") } - if s.HostId == "" { - return errors.NewDeprecated(errors.InvalidParameter, op, "missing host id") - } if s.TargetId == "" { return errors.NewDeprecated(errors.InvalidParameter, op, "missing target id") } - if s.HostSetId == "" { - return errors.NewDeprecated(errors.InvalidParameter, op, "missing host set id") - } if s.AuthTokenId == "" { return errors.NewDeprecated(errors.InvalidParameter, op, "missing auth token id") } diff --git a/internal/session/session_host_set_host.go b/internal/session/session_host_set_host.go new file mode 100644 index 0000000000..3f82d9e2e0 --- /dev/null +++ b/internal/session/session_host_set_host.go @@ -0,0 +1,69 @@ +package session + +import "github.com/hashicorp/boundary/internal/errors" + +const ( + defaultSessionHostSetHostTableName = "session_host_set_host" +) + +// SessionHostSetHost contains information about a user's session with a target that has a host source association. +type SessionHostSetHost struct { + // SessionId of the session + SessionId string `json:"session_id,omitempty" gorm:"primary_key"` + // HostSetId of the session + HostSetId string `json:"host_set_id,omitempty" gorm:"default:null"` + // HostId of the session + HostId string `json:"host_id,omitempty" gorm:"default:null"` + + tableName string `gorm:"-"` +} + +// NewSessionHostSetHost creates a new in memory session to host set & host association. +func NewSessionHostSetHost(sessionId, hostSetId, hostId string) (*SessionHostSetHost, error) { + const op = "session.NewSessionHostSetHost" + if sessionId == "" { + return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing session id") + } + if hostSetId == "" { + return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing host set id") + } + if hostId == "" { + return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing host id") + } + shs := &SessionHostSetHost{ + SessionId: sessionId, + HostSetId: hostSetId, + HostId: hostId, + } + return shs, nil +} + +// TableName returns the tablename to override the default gorm table name +func (s *SessionHostSetHost) TableName() string { + if s.tableName != "" { + return s.tableName + } + return defaultSessionHostSetHostTableName +} + +// SetTableName sets the tablename and satisfies the ReplayableMessage +// interface. If the caller attempts to set the name to "" the name will be +// reset to the default name. +func (s *SessionHostSetHost) SetTableName(n string) { + s.tableName = n +} + +// AllocSessionHostSet will allocate a SessionHostSetHost +func AllocSessionHostSetHost() *SessionHostSetHost { + return &SessionHostSetHost{} +} + +// Clone creates a clone of the SessionHostSetHost +func (s *SessionHostSetHost) Clone() any { + clone := &SessionHostSetHost{ + SessionId: s.SessionId, + HostSetId: s.HostSetId, + HostId: s.HostId, + } + return clone +} diff --git a/internal/session/session_target_address.go b/internal/session/session_target_address.go new file mode 100644 index 0000000000..83e290ce17 --- /dev/null +++ b/internal/session/session_target_address.go @@ -0,0 +1,62 @@ +package session + +import "github.com/hashicorp/boundary/internal/errors" + +const ( + defaultSessionTargetAddressTableName = "session_target_address" +) + +// SessionTargetAddress contains information about a user's session with a target that has a direct network address association. +type SessionTargetAddress struct { + // SessionId of the session + SessionId string `json:"session_id,omitempty" gorm:"primary_key"` + // TargetId of the session + TargetId string `json:"target_id,omitempty" gorm:"default:null"` + + tableName string `gorm:"-"` +} + +// NewSessionTargetAddress creates a new in memory session target address. +func NewSessionTargetAddress(sessionId, targetId string) (*SessionTargetAddress, error) { + const op = "sesssion.NewSessionTargetAddress" + if sessionId == "" { + return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing session id") + } + if targetId == "" { + return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing target id") + } + sta := &SessionTargetAddress{ + SessionId: sessionId, + TargetId: targetId, + } + return sta, nil +} + +// TableName returns the tablename to override the default gorm table name +func (s *SessionTargetAddress) TableName() string { + if s.tableName != "" { + return s.tableName + } + return defaultSessionTargetAddressTableName +} + +// SetTableName sets the tablename and satisfies the ReplayableMessage +// interface. If the caller attempts to set the name to "" the name will be +// reset to the default name. +func (s *SessionTargetAddress) SetTableName(n string) { + s.tableName = n +} + +// AllocSessionTargetAddress will allocate a SessionTargetAddress +func AllocSessionTargetAddress() *SessionTargetAddress { + return &SessionTargetAddress{} +} + +// Clone creates a clone of the SessionTargetAddress +func (s *SessionTargetAddress) Clone() any { + clone := &SessionTargetAddress{ + SessionId: s.SessionId, + TargetId: s.TargetId, + } + return clone +} diff --git a/internal/session/session_test.go b/internal/session/session_test.go index d90fba65a4..1f650e9da3 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -45,7 +45,7 @@ func TestSession_Create(t *testing.T) { wantCreateErr bool }{ { - name: "valid", + name: "valid-hostset-host", args: args{ composedOf: composedOf, opt: []Option{WithExpirationTime(exp)}, @@ -67,24 +67,40 @@ func TestSession_Create(t *testing.T) { create: true, }, { - name: "empty-userId", + name: "valid-target-address", args: args{ composedOf: func() ComposedOf { c := composedOf - c.UserId = "" + c.HostSetId = "" + c.HostId = "" return c }(), + opt: []Option{WithExpirationTime(exp)}, addresses: defaultAddresses, }, - wantErr: true, - wantIsErr: errors.InvalidParameter, + want: &Session{ + UserId: composedOf.UserId, + HostId: "", + HostSetId: "", + TargetId: composedOf.TargetId, + AuthTokenId: composedOf.AuthTokenId, + ProjectId: composedOf.ProjectId, + Endpoint: "tcp://127.0.0.1:22", + ExpirationTime: composedOf.ExpirationTime, + ConnectionLimit: composedOf.ConnectionLimit, + DynamicCredentials: composedOf.DynamicCredentials, + StaticCredentials: composedOf.StaticCredentials, + }, + create: true, }, { - name: "empty-hostId", + name: "invalid-missing-target-address-host-source", args: args{ composedOf: func() ComposedOf { c := composedOf + c.HostSetId = "" c.HostId = "" + c.Endpoint = "" return c }(), addresses: defaultAddresses, @@ -93,11 +109,11 @@ func TestSession_Create(t *testing.T) { wantIsErr: errors.InvalidParameter, }, { - name: "empty-targetId", + name: "empty-userId", args: args{ composedOf: func() ComposedOf { c := composedOf - c.TargetId = "" + c.UserId = "" return c }(), addresses: defaultAddresses, @@ -106,11 +122,11 @@ func TestSession_Create(t *testing.T) { wantIsErr: errors.InvalidParameter, }, { - name: "empty-hostSetId", + name: "empty-targetId", args: args{ composedOf: func() ComposedOf { c := composedOf - c.HostSetId = "" + c.TargetId = "" return c }(), addresses: defaultAddresses, diff --git a/internal/session/testing.go b/internal/session/testing.go index 99cfc81257..8bb4861766 100644 --- a/internal/session/testing.go +++ b/internal/session/testing.go @@ -67,6 +67,28 @@ func TestState(t testing.TB, conn *db.DB, sessionId string, state Status) *State return s } +// TestSessionHostSetHost creates a test session to host set host association for the sessionId in the repository. +func TestSessionHostSetHost(t testing.TB, conn *db.DB, sessionId, hostSetId, hostId string) { + t.Helper() + require := require.New(t) + rw := db.New(conn) + hs, err := NewSessionHostSetHost(sessionId, hostSetId, hostId) + require.NoError(err) + err = rw.Create(context.Background(), hs) + require.NoError(err) +} + +// TestSessionTargetAddress creates a test session to target address association for the sessionId in the repository. +func TestSessionTargetAddress(t testing.TB, conn *db.DB, sessionId, targetId string) { + t.Helper() + require := require.New(t) + rw := db.New(conn) + ta, err := NewSessionTargetAddress(sessionId, targetId) + require.NoError(err) + err = rw.Create(context.Background(), ta) + require.NoError(err) +} + // TestSession creates a test session composed of c in the repository. Options // are passed into New, and withServerId is handled locally. func TestSession(t testing.TB, conn *db.DB, rootWrapper wrapping.Wrapper, c ComposedOf, opt ...Option) *Session { @@ -113,6 +135,12 @@ func TestSession(t testing.TB, conn *db.DB, rootWrapper wrapping.Wrapper, c Comp require.NoError(err) } + if s.HostId != "" && s.HostSetId != "" { + TestSessionHostSetHost(t, conn, s.PublicId, s.HostSetId, s.HostId) + } else if s.Endpoint != "" { + TestSessionTargetAddress(t, conn, s.PublicId, s.TargetId) + } + ss, err := fetchStates(ctx, rw, s.PublicId, append(opts.withDbOpts, db.WithOrder("start_time desc"))...) require.NoError(err) s.States = ss @@ -120,6 +148,48 @@ func TestSession(t testing.TB, conn *db.DB, rootWrapper wrapping.Wrapper, c Comp return s } +func TestSessionWithTargetAddress(t testing.TB, conn *db.DB, wrapper wrapping.Wrapper, iamRepo *iam.Repository, opt ...Option) *Session { + t.Helper() + composedOf := TestSessionTargetAddressParams(t, conn, wrapper, iamRepo) + future := timestamppb.New(time.Now().Add(time.Hour)) + exp := ×tamp.Timestamp{Timestamp: future} + return TestSession(t, conn, wrapper, composedOf, append(opt, WithExpirationTime(exp))...) +} + +func TestSessionTargetAddressParams(t testing.TB, conn *db.DB, wrapper wrapping.Wrapper, iamRepo *iam.Repository) ComposedOf { + t.Helper() + ctx := context.Background() + + require := require.New(t) + rw := db.New(conn) + org, proj := iam.TestScopes(t, iamRepo) + + tcpTarget := tcp.TestTarget(ctx, t, conn, proj.PublicId, "test target") + target.TestTargetAddress(t, conn, tcpTarget.GetPublicId(), "tcp://127.0.0.1:22") + + kms := kms.TestKms(t, conn, wrapper) + authMethod := password.TestAuthMethods(t, conn, org.PublicId, 1)[0] + acct := password.TestAccount(t, conn, authMethod.GetPublicId(), "name1") + user := iam.TestUser(t, iamRepo, org.PublicId, iam.WithAccountIds(acct.PublicId)) + + authTokenRepo, err := authtoken.NewRepository(rw, rw, kms) + require.NoError(err) + at, err := authTokenRepo.CreateAuthToken(ctx, user, acct.GetPublicId()) + require.NoError(err) + + expTime := timestamppb.Now() + expTime.Seconds += int64(tcpTarget.GetSessionMaxSeconds()) + return ComposedOf{ + UserId: user.PublicId, + TargetId: tcpTarget.GetPublicId(), + AuthTokenId: at.PublicId, + ProjectId: tcpTarget.GetProjectId(), + Endpoint: "tcp://127.0.0.1:22", + ExpirationTime: ×tamp.Timestamp{Timestamp: expTime}, + ConnectionLimit: tcpTarget.GetSessionConnectionLimit(), + } +} + // TestDefaultSession creates a test session in the repository using defaults. func TestDefaultSession(t testing.TB, conn *db.DB, wrapper wrapping.Wrapper, iamRepo *iam.Repository, opt ...Option) *Session { t.Helper()