|
|
|
|
@ -279,17 +279,20 @@ func (r *Repository) TerminateSession(ctx context.Context, sessionId string, ses
|
|
|
|
|
return &updatedSession, returnedStates, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ConnectSession creates a connection in the repo with a state of "connected".
|
|
|
|
|
// Returns an ErrCancelledOrTerminatedSession error if a connection cannot be made
|
|
|
|
|
// because the session was cancelled or terminated.
|
|
|
|
|
func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connection, []*ConnectionState, error) {
|
|
|
|
|
// ConnectWith.validate will check all the fields...
|
|
|
|
|
if err := c.validate(); err != nil {
|
|
|
|
|
return nil, nil, fmt.Errorf("connect session: %w", err)
|
|
|
|
|
// AuthorizeConnection will check to see if a connection is allowed. Currently,
|
|
|
|
|
// that authorization checks:
|
|
|
|
|
// * the hasn't expired based on the session.Expiration
|
|
|
|
|
// * number of connections already created is less than session.ConnectionLimit
|
|
|
|
|
// If authorization is success, it creates/stores a new connection in the repo
|
|
|
|
|
// and returns it, along with it's states. If the authorization fails, it
|
|
|
|
|
// an error of ErrInvalidStateForOperation.
|
|
|
|
|
func (r *Repository) AuthorizeConnection(ctx context.Context, sessionId string) (*Connection, []*ConnectionState, error) {
|
|
|
|
|
if sessionId == "" {
|
|
|
|
|
return nil, nil, fmt.Errorf("authorize connection: missing session id: %w", db.ErrInvalidParameter)
|
|
|
|
|
}
|
|
|
|
|
connectionId, err := newConnectionId()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, nil, fmt.Errorf("connect session: %w", err)
|
|
|
|
|
return nil, nil, fmt.Errorf("authorize connection: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
connection := AllocConnection()
|
|
|
|
|
@ -300,15 +303,15 @@ func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connec
|
|
|
|
|
db.StdRetryCnt,
|
|
|
|
|
db.ExpBackoff{},
|
|
|
|
|
func(reader db.Reader, w db.Writer) error {
|
|
|
|
|
rowsAffected, err := w.Exec(createConnectionCte, []interface{}{c.SessionId, connectionId, c.ClientTcpAddress, c.ClientTcpPort, c.EndpointTcpAddress, c.EndpointTcpPort})
|
|
|
|
|
rowsAffected, err := w.Exec(authorizeConnectionCte, []interface{}{sessionId, connectionId})
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("unable to connect session %s: %w", c.SessionId, err)
|
|
|
|
|
return fmt.Errorf("unable to authorize connection %s: %w", sessionId, err)
|
|
|
|
|
}
|
|
|
|
|
if rowsAffected == 0 {
|
|
|
|
|
return fmt.Errorf("session %s is not active: %w", c.SessionId, ErrInvalidStateForOperation)
|
|
|
|
|
return fmt.Errorf("session %s is not authorized (not active, expired or connection limit reached): %w", sessionId, ErrInvalidStateForOperation)
|
|
|
|
|
}
|
|
|
|
|
if err := reader.LookupById(ctx, &connection); err != nil {
|
|
|
|
|
return fmt.Errorf("lookup session: failed %w for %s", err, c.SessionId)
|
|
|
|
|
return fmt.Errorf("lookup connection: failed %w for %s", err, sessionId)
|
|
|
|
|
}
|
|
|
|
|
connectionStates, err = fetchConnectionStates(ctx, reader, connectionId, db.WithOrder("start_time desc"))
|
|
|
|
|
if err != nil {
|
|
|
|
|
@ -317,6 +320,59 @@ func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connec
|
|
|
|
|
return nil
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, nil, fmt.Errorf("authorize connection: %w", err)
|
|
|
|
|
}
|
|
|
|
|
return &connection, connectionStates, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ConnectSession updates a connection in the repo with a state of "connected".
|
|
|
|
|
func (r *Repository) ConnectSession(ctx context.Context, c ConnectWith) (*Connection, []*ConnectionState, error) {
|
|
|
|
|
// ConnectWith.validate will check all the fields...
|
|
|
|
|
if err := c.validate(); err != nil {
|
|
|
|
|
return nil, nil, fmt.Errorf("connect session: %w", err)
|
|
|
|
|
}
|
|
|
|
|
var connection Connection
|
|
|
|
|
var connectionStates []*ConnectionState
|
|
|
|
|
_, err := r.writer.DoTx(
|
|
|
|
|
ctx,
|
|
|
|
|
db.StdRetryCnt,
|
|
|
|
|
db.ExpBackoff{},
|
|
|
|
|
func(reader db.Reader, w db.Writer) error {
|
|
|
|
|
connection = AllocConnection()
|
|
|
|
|
connection.PublicId = c.ConnectionId
|
|
|
|
|
connection.ClientTcpAddress = c.ClientTcpAddress
|
|
|
|
|
connection.ClientTcpPort = c.ClientTcpPort
|
|
|
|
|
connection.EndpointTcpAddress = c.EndpointTcpAddress
|
|
|
|
|
connection.EndpointTcpPort = c.EndpointTcpPort
|
|
|
|
|
fieldMask := []string{
|
|
|
|
|
"ClientTcpAddress",
|
|
|
|
|
"ClientTcpPort",
|
|
|
|
|
"EndpointTcpAddress",
|
|
|
|
|
"EndpointTcpPort",
|
|
|
|
|
}
|
|
|
|
|
rowsUpdated, err := w.Update(ctx, &connection, fieldMask, nil)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if err == nil && rowsUpdated > 1 {
|
|
|
|
|
// return err, which will result in a rollback of the update
|
|
|
|
|
return errors.New("error more than 1 connection would have been updated ")
|
|
|
|
|
}
|
|
|
|
|
newState, err := NewConnectionState(connection.PublicId, StatusConnected)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
if err := w.Create(ctx, newState); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
connectionStates, err = fetchConnectionStates(ctx, reader, c.ConnectionId, db.WithOrder("start_time desc"))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
},
|
|
|
|
|
)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, nil, fmt.Errorf("connect session: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|