mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
280 lines
9.0 KiB
280 lines
9.0 KiB
package session
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/hashicorp/boundary/internal/db"
|
|
"github.com/hashicorp/boundary/internal/errors"
|
|
"github.com/hashicorp/boundary/internal/servers"
|
|
)
|
|
|
|
// deadWorkerConnCloseMinGrace is the minimum allowable setting for
|
|
// the CloseConnectionsForDeadWorkers method. This is synced with
|
|
// the default server liveness setting.
|
|
var DeadWorkerConnCloseMinGrace = int(servers.DefaultLiveness.Seconds())
|
|
|
|
// LookupConnection will look up a connection in the repository and return the connection
|
|
// with its states. If the connection is not found, it will return nil, nil, nil.
|
|
// No options are currently supported.
|
|
func (r *Repository) LookupConnection(ctx context.Context, connectionId string, _ ...Option) (*Connection, []*ConnectionState, error) {
|
|
const op = "session.(Repository).LookupConnection"
|
|
if connectionId == "" {
|
|
return nil, nil, errors.New(errors.InvalidParameter, op, "missing connectionId id")
|
|
}
|
|
connection := AllocConnection()
|
|
connection.PublicId = connectionId
|
|
var states []*ConnectionState
|
|
_, err := r.writer.DoTx(
|
|
ctx,
|
|
db.StdRetryCnt,
|
|
db.ExpBackoff{},
|
|
func(read db.Reader, w db.Writer) error {
|
|
if err := read.LookupById(ctx, &connection); err != nil {
|
|
return errors.Wrap(err, op, errors.WithMsg(fmt.Sprintf("failed for %s", connectionId)))
|
|
}
|
|
var err error
|
|
if states, err = fetchConnectionStates(ctx, read, connectionId, db.WithOrder("start_time desc")); err != nil {
|
|
return errors.Wrap(err, op)
|
|
}
|
|
return nil
|
|
},
|
|
)
|
|
if err != nil {
|
|
if errors.IsNotFoundError(err) {
|
|
return nil, nil, nil
|
|
}
|
|
return nil, nil, errors.Wrap(err, op)
|
|
}
|
|
return &connection, states, nil
|
|
}
|
|
|
|
// ListConnectionsBySessionId will list connections by session ID. Supports the
|
|
// WithLimit and WithOrder options.
|
|
func (r *Repository) ListConnectionsBySessionId(ctx context.Context, sessionId string, opt ...Option) ([]*Connection, error) {
|
|
const op = "session.(Repository).ListConnectionsBySessionId"
|
|
if sessionId == "" {
|
|
return nil, errors.New(errors.InvalidParameter, op, "no session ID supplied")
|
|
}
|
|
var connections []*Connection
|
|
err := r.list(ctx, &connections, "session_id = ?", []interface{}{sessionId}, opt...) // pass options, so WithLimit and WithOrder are supported
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, op)
|
|
}
|
|
return connections, nil
|
|
}
|
|
|
|
// DeleteConnection will delete a connection from the repository.
|
|
func (r *Repository) DeleteConnection(ctx context.Context, publicId string, _ ...Option) (int, error) {
|
|
const op = "session.(Repository).DeleteConnection"
|
|
if publicId == "" {
|
|
return db.NoRowsAffected, errors.New(errors.InvalidParameter, op, "missing public id")
|
|
}
|
|
connection := AllocConnection()
|
|
connection.PublicId = publicId
|
|
if err := r.reader.LookupByPublicId(ctx, &connection); err != nil {
|
|
return db.NoRowsAffected, errors.Wrap(err, op, errors.WithMsg(fmt.Sprintf("failed for %s", publicId)))
|
|
}
|
|
|
|
var rowsDeleted int
|
|
_, err := r.writer.DoTx(
|
|
ctx,
|
|
db.StdRetryCnt,
|
|
db.ExpBackoff{},
|
|
func(_ db.Reader, w db.Writer) error {
|
|
deleteConnection := connection.Clone()
|
|
var err error
|
|
rowsDeleted, err = w.Delete(
|
|
ctx,
|
|
deleteConnection,
|
|
)
|
|
if err != nil {
|
|
return errors.Wrap(err, op)
|
|
}
|
|
if rowsDeleted > 1 {
|
|
// return err, which will result in a rollback of the delete
|
|
return errors.New(errors.MultipleRecords, op, "more than 1 resource would have been deleted")
|
|
}
|
|
return nil
|
|
},
|
|
)
|
|
if err != nil {
|
|
return db.NoRowsAffected, errors.Wrap(err, op, errors.WithMsg(fmt.Sprintf("failed for %s", publicId)))
|
|
}
|
|
return rowsDeleted, nil
|
|
}
|
|
|
|
// CloseDeadConnectionsForWorker will run closeDeadConnectionsCte to look for
|
|
// connections that should be marked closed because they are no longer claimed
|
|
// by a server.
|
|
//
|
|
// The foundConns input should be the currently-claimed connections; the CTE
|
|
// uses a NOT IN clause to ensure these are excluded. It is not an error for
|
|
// this to be empty as the worker could claim no connections; in that case all
|
|
// connections will immediately transition to closed.
|
|
func (r *Repository) CloseDeadConnectionsForWorker(ctx context.Context, serverId string, foundConns []string) (int, error) {
|
|
const op = "session.(Repository).CloseDeadConnectionsForWorker"
|
|
if serverId == "" {
|
|
return db.NoRowsAffected, errors.New(errors.InvalidParameter, op, "missing server id")
|
|
}
|
|
|
|
args := make([]interface{}, 0, len(foundConns)+1)
|
|
args = append(args, serverId)
|
|
|
|
var publicIdStr string
|
|
if len(foundConns) > 0 {
|
|
publicIdStr = `public_id not in (%s) and`
|
|
params := make([]string, len(foundConns))
|
|
for i, connId := range foundConns {
|
|
params[i] = fmt.Sprintf("$%d", i+2) // Add one for server ID, and offsets start at 1
|
|
args = append(args, connId)
|
|
}
|
|
publicIdStr = fmt.Sprintf(publicIdStr, strings.Join(params, ","))
|
|
}
|
|
var rowsAffected int
|
|
_, err := r.writer.DoTx(
|
|
ctx,
|
|
db.StdRetryCnt,
|
|
db.ExpBackoff{},
|
|
func(reader db.Reader, w db.Writer) error {
|
|
var err error
|
|
rowsAffected, err = w.Exec(ctx, fmt.Sprintf(closeDeadConnectionsCte, publicIdStr), args)
|
|
if err != nil {
|
|
return errors.Wrap(err, op)
|
|
}
|
|
return nil
|
|
},
|
|
)
|
|
if err != nil {
|
|
return db.NoRowsAffected, errors.Wrap(err, op)
|
|
}
|
|
return rowsAffected, nil
|
|
}
|
|
|
|
// CloseConnectionsForDeadWorkersResult is one row produced by
// CloseConnectionsForDeadWorkers. Its fields are filled by scanning the
// columns returned by closeConnectionsForDeadServersCte.
type CloseConnectionsForDeadWorkersResult struct {
	// ServerId identifies the server the row refers to.
	ServerId string
	// LastUpdateTime is presumably the server's most recent status update
	// time — TODO confirm against closeConnectionsForDeadServersCte.
	LastUpdateTime time.Time
	// NumberConnectionsClosed is presumably how many of this server's
	// connections were closed by the call — TODO confirm against the CTE.
	NumberConnectionsClosed int
}
|
|
|
|
// CloseConnectionsForDeadWorkers will run
|
|
// closeConnectionsForDeadServersCte to look for connections that
|
|
// should be marked because they are on a server that is no longer
|
|
// sending status updates to the controller(s).
|
|
//
|
|
// The only input to the method is the grace period, in seconds.
|
|
func (r *Repository) CloseConnectionsForDeadWorkers(ctx context.Context, gracePeriod int) ([]CloseConnectionsForDeadWorkersResult, error) {
|
|
const op = "session.(Repository).CloseConnectionsForDeadWorkers"
|
|
if gracePeriod < DeadWorkerConnCloseMinGrace {
|
|
return nil, errors.New(
|
|
errors.InvalidParameter, op, fmt.Sprintf("gracePeriod must be at least %d seconds", DeadWorkerConnCloseMinGrace))
|
|
}
|
|
|
|
results := make([]CloseConnectionsForDeadWorkersResult, 0)
|
|
_, err := r.writer.DoTx(
|
|
ctx,
|
|
db.StdRetryCnt,
|
|
db.ExpBackoff{},
|
|
func(reader db.Reader, w db.Writer) error {
|
|
rows, err := w.Query(ctx, closeConnectionsForDeadServersCte, []interface{}{gracePeriod})
|
|
if err != nil {
|
|
return errors.Wrap(err, op)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var result CloseConnectionsForDeadWorkersResult
|
|
if err := w.ScanRows(rows, &result); err != nil {
|
|
return errors.Wrap(err, op)
|
|
}
|
|
|
|
results = append(results, result)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, op)
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
// ShouldCloseConnectionsOnWorker will run shouldCloseConnectionsCte to look
|
|
// for connections that the worker should close because they are currently
|
|
// reporting them as open incorrectly.
|
|
//
|
|
// The foundConns input here is used to filter closed connection states. This
|
|
// is further filtered against the filterSessions input, which is expected to
|
|
// be a set of sessions we've already submitted close requests for, so adding
|
|
// them again would be redundant.
|
|
//
|
|
// The returned map[string][]string is indexed by session ID.
|
|
func (r *Repository) ShouldCloseConnectionsOnWorker(ctx context.Context, foundConns, filterSessions []string) (map[string][]string, error) {
|
|
const op = "session.(Repository).ShouldCloseConnectionsOnWorker"
|
|
if len(foundConns) < 1 {
|
|
return nil, nil // nothing to do
|
|
}
|
|
|
|
args := make([]interface{}, 0, len(foundConns)+len(filterSessions))
|
|
|
|
// foundConns first
|
|
connsParams := make([]string, len(foundConns))
|
|
for i, connId := range foundConns {
|
|
connsParams[i] = fmt.Sprintf("$%d", i+1)
|
|
args = append(args, connId)
|
|
}
|
|
connsStr := strings.Join(connsParams, ",")
|
|
|
|
// then filterSessions
|
|
var sessionsStr string
|
|
if len(filterSessions) > 0 {
|
|
offset := len(foundConns) + 1
|
|
sessionsParams := make([]string, len(filterSessions))
|
|
for i, sessionId := range filterSessions {
|
|
sessionsParams[i] = fmt.Sprintf("$%d", i+offset)
|
|
args = append(args, sessionId)
|
|
}
|
|
|
|
const sessionIdFmtStr = `and session_id not in (%s)`
|
|
sessionsStr = fmt.Sprintf(sessionIdFmtStr, strings.Join(sessionsParams, ","))
|
|
}
|
|
|
|
rows, err := r.reader.Query(
|
|
ctx,
|
|
fmt.Sprintf(shouldCloseConnectionsCte, connsStr, sessionsStr),
|
|
args,
|
|
)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, op)
|
|
}
|
|
defer rows.Close()
|
|
|
|
result := make(map[string][]string)
|
|
for rows.Next() {
|
|
var connectionId, sessionId string
|
|
if err := rows.Scan(&connectionId, &sessionId); err != nil {
|
|
return nil, errors.Wrap(err, op)
|
|
}
|
|
|
|
result[sessionId] = append(result[sessionId], connectionId)
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func fetchConnectionStates(ctx context.Context, r db.Reader, connectionId string, opt ...db.Option) ([]*ConnectionState, error) {
|
|
const op = "session.fetchConnectionStates"
|
|
var states []*ConnectionState
|
|
if err := r.SearchWhere(ctx, &states, "connection_id = ?", []interface{}{connectionId}, opt...); err != nil {
|
|
return nil, errors.Wrap(err, op)
|
|
}
|
|
if len(states) == 0 {
|
|
return nil, nil
|
|
}
|
|
return states, nil
|
|
}
|