You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
boundary/internal/servers/worker/controller_connection.go

253 lines
7.3 KiB

package worker
import (
"context"
"crypto/ed25519"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"math"
"math/big"
mathrand "math/rand"
"net"
"strings"
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
"github.com/hashicorp/vault/sdk/helper/base62"
"google.golang.org/grpc"
"google.golang.org/grpc/resolver"
"google.golang.org/protobuf/proto"
)
// startControllerConnections parses the controller addresses from the worker
// configuration, hands them to the worker's resolver as its initial state,
// and creates the client connection using the first parsed address.
//
// An address without a port is given the default port 9201. An error is
// returned when an address cannot be parsed, when no addresses are
// configured, or when creating the client connection fails.
func (w *Worker) startControllerConnections() error {
	controllers := w.conf.RawConfig.Worker.Controllers
	initialAddrs := make([]resolver.Address, 0, len(controllers))
	for _, addr := range controllers {
		host, port, err := net.SplitHostPort(addr)
		if err != nil && strings.Contains(err.Error(), "missing port in address") {
			w.logger.Trace("missing port in controller address, using port 9201", "address", addr)
			// The port is appended textually on purpose: a bracketed IPv6
			// literal such as "[::1]" must become "[::1]:9201", and
			// JoinHostPort would wrap it in a second pair of brackets.
			host, port, err = net.SplitHostPort(fmt.Sprintf("%s:%s", addr, "9201"))
		}
		if err != nil {
			return fmt.Errorf("error parsing controller address: %w", err)
		}
		// SplitHostPort strips the brackets from IPv6 literals, so the
		// address has to be rebuilt with JoinHostPort; plain "host:port"
		// formatting would turn "[::1]:9201" into the undialable "::1:9201".
		initialAddrs = append(initialAddrs, resolver.Address{Addr: net.JoinHostPort(host, port)})
	}

	if len(initialAddrs) == 0 {
		return errors.New("no initial controller addresses found")
	}

	w.Resolver().InitialState(resolver.State{
		Addresses: initialAddrs,
	})
	if err := w.createClientConn(initialAddrs[0].Addr); err != nil {
		return fmt.Errorf("error making client connection to controller: %w", err)
	}

	return nil
}
// controllerDialerFunc returns the dial function used to reach a controller.
//
// Each invocation of the returned function builds fresh worker-auth TLS
// material, opens a TCP connection to addr, wraps it in a TLS client
// connection, and writes the connection nonce over it. The TLS connection is
// returned on success; on a failed or short nonce write the underlying TCP
// connection is closed and an error is returned.
//
// The receiver is a pointer, matching the other Worker methods, so the
// closure uses the live Worker instead of a copy made when the dialer was
// created.
func (w *Worker) controllerDialerFunc() func(context.Context, string) (net.Conn, error) {
	return func(ctx context.Context, addr string) (net.Conn, error) {
		tlsConf, authInfo, err := w.workerAuthTLSConfig()
		if err != nil {
			return nil, fmt.Errorf("error creating tls config for worker auth: %w", err)
		}

		var dialer net.Dialer
		rawConn, err := dialer.DialContext(ctx, "tcp", addr)
		if err != nil {
			return nil, fmt.Errorf("unable to dial to controller: %w", err)
		}

		// closeRaw releases the TCP connection when the nonce cannot be
		// delivered; a close failure is only logged because the write
		// failure is the error reported to the caller.
		closeRaw := func() {
			if closeErr := rawConn.Close(); closeErr != nil {
				w.logger.Error("error closing connection after writing failure", "error", closeErr)
			}
		}

		tlsConn := tls.Client(rawConn, tlsConf)

		// The first Write on a tls.Conn runs the handshake before sending
		// any application data.
		// NOTE(review): that handshake is not bounded by ctx here — confirm
		// whether a deadline should be applied.
		nonce := []byte(authInfo.ConnectionNonce)
		written, err := tlsConn.Write(nonce)
		if err != nil {
			closeRaw()
			return nil, fmt.Errorf("unable to write connection nonce: %w", err)
		}
		if written != len(nonce) {
			closeRaw()
			return nil, fmt.Errorf("expected to write %d bytes of connection nonce, wrote %d", len(nonce), written)
		}

		return tlsConn, nil
	}
}
// createClientConn dials addr through the worker's resolver scheme and stores
// the resulting server coordination and session service clients on the
// worker. It returns an error if the gRPC dial fails.
func (w *Worker) createClientConn(addr string) error {
	// NOTE(review): gRPC's service config documentation says a methodConfig
	// with an empty "name" list is skipped, so the timeout and waitForReady
	// settings below may not take effect — confirm before relying on them.
	timeout := (time.Second + time.Nanosecond).String()
	serviceConfig := fmt.Sprintf(`
{
"loadBalancingConfig": [ { "round_robin": {} } ],
"methodConfig": [
{
"name": [],
"timeout": %q,
"waitForReady": true
}
]
}
`, timeout)

	target := fmt.Sprintf("%s:///%s", w.Resolver().Scheme(), addr)
	dialOpts := []grpc.DialOption{
		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
		grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(math.MaxInt32)),
		// Connections are produced by the worker's own dialer, which returns
		// a TLS connection, so gRPC-level transport security is switched off.
		grpc.WithContextDialer(w.controllerDialerFunc()),
		grpc.WithInsecure(),
		grpc.WithDefaultServiceConfig(serviceConfig),
		// Always use the default service config above rather than letting
		// the resolver supply one.
		grpc.WithDisableServiceConfig(),
	}

	conn, err := grpc.DialContext(w.baseContext, target, dialOpts...)
	if err != nil {
		return fmt.Errorf("error dialing controller for worker auth: %w", err)
	}

	// Both service clients share the single underlying connection.
	w.controllerStatusConn.Store(pbs.NewServerCoordinationServiceClient(conn))
	w.controllerSessionConn.Store(pbs.NewSessionServiceClient(conn))

	w.logger.Info("connected to controller", "address", addr)
	return nil
}
// workerAuthTLSConfig builds the TLS material used to authenticate one new
// connection from this worker to a controller.
//
// It generates a random connection nonce, a short-lived self-signed ed25519
// CA certificate, and a leaf certificate signed by that CA. The worker name
// and description, the nonce, both certificates, and the leaf private key are
// collected in a base.WorkerAuthInfo, which is JSON-marshaled, encrypted with
// the configured worker auth KMS, base64-encoded, and split into chunks that
// become the ALPN protocol list of the returned TLS config.
//
// It returns the client TLS config and the unencrypted auth info. Any failing
// step returns a wrapped error describing that step, with nil for both other
// results.
func (w *Worker) workerAuthTLSConfig() (*tls.Config, *base.WorkerAuthInfo, error) {
	// Number of base64 characters carried per ALPN entry. Together with the
	// "v1workerauth-NN-" prefix this keeps each entry under the 255-byte
	// limit on a single ALPN protocol name.
	const alpnChunkSize = 230

	info := &base.WorkerAuthInfo{
		Name:        w.conf.RawConfig.Worker.Name,
		Description: w.conf.RawConfig.Worker.Description,
	}
	var err error
	if info.ConnectionNonce, err = base62.Random(20); err != nil {
		return nil, nil, fmt.Errorf("generating connection nonce: %w", err)
	}

	//
	// CA generation
	//
	_, caKey, err := ed25519.GenerateKey(w.conf.SecureRandomReader)
	if err != nil {
		return nil, nil, fmt.Errorf("generating CA key: %w", err)
	}
	caHost, err := base62.Random(20)
	if err != nil {
		return nil, nil, fmt.Errorf("generating CA host name: %w", err)
	}
	// NOTE(review): serial numbers come from math/rand rather than a CSPRNG;
	// presumably tolerated because these certificates are valid for only a
	// few minutes — confirm.
	caCertTemplate := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: caHost,
		},
		DNSNames:              []string{caHost},
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
		SerialNumber:          big.NewInt(mathrand.Int63()),
		NotBefore:             time.Now().Add(-30 * time.Second),
		NotAfter:              time.Now().Add(3 * time.Minute),
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	// Template and parent are the same certificate, so the CA is self-signed.
	caBytes, err := x509.CreateCertificate(w.conf.SecureRandomReader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
	if err != nil {
		return nil, nil, fmt.Errorf("creating CA certificate: %w", err)
	}
	info.CACertPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: caBytes,
	})
	caCert, err := x509.ParseCertificate(caBytes)
	if err != nil {
		return nil, nil, fmt.Errorf("parsing CA certificate: %w", err)
	}

	//
	// Certs generation
	//
	_, key, err := ed25519.GenerateKey(w.conf.SecureRandomReader)
	if err != nil {
		return nil, nil, fmt.Errorf("generating certificate key: %w", err)
	}
	host, err := base62.Random(20)
	if err != nil {
		return nil, nil, fmt.Errorf("generating certificate host name: %w", err)
	}
	certTemplate := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: host,
		},
		DNSNames: []string{host},
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
			x509.ExtKeyUsageClientAuth,
		},
		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
		SerialNumber: big.NewInt(mathrand.Int63()),
		NotBefore:    time.Now().Add(-30 * time.Second),
		NotAfter:     time.Now().Add(2 * time.Minute),
	}
	// The leaf certificate is signed by the CA key generated above.
	certBytes, err := x509.CreateCertificate(w.conf.SecureRandomReader, certTemplate, caCert, key.Public(), caKey)
	if err != nil {
		return nil, nil, fmt.Errorf("creating certificate: %w", err)
	}
	info.CertPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certBytes,
	})
	marshaledKey, err := x509.MarshalPKCS8PrivateKey(key)
	if err != nil {
		return nil, nil, fmt.Errorf("marshaling certificate private key: %w", err)
	}
	info.KeyPEM = pem.EncodeToMemory(&pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: marshaledKey,
	})

	// Marshal and encrypt
	marshaledInfo, err := json.Marshal(info)
	if err != nil {
		return nil, nil, fmt.Errorf("marshaling worker auth info: %w", err)
	}
	encInfo, err := w.conf.WorkerAuthKms.Encrypt(context.Background(), marshaledInfo, nil)
	if err != nil {
		return nil, nil, fmt.Errorf("encrypting worker auth info: %w", err)
	}
	marshaledEncInfo, err := proto.Marshal(encInfo)
	if err != nil {
		return nil, nil, fmt.Errorf("marshaling encrypted worker auth info: %w", err)
	}

	// Split the encoded blob into numbered ALPN entries; the two-digit index
	// in each entry records the chunk order.
	b64alpn := base64.RawStdEncoding.EncodeToString(marshaledEncInfo)
	var nextProtos []string
	for count, start := 0, 0; start < len(b64alpn); count, start = count+1, start+alpnChunkSize {
		end := start + alpnChunkSize
		if end > len(b64alpn) {
			end = len(b64alpn)
		}
		nextProtos = append(nextProtos, fmt.Sprintf("v1workerauth-%02d-%s", count, b64alpn[start:end]))
	}

	// Build local tls config
	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(caCert)
	tlsCert, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM)
	if err != nil {
		return nil, nil, fmt.Errorf("loading certificate key pair: %w", err)
	}
	tlsConfig := &tls.Config{
		ServerName:   host,
		Certificates: []tls.Certificate{tlsCert},
		RootCAs:      rootCAs,
		NextProtos:   nextProtos,
		MinVersion:   tls.VersionTLS13,
	}

	return tlsConfig, info, nil
}