Add some more fixes caught by race detector (#4488)

pull/4496/head
Jeff Mitchell 2 years ago committed by GitHub
parent 39a0c0d833
commit 9408b2b299
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -172,7 +172,7 @@ func TestService_ListAccounts(t *testing.T) {
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
_, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -338,7 +338,7 @@ func TestService_ListAccounts(t *testing.T) {
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
_, err := ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -344,7 +344,7 @@ func TestService_ListAccounts(t *testing.T) {
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err = oidc.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
_, err := oidc.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -250,7 +250,7 @@ func TestService_ListAccounts(t *testing.T) {
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err = password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
_, err := password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("nil repo", func(t *testing.T) {
@ -328,7 +328,7 @@ func TestService_ListAccounts(t *testing.T) {
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err = password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
_, err := password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -237,7 +237,7 @@ func TestStoreService_List(t *testing.T) {
filterFunc := func(_ context.Context, am auth.AuthMethod) (bool, error) {
return true, nil
}
_, err = auth.ListAuthMethodsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false)
_, err := auth.ListAuthMethodsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false)
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -325,7 +325,7 @@ func TestStoreService_List(t *testing.T) {
filterFunc := func(_ context.Context, am auth.AuthMethod) (bool, error) {
return true, nil
}
_, err = auth.ListAuthMethodsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false)
_, err := auth.ListAuthMethodsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false)
require.ErrorContains(t, err, "missing token")
})
t.Run("missing repo", func(t *testing.T) {
@ -403,7 +403,7 @@ func TestStoreService_List(t *testing.T) {
filterFunc := func(_ context.Context, am auth.AuthMethod) (bool, error) {
return true, nil
}
_, err = auth.ListAuthMethodsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false)
_, err := auth.ListAuthMethodsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.PublicId}, false)
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -331,7 +331,7 @@ func TestService_List(t *testing.T) {
filterFunc := func(_ context.Context, at *authtoken.AuthToken) (bool, error) {
return true, nil
}
_, err = authtoken.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
_, err := authtoken.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -538,7 +538,7 @@ func TestCloseBSRMethods(t *testing.T) {
// Get session container
sessionContainer := f.Containers[fmt.Sprintf(bsrFileNameTemplate, sessionRecordingId)]
require.NotNil(t, sessionContainer)
assert.True(t, sessionContainer.Closed)
assert.True(t, sessionContainer.IsClosed())
// Ensure all session files are closed
for _, file := range sessionContainer.Files {
@ -548,7 +548,7 @@ func TestCloseBSRMethods(t *testing.T) {
// Get connection container
connectionContainer := sessionContainer.Sub[fmt.Sprintf(connectionFileNameTemplate, connectionId)]
require.NotNil(t, connectionContainer)
assert.True(t, connectionContainer.Closed)
assert.True(t, connectionContainer.IsClosed())
// Ensure all connection files are closed
for _, file := range connectionContainer.Files {
@ -558,7 +558,7 @@ func TestCloseBSRMethods(t *testing.T) {
// Get channel container
channelContainer := connectionContainer.Sub[fmt.Sprintf(channelFileNameTemplate, channelId)]
require.NotNil(t, channelContainer)
assert.True(t, channelContainer.Closed)
assert.True(t, channelContainer.IsClosed())
// Ensure all channel files are closed
for _, file := range channelContainer.Files {

@ -355,7 +355,7 @@ func TestBSR_Validate_ValidateBSR(t *testing.T) {
// Ensure session container is closed
sessionContainer := fs.Containers[fmt.Sprintf(bsrFileNameTemplate, validation.SessionRecordingId)]
require.NotNil(t, sessionContainer)
assert.True(t, sessionContainer.Closed)
assert.True(t, sessionContainer.IsClosed())
// Validate Multiple Connections
for _, connection := range validation.SessionRecordingValidation.SubContainers {
@ -368,7 +368,7 @@ func TestBSR_Validate_ValidateBSR(t *testing.T) {
// Ensure connection container is closed
connectionContainer := sessionContainer.Sub[fmt.Sprintf(connectionFileNameTemplate, connection.Name)]
require.NotNil(t, connectionContainer)
assert.True(t, connectionContainer.Closed)
assert.True(t, connectionContainer.IsClosed())
// Validate Multiple Channels
for _, channel := range connection.SubContainers {
@ -381,7 +381,7 @@ func TestBSR_Validate_ValidateBSR(t *testing.T) {
// Ensure channel container is closed
channelContainer := connectionContainer.Sub[fmt.Sprintf(channelFileNameTemplate, channel.Name)]
require.NotNil(t, channelContainer)
assert.True(t, channelContainer.Closed)
assert.True(t, channelContainer.IsClosed())
}
}
})

@ -91,7 +91,7 @@ func (m *MemFS) Open(_ context.Context, n string) (storage.Container, error) {
if !ok {
return nil, fmt.Errorf("container %s not found: %w", n, ErrDoesNotExist)
}
c.Closed = false
c.closed = false
return c, nil
}
@ -103,7 +103,7 @@ type MemContainer struct {
Files map[string]*MemFile
originalFile bool
Closed bool
closed bool
accessMode storage.AccessMode
mode sfs.FileMode
@ -115,13 +115,13 @@ type MemContainer struct {
func (m *MemContainer) Close() error {
m.Lock()
defer m.Unlock()
m.Closed = true
m.closed = true
return nil
}
// Create makes a new storage.File in the container.
func (m *MemContainer) Create(ctx context.Context, n string) (storage.File, error) {
if m.Closed {
if m.closed {
return nil, fmt.Errorf("create on closed container: %w", ErrClosed)
}
return m.OpenFile(ctx, n, storage.WithCreateFile(), storage.WithFileAccessMode(storage.ReadWrite))
@ -138,7 +138,7 @@ func (m *MemContainer) OpenFile(_ context.Context, n string, option ...storage.O
m.Lock()
defer m.Unlock()
if m.Closed {
if m.closed {
return nil, fmt.Errorf("create on closed container: %w", ErrClosed)
}
opts := storage.GetOpts(option...)
@ -204,7 +204,7 @@ func (m *MemContainer) OpenFile(_ context.Context, n string, option ...storage.O
func (m *MemContainer) SubContainer(_ context.Context, n string, option ...storage.Option) (storage.Container, error) {
m.Lock()
defer m.Unlock()
if m.Closed {
if m.closed {
return nil, fmt.Errorf("subcontainer on closed container: %w", ErrClosed)
}
opts := storage.GetOpts(option...)
@ -231,11 +231,17 @@ func (m *MemContainer) SubContainer(_ context.Context, n string, option ...stora
}
}
c.accessMode = opts.WithFileAccessMode
c.Closed = false
c.closed = false
m.Sub[n] = c
return c, nil
}
func (m *MemContainer) IsClosed() bool {
m.RLock()
defer m.RUnlock()
return m.closed
}
// memFileInfo implements storage.FileInfo
type memFileInfo struct {
name string

@ -241,6 +241,7 @@ func TestReloadControllerDatabase_InvalidNewDatabaseState(t *testing.T) {
cfgHcl := fmt.Sprintf(dbSwapConfig, urlA, controllerKey, workerAuthKey, recoveryKey)
require.NoError(t, os.WriteFile(td+"/config.hcl", []byte(cfgHcl), 0o644))
errCh := make(chan error, 1)
wg := &sync.WaitGroup{}
wg.Add(1)
go func() {
@ -250,12 +251,15 @@ func TestReloadControllerDatabase_InvalidNewDatabaseState(t *testing.T) {
exitCode := cmd.Run(args)
if exitCode != 0 {
output := cmd.UI.(*cli.MockUi).ErrorWriter.String() + cmd.UI.(*cli.MockUi).OutputWriter.String()
t.Errorf("got a non-zero exit status: %s", output)
errCh <- fmt.Errorf("got a non-zero exit status: %s", output)
close(errCh)
}
}()
// Wait until things are up and running (or timeout).
select {
case err := <-errCh:
t.Fatal(err)
case <-cmd.startedCh:
case <-time.After(15 * time.Second):
t.Fatal("timeout")

@ -348,7 +348,7 @@ func TestService_List(t *testing.T) {
filterFunc := func(_ context.Context, c credential.Static) (bool, error) {
return true, nil
}
_, err = credential.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId())
_, err := credential.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -165,7 +165,7 @@ func TestLibraryService_List(t *testing.T) {
filterFunc := func(_ context.Context, l credential.Library) (bool, error) {
return true, nil
}
_, err = credential.ListLibrariesPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId())
_, err := credential.ListLibrariesPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -253,7 +253,7 @@ func TestLibraryService_List(t *testing.T) {
filterFunc := func(_ context.Context, l credential.Library) (bool, error) {
return true, nil
}
_, err = credential.ListLibrariesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId())
_, err := credential.ListLibrariesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, credStore.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("nil repo", func(t *testing.T) {

@ -359,12 +359,13 @@ func (w *Worker) sendWorkerStatus(cancelCtx context.Context, sessionManager sess
// If we have post hooks for after the first status check, run them now
if w.everAuthenticated.CompareAndSwap(authenticationStatusFirstAuthentication, authenticationStatusFirstStatusRpcSuccessful) {
if downstreamWorkersFactory != nil {
w.downstreamWorkers, err = downstreamWorkersFactory(cancelCtx, w.LastStatusSuccess().WorkerId, versionInfo.FullVersionNumber(false))
downstreamWorkers, err := downstreamWorkersFactory(cancelCtx, w.LastStatusSuccess().WorkerId, versionInfo.FullVersionNumber(false))
if err != nil {
event.WriteError(cancelCtx, op, err)
w.conf.ServerSideShutdownCh <- struct{}{}
return
}
w.downstreamWorkers.Store(&downstreamersContainer{downstreamers: downstreamWorkers})
}
for _, fn := range firstStatusCheckPostHooks {
if err := fn(cancelCtx, w); err != nil {

@ -68,6 +68,13 @@ type reverseConnReceiver interface {
StartProcessingPendingConnections(context.Context, func() string) error
}
// downstreamersContainer is a struct that exists purely so we can perform
// atomic swap operations on the interface, to avoid/fix data races in tests
// (and any other potential location).
type downstreamersContainer struct {
downstreamers
}
// downstreamers provides at least a minimum interface that must be met by a
// Worker.downstreamWorkers field which is far better than allowing any (empty
// interface)
@ -162,7 +169,7 @@ type Worker struct {
RecordingStorage storage.RecordingStorage
// downstream workers and routes to those workers
downstreamWorkers downstreamers
downstreamWorkers *atomic.Pointer[downstreamersContainer]
downstreamReceiver reverseConnReceiver
// Timing variables. These are atomics for SIGHUP support, and are int64
@ -211,6 +218,7 @@ func New(ctx context.Context, conf *Config) (*Worker, error) {
successfulStatusGracePeriod: new(atomic.Int64),
statusCallTimeoutDuration: new(atomic.Int64),
upstreamConnectionState: new(atomic.Value),
downstreamWorkers: new(atomic.Pointer[downstreamersContainer]),
}
w.operationalState.Store(server.UnknownOperationalState)
@ -301,7 +309,7 @@ func New(ctx context.Context, conf *Config) (*Worker, error) {
w.statusCallTimeoutDuration.Store(int64(conf.RawConfig.Worker.StatusCallTimeoutDuration))
}
// FIXME: This is really ugly, but works.
session.CloseCallTimeout = w.statusCallTimeoutDuration
session.CloseCallTimeout.Store(w.successfulStatusGracePeriod.Load())
if recorderManagerFactory != nil {
var err error

@ -335,7 +335,7 @@ func TestService_ListHostSets(t *testing.T) {
filterFunc := func(_ context.Context, s host.Set, plg *plugin.Plugin) (bool, error) {
return true, nil
}
_, _, err = hostplugin.ListHostSetsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
_, _, err := hostplugin.ListHostSetsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -170,7 +170,7 @@ func TestService_ListHosts(t *testing.T) {
filterFunc := func(_ context.Context, h host.Host, plg *plugin.Plugin) (bool, error) {
return true, nil
}
_, _, err = hostplugin.ListHostsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
_, _, err := hostplugin.ListHostsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -336,7 +336,7 @@ func TestService_ListHosts(t *testing.T) {
filterFunc := func(_ context.Context, h host.Host, plg *plugin.Plugin) (bool, error) {
return true, nil
}
_, _, err = hostplugin.ListHostsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
_, _, err := hostplugin.ListHostsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -375,7 +375,7 @@ func TestCatalogService_List(t *testing.T) {
filterFunc := func(_ context.Context, c host.Catalog, plgs map[string]*plugin.Plugin) (bool, error) {
return true, nil
}
_, _, err = host.ListCatalogsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{prj.PublicId})
_, _, err := host.ListCatalogsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{prj.PublicId})
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -158,7 +158,7 @@ func TestService_ListHostSets(t *testing.T) {
filterFunc := func(_ context.Context, s host.Set) (bool, error) {
return true, nil
}
_, err = static.ListHostSetsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
_, err := static.ListHostSetsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -324,7 +324,7 @@ func TestService_ListHostSets(t *testing.T) {
filterFunc := func(_ context.Context, s host.Set) (bool, error) {
return true, nil
}
_, err = static.ListHostSetsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
_, err := static.ListHostSetsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -158,7 +158,7 @@ func TestService_ListHosts(t *testing.T) {
filterFunc := func(_ context.Context, h host.Host) (bool, error) {
return true, nil
}
_, err = static.ListHostsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
_, err := static.ListHostsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -246,7 +246,7 @@ func TestService_ListHosts(t *testing.T) {
filterFunc := func(_ context.Context, h host.Host) (bool, error) {
return true, nil
}
_, err = static.ListHostsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
_, err := static.ListHostsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("nil repo", func(t *testing.T) {
@ -324,7 +324,7 @@ func TestService_ListHosts(t *testing.T) {
filterFunc := func(_ context.Context, h host.Host) (bool, error) {
return true, nil
}
_, err = static.ListHostsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
_, err := static.ListHostsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, catalog.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -151,7 +151,7 @@ func TestService_ListRoles(t *testing.T) {
filterFunc := func(_ context.Context, r *iam.Role) (bool, error) {
return true, nil
}
_, err = iam.ListRolesPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
_, err := iam.ListRolesPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -239,7 +239,7 @@ func TestService_ListRoles(t *testing.T) {
filterFunc := func(_ context.Context, r *iam.Role) (bool, error) {
return true, nil
}
_, err = iam.ListRolesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
_, err := iam.ListRolesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -327,7 +327,7 @@ func TestService_ListRoles(t *testing.T) {
filterFunc := func(_ context.Context, r *iam.Role) (bool, error) {
return true, nil
}
_, err = iam.ListRolesRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
_, err := iam.ListRolesRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -751,7 +751,7 @@ func TestService_ListUsers(t *testing.T) {
filterFunc := func(_ context.Context, r *iam.User) (bool, error) {
return true, nil
}
_, err = iam.ListUsersPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
_, err := iam.ListUsersPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -927,7 +927,7 @@ func TestService_ListUsers(t *testing.T) {
filterFunc := func(_ context.Context, r *iam.User) (bool, error) {
return true, nil
}
_, err = iam.ListUsersRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
_, err := iam.ListUsersRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -1544,7 +1544,7 @@ func TestService_ListGroups(t *testing.T) {
filterFunc := func(_ context.Context, r *iam.Group) (bool, error) {
return true, nil
}
_, err = iam.ListGroupsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, relevantScopes)
_, err := iam.ListGroupsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, relevantScopes)
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -2045,7 +2045,7 @@ func TestService_ListScopes(t *testing.T) {
filterFunc := func(_ context.Context, r *iam.Scope) (bool, error) {
return true, nil
}
_, err = iam.ListScopesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
_, err := iam.ListScopesRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -2133,7 +2133,7 @@ func TestService_ListScopes(t *testing.T) {
filterFunc := func(_ context.Context, r *iam.Scope) (bool, error) {
return true, nil
}
_, err = iam.ListScopesRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
_, err := iam.ListScopesRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, []string{org.GetPublicId()})
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -143,6 +143,9 @@ type getObjectServer struct {
// isStreamClosed is used to check if the stream is closed.
// This is needed because the channel can be closed by the client or the server.
isStreamClosed func() bool
// This is shared with the stream to prevent sending on closed channels
m *sync.Mutex
}
// Send will send a message to the client.
@ -156,6 +159,8 @@ func (s *getObjectServer) Send(resp *plgpb.GetObjectResponse) error {
return fmt.Errorf("stream is closed")
}
s.m.Lock()
defer s.m.Unlock()
select {
case s.sendToClient <- &getObjectStreamResponse{msg: resp}:
case <-s.ctx.Done():
@ -201,6 +206,8 @@ func (s *getObjectServer) SendMsg(m interface{}) error {
if s.isStreamClosed() {
return fmt.Errorf("stream is closed")
}
s.m.Lock()
defer s.m.Unlock()
select {
case s.sendToClient <- &getObjectStreamResponse{msg: msg}:
case <-s.ctx.Done():
@ -211,6 +218,8 @@ func (s *getObjectServer) SendMsg(m interface{}) error {
return fmt.Errorf("stream is closed")
}
defer s.closeStream()
s.m.Lock()
defer s.m.Unlock()
select {
case s.sendToClient <- &getObjectStreamResponse{err: msg}:
case <-s.ctx.Done():
@ -250,6 +259,7 @@ func newGetObjectStream() *getObjectStream {
sendToClient: stream.messages,
closeStream: stream.Close,
isStreamClosed: stream.IsStreamClosed,
m: stream.m,
}
return stream
}

@ -90,7 +90,10 @@ func Test_GetObjectStream_Client(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
stream.messages <- &getObjectStreamResponse{
stream.m.Lock()
messages := stream.messages
stream.m.Unlock()
messages <- &getObjectStreamResponse{
msg: &plgpb.GetObjectResponse{},
}
}()
@ -116,7 +119,10 @@ func Test_GetObjectStream_Client(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
stream.messages <- &getObjectStreamResponse{
stream.m.Lock()
messages := stream.messages
stream.m.Unlock()
messages <- &getObjectStreamResponse{
err: fmt.Errorf("mock error"),
}
}()

@ -233,7 +233,7 @@ func TestService_List(t *testing.T) {
filterFunc := func(_ context.Context, s *session.Session) (bool, error) {
return true, nil
}
_, err = session.ListPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, true)
_, err := session.ListPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, true)
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
@ -311,7 +311,7 @@ func TestService_List(t *testing.T) {
filterFunc := func(_ context.Context, s *session.Session) (bool, error) {
return true, nil
}
_, err = session.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, true)
_, err := session.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, true)
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

@ -316,7 +316,7 @@ func TestService_List(t *testing.T) {
filterFunc := func(_ context.Context, t target.Target) (bool, error) {
return true, nil
}
_, err = target.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo)
_, err := target.ListRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo)
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {

Loading…
Cancel
Save