feat(session): Support connecting to an address associated to a Target

pull/2613/head
Damian Debkowski 3 years ago committed by Hugo Vieira
parent 760671f223
commit bca7c371d8

@ -34,21 +34,18 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.
if newSession.TargetId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing target id")
}
if newSession.HostId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing host id")
}
if newSession.UserId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing user id")
}
if newSession.HostSetId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing host set id")
}
if newSession.AuthTokenId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth token id")
}
if newSession.ProjectId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing project id")
}
if newSession.HostId == "" && newSession.HostSetId == "" && newSession.Endpoint == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing host source and endpoint")
}
if newSession.CtTofuToken != nil {
return nil, errors.New(ctx, errors.InvalidParameter, op, "ct is not empty")
}
@ -91,6 +88,26 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.
return errors.Wrap(ctx, err, op)
}
if newSession.HostSetId != "" && newSession.HostId != "" {
hs, err := NewSessionHostSetHost(newSession.PublicId, newSession.HostSetId, newSession.HostId)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if err = w.Create(ctx, hs); err != nil {
return errors.Wrap(ctx, err, op)
}
returnedSession.HostSetId = hs.HostSetId
returnedSession.HostId = hs.HostId
} else if newSession.Endpoint != "" {
ta, err := NewSessionTargetAddress(newSession.PublicId, newSession.TargetId)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if err = w.Create(ctx, ta); err != nil {
return errors.Wrap(ctx, err, op)
}
}
for _, cred := range newSession.DynamicCredentials {
cred.SessionId = newSession.PublicId
}
@ -197,6 +214,13 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt ..
session.StaticCredentials = staticCreds
}
sessionHostSetHost := AllocSessionHostSetHost()
if err := read.LookupWhere(ctx, sessionHostSetHost, "session_id = ?", []any{sessionId}); err != nil && !errors.IsNotFoundError(err) {
return errors.Wrap(ctx, err, op)
}
session.HostSetId = sessionHostSetHost.HostSetId
session.HostId = sessionHostSetHost.HostId
connections, err := fetchConnections(ctx, read, sessionId, db.WithOrder("create_time desc"))
if err != nil {
return errors.Wrap(ctx, err, op)

@ -308,24 +308,21 @@ func TestRepository_CreateSession(t *testing.T) {
wantErr: false,
},
{
name: "empty-userId",
name: "valid-static-address",
args: args{
composedOf: func() ComposedOf {
c := TestSessionParams(t, conn, wrapper, iamRepo)
c.UserId = ""
return c
}(),
composedOf: TestSessionTargetAddressParams(t, conn, wrapper, iamRepo),
workerAddresses: workerAddresses,
},
wantErr: true,
wantIsError: errors.InvalidParameter,
wantErr: false,
},
{
name: "empty-hostId",
name: "empty-host-source-endpoint",
args: args{
composedOf: func() ComposedOf {
c := TestSessionParams(t, conn, wrapper, iamRepo)
c.HostId = ""
c.HostSetId = ""
c.Endpoint = ""
return c
}(),
workerAddresses: workerAddresses,
@ -334,11 +331,11 @@ func TestRepository_CreateSession(t *testing.T) {
wantIsError: errors.InvalidParameter,
},
{
name: "empty-targetId",
name: "empty-userId",
args: args{
composedOf: func() ComposedOf {
c := TestSessionParams(t, conn, wrapper, iamRepo)
c.TargetId = ""
c.UserId = ""
return c
}(),
workerAddresses: workerAddresses,
@ -347,11 +344,11 @@ func TestRepository_CreateSession(t *testing.T) {
wantIsError: errors.InvalidParameter,
},
{
name: "empty-hostSetId",
name: "empty-targetId",
args: args{
composedOf: func() ComposedOf {
c := TestSessionParams(t, conn, wrapper, iamRepo)
c.HostSetId = ""
c.TargetId = ""
return c
}(),
workerAddresses: workerAddresses,
@ -440,7 +437,7 @@ func TestRepository_CreateSession(t *testing.T) {
HostSetId: tt.args.composedOf.HostSetId,
AuthTokenId: tt.args.composedOf.AuthTokenId,
ProjectId: tt.args.composedOf.ProjectId,
Endpoint: "tcp://127.0.0.1:22",
Endpoint: tt.args.composedOf.Endpoint,
ExpirationTime: tt.args.composedOf.ExpirationTime,
ConnectionLimit: tt.args.composedOf.ConnectionLimit,
DynamicCredentials: tt.args.composedOf.DynamicCredentials,
@ -1254,19 +1251,19 @@ func TestRepository_CancelSessionViaFKNull(t *testing.T) {
name: "canceled-only-once",
cancelFk: func() cancelFk {
s := setupFn()
var err error
s, err = repo.CancelSession(context.Background(), s.PublicId, s.Version)
require.NoError(t, err)
require.Equal(t, StatusCanceling, s.States[0].Status)
t := &static.Host{
h := &static.Host{
Host: &staticStore.Host{
PublicId: s.HostId,
},
}
var err error
s, err = repo.CancelSession(context.Background(), s.PublicId, s.Version)
require.NoError(t, err)
require.Equal(t, StatusCanceling, s.States[0].Status)
return cancelFk{
s: s,
fkType: t,
fkType: h,
}
}(),
},

@ -68,12 +68,8 @@ type Session struct {
PublicId string `json:"public_id,omitempty" gorm:"primary_key"`
// UserId for the session
UserId string `json:"user_id,omitempty" gorm:"default:null"`
// HostId of the session
HostId string `json:"host_id,omitempty" gorm:"default:null"`
// TargetId for the session
TargetId string `json:"target_id,omitempty" gorm:"default:null"`
// HostSetId for the session
HostSetId string `json:"host_set_id,omitempty" gorm:"default:null"`
// AuthTokenId for the session
AuthTokenId string `json:"auth_token_id,omitempty" gorm:"default:null"`
// ProjectId for the session
@ -124,6 +120,12 @@ type Session struct {
// StaticCredentials for the session.
StaticCredentials []*StaticCredential `gorm:"-"`
// HostSetId for the session
HostSetId string `gorm:"-"`
// HostId of the session
HostId string `gorm:"-"`
// Connections for the session are for read only and are ignored during write operations
Connections []*Connection `gorm:"-"`
@ -347,15 +349,9 @@ func (s *Session) validateNewSession() error {
if s.UserId == "" {
return errors.NewDeprecated(errors.InvalidParameter, op, "missing user id")
}
if s.HostId == "" {
return errors.NewDeprecated(errors.InvalidParameter, op, "missing host id")
}
if s.TargetId == "" {
return errors.NewDeprecated(errors.InvalidParameter, op, "missing target id")
}
if s.HostSetId == "" {
return errors.NewDeprecated(errors.InvalidParameter, op, "missing host set id")
}
if s.AuthTokenId == "" {
return errors.NewDeprecated(errors.InvalidParameter, op, "missing auth token id")
}

@ -0,0 +1,69 @@
package session
import "github.com/hashicorp/boundary/internal/errors"
const (
defaultSessionHostSetHostTableName = "session_host_set_host"
)
// SessionHostSetHost contains information about a user's session with a target that has a host source association.
type SessionHostSetHost struct {
// SessionId of the session
SessionId string `json:"session_id,omitempty" gorm:"primary_key"`
// HostSetId of the session
HostSetId string `json:"host_set_id,omitempty" gorm:"default:null"`
// HostId of the session
HostId string `json:"host_id,omitempty" gorm:"default:null"`
tableName string `gorm:"-"`
}
// NewSessionHostSetHost creates a new in memory session to host set & host association.
func NewSessionHostSetHost(sessionId, hostSetId, hostId string) (*SessionHostSetHost, error) {
const op = "session.NewSessionHostSetHost"
if sessionId == "" {
return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing session id")
}
if hostSetId == "" {
return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing host set id")
}
if hostId == "" {
return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing host id")
}
shs := &SessionHostSetHost{
SessionId: sessionId,
HostSetId: hostSetId,
HostId: hostId,
}
return shs, nil
}
// TableName returns the tablename to override the default gorm table name
func (s *SessionHostSetHost) TableName() string {
if s.tableName != "" {
return s.tableName
}
return defaultSessionHostSetHostTableName
}
// SetTableName sets the tablename and satisfies the ReplayableMessage
// interface. If the caller attempts to set the name to "" the name will be
// reset to the default name.
func (s *SessionHostSetHost) SetTableName(n string) {
s.tableName = n
}
// AllocSessionHostSet will allocate a SessionHostSetHost
func AllocSessionHostSetHost() *SessionHostSetHost {
return &SessionHostSetHost{}
}
// Clone creates a clone of the SessionHostSetHost
func (s *SessionHostSetHost) Clone() any {
clone := &SessionHostSetHost{
SessionId: s.SessionId,
HostSetId: s.HostSetId,
HostId: s.HostId,
}
return clone
}

@ -0,0 +1,62 @@
package session
import "github.com/hashicorp/boundary/internal/errors"
const (
defaultSessionTargetAddressTableName = "session_target_address"
)
// SessionTargetAddress contains information about a user's session with a target that has a direct network address association.
type SessionTargetAddress struct {
// SessionId of the session
SessionId string `json:"session_id,omitempty" gorm:"primary_key"`
// TargetId of the session
TargetId string `json:"target_id,omitempty" gorm:"default:null"`
tableName string `gorm:"-"`
}
// NewSessionTargetAddress creates a new in memory session target address.
func NewSessionTargetAddress(sessionId, targetId string) (*SessionTargetAddress, error) {
const op = "sesssion.NewSessionTargetAddress"
if sessionId == "" {
return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing session id")
}
if targetId == "" {
return nil, errors.NewDeprecated(errors.InvalidParameter, op, "missing target id")
}
sta := &SessionTargetAddress{
SessionId: sessionId,
TargetId: targetId,
}
return sta, nil
}
// TableName returns the tablename to override the default gorm table name
func (s *SessionTargetAddress) TableName() string {
if s.tableName != "" {
return s.tableName
}
return defaultSessionTargetAddressTableName
}
// SetTableName sets the tablename and satisfies the ReplayableMessage
// interface. If the caller attempts to set the name to "" the name will be
// reset to the default name.
func (s *SessionTargetAddress) SetTableName(n string) {
s.tableName = n
}
// AllocSessionTargetAddress will allocate a SessionTargetAddress
func AllocSessionTargetAddress() *SessionTargetAddress {
return &SessionTargetAddress{}
}
// Clone creates a clone of the SessionTargetAddress
func (s *SessionTargetAddress) Clone() any {
clone := &SessionTargetAddress{
SessionId: s.SessionId,
TargetId: s.TargetId,
}
return clone
}

@ -45,7 +45,7 @@ func TestSession_Create(t *testing.T) {
wantCreateErr bool
}{
{
name: "valid",
name: "valid-hostset-host",
args: args{
composedOf: composedOf,
opt: []Option{WithExpirationTime(exp)},
@ -67,24 +67,40 @@ func TestSession_Create(t *testing.T) {
create: true,
},
{
name: "empty-userId",
name: "valid-target-address",
args: args{
composedOf: func() ComposedOf {
c := composedOf
c.UserId = ""
c.HostSetId = ""
c.HostId = ""
return c
}(),
opt: []Option{WithExpirationTime(exp)},
addresses: defaultAddresses,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
want: &Session{
UserId: composedOf.UserId,
HostId: "",
HostSetId: "",
TargetId: composedOf.TargetId,
AuthTokenId: composedOf.AuthTokenId,
ProjectId: composedOf.ProjectId,
Endpoint: "tcp://127.0.0.1:22",
ExpirationTime: composedOf.ExpirationTime,
ConnectionLimit: composedOf.ConnectionLimit,
DynamicCredentials: composedOf.DynamicCredentials,
StaticCredentials: composedOf.StaticCredentials,
},
create: true,
},
{
name: "empty-hostId",
name: "invalid-missing-target-address-host-source",
args: args{
composedOf: func() ComposedOf {
c := composedOf
c.HostSetId = ""
c.HostId = ""
c.Endpoint = ""
return c
}(),
addresses: defaultAddresses,
@ -93,11 +109,11 @@ func TestSession_Create(t *testing.T) {
wantIsErr: errors.InvalidParameter,
},
{
name: "empty-targetId",
name: "empty-userId",
args: args{
composedOf: func() ComposedOf {
c := composedOf
c.TargetId = ""
c.UserId = ""
return c
}(),
addresses: defaultAddresses,
@ -106,11 +122,11 @@ func TestSession_Create(t *testing.T) {
wantIsErr: errors.InvalidParameter,
},
{
name: "empty-hostSetId",
name: "empty-targetId",
args: args{
composedOf: func() ComposedOf {
c := composedOf
c.HostSetId = ""
c.TargetId = ""
return c
}(),
addresses: defaultAddresses,

@ -67,6 +67,28 @@ func TestState(t testing.TB, conn *db.DB, sessionId string, state Status) *State
return s
}
// TestSessionHostSetHost creates a test session to host set host association for the sessionId in the repository.
func TestSessionHostSetHost(t testing.TB, conn *db.DB, sessionId, hostSetId, hostId string) {
t.Helper()
require := require.New(t)
rw := db.New(conn)
hs, err := NewSessionHostSetHost(sessionId, hostSetId, hostId)
require.NoError(err)
err = rw.Create(context.Background(), hs)
require.NoError(err)
}
// TestSessionTargetAddress creates a test session to target address association for the sessionId in the repository.
func TestSessionTargetAddress(t testing.TB, conn *db.DB, sessionId, targetId string) {
t.Helper()
require := require.New(t)
rw := db.New(conn)
ta, err := NewSessionTargetAddress(sessionId, targetId)
require.NoError(err)
err = rw.Create(context.Background(), ta)
require.NoError(err)
}
// TestSession creates a test session composed of c in the repository. Options
// are passed into New, and withServerId is handled locally.
func TestSession(t testing.TB, conn *db.DB, rootWrapper wrapping.Wrapper, c ComposedOf, opt ...Option) *Session {
@ -113,6 +135,12 @@ func TestSession(t testing.TB, conn *db.DB, rootWrapper wrapping.Wrapper, c Comp
require.NoError(err)
}
if s.HostId != "" && s.HostSetId != "" {
TestSessionHostSetHost(t, conn, s.PublicId, s.HostSetId, s.HostId)
} else if s.Endpoint != "" {
TestSessionTargetAddress(t, conn, s.PublicId, s.TargetId)
}
ss, err := fetchStates(ctx, rw, s.PublicId, append(opts.withDbOpts, db.WithOrder("start_time desc"))...)
require.NoError(err)
s.States = ss
@ -120,6 +148,48 @@ func TestSession(t testing.TB, conn *db.DB, rootWrapper wrapping.Wrapper, c Comp
return s
}
func TestSessionWithTargetAddress(t testing.TB, conn *db.DB, wrapper wrapping.Wrapper, iamRepo *iam.Repository, opt ...Option) *Session {
t.Helper()
composedOf := TestSessionTargetAddressParams(t, conn, wrapper, iamRepo)
future := timestamppb.New(time.Now().Add(time.Hour))
exp := &timestamp.Timestamp{Timestamp: future}
return TestSession(t, conn, wrapper, composedOf, append(opt, WithExpirationTime(exp))...)
}
func TestSessionTargetAddressParams(t testing.TB, conn *db.DB, wrapper wrapping.Wrapper, iamRepo *iam.Repository) ComposedOf {
t.Helper()
ctx := context.Background()
require := require.New(t)
rw := db.New(conn)
org, proj := iam.TestScopes(t, iamRepo)
tcpTarget := tcp.TestTarget(ctx, t, conn, proj.PublicId, "test target")
target.TestTargetAddress(t, conn, tcpTarget.GetPublicId(), "tcp://127.0.0.1:22")
kms := kms.TestKms(t, conn, wrapper)
authMethod := password.TestAuthMethods(t, conn, org.PublicId, 1)[0]
acct := password.TestAccount(t, conn, authMethod.GetPublicId(), "name1")
user := iam.TestUser(t, iamRepo, org.PublicId, iam.WithAccountIds(acct.PublicId))
authTokenRepo, err := authtoken.NewRepository(rw, rw, kms)
require.NoError(err)
at, err := authTokenRepo.CreateAuthToken(ctx, user, acct.GetPublicId())
require.NoError(err)
expTime := timestamppb.Now()
expTime.Seconds += int64(tcpTarget.GetSessionMaxSeconds())
return ComposedOf{
UserId: user.PublicId,
TargetId: tcpTarget.GetPublicId(),
AuthTokenId: at.PublicId,
ProjectId: tcpTarget.GetProjectId(),
Endpoint: "tcp://127.0.0.1:22",
ExpirationTime: &timestamp.Timestamp{Timestamp: expTime},
ConnectionLimit: tcpTarget.GetSessionConnectionLimit(),
}
}
// TestDefaultSession creates a test session in the repository using defaults.
func TestDefaultSession(t testing.TB, conn *db.DB, wrapper wrapping.Wrapper, iamRepo *iam.Repository, opt ...Option) *Session {
t.Helper()

Loading…
Cancel
Save