mirror of https://github.com/hashicorp/boundary
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
276 lines
8.7 KiB
276 lines
8.7 KiB
package worker
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
|
|
"github.com/hashicorp/boundary/internal/session"
|
|
)
|
|
|
|
// validateSessionTimeout bounds how long getSessionTls waits for the
// controller to answer a LookupSession request before giving up.
const validateSessionTimeout = 90 * time.Second
|
|
|
|
type connInfo struct {
|
|
id string
|
|
connCtx context.Context
|
|
connCancel context.CancelFunc
|
|
status pbs.CONNECTIONSTATUS
|
|
closeTime time.Time
|
|
}
|
|
|
|
type sessionInfo struct {
|
|
sync.RWMutex
|
|
id string
|
|
sessionTls *tls.Config
|
|
status pbs.SESSIONSTATUS
|
|
lookupSessionResponse *pbs.LookupSessionResponse
|
|
connInfoMap map[string]*connInfo
|
|
}
|
|
|
|
func (w *Worker) getSessionTls(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
var sessionId string
|
|
switch {
|
|
case strings.HasPrefix(hello.ServerName, "s_"):
|
|
w.logger.Trace("got valid session in SNI", "session_id", hello.ServerName)
|
|
sessionId = hello.ServerName
|
|
default:
|
|
w.logger.Trace("invalid session in SNI", "session_id", hello.ServerName)
|
|
return nil, fmt.Errorf("could not find session ID in SNI")
|
|
}
|
|
|
|
rawConn := w.controllerSessionConn.Load()
|
|
if rawConn == nil {
|
|
w.logger.Trace("could not get a controller client", "session_id", sessionId)
|
|
return nil, errors.New("could not get a controller client")
|
|
}
|
|
conn, ok := rawConn.(pbs.SessionServiceClient)
|
|
if !ok {
|
|
w.logger.Trace("could not cast controller client to the real thing", "session_id", sessionId)
|
|
return nil, errors.New("could not cast atomic controller client to the real thing")
|
|
}
|
|
if conn == nil {
|
|
w.logger.Trace("controller client is nil", "session_id", sessionId)
|
|
return nil, errors.New("controller client is nil")
|
|
}
|
|
|
|
timeoutContext, cancel := context.WithTimeout(w.baseContext, validateSessionTimeout)
|
|
defer cancel()
|
|
|
|
w.logger.Trace("looking up session", "session_id", sessionId)
|
|
resp, err := conn.LookupSession(timeoutContext, &pbs.LookupSessionRequest{
|
|
SessionId: sessionId,
|
|
})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error validating session: %w", err)
|
|
}
|
|
|
|
if resp.GetExpiration().AsTime().Before(time.Now()) {
|
|
return nil, fmt.Errorf("session is expired")
|
|
}
|
|
|
|
parsedCert, err := x509.ParseCertificate(resp.GetAuthorization().Certificate)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error parsing session certificate: %w", err)
|
|
}
|
|
|
|
if len(parsedCert.DNSNames) != 1 {
|
|
return nil, fmt.Errorf("invalid length of DNS names (%d) in parsed certificate", len(parsedCert.DNSNames))
|
|
}
|
|
|
|
certPool := x509.NewCertPool()
|
|
certPool.AddCert(parsedCert)
|
|
|
|
tlsConf := &tls.Config{
|
|
Certificates: []tls.Certificate{
|
|
{
|
|
Certificate: [][]byte{resp.GetAuthorization().Certificate},
|
|
PrivateKey: ed25519.PrivateKey(resp.GetAuthorization().PrivateKey),
|
|
Leaf: parsedCert,
|
|
},
|
|
},
|
|
ServerName: parsedCert.DNSNames[0],
|
|
ClientAuth: tls.RequireAndVerifyClientCert,
|
|
ClientCAs: certPool,
|
|
MinVersion: tls.VersionTLS13,
|
|
}
|
|
|
|
si := &sessionInfo{
|
|
id: resp.GetAuthorization().GetSessionId(),
|
|
sessionTls: tlsConf,
|
|
lookupSessionResponse: resp,
|
|
status: resp.GetStatus(),
|
|
connInfoMap: make(map[string]*connInfo),
|
|
}
|
|
// TODO: Periodicially clean this up. We can't rely on things in here but
|
|
// not in cancellation because they could be on the way to being
|
|
// established. However, since cert lifetimes are short, we can simply range
|
|
// through and remove values that are expired.
|
|
actualSiRaw, loaded := w.sessionInfoMap.LoadOrStore(sessionId, si)
|
|
if loaded {
|
|
// Update the response to the latest
|
|
actualSi := actualSiRaw.(*sessionInfo)
|
|
actualSi.Lock()
|
|
actualSi.lookupSessionResponse = resp
|
|
actualSi.Unlock()
|
|
}
|
|
|
|
w.logger.Trace("returning TLS configuration", "session_id", sessionId)
|
|
return tlsConf, nil
|
|
}
|
|
|
|
func (w *Worker) activateSession(ctx context.Context, sessionId, tofuToken string, version uint32) (pbs.SESSIONSTATUS, error) {
|
|
rawConn := w.controllerSessionConn.Load()
|
|
if rawConn == nil {
|
|
return pbs.SESSIONSTATUS_SESSIONSTATUS_UNSPECIFIED, errors.New("could not get a controller client")
|
|
}
|
|
conn, ok := rawConn.(pbs.SessionServiceClient)
|
|
if !ok {
|
|
return pbs.SESSIONSTATUS_SESSIONSTATUS_UNSPECIFIED, errors.New("could not cast atomic controller client to the real thing")
|
|
}
|
|
if conn == nil {
|
|
return pbs.SESSIONSTATUS_SESSIONSTATUS_UNSPECIFIED, errors.New("controller client is nil")
|
|
}
|
|
|
|
resp, err := conn.ActivateSession(ctx, &pbs.ActivateSessionRequest{
|
|
SessionId: sessionId,
|
|
TofuToken: tofuToken,
|
|
Version: version,
|
|
WorkerId: w.conf.RawConfig.Worker.Name,
|
|
})
|
|
if err != nil {
|
|
return pbs.SESSIONSTATUS_SESSIONSTATUS_UNSPECIFIED, fmt.Errorf("error activating session: %w", err)
|
|
}
|
|
return resp.GetStatus(), nil
|
|
}
|
|
|
|
func (w *Worker) authorizeConnection(ctx context.Context, sessionId string) (*connInfo, int32, error) {
|
|
rawConn := w.controllerSessionConn.Load()
|
|
if rawConn == nil {
|
|
return nil, 0, errors.New("could not get a controller client")
|
|
}
|
|
conn, ok := rawConn.(pbs.SessionServiceClient)
|
|
if !ok {
|
|
return nil, 0, errors.New("could not cast atomic controller client to the real thing")
|
|
}
|
|
if conn == nil {
|
|
return nil, 0, errors.New("controller client is nil")
|
|
}
|
|
|
|
resp, err := conn.AuthorizeConnection(ctx, &pbs.AuthorizeConnectionRequest{
|
|
SessionId: sessionId,
|
|
})
|
|
if err != nil {
|
|
return nil, 0, fmt.Errorf("error authorizing connection: %w", err)
|
|
}
|
|
|
|
return &connInfo{
|
|
id: resp.ConnectionId,
|
|
status: resp.GetStatus(),
|
|
}, resp.GetConnectionsLeft(), nil
|
|
}
|
|
|
|
func (w *Worker) connectConnection(ctx context.Context, req *pbs.ConnectConnectionRequest) (pbs.CONNECTIONSTATUS, error) {
|
|
rawConn := w.controllerSessionConn.Load()
|
|
if rawConn == nil {
|
|
return pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_UNSPECIFIED, errors.New("could not get a controller client")
|
|
}
|
|
conn, ok := rawConn.(pbs.SessionServiceClient)
|
|
if !ok {
|
|
return pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_UNSPECIFIED, errors.New("could not cast atomic controller client to the real thing")
|
|
}
|
|
if conn == nil {
|
|
return pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_UNSPECIFIED, errors.New("controller client is nil")
|
|
}
|
|
|
|
resp, err := conn.ConnectConnection(ctx, req)
|
|
if err != nil {
|
|
return pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_UNSPECIFIED, err
|
|
}
|
|
|
|
if resp.GetStatus() != pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CONNECTED {
|
|
return pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_UNSPECIFIED, fmt.Errorf("unexpected state returned: %v", resp.GetStatus().String())
|
|
}
|
|
|
|
return resp.GetStatus(), nil
|
|
}
|
|
|
|
func (w *Worker) closeConnection(ctx context.Context, req *pbs.CloseConnectionRequest) (*pbs.CloseConnectionResponse, error) {
|
|
rawConn := w.controllerSessionConn.Load()
|
|
if rawConn == nil {
|
|
return nil, errors.New("could not get a controller client")
|
|
}
|
|
conn, ok := rawConn.(pbs.SessionServiceClient)
|
|
if !ok {
|
|
return nil, errors.New("could not cast atomic controller client to the real thing")
|
|
}
|
|
if conn == nil {
|
|
return nil, errors.New("controller client is nil")
|
|
}
|
|
|
|
resp, err := conn.CloseConnection(ctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(resp.GetCloseResponseData()) != len(req.GetCloseRequestData()) {
|
|
w.logger.Warn("mismatched number of states returned on connection closed", "expected", len(req.GetCloseRequestData()), "got", len(resp.GetCloseResponseData()))
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (w *Worker) closeConnections(ctx context.Context, closeMap map[string]string) error {
|
|
w.logger.Trace("marking connections as closed", "session_and_connection_ids", fmt.Sprintf("%#v", closeMap))
|
|
|
|
closeData := make([]*pbs.CloseConnectionRequestData, 0, len(closeMap))
|
|
for connId := range closeMap {
|
|
closeData = append(closeData, &pbs.CloseConnectionRequestData{
|
|
ConnectionId: connId,
|
|
Reason: session.UnknownReason.String(),
|
|
})
|
|
}
|
|
closeInfo := &pbs.CloseConnectionRequest{
|
|
CloseRequestData: closeData,
|
|
}
|
|
|
|
connStatus, err := w.closeConnection(ctx, closeInfo)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
closedIds := make([]string, 0, len(connStatus.GetCloseResponseData()))
|
|
|
|
// Here we build a reverse map from closeMap, that is, session ID to
|
|
// connection IDs, for more efficient locking
|
|
revMap := make(map[string][]*pbs.CloseConnectionResponseData)
|
|
for _, v := range connStatus.GetCloseResponseData() {
|
|
revMap[closeMap[v.GetConnectionId()]] = append(revMap[closeMap[v.GetConnectionId()]], v)
|
|
}
|
|
for k, v := range revMap {
|
|
siRaw, ok := w.sessionInfoMap.Load(k)
|
|
if !ok {
|
|
w.logger.Warn("could not find session ID in info map after closing connections", "session_id", k)
|
|
continue
|
|
}
|
|
si := siRaw.(*sessionInfo)
|
|
si.Lock()
|
|
for _, connResult := range v {
|
|
ci := si.connInfoMap[connResult.GetConnectionId()]
|
|
ci.status = connResult.GetStatus()
|
|
if ci.status == pbs.CONNECTIONSTATUS_CONNECTIONSTATUS_CLOSED {
|
|
ci.closeTime = time.Now()
|
|
}
|
|
}
|
|
si.Unlock()
|
|
}
|
|
w.logger.Trace("connections successfully marked closed", "connection_ids", closedIds)
|
|
return nil
|
|
}
|