You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/daemon/worker/controller_connection.go

320 lines
10 KiB

package worker
import (
"context"
"crypto/ed25519"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math"
"math/big"
mathrand "math/rand"
"net"
"os"
"path/filepath"
"strings"
"time"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/daemon/worker/internal/metric"
"github.com/hashicorp/boundary/internal/errors"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/boundary/internal/observability/event"
"github.com/hashicorp/go-secure-stdlib/base62"
"github.com/hashicorp/nodeenrollment"
"github.com/hashicorp/nodeenrollment/multihop"
"github.com/hashicorp/nodeenrollment/protocol"
"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/resolver"
"google.golang.org/protobuf/proto"
)
// hcpbUrlSuffix is appended to the configured HCP Boundary cluster ID to
// build the upstream address used when no initial upstreams are configured.
const hcpbUrlSuffix = ".proxy.boundary.hashicorp.cloud:9202"
// StartControllerConnections assembles the list of initial upstream
// addresses, passes that list to every registered address receiver, and then
// creates the client connection using the first address in the list.
//
// Configured upstreams that begin with "/" are used verbatim. Any other value
// is parsed as host:port; when the port is missing, 9201 is filled in. If no
// upstreams are configured, the HCP Boundary cluster ID (when set) is combined
// with hcpbUrlSuffix to form the single upstream address; if that is unset as
// well, an InvalidParameter error is returned.
func (w *Worker) StartControllerConnections() error {
	const op = "worker.(Worker).StartControllerConnections"

	configured := w.conf.RawConfig.Worker.InitialUpstreams
	upstreams := make([]string, 0, len(configured))
	for _, raw := range configured {
		if strings.HasPrefix(raw, "/") {
			upstreams = append(upstreams, raw)
			continue
		}

		host, port, err := net.SplitHostPort(raw)
		if err != nil && strings.Contains(err.Error(), "missing port in address") {
			// Retry the parse with the default port attached.
			host, port, err = net.SplitHostPort(net.JoinHostPort(raw, "9201"))
		}
		if err != nil {
			return fmt.Errorf("error parsing upstream address: %w", err)
		}
		upstreams = append(upstreams, net.JoinHostPort(host, port))
	}

	if len(upstreams) == 0 {
		// Nothing configured explicitly; fall back to the HCP Boundary
		// cluster address if a cluster ID is available.
		if w.conf.RawConfig.HcpbClusterId == "" {
			return errors.New(w.baseContext, errors.InvalidParameter, op, "no initial upstream addresses found")
		}
		clusterAddress := fmt.Sprintf("%s%s", w.conf.RawConfig.HcpbClusterId, hcpbUrlSuffix)
		upstreams = append(upstreams, clusterAddress)
		event.WriteSysEvent(w.baseContext, op, fmt.Sprintf("Setting HCP Boundary cluster address %s as upstream address", clusterAddress))
	}

	for _, receiver := range w.addressReceivers {
		receiver.InitialAddresses(upstreams)
	}

	err := w.createClientConn(upstreams[0])
	if err != nil {
		return fmt.Errorf("error making client connection to upstream address %s: %w", upstreams[0], err)
	}
	return nil
}
// controllerDialerFunc returns a context-aware dial function for connecting
// to an upstream at a given address. Any extraAlpnProtos are forwarded to the
// authentication path that ends up being used.
//
// When a worker-auth KMS is configured and DevUsePkiForUpstream is not set,
// the connection is made through v1KmsAuthDialFn. Otherwise it is made through
// protocol.Dial using the worker's auth storage; on success the stored worker
// auth request file (if a storage base directory is set) is removed on a
// best-effort basis.
//
// Dial errors are written as error events, except for not-authorized errors,
// which are deliberately left unlogged (see the comment in the body). On the
// first successful dial the worker's everAuthenticated state is advanced from
// "never authenticated" to "first authentication".
func (w *Worker) controllerDialerFunc(extraAlpnProtos ...string) func(context.Context, string) (net.Conn, error) {
	const op = "worker.(Worker).controllerDialerFunc"
	return func(ctx context.Context, addr string) (net.Conn, error) {
		var conn net.Conn
		var err error
		switch {
		case w.conf.WorkerAuthKms != nil && !w.conf.DevUsePkiForUpstream:
			conn, err = w.v1KmsAuthDialFn(ctx, addr, extraAlpnProtos...)
		default:
			conn, err = protocol.Dial(ctx, w.WorkerAuthStorage, addr, nodeenrollment.WithWrapper(w.conf.WorkerAuthStorageKms), nodeenrollment.WithExtraAlpnProtos(extraAlpnProtos))
			// No error and a valid connection means the WorkerAuthRegistrationRequest was populated
			// We can remove the stored workerAuthRequest file
			if err == nil && conn != nil {
				if w.WorkerAuthStorage.BaseDir() != "" {
					workerAuthReqFilePath := filepath.Join(w.WorkerAuthStorage.BaseDir(), base.WorkerAuthReqFile)
					// Intentionally ignoring any error removing this file
					_ = os.Remove(workerAuthReqFilePath)
				}
			}
		}

		// Match the not-authorized sentinel with errors.Is rather than by
		// identity so that a wrapped ErrNotAuthorized is still recognized and
		// does not fall through to the logging branch.
		// NOTE(review): relies on the errors package imported by this file
		// exposing Is with standard-library semantics; confirm.
		switch {
		case err == nil:
		case errors.Is(err, nodeenrollment.ErrNotAuthorized):
			// We don't event in this case, because the function retries often
			// and will spam the logs. The status function will event indicating
			// that it can't send status because it's not authorized, so that
			// will be a fine hint to the user as to the issue.
		default:
			event.WriteError(ctx, op, err)
		}

		if err == nil && conn != nil {
			if w.everAuthenticated.Load() == authenticationStatusNeverAuthenticated {
				w.everAuthenticated.Store(authenticationStatusFirstAuthentication)
			}
			event.WriteSysEvent(ctx, op, "worker has successfully authenticated")
		}
		return conn, err
	}
}
// v1KmsAuthDialFn dials addr and authenticates to the upstream using the
// KMS-based worker auth flow.
//
// It builds the worker-auth TLS configuration (passing along extraAlpnProtos),
// dials addr over a unix socket when addr starts with "/" and over TCP
// otherwise, performs the TLS handshake as a client, and finally writes the
// connection nonce over the TLS connection. The TLS connection is returned on
// success.
//
// The handshake is run explicitly with ctx so that cancellation or a deadline
// on ctx interrupts a stalled handshake; a handshake triggered implicitly by
// the first Write would not observe ctx. On any failure after the dial the
// underlying connection is closed and an error is returned.
func (w *Worker) v1KmsAuthDialFn(ctx context.Context, addr string, extraAlpnProtos ...string) (net.Conn, error) {
	const op = "worker.(Worker).v1KmsAuthDialFn"
	tlsConf, authInfo, err := w.workerAuthTLSConfig(extraAlpnProtos...)
	if err != nil {
		return nil, fmt.Errorf("error creating tls config for worker auth: %w", err)
	}

	network := "tcp"
	if strings.HasPrefix(addr, "/") {
		network = "unix"
	}
	dialer := &net.Dialer{}
	nonTlsConn, err := dialer.DialContext(ctx, network, addr)
	if err != nil {
		return nil, fmt.Errorf("unable to dial to upstream: %w", err)
	}

	// closeOnFailure closes the raw connection and events any close error
	// with the supplied message; the close error is never returned.
	closeOnFailure := func(msg string) {
		if closeErr := nonTlsConn.Close(); closeErr != nil {
			event.WriteError(ctx, op, closeErr, event.WithInfoMsg(msg))
		}
	}

	tlsConn := tls.Client(nonTlsConn, tlsConf)
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		closeOnFailure("error closing connection after handshake failure")
		return nil, fmt.Errorf("unable to perform tls handshake with upstream: %w", err)
	}

	written, err := tlsConn.Write([]byte(authInfo.ConnectionNonce))
	if err != nil {
		closeOnFailure("error closing connection after writing failure")
		return nil, fmt.Errorf("unable to write connection nonce: %w", err)
	}
	if written != len(authInfo.ConnectionNonce) {
		closeOnFailure("error closing connection after writing failure")
		return nil, fmt.Errorf("expected to write %d bytes of connection nonce, wrote %d", len(authInfo.ConnectionNonce), written)
	}

	return tlsConn, nil
}
// createClientConn dials the upstream at addr through gRPC and stores the
// resulting connection and service clients on the worker.
//
// The resolver builder is taken from the worker's address receivers (the last
// *grpcResolverReceiver found is used); an Internal error is returned if none
// is present. The connection is dialed with a default service config that
// selects round_robin load balancing and applies a timeout with waitForReady
// to all methods, and with the worker's controller dialer function as the
// context dialer. On success, GrpcClientConn, controllerStatusConn and
// controllerMultihopConn are populated from the new connection.
func (w *Worker) createClientConn(addr string) error {
	const op = "worker.(Worker).createClientConn"

	callTimeout := time.Second + time.Nanosecond
	serviceConfig := fmt.Sprintf(`
{
"loadBalancingConfig": [ { "round_robin": {} } ],
"methodConfig": [
{
"name": [],
"timeout": %q,
"waitForReady": true
}
]
}
`, callTimeout.String())

	// Deliberately no break: if several receivers match, the last one wins.
	var builder resolver.Builder
	for _, receiver := range w.addressReceivers {
		grpcReceiver, ok := receiver.(*grpcResolverReceiver)
		if !ok {
			continue
		}
		builder = grpcReceiver.Resolver
	}
	if builder == nil {
		return errors.New(w.baseContext, errors.Internal, op, "unable to find a resolver.Builder amongst the address receivers")
	}

	connectParams := grpc.ConnectParams{
		Backoff: backoff.Config{
			BaseDelay:  time.Second,
			Multiplier: 1.2,
			Jitter:     0.2,
			MaxDelay:   3 * time.Second,
		},
	}
	opts := []grpc.DialOption{
		grpc.WithResolvers(builder),
		grpc.WithUnaryInterceptor(metric.InstrumentClusterClient()),
		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
		grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(math.MaxInt32)),
		grpc.WithContextDialer(w.controllerDialerFunc()),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultServiceConfig(serviceConfig),
		// Use only the default service config given above; do not let the
		// resolver supply one.
		grpc.WithDisableServiceConfig(),
		grpc.WithConnectParams(connectParams),
	}

	target := fmt.Sprintf("%s:///%s", builder.Scheme(), addr)
	clientConn, err := grpc.DialContext(w.baseContext, target, opts...)
	if err != nil {
		return fmt.Errorf("error dialing controller for worker auth: %w", err)
	}

	w.GrpcClientConn = clientConn
	w.controllerStatusConn.Store(pbs.NewServerCoordinationServiceClient(clientConn))
	w.controllerMultihopConn.Store(multihop.NewMultihopServiceClient(clientConn))
	return nil
}
// workerAuthTLSConfig builds the client TLS configuration and the matching
// WorkerAuthInfo used for KMS-based worker authentication.
//
// It fills WorkerAuthInfo with the worker's configured name, description and
// public address plus a fresh connection nonce, generates an ed25519 key pair,
// and creates a self-signed certificate whose only DNS name is a random
// base62 string. The certificate and private key are stored PEM-encoded on
// the info struct.
//
// The info struct is then JSON-marshaled, encrypted with the worker-auth KMS,
// proto-marshaled and base64-encoded (raw standard encoding). The encoded
// string is split into pieces of at most 230 characters, each emitted as an
// ALPN protocol of the form "v1workerauth-NN-<piece>" after any
// extraAlpnProtos.
//
// The returned tls.Config uses the random name as ServerName, the generated
// certificate as both client certificate and sole root CA, and requires TLS
// 1.3 as the minimum version.
func (w *Worker) workerAuthTLSConfig(extraAlpnProtos ...string) (*tls.Config, *base.WorkerAuthInfo, error) {
	workerCfg := w.conf.RawConfig.Worker
	info := &base.WorkerAuthInfo{
		Name:         workerCfg.Name,
		Description:  workerCfg.Description,
		ProxyAddress: workerCfg.PublicAddr,
	}

	var err error
	if info.ConnectionNonce, err = w.nonceFn(20); err != nil {
		return nil, nil, err
	}

	publicKey, privateKey, err := ed25519.GenerateKey(w.conf.SecureRandomReader)
	if err != nil {
		return nil, nil, err
	}

	serverName, err := base62.Random(20)
	if err != nil {
		return nil, nil, err
	}

	// The same template serves as both subject and issuer below, which makes
	// the resulting certificate self-signed.
	serial := big.NewInt(mathrand.Int63())
	notBefore := time.Now().Add(-30 * time.Second)
	notAfter := time.Now().Add(globals.WorkerAuthNonceValidityPeriod)
	certTemplate := &x509.Certificate{
		SerialNumber:          serial,
		DNSNames:              []string{serverName},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	certDER, err := x509.CreateCertificate(w.conf.SecureRandomReader, certTemplate, certTemplate, publicKey, privateKey)
	if err != nil {
		return nil, nil, err
	}
	info.CertPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})

	keyDER, err := x509.MarshalPKCS8PrivateKey(privateKey)
	if err != nil {
		return nil, nil, err
	}
	info.KeyPEM = pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: keyDER})

	// Serialize the auth info, encrypt it, and serialize the encrypted blob.
	infoJSON, err := json.Marshal(info)
	if err != nil {
		return nil, nil, err
	}
	encrypted, err := w.conf.WorkerAuthKms.Encrypt(w.baseContext, infoJSON)
	if err != nil {
		return nil, nil, err
	}
	encryptedBytes, err := proto.Marshal(encrypted)
	if err != nil {
		return nil, nil, err
	}

	// Carry the encoded blob in numbered ALPN entries, each holding at most
	// chunkLen characters of it, after any caller-supplied protocols.
	const chunkLen = 230
	var alpnProtos []string
	alpnProtos = append(alpnProtos, extraAlpnProtos...)
	remaining := base64.RawStdEncoding.EncodeToString(encryptedBytes)
	for seq := 0; len(remaining) > 0; seq++ {
		take := chunkLen
		if take > len(remaining) {
			take = len(remaining)
		}
		alpnProtos = append(alpnProtos, fmt.Sprintf("v1workerauth-%02d-%s", seq, remaining[:take]))
		remaining = remaining[take:]
	}

	parsedCert, err := x509.ParseCertificate(certDER)
	if err != nil {
		return nil, nil, err
	}

	// Assemble the local TLS config, trusting only the generated certificate.
	pool := x509.NewCertPool()
	pool.AddCert(parsedCert)
	keyPair, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM)
	if err != nil {
		return nil, nil, err
	}

	return &tls.Config{
		ServerName:   serverName,
		Certificates: []tls.Certificate{keyPair},
		RootCAs:      pool,
		NextProtos:   alpnProtos,
		MinVersion:   tls.VersionTLS13,
	}, info, nil
}