mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
253 lines
7.3 KiB
253 lines
7.3 KiB
package worker
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"math/big"
|
|
mathrand "math/rand"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/hashicorp/boundary/internal/cmd/base"
|
|
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
|
|
"github.com/hashicorp/vault/sdk/helper/base62"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/resolver"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
func (w *Worker) startControllerConnections() error {
|
|
initialAddrs := make([]resolver.Address, 0, len(w.conf.RawConfig.Worker.Controllers))
|
|
for _, addr := range w.conf.RawConfig.Worker.Controllers {
|
|
host, port, err := net.SplitHostPort(addr)
|
|
if err != nil && strings.Contains(err.Error(), "missing port in address") {
|
|
w.logger.Trace("missing port in controller address, using port 9201", "address", addr)
|
|
host, port, err = net.SplitHostPort(fmt.Sprintf("%s:%s", addr, "9201"))
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("error parsing controller address: %w", err)
|
|
}
|
|
initialAddrs = append(initialAddrs, resolver.Address{Addr: fmt.Sprintf("%s:%s", host, port)})
|
|
}
|
|
|
|
if len(initialAddrs) == 0 {
|
|
return errors.New("no initial controller addresses found")
|
|
}
|
|
|
|
w.Resolver().InitialState(resolver.State{
|
|
Addresses: initialAddrs,
|
|
})
|
|
if err := w.createClientConn(initialAddrs[0].Addr); err != nil {
|
|
return fmt.Errorf("error making client connection to controller: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w Worker) controllerDialerFunc() func(context.Context, string) (net.Conn, error) {
|
|
return func(ctx context.Context, addr string) (net.Conn, error) {
|
|
tlsConf, authInfo, err := w.workerAuthTLSConfig()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error creating tls config for worker auth: %w", err)
|
|
}
|
|
dialer := &net.Dialer{}
|
|
nonTlsConn, err := dialer.DialContext(ctx, "tcp", addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("unable to dial to controller: %w", err)
|
|
}
|
|
tlsConn := tls.Client(nonTlsConn, tlsConf)
|
|
written, err := tlsConn.Write([]byte(authInfo.ConnectionNonce))
|
|
if err != nil {
|
|
if err := nonTlsConn.Close(); err != nil {
|
|
w.logger.Error("error closing connection after writing failure", "error", err)
|
|
}
|
|
return nil, fmt.Errorf("unable to write connection nonce: %w", err)
|
|
}
|
|
if written != len(authInfo.ConnectionNonce) {
|
|
if err := nonTlsConn.Close(); err != nil {
|
|
w.logger.Error("error closing connection after writing failure", "error", err)
|
|
}
|
|
return nil, fmt.Errorf("expected to write %d bytes of connection nonce, wrote %d", len(authInfo.ConnectionNonce), written)
|
|
}
|
|
return tlsConn, nil
|
|
}
|
|
}
|
|
|
|
func (w *Worker) createClientConn(addr string) error {
|
|
defaultTimeout := (time.Second + time.Nanosecond).String()
|
|
defServiceConfig := fmt.Sprintf(`
|
|
{
|
|
"loadBalancingConfig": [ { "round_robin": {} } ],
|
|
"methodConfig": [
|
|
{
|
|
"name": [],
|
|
"timeout": %q,
|
|
"waitForReady": true
|
|
}
|
|
]
|
|
}
|
|
`, defaultTimeout)
|
|
cc, err := grpc.DialContext(w.baseContext,
|
|
fmt.Sprintf("%s:///%s", w.Resolver().Scheme(), addr),
|
|
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
|
|
grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(math.MaxInt32)),
|
|
grpc.WithContextDialer(w.controllerDialerFunc()),
|
|
grpc.WithInsecure(),
|
|
grpc.WithDefaultServiceConfig(defServiceConfig),
|
|
// Don't have the resolver reach out for a service config from the
|
|
// resolver, use the one specified as default
|
|
grpc.WithDisableServiceConfig(),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("error dialing controller for worker auth: %w", err)
|
|
}
|
|
|
|
w.controllerStatusConn.Store(pbs.NewServerCoordinationServiceClient(cc))
|
|
w.controllerSessionConn.Store(pbs.NewSessionServiceClient(cc))
|
|
|
|
w.logger.Info("connected to controller", "address", addr)
|
|
return nil
|
|
}
|
|
|
|
func (w Worker) workerAuthTLSConfig() (*tls.Config, *base.WorkerAuthInfo, error) {
|
|
var err error
|
|
info := &base.WorkerAuthInfo{
|
|
Name: w.conf.RawConfig.Worker.Name,
|
|
Description: w.conf.RawConfig.Worker.Description,
|
|
}
|
|
if info.ConnectionNonce, err = base62.Random(20); err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
_, caKey, err := ed25519.GenerateKey(w.conf.SecureRandomReader)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
caHost, err := base62.Random(20)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
caCertTemplate := &x509.Certificate{
|
|
Subject: pkix.Name{
|
|
CommonName: caHost,
|
|
},
|
|
DNSNames: []string{caHost},
|
|
KeyUsage: x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
|
|
SerialNumber: big.NewInt(mathrand.Int63()),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
NotAfter: time.Now().Add(3 * time.Minute),
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
}
|
|
caBytes, err := x509.CreateCertificate(w.conf.SecureRandomReader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
caCertPEMBlock := &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: caBytes,
|
|
}
|
|
info.CACertPEM = pem.EncodeToMemory(caCertPEMBlock)
|
|
caCert, err := x509.ParseCertificate(caBytes)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
|
|
//
|
|
// Certs generation
|
|
//
|
|
_, key, err := ed25519.GenerateKey(w.conf.SecureRandomReader)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
host, err := base62.Random(20)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
certTemplate := &x509.Certificate{
|
|
Subject: pkix.Name{
|
|
CommonName: host,
|
|
},
|
|
DNSNames: []string{host},
|
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
|
x509.ExtKeyUsageServerAuth,
|
|
x509.ExtKeyUsageClientAuth,
|
|
},
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
|
|
SerialNumber: big.NewInt(mathrand.Int63()),
|
|
NotBefore: time.Now().Add(-30 * time.Second),
|
|
NotAfter: time.Now().Add(2 * time.Minute),
|
|
}
|
|
certBytes, err := x509.CreateCertificate(w.conf.SecureRandomReader, certTemplate, caCert, key.Public(), caKey)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
certPEMBlock := &pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: certBytes,
|
|
}
|
|
info.CertPEM = pem.EncodeToMemory(certPEMBlock)
|
|
marshaledKey, err := x509.MarshalPKCS8PrivateKey(key)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
keyPEMBlock := &pem.Block{
|
|
Type: "PRIVATE KEY",
|
|
Bytes: marshaledKey,
|
|
}
|
|
info.KeyPEM = pem.EncodeToMemory(keyPEMBlock)
|
|
|
|
// Marshal and encrypt
|
|
marshaledInfo, err := json.Marshal(info)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
encInfo, err := w.conf.WorkerAuthKms.Encrypt(context.Background(), marshaledInfo, nil)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
marshaledEncInfo, err := proto.Marshal(encInfo)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
b64alpn := base64.RawStdEncoding.EncodeToString(marshaledEncInfo)
|
|
var nextProtos []string
|
|
var count int
|
|
for i := 0; i < len(b64alpn); i += 230 {
|
|
end := i + 230
|
|
if end > len(b64alpn) {
|
|
end = len(b64alpn)
|
|
}
|
|
nextProtos = append(nextProtos, fmt.Sprintf("v1workerauth-%02d-%s", count, b64alpn[i:end]))
|
|
count++
|
|
}
|
|
|
|
// Build local tls config
|
|
rootCAs := x509.NewCertPool()
|
|
rootCAs.AddCert(caCert)
|
|
tlsCert, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
tlsConfig := &tls.Config{
|
|
ServerName: host,
|
|
Certificates: []tls.Certificate{tlsCert},
|
|
RootCAs: rootCAs,
|
|
NextProtos: nextProtos,
|
|
MinVersion: tls.VersionTLS13,
|
|
}
|
|
|
|
return tlsConfig, info, nil
|
|
}
|