diff --git a/internal/daemon/cluster/handlers/worker_service.go b/internal/daemon/cluster/handlers/worker_service.go index a7963c127b..8e21989ee9 100644 --- a/internal/daemon/cluster/handlers/worker_service.go +++ b/internal/daemon/cluster/handlers/worker_service.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/boundary/internal/daemon/controller/common" "github.com/hashicorp/boundary/internal/daemon/controller/handlers" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" + intglobals "github.com/hashicorp/boundary/internal/globals" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/observability/event" "github.com/hashicorp/boundary/internal/server" @@ -39,6 +40,7 @@ type workerServiceServer struct { updateTimes *sync.Map kms *kms.Kms livenessTimeToStale *atomic.Int64 + controllerExt intglobals.ControllerExtension } var ( @@ -73,6 +75,7 @@ func NewWorkerServiceServer( updateTimes *sync.Map, kms *kms.Kms, livenessTimeToStale *atomic.Int64, + controllerExt intglobals.ControllerExtension, ) *workerServiceServer { return &workerServiceServer{ serversRepoFn: serversRepoFn, @@ -83,6 +86,7 @@ func NewWorkerServiceServer( updateTimes: updateTimes, kms: kms, livenessTimeToStale: livenessTimeToStale, + controllerExt: controllerExt, } } @@ -381,7 +385,16 @@ func egressFilterSelector(sessionInfo *session.Session) string { } // noProtocolContext doesn't provide any protocol context since tcp doesn't need any -func noProtocolContext(context.Context, *session.Repository, *server.Repository, common.WorkerAuthRepoStorageFactory, *pbs.AuthorizeConnectionRequest, []string) (*anypb.Any, error) { +func noProtocolContext( + context.Context, + *session.Repository, + *server.Repository, + common.WorkerAuthRepoStorageFactory, + *pbs.AuthorizeConnectionRequest, + []string, + string, + intglobals.ControllerExtension, +) (*anypb.Any, error) { return nil, nil } @@ -625,7 +638,16 @@ func (ws *workerServiceServer) AuthorizeConnection(ctx context.Context, req *pbs ConnectionsLeft: authzSummary.ConnectionLimit, Route: route, } - if pc, err := getProtocolContext(ctx, sessionRepo, serversRepo, ws.workerAuthRepoFn, req, route); err != nil { + if pc, err := getProtocolContext( + ctx, + sessionRepo, + serversRepo, + ws.workerAuthRepoFn, + req, + route, + ret.ConnectionId, + ws.controllerExt, + ); err != nil { return nil, err } else { ret.ProtocolContext = pc diff --git a/internal/daemon/cluster/handlers/worker_service_status_test.go b/internal/daemon/cluster/handlers/worker_service_status_test.go index 0019b24951..59089d38ff 100644 --- a/internal/daemon/cluster/handlers/worker_service_status_test.go +++ b/internal/daemon/cluster/handlers/worker_service_status_test.go @@ -63,6 +63,10 @@ func TestStatus(t *testing.T) { connRepoFn := func() (*session.ConnectionRepository, error) { return session.NewConnectionRepository(ctx, rw, rw, kms) } + fce := &fakeControllerExtension{ + reader: rw, + writer: rw, + } repo, err := sessionRepoFn() require.NoError(t, err) @@ -109,7 +113,7 @@ func TestStatus(t *testing.T) { require.NoError(t, err) require.NoError(t, err) - s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64)) + s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) @@ -495,6 +499,10 @@ func TestStatusSessionClosed(t *testing.T) { connRepoFn := func() (*session.ConnectionRepository, error) { return session.NewConnectionRepository(ctx, rw, rw, kms) } + fce := &fakeControllerExtension{ + reader: rw, + writer: rw, + } repo, err := sessionRepoFn() require.NoError(t, err) @@ -543,7 +551,7 @@ func TestStatusSessionClosed(t *testing.T) { sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, tofu2) require.NoError(t, err) - s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64)) + s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) @@ -683,6 +691,10 @@ func TestStatusDeadConnection(t *testing.T) { connRepoFn := func() (*session.ConnectionRepository, error) { return session.NewConnectionRepository(ctx, rw, rw, kms, session.WithWorkerStateDelay(0)) } + fce := &fakeControllerExtension{ + reader: rw, + writer: rw, + } repo, err := sessionRepoFn() require.NoError(t, err) @@ -729,7 +741,7 @@ func TestStatusDeadConnection(t *testing.T) { sess2, _, err = repo.ActivateSession(ctx, sess2.PublicId, sess2.Version, tofu2) require.NoError(t, err) - s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64)) + s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) @@ -830,6 +842,10 @@ func TestStatusWorkerWithKeyId(t *testing.T) { connRepoFn := func() (*session.ConnectionRepository, error) { return session.NewConnectionRepository(ctx, rw, rw, kms) } + fce := &fakeControllerExtension{ + reader: rw, + writer: rw, + } repo, err := sessionRepoFn() require.NoError(t, err) @@ -889,7 +905,7 @@ func TestStatusWorkerWithKeyId(t *testing.T) { require.NoError(t, err) require.NoError(t, err) - s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64)) + s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) connection, _, err := connRepo.AuthorizeConnection(ctx, sess.PublicId, worker1.PublicId) @@ -1027,13 +1043,17 @@ func TestStatusAuthorizedWorkers(t *testing.T) { connRepoFn := func() (*session.ConnectionRepository, error) { return session.NewConnectionRepository(ctx, rw, rw, kmsCache) } + fce := &fakeControllerExtension{ + reader: rw, + writer: rw, + } worker1 := server.TestKmsWorker(t, conn, wrapper) var w1KeyId, w2KeyId string w1 := server.TestPkiWorker(t, conn, wrapper, server.WithTestPkiWorkerAuthorizedKeyId(&w1KeyId)) w2 := server.TestPkiWorker(t, conn, wrapper, server.WithTestPkiWorkerAuthorizedKeyId(&w2KeyId)) - s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kmsCache, new(atomic.Int64)) + s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kmsCache, new(atomic.Int64), fce) require.NotNil(t, s) cases := []struct { @@ -1238,10 +1258,14 @@ func TestWorkerOperationalStatus(t *testing.T) { connRepoFn := func() (*session.ConnectionRepository, error) { return session.NewConnectionRepository(ctx, rw, rw, kms) } + fce := &fakeControllerExtension{ + reader: rw, + writer: rw, + } worker1 := server.TestKmsWorker(t, conn, wrapper) - s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64)) + s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) cases := []struct { diff --git a/internal/daemon/cluster/handlers/worker_service_test.go b/internal/daemon/cluster/handlers/worker_service_test.go index 7bf2905f00..851b48e55e 100644 --- a/internal/daemon/cluster/handlers/worker_service_test.go +++ b/internal/daemon/cluster/handlers/worker_service_test.go @@ -18,6 +18,7 @@ import ( dcommon "github.com/hashicorp/boundary/internal/daemon/common" "github.com/hashicorp/boundary/internal/db" pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services" + intglobals "github.com/hashicorp/boundary/internal/globals" "github.com/hashicorp/boundary/internal/host/static" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" @@ -35,6 +36,15 @@ import ( "google.golang.org/protobuf/testing/protocmp" ) +type fakeControllerExtension struct { + reader db.Reader + writer db.Writer +} + +var _ intglobals.ControllerExtension = (*fakeControllerExtension)(nil) + +func (f *fakeControllerExtension) Start(_ context.Context) error { return nil } + func TestLookupSession(t *testing.T) { ctx := context.Background() conn, _ := db.TestSetup(t, "postgres") @@ -55,6 +65,10 @@ func TestLookupSession(t *testing.T) { connectionRepoFn := func() (*session.ConnectionRepository, error) { return session.NewConnectionRepository(ctx, rw, rw, kms) } + fce := &fakeControllerExtension{ + reader: rw, + writer: rw, + } at := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId()) uId := at.GetIamUserId() @@ -144,7 +158,7 @@ func TestLookupSession(t *testing.T) { err = repo.AddSessionCredentials(ctx, sessWithCreds.ProjectId, sessWithCreds.GetPublicId(), workerCreds) require.NoError(t, err) - s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connectionRepoFn, nil, new(sync.Map), kms, new(atomic.Int64)) + s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connectionRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) oldFn := connectionRouteFn @@ -310,6 +324,10 @@ func TestAuthorizeConnection(t *testing.T) { connectionRepoFn := func() (*session.ConnectionRepository, error) { return session.NewConnectionRepository(ctx, rw, rw, kmsCache) } + fce := &fakeControllerExtension{ + reader: rw, + writer: rw, + } var workerKeyId string worker := server.TestPkiWorker(t, conn, wrapper, server.WithTestPkiWorkerAuthorizedKeyId(&workerKeyId)) @@ -340,7 +358,7 @@ func TestAuthorizeConnection(t *testing.T) { repo, err := sessionRepoFn() require.NoError(t, err) - s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connectionRepoFn, nil, new(sync.Map), kmsCache, new(atomic.Int64)) + s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connectionRepoFn, nil, new(sync.Map), kmsCache, new(atomic.Int64), fce) require.NotNil(t, s) cases := []struct { @@ -462,6 +480,10 @@ func TestCancelSession(t *testing.T) { connectionRepoFn := func() (*session.ConnectionRepository, error) { return session.NewConnectionRepository(ctx, rw, rw, kms) } + fce := &fakeControllerExtension{ + reader: rw, + writer: rw, + } at := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId()) uId := at.GetIamUserId() @@ -480,7 +502,7 @@ func TestCancelSession(t *testing.T) { ProjectId: prj.GetPublicId(), Endpoint: "tcp://127.0.0.1:22", }) - s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connectionRepoFn, nil, new(sync.Map), kms, new(atomic.Int64)) + s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connectionRepoFn, nil, new(sync.Map), kms, new(atomic.Int64), fce) require.NotNil(t, s) cases := []struct { name string @@ -559,6 +581,10 @@ func TestHcpbWorkers(t *testing.T) { } var liveDur atomic.Int64 liveDur.Store(int64(1 * time.Second)) + fce := &fakeControllerExtension{ + reader: rw, + writer: rw, + } // Stale/unalive kms worker aren't expected... server.TestKmsWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: dcommon.ManagedWorkerTag, Value: "true"}), @@ -578,7 +604,7 @@ func TestHcpbWorkers(t *testing.T) { // PKI workers aren't expected server.TestPkiWorker(t, conn, wrapper, server.WithWorkerTags(&server.Tag{Key: dcommon.ManagedWorkerTag, Value: "true"})) - s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connectionRepoFn, nil, new(sync.Map), kmsCache, &liveDur) + s := NewWorkerServiceServer(serversRepoFn, workerAuthRepoFn, sessionRepoFn, connectionRepoFn, nil, new(sync.Map), kmsCache, &liveDur, fce) require.NotNil(t, s) res, err := s.ListHcpbWorkers(ctx, &pbs.ListHcpbWorkersRequest{}) diff --git a/internal/daemon/controller/rpc_registration.go b/internal/daemon/controller/rpc_registration.go index e426866734..fc3ca3dd8a 100644 --- a/internal/daemon/controller/rpc_registration.go +++ b/internal/daemon/controller/rpc_registration.go @@ -41,8 +41,17 @@ func registerControllerServerCoordinationService(ctx context.Context, c *Control return fmt.Errorf("%s: server is nil", op) } - workerService := handlers.NewWorkerServiceServer(c.ServersRepoFn, c.WorkerAuthRepoStorageFn, - c.SessionRepoFn, c.ConnectionRepoFn, c.downstreamWorkers, c.workerStatusUpdateTimes, c.kms, c.livenessTimeToStale) + workerService := handlers.NewWorkerServiceServer( + c.ServersRepoFn, + c.WorkerAuthRepoStorageFn, + c.SessionRepoFn, + c.ConnectionRepoFn, + c.downstreamWorkers, + c.workerStatusUpdateTimes, + c.kms, + c.livenessTimeToStale, + c.ControllerExtension, + ) pbs.RegisterServerCoordinationServiceServer(server, workerService) return nil } @@ -59,8 +68,17 @@ func registerControllerSessionService(ctx context.Context, c *Controller, server return fmt.Errorf("%s: server is nil", op) } - workerService := handlers.NewWorkerServiceServer(c.ServersRepoFn, c.WorkerAuthRepoStorageFn, - c.SessionRepoFn, c.ConnectionRepoFn, c.downstreamWorkers, c.workerStatusUpdateTimes, c.kms, c.livenessTimeToStale) + workerService := handlers.NewWorkerServiceServer( + c.ServersRepoFn, + c.WorkerAuthRepoStorageFn, + c.SessionRepoFn, + c.ConnectionRepoFn, + c.downstreamWorkers, + c.workerStatusUpdateTimes, + c.kms, + c.livenessTimeToStale, + c.ControllerExtension, + ) pbs.RegisterSessionServiceServer(server, workerService) return nil }