mirror of https://github.com/hashicorp/boundary
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
358 lines
12 KiB
358 lines
12 KiB
package workers
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/hashicorp/boundary/internal/gen/controller/api/resources/targets"
|
|
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
|
|
"github.com/hashicorp/boundary/internal/kms"
|
|
"github.com/hashicorp/boundary/internal/servers/controller/common"
|
|
"github.com/hashicorp/boundary/internal/session"
|
|
"github.com/hashicorp/boundary/internal/types/resource"
|
|
"github.com/hashicorp/go-hclog"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
type workerServiceServer struct {
|
|
logger hclog.Logger
|
|
serversRepoFn common.ServersRepoFactory
|
|
sessionRepoFn common.SessionRepoFactory
|
|
updateTimes *sync.Map
|
|
kms *kms.Kms
|
|
}
|
|
|
|
func NewWorkerServiceServer(
|
|
logger hclog.Logger,
|
|
serversRepoFn common.ServersRepoFactory,
|
|
sessionRepoFn common.SessionRepoFactory,
|
|
updateTimes *sync.Map,
|
|
kms *kms.Kms) *workerServiceServer {
|
|
return &workerServiceServer{
|
|
logger: logger,
|
|
serversRepoFn: serversRepoFn,
|
|
sessionRepoFn: sessionRepoFn,
|
|
updateTimes: updateTimes,
|
|
kms: kms,
|
|
}
|
|
}
|
|
|
|
var _ pbs.SessionServiceServer = &workerServiceServer{}
|
|
var _ pbs.ServerCoordinationServiceServer = &workerServiceServer{}
|
|
|
|
func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusRequest) (*pbs.StatusResponse, error) {
|
|
ws.logger.Trace("got status request from worker", "name", req.Worker.Name, "address", req.Worker.Address, "jobs", req.GetJobs())
|
|
ws.updateTimes.Store(req.Worker.Name, time.Now())
|
|
repo, err := ws.serversRepoFn()
|
|
if err != nil {
|
|
ws.logger.Error("error getting servers repo", "error", err)
|
|
return &pbs.StatusResponse{}, status.Errorf(codes.Internal, "Error aqcuiring repo to store worker status: %v", err)
|
|
}
|
|
req.Worker.Type = resource.Worker.String()
|
|
controllers, _, err := repo.UpsertServer(ctx, req.Worker)
|
|
if err != nil {
|
|
ws.logger.Error("error storing worker status", "error", err)
|
|
return &pbs.StatusResponse{}, status.Errorf(codes.Internal, "Error storing worker status: %v", err)
|
|
}
|
|
ret := &pbs.StatusResponse{
|
|
Controllers: controllers,
|
|
}
|
|
|
|
// Happy path
|
|
if len(req.GetJobs()) == 0 {
|
|
return ret, nil
|
|
}
|
|
|
|
sessRepo, err := ws.sessionRepoFn()
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "Error getting session repo: %v", err)
|
|
}
|
|
|
|
for _, jobStatus := range req.GetJobs() {
|
|
switch jobStatus.Job.GetType() {
|
|
// Check for session cancelation
|
|
case pbs.JOBTYPE_JOBTYPE_SESSION:
|
|
si := jobStatus.GetJob().GetSessionInfo()
|
|
if si == nil {
|
|
return nil, status.Error(codes.Internal, "Error getting session info at status time")
|
|
}
|
|
switch si.Status {
|
|
case pbs.SESSIONSTATUS_SESSIONSTATUS_CANCELING,
|
|
pbs.SESSIONSTATUS_SESSIONSTATUS_TERMINATED:
|
|
// No need to see about canceling anything
|
|
continue
|
|
}
|
|
sessionId := si.GetSessionId()
|
|
sessionInfo, _, err := sessRepo.LookupSession(ctx, sessionId)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "Error looking up session with id %s: %v", sessionId, err)
|
|
}
|
|
if sessionInfo == nil {
|
|
return nil, status.Errorf(codes.Internal, "Unknown session ID %s at status time.", sessionId)
|
|
}
|
|
if len(sessionInfo.States) == 0 {
|
|
return nil, status.Error(codes.Internal, "Empty session states during lookup at status time.")
|
|
}
|
|
// If the session from the DB is in canceling status, and we're
|
|
// here, it means the job is in pending or active; cancel it. If
|
|
// it's in termianted status something went wrong and we're
|
|
// mismatched, so ensure we cancel it also.
|
|
currState := sessionInfo.States[0].Status
|
|
if currState.ProtoVal() != si.Status {
|
|
switch currState {
|
|
case session.StatusCanceling,
|
|
session.StatusTerminated:
|
|
// If we're here the job is pending or active so we do want
|
|
// to actually send a change request
|
|
ret.JobsRequests = append(ret.JobsRequests, &pbs.JobChangeRequest{
|
|
Job: &pbs.Job{
|
|
Type: pbs.JOBTYPE_JOBTYPE_SESSION,
|
|
JobInfo: &pbs.Job_SessionInfo{
|
|
SessionInfo: &pbs.SessionJobInfo{
|
|
SessionId: sessionId,
|
|
Status: currState.ProtoVal(),
|
|
},
|
|
},
|
|
},
|
|
RequestType: pbs.CHANGETYPE_CHANGETYPE_UPDATE_STATE,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return ret, nil
|
|
}
|
|
|
|
func (ws *workerServiceServer) LookupSession(ctx context.Context, req *pbs.LookupSessionRequest) (*pbs.LookupSessionResponse, error) {
|
|
ws.logger.Trace("got validate session request from worker", "session_id", req.GetSessionId())
|
|
|
|
sessRepo, err := ws.sessionRepoFn()
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "Error getting session repo: %v", err)
|
|
}
|
|
|
|
sessionInfo, authzSummary, err := sessRepo.LookupSession(ctx, req.GetSessionId())
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "Error looking up session: %v", err)
|
|
}
|
|
if sessionInfo == nil {
|
|
return nil, status.Error(codes.PermissionDenied, "Unknown session ID.")
|
|
}
|
|
if len(sessionInfo.States) == 0 {
|
|
return nil, status.Error(codes.Internal, "Empty session states during lookup.")
|
|
}
|
|
|
|
resp := &pbs.LookupSessionResponse{
|
|
Authorization: &targets.SessionAuthorizationData{
|
|
SessionId: sessionInfo.GetPublicId(),
|
|
Certificate: sessionInfo.Certificate,
|
|
},
|
|
Status: sessionInfo.States[0].Status.ProtoVal(),
|
|
Version: sessionInfo.Version,
|
|
TofuToken: string(sessionInfo.TofuToken),
|
|
Endpoint: sessionInfo.Endpoint,
|
|
Expiration: sessionInfo.ExpirationTime.Timestamp,
|
|
ConnectionLimit: sessionInfo.ConnectionLimit,
|
|
ConnectionsLeft: authzSummary.ConnectionLimit,
|
|
HostId: sessionInfo.HostId,
|
|
HostSetId: sessionInfo.HostSetId,
|
|
TargetId: sessionInfo.TargetId,
|
|
UserId: sessionInfo.UserId,
|
|
}
|
|
if resp.ConnectionsLeft != -1 {
|
|
resp.ConnectionsLeft -= int32(authzSummary.CurrentConnectionCount)
|
|
}
|
|
|
|
wrapper, err := ws.kms.GetWrapper(ctx, sessionInfo.ScopeId, kms.KeyPurposeSessions)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "Error getting sessions wrapper: %v", err)
|
|
}
|
|
|
|
// Derive the private key, which should match. Deriving on both ends allows
|
|
// us to not store it in the DB.
|
|
_, resp.Authorization.PrivateKey, err = session.DeriveED25519Key(wrapper, sessionInfo.UserId, sessionInfo.GetPublicId())
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "Error deriving session key: %v", err)
|
|
}
|
|
|
|
return resp, nil
|
|
}
|
|
|
|
func (ws *workerServiceServer) ActivateSession(ctx context.Context, req *pbs.ActivateSessionRequest) (*pbs.ActivateSessionResponse, error) {
|
|
ws.logger.Trace("got activate session request from worker", "session_id", req.GetSessionId())
|
|
|
|
sessRepo, err := ws.sessionRepoFn()
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "error getting session repo: %v", err)
|
|
}
|
|
|
|
sessionInfo, sessionStates, err := sessRepo.ActivateSession(
|
|
ctx,
|
|
req.GetSessionId(),
|
|
req.GetVersion(),
|
|
req.GetWorkerId(),
|
|
resource.Worker.String(),
|
|
[]byte(req.GetTofuToken()))
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "error looking up session: %v", err)
|
|
}
|
|
if sessionInfo == nil {
|
|
return nil, status.Error(codes.PermissionDenied, "Unknown session ID.")
|
|
}
|
|
if len(sessionStates) == 0 {
|
|
return nil, status.Error(codes.Internal, "Invalid session state in activate response.")
|
|
}
|
|
|
|
ws.logger.Info("session activated",
|
|
"session_id", sessionInfo.PublicId,
|
|
"target_id", sessionInfo.TargetId,
|
|
"user_id", sessionInfo.UserId,
|
|
"host_set_id", sessionInfo.HostSetId,
|
|
"host_id", sessionInfo.HostId)
|
|
|
|
return &pbs.ActivateSessionResponse{
|
|
Status: sessionStates[0].Status.ProtoVal(),
|
|
}, nil
|
|
}
|
|
|
|
func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs.AuthorizeConnectionRequest) (*pbs.AuthorizeConnectionResponse, error) {
|
|
ws.logger.Trace("got authorize connection request from worker", "session_id", req.GetSessionId())
|
|
|
|
sessRepo, err := ws.sessionRepoFn()
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "error getting session repo: %v", err)
|
|
}
|
|
|
|
connectionInfo, connStates, authzSummary, err := sessRepo.AuthorizeConnection(ctx, req.GetSessionId())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if connectionInfo == nil {
|
|
return nil, status.Error(codes.Internal, "Invalid authorize connection response.")
|
|
}
|
|
if len(connStates) == 0 {
|
|
return nil, status.Error(codes.Internal, "Invalid connection state in authorize response.")
|
|
}
|
|
|
|
ret := &pbs.AuthorizeConnectionResponse{
|
|
ConnectionId: connectionInfo.GetPublicId(),
|
|
Status: connStates[0].Status.ProtoVal(),
|
|
ConnectionsLeft: authzSummary.ConnectionLimit,
|
|
}
|
|
if ret.ConnectionsLeft != -1 {
|
|
ret.ConnectionsLeft -= int32(authzSummary.CurrentConnectionCount)
|
|
}
|
|
|
|
ws.logger.Info("authorized connection",
|
|
"session_id", req.GetSessionId(),
|
|
"connection_id", ret.ConnectionId,
|
|
"connections_left", ret.ConnectionsLeft)
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (ws *workerServiceServer) ConnectConnection(ctx context.Context, req *pbs.ConnectConnectionRequest) (*pbs.ConnectConnectionResponse, error) {
|
|
ws.logger.Trace("got connection established information from worker", "connection_id", req.GetConnectionId())
|
|
|
|
sessRepo, err := ws.sessionRepoFn()
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "error getting session repo: %v", err)
|
|
}
|
|
|
|
connectionInfo, connStates, err := sessRepo.ConnectConnection(ctx, session.ConnectWith{
|
|
ConnectionId: req.GetConnectionId(),
|
|
ClientTcpAddress: req.GetClientTcpAddress(),
|
|
ClientTcpPort: req.GetClientTcpPort(),
|
|
EndpointTcpAddress: req.GetEndpointTcpAddress(),
|
|
EndpointTcpPort: req.GetEndpointTcpPort(),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if connectionInfo == nil {
|
|
return nil, status.Error(codes.Internal, "Invalid connect connection response.")
|
|
}
|
|
|
|
ret := &pbs.ConnectConnectionResponse{
|
|
Status: connStates[0].Status.ProtoVal(),
|
|
}
|
|
|
|
loggerPairs := []interface{}{
|
|
"session_id", connectionInfo.SessionId,
|
|
"connection_id", req.ConnectionId,
|
|
"client_tcp_address", req.ClientTcpAddress,
|
|
"client_tcp_port", req.ClientTcpPort,
|
|
}
|
|
switch req.GetType() {
|
|
case "tcp":
|
|
loggerPairs = append(loggerPairs,
|
|
"endpoint_tcp_address", connectionInfo.EndpointTcpAddress,
|
|
"endpoint_tcp_port", connectionInfo.EndpointTcpPort,
|
|
)
|
|
}
|
|
|
|
ws.logger.Info("connection established", loggerPairs...)
|
|
|
|
return ret, nil
|
|
}
|
|
|
|
func (ws *workerServiceServer) CloseConnection(ctx context.Context, req *pbs.CloseConnectionRequest) (*pbs.CloseConnectionResponse, error) {
|
|
numCloses := len(req.GetCloseRequestData())
|
|
if numCloses == 0 {
|
|
return &pbs.CloseConnectionResponse{}, nil
|
|
}
|
|
|
|
closeWiths := make([]session.CloseWith, 0, numCloses)
|
|
closeIds := make([]string, 0, numCloses)
|
|
|
|
for _, v := range req.GetCloseRequestData() {
|
|
closeIds = append(closeIds, v.GetConnectionId())
|
|
closeWiths = append(closeWiths, session.CloseWith{
|
|
ConnectionId: v.GetConnectionId(),
|
|
BytesUp: v.GetBytesUp(),
|
|
BytesDown: v.GetBytesDown(),
|
|
ClosedReason: session.ClosedReason(v.GetReason()),
|
|
})
|
|
}
|
|
ws.logger.Trace("got connection close information from worker", "connection_ids", closeIds)
|
|
|
|
sessRepo, err := ws.sessionRepoFn()
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "error getting session repo: %v", err)
|
|
}
|
|
|
|
closeInfos, err := sessRepo.CloseConnections(ctx, closeWiths)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if closeInfos == nil {
|
|
return nil, status.Error(codes.Internal, "Invalid close connection response.")
|
|
}
|
|
|
|
closeData := make([]*pbs.CloseConnectionResponseData, 0, numCloses)
|
|
for _, v := range closeInfos {
|
|
if v.Connection == nil {
|
|
return nil, status.Errorf(codes.Internal, "No connection found while closing one of the connection IDs: %v", closeIds)
|
|
}
|
|
if len(v.ConnectionStates) == 0 {
|
|
return nil, status.Errorf(codes.Internal, "No connection states found while closing one of the connection IDs: %v", closeIds)
|
|
}
|
|
closeData = append(closeData, &pbs.CloseConnectionResponseData{
|
|
ConnectionId: v.Connection.GetPublicId(),
|
|
Status: v.ConnectionStates[0].Status.ProtoVal(),
|
|
})
|
|
}
|
|
|
|
for _, v := range req.GetCloseRequestData() {
|
|
ws.logger.Info("connection closed", "connection_id", v.ConnectionId)
|
|
}
|
|
|
|
ret := &pbs.CloseConnectionResponse{
|
|
CloseResponseData: closeData,
|
|
}
|
|
|
|
return ret, nil
|
|
}
|