From 9408b2b299dbd9ac357d2f6d516d56547212593a Mon Sep 17 00:00:00 2001 From: Jeff Mitchell Date: Thu, 7 Mar 2024 14:46:10 -0500 Subject: [PATCH] Add some more fixes caught by race detector (#4488) --- .../ldap/service_list_accounts_ext_test.go | 4 ++-- .../oidc/service_list_accounts_ext_test.go | 2 +- .../service_list_accounts_ext_test.go | 4 ++-- .../service_list_auth_methods_ext_test.go | 6 +++--- internal/authtoken/service_list_ext_test.go | 2 +- internal/bsr/bsr_open_test.go | 6 +++--- internal/bsr/bsr_validate_test.go | 6 +++--- internal/bsr/internal/fstest/fs.go | 20 ++++++++++++------- .../server/controller_db_swap_test.go | 6 +++++- .../service_list_credentials_ext_test.go | 2 +- .../service_list_libraries_ext_test.go | 4 ++-- internal/daemon/worker/status.go | 3 ++- internal/daemon/worker/worker.go | 12 +++++++++-- .../plugin/service_list_host_sets_ext_test.go | 2 +- .../plugin/service_list_hosts_ext_test.go | 4 ++-- .../host/service_list_catalogs_ext_test.go | 2 +- .../static/service_list_host_sets_ext_test.go | 4 ++-- .../static/service_list_hosts_ext_test.go | 6 +++--- internal/iam/service_list_ext_test.go | 16 +++++++-------- .../plugin/loopback/testing_grpc_stream.go | 10 ++++++++++ .../loopback/testing_grpc_stream_test.go | 10 ++++++++-- internal/session/service_list_ext_test.go | 4 ++-- internal/target/service_list_ext_test.go | 2 +- 23 files changed, 86 insertions(+), 51 deletions(-) diff --git a/internal/auth/ldap/service_list_accounts_ext_test.go b/internal/auth/ldap/service_list_accounts_ext_test.go index 7a6565be69..a48c472aa2 100644 --- a/internal/auth/ldap/service_list_accounts_ext_test.go +++ b/internal/auth/ldap/service_list_accounts_ext_test.go @@ -172,7 +172,7 @@ func TestService_ListAccounts(t *testing.T) { filterFunc := func(_ context.Context, a auth.Account) (bool, error) { return true, nil } - _, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + _, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -338,7 +338,7 @@ func TestService_ListAccounts(t *testing.T) { filterFunc := func(_ context.Context, a auth.Account) (bool, error) { return true, nil } - _, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + _, err := ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/auth/oidc/service_list_accounts_ext_test.go b/internal/auth/oidc/service_list_accounts_ext_test.go index af53740a6e..f42ee7ef50 100644 --- a/internal/auth/oidc/service_list_accounts_ext_test.go +++ b/internal/auth/oidc/service_list_accounts_ext_test.go @@ -344,7 +344,7 @@ func TestService_ListAccounts(t *testing.T) { filterFunc := func(_ context.Context, a auth.Account) (bool, error) { return true, nil } - _, err = oidc.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + _, err := oidc.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/auth/password/service_list_accounts_ext_test.go b/internal/auth/password/service_list_accounts_ext_test.go index d122ac6549..3b2fb2a49e 100644 --- a/internal/auth/password/service_list_accounts_ext_test.go +++ b/internal/auth/password/service_list_accounts_ext_test.go @@ -250,7 +250,7 @@ func TestService_ListAccounts(t *testing.T) { filterFunc := func(_ context.Context, a auth.Account) (bool, error) { return true, nil } - _, err = password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + _, err := password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("nil repo", func(t *testing.T) { @@ -328,7 +328,7 @@ func TestService_ListAccounts(t *testing.T) { filterFunc := func(_ context.Context, a auth.Account) (bool, error) { return true, nil } - _, err = password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + _, err := password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/auth/service_list_auth_methods_ext_test.go b/internal/auth/service_list_auth_methods_ext_test.go index f6c98d886b..d368d72448 100644 --- a/internal/auth/service_list_auth_methods_ext_test.go +++ b/internal/auth/service_list_auth_methods_ext_test.go @@ -237,7 +237,7 @@ func TestStoreService_List(t *testing.T) { filterFunc := func(_ context.Context, am auth.AuthMethod) (bool, error) { return true, nil } - _, err = auth.ListAuthMethodsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false) + _, err := auth.ListAuthMethodsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -325,7 +325,7 @@ func TestStoreService_List(t *testing.T) { filterFunc := func(_ context.Context, am auth.AuthMethod) (bool, error) { return true, nil } - _, err = auth.ListAuthMethodsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false) + _, err := auth.ListAuthMethodsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false) require.ErrorContains(t, err, "missing token") }) t.Run("missing repo", func(t *testing.T) { @@ -403,7 +403,7 @@ func TestStoreService_List(t *testing.T) { filterFunc := func(_ context.Context, am auth.AuthMethod) (bool, error) { return true, nil } - _, err = auth.ListAuthMethodsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false) + _, err := auth.ListAuthMethodsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/authtoken/service_list_ext_test.go b/internal/authtoken/service_list_ext_test.go index e97bf1ffa4..2ce0f2e8d5 100644 --- a/internal/authtoken/service_list_ext_test.go +++ b/internal/authtoken/service_list_ext_test.go @@ -331,7 +331,7 @@ func TestService_List(t *testing.T) { filterFunc := func(_ context.Context, at *authtoken.AuthToken) (bool, error) { return true, nil } - _, err = authtoken.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) + _, err := authtoken.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/bsr/bsr_open_test.go b/internal/bsr/bsr_open_test.go index 3ba32354a0..e848e126cd 100644 --- a/internal/bsr/bsr_open_test.go +++ b/internal/bsr/bsr_open_test.go @@ -538,7 +538,7 @@ func TestCloseBSRMethods(t *testing.T) { // Get session container sessionContainer := f.Containers[fmt.Sprintf(bsrFileNameTemplate, sessionRecordingId)] require.NotNil(t, sessionContainer) - assert.True(t, sessionContainer.Closed) + assert.True(t, sessionContainer.IsClosed()) // Ensure all session files are closed for _, file := range sessionContainer.Files { @@ -548,7 +548,7 @@ func TestCloseBSRMethods(t *testing.T) { // Get connection container connectionContainer := sessionContainer.Sub[fmt.Sprintf(connectionFileNameTemplate, connectionId)] require.NotNil(t, connectionContainer) - assert.True(t, connectionContainer.Closed) + assert.True(t, connectionContainer.IsClosed()) // Ensure all connection files are closed for _, file := range connectionContainer.Files { @@ -558,7 +558,7 @@ func TestCloseBSRMethods(t *testing.T) { // Get channel container channelContainer := connectionContainer.Sub[fmt.Sprintf(channelFileNameTemplate, channelId)] require.NotNil(t, channelContainer) - assert.True(t, channelContainer.Closed) + assert.True(t, channelContainer.IsClosed()) // Ensure all channel files are closed for _, file := range channelContainer.Files { diff --git a/internal/bsr/bsr_validate_test.go b/internal/bsr/bsr_validate_test.go index ea5a6ef344..c758266a89 100644 --- a/internal/bsr/bsr_validate_test.go +++ b/internal/bsr/bsr_validate_test.go @@ -355,7 +355,7 @@ func TestBSR_Validate_ValidateBSR(t *testing.T) { // Ensure session container is closed sessionContainer := fs.Containers[fmt.Sprintf(bsrFileNameTemplate, validation.SessionRecordingId)] require.NotNil(t, sessionContainer) - assert.True(t, sessionContainer.Closed) + assert.True(t, sessionContainer.IsClosed()) // Validate Multiple Connections for _, connection := range validation.SessionRecordingValidation.SubContainers { @@ -368,7 +368,7 @@ func TestBSR_Validate_ValidateBSR(t *testing.T) { // Ensure connection container is closed connectionContainer := sessionContainer.Sub[fmt.Sprintf(connectionFileNameTemplate, connection.Name)] require.NotNil(t, connectionContainer) - assert.True(t, connectionContainer.Closed) + assert.True(t, connectionContainer.IsClosed()) // Validate Multiple Channels for _, channel := range connection.SubContainers { @@ -381,7 +381,7 @@ func TestBSR_Validate_ValidateBSR(t *testing.T) { // Ensure channel container is closed channelContainer := connectionContainer.Sub[fmt.Sprintf(channelFileNameTemplate, channel.Name)] require.NotNil(t, channelContainer) - assert.True(t, channelContainer.Closed) + assert.True(t, channelContainer.IsClosed()) } } }) diff --git a/internal/bsr/internal/fstest/fs.go b/internal/bsr/internal/fstest/fs.go index 25fcf5bcf9..6f520ec2cd 100644 --- a/internal/bsr/internal/fstest/fs.go +++ b/internal/bsr/internal/fstest/fs.go @@ -91,7 +91,7 @@ func (m *MemFS) Open(_ context.Context, n string) (storage.Container, error) { if !ok { return nil, fmt.Errorf("container %s not found: %w", n, ErrDoesNotExist) } - c.Closed = false + c.closed = false return c, nil } @@ -103,7 +103,7 @@ type MemContainer struct { Files map[string]*MemFile originalFile bool - Closed bool + closed bool accessMode storage.AccessMode mode sfs.FileMode @@ -115,13 +115,13 @@ type MemContainer struct { func (m *MemContainer) Close() error { m.Lock() defer m.Unlock() - m.Closed = true + m.closed = true return nil } // Create makes a new storage.File in the container. func (m *MemContainer) Create(ctx context.Context, n string) (storage.File, error) { - if m.Closed { + if m.closed { return nil, fmt.Errorf("create on closed container: %w", ErrClosed) } return m.OpenFile(ctx, n, storage.WithCreateFile(), storage.WithFileAccessMode(storage.ReadWrite)) @@ -138,7 +138,7 @@ func (m *MemContainer) OpenFile(_ context.Context, n string, option ...storage.O m.Lock() defer m.Unlock() - if m.Closed { + if m.closed { return nil, fmt.Errorf("create on closed container: %w", ErrClosed) } opts := storage.GetOpts(option...) @@ -204,7 +204,7 @@ func (m *MemContainer) OpenFile(_ context.Context, n string, option ...storage.O func (m *MemContainer) SubContainer(_ context.Context, n string, option ...storage.Option) (storage.Container, error) { m.Lock() defer m.Unlock() - if m.Closed { + if m.closed { return nil, fmt.Errorf("subcontainer on closed container: %w", ErrClosed) } opts := storage.GetOpts(option...) @@ -231,11 +231,17 @@ func (m *MemContainer) SubContainer(_ context.Context, n string, option ...stora } } c.accessMode = opts.WithFileAccessMode - c.Closed = false + c.closed = false m.Sub[n] = c return c, nil } +func (m *MemContainer) IsClosed() bool { + m.RLock() + defer m.RUnlock() + return m.closed +} + // memFileInfo implements storage.FileInfo type memFileInfo struct { name string diff --git a/internal/cmd/commands/server/controller_db_swap_test.go b/internal/cmd/commands/server/controller_db_swap_test.go index 846d3eafd9..73abff3f66 100644 --- a/internal/cmd/commands/server/controller_db_swap_test.go +++ b/internal/cmd/commands/server/controller_db_swap_test.go @@ -241,6 +241,7 @@ func TestReloadControllerDatabase_InvalidNewDatabaseState(t *testing.T) { cfgHcl := fmt.Sprintf(dbSwapConfig, urlA, controllerKey, workerAuthKey, recoveryKey) require.NoError(t, os.WriteFile(td+"/config.hcl", []byte(cfgHcl), 0o644)) + errCh := make(chan error, 1) wg := &sync.WaitGroup{} wg.Add(1) go func() { @@ -250,12 +251,15 @@ func TestReloadControllerDatabase_InvalidNewDatabaseState(t *testing.T) { exitCode := cmd.Run(args) if exitCode != 0 { output := cmd.UI.(*cli.MockUi).ErrorWriter.String() + cmd.UI.(*cli.MockUi).OutputWriter.String() - t.Errorf("got a non-zero exit status: %s", output) + errCh <- fmt.Errorf("got a non-zero exit status: %s", output) + close(errCh) } }() // Wait until things are up and running (or timeout). select { + case err := <-errCh: + t.Fatal(err) case <-cmd.startedCh: case <-time.After(15 * time.Second): t.Fatal("timeout") diff --git a/internal/credential/service_list_credentials_ext_test.go b/internal/credential/service_list_credentials_ext_test.go index b2cc0a326c..3a588bd0a2 100644 --- a/internal/credential/service_list_credentials_ext_test.go +++ b/internal/credential/service_list_credentials_ext_test.go @@ -348,7 +348,7 @@ func TestService_List(t *testing.T) { filterFunc := func(_ context.Context, c credential.Static) (bool, error) { return true, nil } - _, err = credential.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId()) + _, err := credential.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/credential/service_list_libraries_ext_test.go b/internal/credential/service_list_libraries_ext_test.go index 27235f932b..614dad9762 100644 --- a/internal/credential/service_list_libraries_ext_test.go +++ b/internal/credential/service_list_libraries_ext_test.go @@ -165,7 +165,7 @@ func TestLibraryService_List(t *testing.T) { filterFunc := func(_ context.Context, l credential.Library) (bool, error) { return true, nil } - _, err = credential.ListLibrariesPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId()) + _, err := credential.ListLibrariesPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -253,7 +253,7 @@ func TestLibraryService_List(t *testing.T) { filterFunc := func(_ context.Context, l credential.Library) (bool, error) { return true, nil } - _, err = credential.ListLibrariesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId()) + _, err := credential.ListLibrariesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("nil repo", func(t *testing.T) { diff --git a/internal/daemon/worker/status.go b/internal/daemon/worker/status.go index 7274823fdc..8b809fe997 100644 --- a/internal/daemon/worker/status.go +++ b/internal/daemon/worker/status.go @@ -359,12 +359,13 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context, sessionManager sess // If we have post hooks for after the first status check, run them now if w.everAuthenticated.CompareAndSwap(authenticationStatusFirstAuthentication, authenticationStatusFirstStatusRpcSuccessful) { if downstreamWorkersFactory != nil { - w.downstreamWorkers, err = downstreamWorkersFactory(cancelCtx, w.LastStatusSuccess().WorkerId, versionInfo.FullVersionNumber(false)) + downstreamWorkers, err := downstreamWorkersFactory(cancelCtx, w.LastStatusSuccess().WorkerId, versionInfo.FullVersionNumber(false)) if err != nil { event.WriteError(cancelCtx, op, err) w.conf.ServerSideShutdownCh <- struct{}{} return } + w.downstreamWorkers.Store(&downstreamersContainer{downstreamers: downstreamWorkers}) } for _, fn := range firstStatusCheckPostHooks { if err := fn(cancelCtx, w); err != nil { diff --git a/internal/daemon/worker/worker.go b/internal/daemon/worker/worker.go index 093004a60f..cd2d421104 100644 --- a/internal/daemon/worker/worker.go +++ b/internal/daemon/worker/worker.go @@ -68,6 +68,13 @@ type reverseConnReceiver interface { StartProcessingPendingConnections(context.Context, func() string) error } +// downstreamersContainer is a struct that exists purely so we can perform +// atomic swap operations on the interface, to avoid/fix data races in tests +// (and any other potential location). +type downstreamersContainer struct { + downstreamers +} + // downstreamers provides at least a minimum interface that must be met by a // Worker.downstreamWorkers field which is far better than allowing any (empty // interface) @@ -162,7 +169,7 @@ type Worker struct { RecordingStorage storage.RecordingStorage // downstream workers and routes to those workers - downstreamWorkers downstreamers + downstreamWorkers *atomic.Pointer[downstreamersContainer] downstreamReceiver reverseConnReceiver // Timing variables. These are atomics for SIGHUP support, and are int64 @@ -211,6 +218,7 @@ func New(ctx context.Context, conf *Config) (*Worker, error) { successfulStatusGracePeriod: new(atomic.Int64), statusCallTimeoutDuration: new(atomic.Int64), upstreamConnectionState: new(atomic.Value), + downstreamWorkers: new(atomic.Pointer[downstreamersContainer]), } w.operationalState.Store(server.UnknownOperationalState) @@ -301,7 +309,7 @@ func New(ctx context.Context, conf *Config) (*Worker, error) { w.statusCallTimeoutDuration.Store(int64(conf.RawConfig.Worker.StatusCallTimeoutDuration)) } // FIXME: This is really ugly, but works. - session.CloseCallTimeout = w.statusCallTimeoutDuration + session.CloseCallTimeout.Store(w.successfulStatusGracePeriod.Load()) if recorderManagerFactory != nil { var err error diff --git a/internal/host/plugin/service_list_host_sets_ext_test.go b/internal/host/plugin/service_list_host_sets_ext_test.go index 4367a0b1cf..8914cf92f8 100644 --- a/internal/host/plugin/service_list_host_sets_ext_test.go +++ b/internal/host/plugin/service_list_host_sets_ext_test.go @@ -335,7 +335,7 @@ func TestService_ListHostSets(t *testing.T) { filterFunc := func(_ context.Context, s host.Set, plg *plugin.Plugin) (bool, error) { return true, nil } - _, _, err = hostplugin.ListHostSetsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) + _, _, err := hostplugin.ListHostSetsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/host/plugin/service_list_hosts_ext_test.go b/internal/host/plugin/service_list_hosts_ext_test.go index e7789fd521..58e9fb79ff 100644 --- a/internal/host/plugin/service_list_hosts_ext_test.go +++ b/internal/host/plugin/service_list_hosts_ext_test.go @@ -170,7 +170,7 @@ func TestService_ListHosts(t *testing.T) { filterFunc := func(_ context.Context, h host.Host, plg *plugin.Plugin) (bool, error) { return true, nil } - _, _, err = hostplugin.ListHostsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) + _, _, err := hostplugin.ListHostsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -336,7 +336,7 @@ func TestService_ListHosts(t *testing.T) { filterFunc := func(_ context.Context, h host.Host, plg *plugin.Plugin) (bool, error) { return true, nil } - _, _, err = hostplugin.ListHostsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) + _, _, err := hostplugin.ListHostsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/host/service_list_catalogs_ext_test.go b/internal/host/service_list_catalogs_ext_test.go index 2e5760521a..1cfc4cada0 100644 --- a/internal/host/service_list_catalogs_ext_test.go +++ b/internal/host/service_list_catalogs_ext_test.go @@ -375,7 +375,7 @@ func TestCatalogService_List(t *testing.T) { filterFunc := func(_ context.Context, c host.Catalog, plgs map[string]*plugin.Plugin) (bool, error) { return true, nil } - _, _, err = host.ListCatalogsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{prj.PublicId}) + _, _, err := host.ListCatalogsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{prj.PublicId}) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/host/static/service_list_host_sets_ext_test.go b/internal/host/static/service_list_host_sets_ext_test.go index a863d43749..24594d2630 100644 --- a/internal/host/static/service_list_host_sets_ext_test.go +++ b/internal/host/static/service_list_host_sets_ext_test.go @@ -158,7 +158,7 @@ func TestService_ListHostSets(t *testing.T) { filterFunc := func(_ context.Context, s host.Set) (bool, error) { return true, nil } - _, err = static.ListHostSetsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) + _, err := static.ListHostSetsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -324,7 +324,7 @@ func TestService_ListHostSets(t *testing.T) { filterFunc := func(_ context.Context, s host.Set) (bool, error) { return true, nil } - _, err = static.ListHostSetsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) + _, err := static.ListHostSetsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/host/static/service_list_hosts_ext_test.go b/internal/host/static/service_list_hosts_ext_test.go index 753a6bfdff..7a6e39066a 100644 --- a/internal/host/static/service_list_hosts_ext_test.go +++ b/internal/host/static/service_list_hosts_ext_test.go @@ -158,7 +158,7 @@ func TestService_ListHosts(t *testing.T) { filterFunc := func(_ context.Context, h host.Host) (bool, error) { return true, nil } - _, err = static.ListHostsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) + _, err := static.ListHostsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -246,7 +246,7 @@ func TestService_ListHosts(t *testing.T) { filterFunc := func(_ context.Context, h host.Host) (bool, error) { return true, nil } - _, err = static.ListHostsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) + _, err := static.ListHostsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("nil repo", func(t *testing.T) { @@ -324,7 +324,7 @@ func TestService_ListHosts(t *testing.T) { filterFunc := func(_ context.Context, h host.Host) (bool, error) { return true, nil } - _, err = static.ListHostsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) + _, err := static.ListHostsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId()) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/iam/service_list_ext_test.go b/internal/iam/service_list_ext_test.go index 83ed3814c0..bf6ed09e48 100644 --- a/internal/iam/service_list_ext_test.go +++ b/internal/iam/service_list_ext_test.go @@ -151,7 +151,7 @@ func TestService_ListRoles(t *testing.T) { filterFunc := func(_ context.Context, r *iam.Role) (bool, error) { return true, nil } - _, err = iam.ListRolesPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) + _, err := iam.ListRolesPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -239,7 +239,7 @@ func TestService_ListRoles(t *testing.T) { filterFunc := func(_ context.Context, r *iam.Role) (bool, error) { return true, nil } - _, err = iam.ListRolesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) + _, err := iam.ListRolesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -327,7 +327,7 @@ func TestService_ListRoles(t *testing.T) { filterFunc := func(_ context.Context, r *iam.Role) (bool, error) { return true, nil } - _, err = iam.ListRolesRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) + _, err := iam.ListRolesRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -751,7 +751,7 @@ func TestService_ListUsers(t *testing.T) { filterFunc := func(_ context.Context, r *iam.User) (bool, error) { return true, nil } - _, err = iam.ListUsersPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) + _, err := iam.ListUsersPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -927,7 +927,7 @@ func TestService_ListUsers(t *testing.T) { filterFunc := func(_ context.Context, r *iam.User) (bool, error) { return true, nil } - _, err = iam.ListUsersRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) + _, err := iam.ListUsersRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -1544,7 +1544,7 @@ func TestService_ListGroups(t *testing.T) { filterFunc := func(_ context.Context, r *iam.Group) (bool, error) { return true, nil } - _, err = iam.ListGroupsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, relevantScopes) + _, err := iam.ListGroupsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, relevantScopes) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -2045,7 +2045,7 @@ func TestService_ListScopes(t *testing.T) { filterFunc := func(_ context.Context, r *iam.Scope) (bool, error) { return true, nil } - _, err = iam.ListScopesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) + _, err := iam.ListScopesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -2133,7 +2133,7 @@ func TestService_ListScopes(t *testing.T) { filterFunc := func(_ context.Context, r *iam.Scope) (bool, error) { return true, nil } - _, err = iam.ListScopesRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) + _, err := iam.ListScopesRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()}) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/plugin/loopback/testing_grpc_stream.go b/internal/plugin/loopback/testing_grpc_stream.go index b61a2b59ec..62ebd444c4 100644 --- a/internal/plugin/loopback/testing_grpc_stream.go +++ b/internal/plugin/loopback/testing_grpc_stream.go @@ -143,6 +143,9 @@ type getObjectServer struct { // isStreamClosed is used to check if the stream is closed. // This is needed because the channel can be closed by the client or the server. isStreamClosed func() bool + + // This is shared with the stream to prevent sending on closed channels + m *sync.Mutex } // Send will send a message to the client. @@ -156,6 +159,8 @@ func (s *getObjectServer) Send(resp *plgpb.GetObjectResponse) error { return fmt.Errorf("stream is closed") } + s.m.Lock() + defer s.m.Unlock() select { case s.sendToClient <- &getObjectStreamResponse{msg: resp}: case <-s.ctx.Done(): @@ -201,6 +206,8 @@ func (s *getObjectServer) SendMsg(m interface{}) error { if s.isStreamClosed() { return fmt.Errorf("stream is closed") } + s.m.Lock() + defer s.m.Unlock() select { case s.sendToClient <- &getObjectStreamResponse{msg: msg}: case <-s.ctx.Done(): @@ -211,6 +218,8 @@ func (s *getObjectServer) SendMsg(m interface{}) error { return fmt.Errorf("stream is closed") } defer s.closeStream() + s.m.Lock() + defer s.m.Unlock() select { case s.sendToClient <- &getObjectStreamResponse{err: msg}: case <-s.ctx.Done(): @@ -250,6 +259,7 @@ func newGetObjectStream() *getObjectStream { sendToClient: stream.messages, closeStream: stream.Close, isStreamClosed: stream.IsStreamClosed, + m: stream.m, } return stream } diff --git a/internal/plugin/loopback/testing_grpc_stream_test.go b/internal/plugin/loopback/testing_grpc_stream_test.go index 21b46b04df..d33a7832ed 100644 --- a/internal/plugin/loopback/testing_grpc_stream_test.go +++ b/internal/plugin/loopback/testing_grpc_stream_test.go @@ -90,7 +90,10 @@ func Test_GetObjectStream_Client(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - stream.messages <- &getObjectStreamResponse{ + stream.m.Lock() + messages := stream.messages + stream.m.Unlock() + messages <- &getObjectStreamResponse{ msg: &plgpb.GetObjectResponse{}, } }() @@ -116,7 +119,10 @@ func Test_GetObjectStream_Client(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - stream.messages <- &getObjectStreamResponse{ + stream.m.Lock() + messages := stream.messages + stream.m.Unlock() + messages <- &getObjectStreamResponse{ err: fmt.Errorf("mock error"), } }() diff --git a/internal/session/service_list_ext_test.go b/internal/session/service_list_ext_test.go index 84a1bcf67f..a602e77227 100644 --- a/internal/session/service_list_ext_test.go +++ b/internal/session/service_list_ext_test.go @@ -233,7 +233,7 @@ func TestService_List(t *testing.T) { filterFunc := func(_ context.Context, s *session.Session) (bool, error) { return true, nil } - _, err = session.ListPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, true) + _, err := session.ListPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, true) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { @@ -311,7 +311,7 @@ func TestService_List(t *testing.T) { filterFunc := func(_ context.Context, s *session.Session) (bool, error) { return true, nil } - _, err = session.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, true) + _, err := session.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, true) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) { diff --git a/internal/target/service_list_ext_test.go b/internal/target/service_list_ext_test.go index 83d8ac9523..ee23b2e752 100644 --- a/internal/target/service_list_ext_test.go +++ b/internal/target/service_list_ext_test.go @@ -316,7 +316,7 @@ func TestService_List(t *testing.T) { filterFunc := func(_ context.Context, t target.Target) (bool, error) { return true, nil } - _, err = target.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo) + _, err := target.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo) require.ErrorContains(t, err, "missing token") }) t.Run("wrong token type", func(t *testing.T) {