You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/daemon/worker/worker_test.go

557 lines
18 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package worker
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"sync"
"testing"
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/daemon/worker/common"
"github.com/hashicorp/boundary/internal/daemon/worker/session"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/event"
"github.com/hashicorp/boundary/internal/gen/controller/servers/services"
wpbs "github.com/hashicorp/boundary/internal/gen/worker/servers/services"
"github.com/hashicorp/boundary/internal/server"
"github.com/hashicorp/boundary/internal/util"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/configutil/v2"
"github.com/hashicorp/go-secure-stdlib/listenerutil"
"github.com/hashicorp/nodeenrollment"
"github.com/hashicorp/nodeenrollment/registration"
"github.com/hashicorp/nodeenrollment/rotation"
nodeefile "github.com/hashicorp/nodeenrollment/storage/file"
"github.com/hashicorp/nodeenrollment/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestWorkerNew(t *testing.T) {
tests := []struct {
name string
in *Config
expErr bool
expErrMsg string
assertions func(t *testing.T, w *Worker)
}{
{
name: "nil listeners",
in: &Config{Server: &base.Server{Listeners: nil}},
expErr: true,
expErrMsg: "exactly one proxy listener is required",
},
{
name: "zero listeners",
in: &Config{Server: &base.Server{Listeners: []*base.ServerListener{}}},
expErr: true,
expErrMsg: "exactly one proxy listener is required",
},
{
name: "populated with nil values",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
nil,
{Config: nil},
{Config: &listenerutil.ListenerConfig{Purpose: nil}},
},
},
},
expErr: true,
expErrMsg: "exactly one proxy listener is required",
},
{
name: "multiple purposes",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
nil,
{Config: nil},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"api", "proxy"}}},
},
},
},
expErr: true,
expErrMsg: `found listener with multiple purposes "api,proxy"`,
},
{
name: "too many proxy listeners",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"api"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"cluster"}}},
},
},
},
expErr: true,
expErrMsg: "exactly one proxy listener is required",
},
{
name: "valid listeners",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"api"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"cluster"}}},
},
},
},
expErr: false,
},
{
name: "worker nonce func is set",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
},
},
},
expErr: false,
assertions: func(t *testing.T, w *Worker) {
require.NotNil(t, w.nonceFn)
},
},
{
name: "worker recording storage path is not set",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
},
Eventer: &event.Eventer{},
},
RawConfig: &config.Config{
Worker: &config.Worker{},
SharedConfig: &configutil.SharedConfig{
DisableMlock: true,
},
},
},
expErr: false,
assertions: func(t *testing.T, w *Worker) {
assert.Equal(t, w.conf.RawConfig.Worker.RecordingStoragePath, "")
assert.Equal(t, w.localStorageState.Load().(server.LocalStorageState).String(), server.NotConfiguredLocalStorageState.String())
},
},
{
name: "worker recording storage path is set",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
},
},
RawConfig: &config.Config{
Worker: &config.Worker{
RecordingStoragePath: "/tmp",
},
SharedConfig: &configutil.SharedConfig{
DisableMlock: true,
},
},
},
expErr: false,
assertions: func(t *testing.T, w *Worker) {
assert.Equal(t, w.conf.RawConfig.Worker.RecordingStoragePath, "/tmp")
assert.Equal(t, w.localStorageState.Load().(server.LocalStorageState).String(), server.UnknownLocalStorageState.String())
},
},
{
name: "worker host service server is the unimplemented one by default",
in: &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
},
},
RawConfig: &config.Config{
Worker: &config.Worker{},
SharedConfig: &configutil.SharedConfig{DisableMlock: true},
},
},
expErr: false,
assertions: func(t *testing.T, w *Worker) {
assert.Equal(t, wpbs.UnimplementedHostServiceServer{}, w.HostServiceServer)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// New() panics if these aren't set
tt.in.Logger = hclog.Default()
if tt.in.RawConfig == nil {
tt.in.RawConfig = &config.Config{SharedConfig: &configutil.SharedConfig{DisableMlock: true}}
}
if util.IsNil(tt.in.Eventer) {
require.NoError(t, event.InitSysEventer(hclog.Default(), &sync.Mutex{}, "worker_test", event.WithEventerConfig(&event.EventerConfig{})))
t.Cleanup(func() { event.TestResetSystEventer(t) })
tt.in.Eventer = event.SysEventer()
}
currentHostServiceFactory := hostServiceServerFactory
hostServiceServerFactory = nil
t.Cleanup(func() { hostServiceServerFactory = currentHostServiceFactory })
w, err := New(context.Background(), tt.in)
if tt.expErr {
require.EqualError(t, err, tt.expErrMsg)
require.Nil(t, w)
return
}
require.NoError(t, err)
if tt.assertions != nil {
tt.assertions(t, w)
}
})
}
}
func TestWorkerReload(t *testing.T) {
t.Run("default config is the same as the reload config", func(t *testing.T) {
require, assert := require.New(t), assert.New(t)
cfg := &Config{
Server: &base.Server{
Logger: hclog.Default(),
Eventer: &event.Eventer{},
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"api"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"cluster"}}},
},
},
RawConfig: &config.Config{
SharedConfig: &configutil.SharedConfig{DisableMlock: true},
Worker: &config.Worker{},
},
}
w, err := New(context.Background(), cfg)
require.NoError(err)
assert.Equal(int64(server.DefaultLiveness), w.successfulRoutingInfoGracePeriod.Load())
assert.Equal(int64(server.DefaultLiveness), w.successfulSessionInfoGracePeriod.Load())
assert.Equal(int64(server.DefaultLiveness), session.CloseCallTimeout.Load())
assert.Equal(int64(common.DefaultRoutingInfoTimeout), w.routingInfoCallTimeoutDuration.Load())
assert.Equal(int64(common.DefaultStatisticsTimeout), w.statisticsCallTimeoutDuration.Load())
assert.Equal(int64(common.DefaultSessionInfoTimeout), w.sessionInfoCallTimeoutDuration.Load())
assert.Equal(int64(server.DefaultLiveness), w.getDownstreamWorkersTimeoutDuration.Load())
w.Reload(context.Background(), cfg.RawConfig)
assert.Equal(int64(server.DefaultLiveness), w.successfulRoutingInfoGracePeriod.Load())
assert.Equal(int64(server.DefaultLiveness), w.successfulSessionInfoGracePeriod.Load())
assert.Equal(int64(server.DefaultLiveness), session.CloseCallTimeout.Load())
assert.Equal(int64(common.DefaultRoutingInfoTimeout), w.routingInfoCallTimeoutDuration.Load())
assert.Equal(int64(common.DefaultStatisticsTimeout), w.statisticsCallTimeoutDuration.Load())
assert.Equal(int64(common.DefaultSessionInfoTimeout), w.sessionInfoCallTimeoutDuration.Load())
assert.Equal(int64(server.DefaultLiveness), w.getDownstreamWorkersTimeoutDuration.Load())
})
t.Run("new config is the same as the reload config", func(t *testing.T) {
require, assert := require.New(t), assert.New(t)
cfg := &Config{
Server: &base.Server{
Logger: hclog.Default(),
Eventer: &event.Eventer{},
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"api"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"cluster"}}},
},
},
RawConfig: &config.Config{
SharedConfig: &configutil.SharedConfig{DisableMlock: true},
Worker: &config.Worker{
SuccessfulControllerRPCGracePeriodDuration: 5 * time.Second,
ControllerRPCCallTimeoutDuration: 10 * time.Second,
GetDownstreamWorkersTimeoutDuration: 20 * time.Second,
},
},
}
w, err := New(context.Background(), cfg)
require.NoError(err)
assert.Equal(int64(5*time.Second), w.successfulRoutingInfoGracePeriod.Load())
assert.Equal(int64(5*time.Second), w.successfulSessionInfoGracePeriod.Load())
assert.Equal(w.successfulRoutingInfoGracePeriod.Load(), session.CloseCallTimeout.Load())
assert.Equal(int64(10*time.Second), w.routingInfoCallTimeoutDuration.Load())
assert.Equal(int64(10*time.Second), w.statisticsCallTimeoutDuration.Load())
assert.Equal(int64(10*time.Second), w.sessionInfoCallTimeoutDuration.Load())
assert.Equal(int64(20*time.Second), w.getDownstreamWorkersTimeoutDuration.Load())
w.Reload(context.Background(), cfg.RawConfig)
assert.Equal(int64(5*time.Second), w.successfulRoutingInfoGracePeriod.Load())
assert.Equal(int64(5*time.Second), w.successfulSessionInfoGracePeriod.Load())
assert.Equal(w.successfulRoutingInfoGracePeriod.Load(), session.CloseCallTimeout.Load())
assert.Equal(int64(10*time.Second), w.routingInfoCallTimeoutDuration.Load())
assert.Equal(int64(10*time.Second), w.statisticsCallTimeoutDuration.Load())
assert.Equal(int64(10*time.Second), w.sessionInfoCallTimeoutDuration.Load())
assert.Equal(int64(20*time.Second), w.getDownstreamWorkersTimeoutDuration.Load())
})
}
func TestSetupWorkerAuthStorage(t *testing.T) {
ctx := context.Background()
ts := db.TestWrapper(t)
keyId, err := ts.KeyId(ctx)
require.NoError(t, err)
// First, just test the key ID is populated
tmpDir := t.TempDir()
tw := NewTestWorker(t, &TestWorkerOpts{
WorkerAuthStorageKms: ts,
WorkerAuthStoragePath: tmpDir,
DisableAutoStart: true,
})
err = tw.Worker().Start()
require.NoError(t, err)
wKeyId, err := tw.Config().WorkerAuthStorageKms.KeyId(ctx)
require.NoError(t, err)
assert.Equal(t, keyId, wKeyId)
// Create a fresh persistent dir for the following tests
tmpDir = t.TempDir()
// Get an initial set of authorized node credentials
initStorage, err := nodeefile.New(ctx)
require.NoError(t, err)
t.Cleanup(func() { initStorage.Cleanup(ctx) })
_, err = rotation.RotateRootCertificates(ctx, initStorage)
require.NoError(t, err)
initNodeCreds, err := types.NewNodeCredentials(ctx, initStorage)
require.NoError(t, err)
req, err := initNodeCreds.CreateFetchNodeCredentialsRequest(ctx)
require.NoError(t, err)
_, err = registration.AuthorizeNode(ctx, initStorage, req)
require.NoError(t, err)
fetchResp, err := registration.FetchNodeCredentials(ctx, initStorage, req)
require.NoError(t, err)
initNodeCreds, err = initNodeCreds.HandleFetchNodeCredentialsResponse(ctx, initStorage, fetchResp)
require.NoError(t, err)
initKeyId, err := nodeenrollment.KeyIdFromPkix(initNodeCreds.CertificatePublicKeyPkix)
require.NoError(t, err)
nonce := make([]byte, nodeenrollment.NonceSize)
_, err = rand.Reader.Read(nonce)
require.NoError(t, err)
// What's going on here: in each test we are simulating a startup of a
// worker that has storage in various states. The input is a function to
// modify the current state of node credentials by using the worker's
// storage, but this happens before Start so we haven't done checking yet;
// the assertions check what the final result is.
tests := []struct {
name string
in func(*testing.T, nodeenrollment.Storage, *Worker)
expKeyId string // If set, the existing key ID to expect
expRegistrationRequest bool // Whether we should have seen a registration request generated
expError string // Some other error
}{
{
name: "no creds",
in: func(t *testing.T, storage nodeenrollment.Storage, w *Worker) {
// Do nothing; in this case it will have already been cleared
},
expRegistrationRequest: true,
},
{
name: "valid creds",
in: func(t *testing.T, storage nodeenrollment.Storage, w *Worker) {
// Store the authorized creds
require.NoError(t, initNodeCreds.Store(ctx, storage))
},
expKeyId: initKeyId,
},
{
name: "existing but not validated",
in: func(t *testing.T, storage nodeenrollment.Storage, w *Worker) {
creds := proto.Clone(initNodeCreds).(*types.NodeCredentials)
creds.CertificateBundles = nil
creds.RegistrationNonce = nonce
require.NoError(t, creds.Store(ctx, storage))
},
expKeyId: initKeyId,
expRegistrationRequest: true,
},
{
name: "existing and outside cert times", // Note that cert from next CA will already not be valid
in: func(t *testing.T, storage nodeenrollment.Storage, w *Worker) {
creds := proto.Clone(initNodeCreds).(*types.NodeCredentials)
creds.CertificateBundles[0].CertificateNotBefore = timestamppb.New(time.Time{})
creds.CertificateBundles[0].CertificateNotAfter = timestamppb.New(time.Time{})
require.NoError(t, creds.Store(ctx, storage))
},
expRegistrationRequest: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tw := NewTestWorker(t, &TestWorkerOpts{
WorkerAuthStoragePath: tmpDir,
DisableAutoStart: true,
})
// Always clear out storage that was there before, ignore errors
storage, err := nodeefile.New(tw.Context(), nodeefile.WithBaseDirectory(tmpDir))
require.NoError(t, err)
_ = storage.Remove(ctx, &types.NodeCredentials{Id: string(nodeenrollment.CurrentId)})
// Run node credentials modification
tt.in(t, storage, tw.Worker())
// Start up to run logic
require.NoError(t, tw.Worker().Start())
// Validate existing key was loaded or new key was created and loaded
if tt.expKeyId != "" {
assert.Equal(t, tt.expKeyId, tw.Worker().WorkerAuthCurrentKeyId.Load())
} else {
assert.NotEmpty(t, tw.Worker().WorkerAuthCurrentKeyId.Load())
}
if tt.expRegistrationRequest {
assert.NotEmpty(t, tw.Worker().WorkerAuthRegistrationRequest)
} else {
assert.Empty(t, tw.Worker().WorkerAuthRegistrationRequest)
}
})
}
}
func Test_Worker_getSessionTls(t *testing.T) {
require.NoError(t, event.InitSysEventer(hclog.Default(), &sync.Mutex{}, "worker_test", event.WithEventerConfig(&event.EventerConfig{})))
t.Cleanup(func() { event.TestResetSystEventer(t) })
conf := &Config{
Server: &base.Server{
Listeners: []*base.ServerListener{
{Config: &listenerutil.ListenerConfig{Purpose: []string{"api"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"proxy"}}},
{Config: &listenerutil.ListenerConfig{Purpose: []string{"cluster"}}},
},
Eventer: event.SysEventer(),
Logger: hclog.Default(),
},
}
conf.RawConfig = &config.Config{SharedConfig: &configutil.SharedConfig{DisableMlock: true}}
w, err := New(context.Background(), conf)
require.NoError(t, err)
w.lastRoutingInfoSuccess.Store(&LastRoutingInfo{RoutingInfoResponse: &services.RoutingInfoResponse{}, RoutingInfoTime: time.Now(), LastCalculatedUpstreams: nil})
w.baseContext = context.Background()
t.Run("success", func(t *testing.T) {
m := &fakeManager{
session: &fakeSession{
cert: &x509.Certificate{
Raw: []byte("something"),
},
privateKey: []byte("something_else"),
},
}
hello := &tls.ClientHelloInfo{ServerName: "s_1234567890"}
tlsConf, err := w.getSessionTls(m)(hello)
require.NoError(t, err)
require.Len(t, tlsConf.Certificates, 1)
require.Len(t, tlsConf.Certificates[0].Certificate, 1)
assert.Equal(t, m.session.cert.Raw, tlsConf.Certificates[0].Certificate[0])
assert.Equal(t, m.session.cert, tlsConf.Certificates[0].Leaf)
assert.Equal(t, ed25519.PrivateKey(m.session.privateKey), tlsConf.Certificates[0].PrivateKey)
require.Len(t, tlsConf.NextProtos, 1)
assert.Equal(t, "http/1.1", tlsConf.NextProtos[0])
assert.Equal(t, tls.VersionTLS13, int(tlsConf.MinVersion))
assert.Equal(t, tls.RequireAnyClientCert, tlsConf.ClientAuth)
assert.Equal(t, true, tlsConf.InsecureSkipVerify)
assert.NotNil(t, tlsConf.VerifyConnection)
})
t.Run("errors-on-empty-cert", func(t *testing.T) {
m := &fakeManager{
session: &fakeSession{
cert: nil,
privateKey: []byte("something_else"),
},
}
hello := &tls.ClientHelloInfo{ServerName: "s_1234567890"}
_, err := w.getSessionTls(m)(hello)
require.Error(t, err)
})
t.Run("errors-on-empty-cert-der", func(t *testing.T) {
m := &fakeManager{
session: &fakeSession{
cert: &x509.Certificate{
Raw: nil,
},
privateKey: []byte("something_else"),
},
}
hello := &tls.ClientHelloInfo{ServerName: "s_1234567890"}
_, err := w.getSessionTls(m)(hello)
require.Error(t, err)
})
t.Run("errors-on-empty-private-key", func(t *testing.T) {
m := &fakeManager{
session: &fakeSession{
cert: &x509.Certificate{
Raw: []byte("something"),
},
privateKey: nil,
},
}
hello := &tls.ClientHelloInfo{ServerName: "s_1234567890"}
_, err := w.getSessionTls(m)(hello)
require.Error(t, err)
})
}
type fakeSession struct {
cert *x509.Certificate
privateKey []byte
session.Session
}
func (f *fakeSession) GetCertificate() *x509.Certificate {
return f.cert
}
func (f *fakeSession) GetPrivateKey() []byte {
return f.privateKey
}
type fakeManager struct {
session.Manager
session *fakeSession
}
func (f *fakeManager) LoadLocalSession(ctx context.Context, id string, workerId string) (session.Session, error) {
return f.session, nil
}