From 5054a5c83610f03aad56ca376fb8245c3073415f Mon Sep 17 00:00:00 2001 From: Damian Debkowski Date: Thu, 6 Feb 2025 16:51:16 +0000 Subject: [PATCH 1/4] backport of commit 0529264f938ff7e296f3cac798a2150cb3b3af2d --- internal/auth/oidc/repository_auth_method.go | 2 +- internal/auth/repository_auth_method.go | 2 +- internal/credential/repository_store.go | 2 +- internal/host/repository_catalog.go | 2 +- internal/host/static/repository_host.go | 4 ++-- internal/iam/repository_role.go | 2 +- internal/iam/repository_role_grant.go | 2 +- internal/server/options.go | 12 ++++++++++++ internal/server/options_test.go | 16 ++++++++++++++++ internal/server/repository_worker.go | 6 +++--- ...ory_worker_storage_bucket_credential_state.go | 12 +++++++++--- 11 files changed, 48 insertions(+), 14 deletions(-) diff --git a/internal/auth/oidc/repository_auth_method.go b/internal/auth/oidc/repository_auth_method.go index 0fff0ca776..c8bf4c592d 100644 --- a/internal/auth/oidc/repository_auth_method.go +++ b/internal/auth/oidc/repository_auth_method.go @@ -179,7 +179,7 @@ func (r *Repository) upsertAccount(ctx context.Context, am *AuthMethod, IdTokenC var rowCnt int for rows.Next() { rowCnt += 1 - err = r.reader.ScanRows(ctx, rows, &result) + err = reader.ScanRows(ctx, rows, &result) if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("unable to scan rows for account")) } diff --git a/internal/auth/repository_auth_method.go b/internal/auth/repository_auth_method.go index d97d226959..bbd35fc549 100644 --- a/internal/auth/repository_auth_method.go +++ b/internal/auth/repository_auth_method.go @@ -147,7 +147,7 @@ func (amr *AuthMethodRepository) ListDeletedIds(ctx context.Context, since time. var deletedAuthMethodIDs []string var transactionTimestamp time.Time if _, err := amr.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { - rows, err := amr.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)}) + rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)}) if err != nil { return errors.Wrap(ctx, err, op) } diff --git a/internal/credential/repository_store.go b/internal/credential/repository_store.go index 02886f9bc1..338cff4452 100644 --- a/internal/credential/repository_store.go +++ b/internal/credential/repository_store.go @@ -118,7 +118,7 @@ func (s *StoreRepository) ListDeletedIds(ctx context.Context, since time.Time) ( var deletedStoreIDs []string var transactionTimestamp time.Time if _, err := s.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { - rows, err := s.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)}) + rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)}) if err != nil { return errors.Wrap(ctx, err, op) } diff --git a/internal/host/repository_catalog.go b/internal/host/repository_catalog.go index be9c04dcc3..b1fdf57118 100644 --- a/internal/host/repository_catalog.go +++ b/internal/host/repository_catalog.go @@ -119,7 +119,7 @@ func (s *CatalogRepository) ListDeletedIds(ctx context.Context, since time.Time) var deletedCatalogIDs []string var transactionTimestamp time.Time if _, err := s.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { - rows, err := s.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)}) + rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)}) if err != nil { return errors.Wrap(ctx, err, op) } diff --git a/internal/host/static/repository_host.go b/internal/host/static/repository_host.go index acf4e970c5..c2d4df5652 100644 --- a/internal/host/static/repository_host.go +++ b/internal/host/static/repository_host.go @@ -171,7 +171,7 @@ func (r *Repository) UpdateHost(ctx context.Context, projectId string, h *Host, var rowsUpdated int var returnedHost *Host _, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(_ db.Reader, w db.Writer) error { + func(r db.Reader, w db.Writer) error { returnedHost = h.clone() var err error rowsUpdated, err = w.Update(ctx, returnedHost, dbMask, nullFields, @@ -186,7 +186,7 @@ func (r *Repository) UpdateHost(ctx context.Context, projectId string, h *Host, ha := &hostAgg{ PublicId: h.PublicId, } - if err := r.reader.LookupByPublicId(ctx, ha); err != nil { + if err := r.LookupByPublicId(ctx, ha); err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("failed to lookup host after update")) } returnedHost.SetIds = ha.getSetIds() diff --git a/internal/iam/repository_role.go b/internal/iam/repository_role.go index 9503303b10..e0feb9cd06 100644 --- a/internal/iam/repository_role.go +++ b/internal/iam/repository_role.go @@ -325,7 +325,7 @@ func (r *Repository) queryRoles(ctx context.Context, whereClause string, args [] for _, retRole := range retRoles { roleIds = append(roleIds, retRole.PublicId) } - retRoleGrantScopes, err = r.ListRoleGrantScopes(ctx, roleIds) + retRoleGrantScopes, err = r.ListRoleGrantScopes(ctx, roleIds, WithReaderWriter(rd, w)) if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("failed to query role grant scopes")) } diff --git a/internal/iam/repository_role_grant.go b/internal/iam/repository_role_grant.go index d9a8d63172..92ed50914b 100644 --- a/internal/iam/repository_role_grant.go +++ b/internal/iam/repository_role_grant.go @@ -359,7 +359,7 @@ func (r *Repository) SetRoleGrants(ctx context.Context, roleId string, roleVersi return errors.Wrap(ctx, err, op, errors.WithMsg("unable to write oplog")) } - currentRoleGrants, err = r.ListRoleGrants(ctx, roleId) + currentRoleGrants, err = r.ListRoleGrants(ctx, roleId, WithReaderWriter(reader, w)) if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("unable to retrieve current role grants after set")) } diff --git a/internal/server/options.go b/internal/server/options.go index a62e18556c..a28a8c2288 100644 --- a/internal/server/options.go +++ b/internal/server/options.go @@ -7,6 +7,7 @@ import ( "context" "time" + "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/version" "github.com/hashicorp/nodeenrollment/types" ) @@ -63,6 +64,8 @@ type options struct { withWorkerPool []string withFilterWorkersByStorageBucketCredentialState *StorageBucketCredentialInfo withFilterWorkersByLocalStorageState bool + WithReader db.Reader + WithWriter db.Writer } func getDefaultOptions() options { @@ -303,3 +306,12 @@ func WithFilterWorkersByLocalStorageState(filter bool) Option { o.withFilterWorkersByLocalStorageState = filter } } + +// WithReaderWriter is used to share the same database reader +// and writer when executing sql within a transaction. +func WithReaderWriter(r db.Reader, w db.Writer) Option { + return func(o *options) { + o.WithReader = r + o.WithWriter = w + } +} diff --git a/internal/server/options_test.go b/internal/server/options_test.go index cb4e620d63..7f0584cd9f 100644 --- a/internal/server/options_test.go +++ b/internal/server/options_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/version" "github.com/stretchr/testify/assert" ) @@ -252,4 +253,19 @@ func Test_GetOpts(t *testing.T) { testOpts.withNewIdFunc = nil assert.Equal(t, opts, testOpts) }) + t.Run("WithReaderWriter", func(t *testing.T) { + reader := &db.Db{} + writer := &db.Db{} + testOpts := getDefaultOptions() + assert.Nil(t, testOpts.WithReader) + assert.Nil(t, testOpts.WithWriter) + testOpts.WithReader = reader + testOpts.WithWriter = writer + opts := GetOpts(WithReaderWriter(reader, writer)) + opts.withNewIdFunc = nil + testOpts.withNewIdFunc = nil + assert.Equal(t, reader, opts.WithReader) + assert.Equal(t, writer, opts.WithWriter) + assert.Equal(t, opts, testOpts) + }) } diff --git a/internal/server/repository_worker.go b/internal/server/repository_worker.go index 435678b867..d77e64a8fe 100644 --- a/internal/server/repository_worker.go +++ b/internal/server/repository_worker.go @@ -600,7 +600,7 @@ func (r *Repository) UpdateWorker(ctx context.Context, worker *Worker, version u if err != nil { return errors.Wrap(ctx, err, op) } - ret.RemoteStorageStates, err = r.ListWorkerStorageBucketCredentialState(ctx, ret.GetPublicId()) + ret.RemoteStorageStates, err = r.ListWorkerStorageBucketCredentialState(ctx, ret.GetPublicId(), WithReaderWriter(reader, w)) if err != nil { return err } @@ -925,7 +925,7 @@ func (r *Repository) SelectSessionWorkers(ctx context.Context, db.StdRetryCnt, db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { - rows, err := r.reader.Query(ctx, query, []any{}) + rows, err := reader.Query(ctx, query, []any{}) if err != nil { return err } @@ -935,7 +935,7 @@ func (r *Repository) SelectSessionWorkers(ctx context.Context, // a Worker object can hold, only a subset. Check the query to // learn exactly what fields are present. var worker Worker - if err := r.reader.ScanRows(ctx, rows, &worker); err != nil { + if err := reader.ScanRows(ctx, rows, &worker); err != nil { return err } livingWorkers = append(livingWorkers, &worker) diff --git a/internal/server/repository_worker_storage_bucket_credential_state.go b/internal/server/repository_worker_storage_bucket_credential_state.go index cf14e4fe4c..ac80bec17b 100644 --- a/internal/server/repository_worker_storage_bucket_credential_state.go +++ b/internal/server/repository_worker_storage_bucket_credential_state.go @@ -9,11 +9,12 @@ import ( "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/util" plgpb "github.com/hashicorp/boundary/sdk/pbs/plugin" ) // ListWorkerStorageBucketCredentialState returns a list of storage bucket credential states for the given worker. -func (r *Repository) ListWorkerStorageBucketCredentialState(ctx context.Context, workerId string) (map[string]*plgpb.StorageBucketCredentialState, error) { +func (r *Repository) ListWorkerStorageBucketCredentialState(ctx context.Context, workerId string, opts ...Option) (map[string]*plgpb.StorageBucketCredentialState, error) { const op = "server.(Repository).ListWorkerStorageBucketCredentialState" if workerId == "" { return nil, errors.New(ctx, errors.InvalidParameter, op, "empty worker id") @@ -25,7 +26,12 @@ func (r *Repository) ListWorkerStorageBucketCredentialState(ctx context.Context, CheckedAt *timestamp.Timestamp ErrorDetails string } - rows, err := r.reader.Query(ctx, getStorageBucketCredentialStatesByWorkerId, []any{sql.Named("worker_id", workerId)}) + opt := GetOpts(opts...) + reader := r.reader + if !util.IsNil(opt.WithReader) { + reader = opt.WithReader + } + rows, err := reader.Query(ctx, getStorageBucketCredentialStatesByWorkerId, []any{sql.Named("worker_id", workerId)}) if err != nil && !errors.Match(errors.T(errors.RecordNotFound), err) { return nil, errors.Wrap(ctx, err, op) } @@ -36,7 +42,7 @@ func (r *Repository) ListWorkerStorageBucketCredentialState(ctx context.Context, return nil, errors.Wrap(ctx, err, op) } var row remoteStorageState - if err := r.reader.ScanRows(ctx, rows, &row); err != nil { + if err := reader.ScanRows(ctx, rows, &row); err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("failed to fetch remote storage state")) } s, ok := remoteStorageStates[row.StorageBucketId] From deeac5c457e86c6c2bf6db50f00a385ecbf4a360 Mon Sep 17 00:00:00 2001 From: Johan Brandhorst-Satzkorn Date: Thu, 6 Feb 2025 17:33:01 +0000 Subject: [PATCH 2/4] backport of commit 8466210d99ee21a3153fbc8d3896b8f69ce862ee --- internal/server/repository_worker.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/server/repository_worker.go b/internal/server/repository_worker.go index d77e64a8fe..e41d7309b0 100644 --- a/internal/server/repository_worker.go +++ b/internal/server/repository_worker.go @@ -255,7 +255,7 @@ func ListWorkers(ctx context.Context, reader db.Reader, scopeIds []string, opt . defer rows.Close() for rows.Next() { var worker Worker - if err := reader.ScanRows(context.Background(), rows, &worker); err != nil { + if err := reader.ScanRows(ctx, rows, &worker); err != nil { return nil, err } workers = append(workers, &worker) From 42bf2f032c5b9b0ac8245ab0782ea78bd71faa71 Mon Sep 17 00:00:00 2001 From: Damian Debkowski Date: Thu, 6 Feb 2025 18:05:12 +0000 Subject: [PATCH 3/4] backport of commit 961d9d7d16a4c62053dfbbc0f7161c4dac4f53d3 --- internal/host/options.go | 20 +++++++++++++++++++ internal/host/options_test.go | 19 ++++++++++++++++++ internal/host/plugin/options.go | 12 +++++++++++ internal/host/plugin/options_test.go | 8 ++++++++ .../host/plugin/repository_host_catalog.go | 14 +++++++++---- internal/host/plugin/repository_host_set.go | 13 ++++++++++-- 6 files changed, 80 insertions(+), 6 deletions(-) diff --git a/internal/host/options.go b/internal/host/options.go index edfe8c98d0..5d240326e5 100644 --- a/internal/host/options.go +++ b/internal/host/options.go @@ -6,7 +6,9 @@ package host import ( "errors" + "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/pagination" + "github.com/hashicorp/boundary/internal/util" ) // GetOpts - iterate the inbound Options and return a struct @@ -26,6 +28,8 @@ type Option func(*options) error // options = how options are represented type options struct { WithLimit int + WithReader db.Reader + WithWriter db.Writer WithOrderByCreateTime bool Ascending bool WithStartPageAfterItem pagination.Item @@ -66,3 +70,19 @@ func WithStartPageAfterItem(item pagination.Item) Option { return nil } } + +// WithReaderWriter is used to share the same database reader +// and writer when executing sql within a transaction. +func WithReaderWriter(r db.Reader, w db.Writer) Option { + return func(o *options) error { + if util.IsNil(r) { + return errors.New("reader cannot be nil") + } + if util.IsNil(w) { + return errors.New("writer cannot be nil") + } + o.WithReader = r + o.WithWriter = w + return nil + } +} diff --git a/internal/host/options_test.go b/internal/host/options_test.go index 90c47b493e..9c3336ce07 100644 --- a/internal/host/options_test.go +++ b/internal/host/options_test.go @@ -77,4 +77,23 @@ func Test_GetOpts(t *testing.T) { assert.Equal(opts.WithStartPageAfterItem.GetPublicId(), "s_1") assert.Equal(opts.WithStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime)) }) + t.Run("WithReaderWriter", func(t *testing.T) { + t.Parallel() + t.Run("nil writer", func(t *testing.T) { + t.Parallel() + _, err := GetOpts(WithReaderWriter(&db.Db{}, nil)) + require.Error(t, err) + }) + t.Run("nil reader", func(t *testing.T) { + t.Parallel() + _, err := GetOpts(WithReaderWriter(nil, &db.Db{})) + require.Error(t, err) + }) + reader := &db.Db{} + writer := &db.Db{} + opts, err := GetOpts(WithReaderWriter(reader, writer)) + require.NoError(t, err) + assert.Equal(t, reader, opts.WithReader) + assert.Equal(t, writer, opts.WithWriter) + }) } diff --git a/internal/host/plugin/options.go b/internal/host/plugin/options.go index 936f6b717d..50dba24add 100644 --- a/internal/host/plugin/options.go +++ b/internal/host/plugin/options.go @@ -4,6 +4,7 @@ package plugin import ( + "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/pagination" "google.golang.org/protobuf/types/known/structpb" ) @@ -38,6 +39,8 @@ type options struct { withSecretsHmac []byte withStartPageAfterItem pagination.Item withWorkerFilter string + WithReader db.Reader + withWriter db.Writer } func getDefaultOptions() options { @@ -162,3 +165,12 @@ func WithWorkerFilter(wf string) Option { o.withWorkerFilter = wf } } + +// WithReaderWriter is used to share the same database reader +// and writer when executing sql within a transaction. +func WithReaderWriter(r db.Reader, w db.Writer) Option { + return func(o *options) { + o.WithReader = r + o.withWriter = w + } +} diff --git a/internal/host/plugin/options_test.go b/internal/host/plugin/options_test.go index 80ef1df197..24fb5abe3b 100644 --- a/internal/host/plugin/options_test.go +++ b/internal/host/plugin/options_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/pagination" "github.com/stretchr/testify/assert" @@ -113,4 +114,11 @@ func Test_GetOpts(t *testing.T) { testOpts.withWorkerFilter = `"test" in "/tags/type"` assert.Equal(t, opts, testOpts) }) + t.Run("WithReaderWriter", func(t *testing.T) { + reader := &db.Db{} + writer := &db.Db{} + opts := getOpts(WithReaderWriter(reader, writer)) + assert.Equal(t, reader, opts.WithReader) + assert.Equal(t, writer, opts.withWriter) + }) } diff --git a/internal/host/plugin/repository_host_catalog.go b/internal/host/plugin/repository_host_catalog.go index 79830b76c1..9832e117a2 100644 --- a/internal/host/plugin/repository_host_catalog.go +++ b/internal/host/plugin/repository_host_catalog.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/event" + "github.com/hashicorp/boundary/internal/host" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/libs/patchstruct" "github.com/hashicorp/boundary/internal/oplog" @@ -404,7 +405,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(_ db.Reader, w db.Writer) error { + func(read db.Reader, w db.Writer) error { msgs := make([]*oplog.Message, 0, 3) ticket, err := w.GetTicket(ctx, newCatalog) if err != nil { @@ -528,7 +529,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version if needSetSync { // We also need to mark all host sets in this catalog to be // synced as well. - setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId) + setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId, host.WithReaderWriter(read, w)) if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get sets for host catalog")) } @@ -713,14 +714,19 @@ func (r *Repository) getCatalog(ctx context.Context, id string) (*HostCatalog, * return c, p, nil } -func (r *Repository) getPlugin(ctx context.Context, plgId string) (*plg.Plugin, error) { +func (r *Repository) getPlugin(ctx context.Context, plgId string, opts ...Option) (*plg.Plugin, error) { const op = "plugin.(Repository).getPlugin" if plgId == "" { return nil, errors.New(ctx, errors.InvalidParameter, op, "no plugin id") } + opt := getOpts(opts...) + reader := r.reader + if !util.IsNil(opt.WithReader) { + reader = opt.WithReader + } plg := plg.NewPlugin() plg.PublicId = plgId - if err := r.reader.LookupByPublicId(ctx, plg); err != nil { + if err := reader.LookupByPublicId(ctx, plg); err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to get host plugin with id %q", plgId))) } return plg, nil diff --git a/internal/host/plugin/repository_host_set.go b/internal/host/plugin/repository_host_set.go index a3b32c7857..919564a013 100644 --- a/internal/host/plugin/repository_host_set.go +++ b/internal/host/plugin/repository_host_set.go @@ -804,6 +804,15 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str limit = opts.WithLimit } + reader := r.reader + writer := r.writer + if !util.IsNil(opts.WithReader) { + reader = opts.WithReader + } + if !util.IsNil(opts.WithWriter) { + writer = opts.WithWriter + } + args := make([]any, 0, 1) var where string @@ -825,7 +834,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str } var aggHostSets []*hostSetAgg - if err := r.reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil { + if err := reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil { return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("in %s", publicId))) } @@ -844,7 +853,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str } var plg *plugin.Plugin if plgId != "" { - plg, err = r.getPlugin(ctx, plgId) + plg, err = r.getPlugin(ctx, plgId, WithReaderWriter(reader, writer)) if err != nil { return nil, nil, errors.Wrap(ctx, err, op) } From 2e2d18a4804d4b2f661f1d51d00c0e172930b15a Mon Sep 17 00:00:00 2001 From: Damian Debkowski Date: Thu, 6 Feb 2025 19:46:11 +0000 Subject: [PATCH 4/4] backport of commit 5dfb5a04826b21ef2c2c72a3822a99252cb3d8e6 --- internal/auth/oidc/repository_managed_group_members.go | 5 +++-- internal/iam/repository.go | 5 +++-- internal/iam/repository_grant_scope.go | 3 ++- internal/iam/repository_role.go | 3 ++- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/internal/auth/oidc/repository_managed_group_members.go b/internal/auth/oidc/repository_managed_group_members.go index 6c65413196..2f7ae1e3c9 100644 --- a/internal/auth/oidc/repository_managed_group_members.go +++ b/internal/auth/oidc/repository_managed_group_members.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" + "github.com/hashicorp/boundary/internal/util" ) // SetManagedGroupMemberships will set the managed groups for the given account @@ -207,7 +208,7 @@ func (r *Repository) ListManagedGroupMembershipsByMember(ctx context.Context, wi limit = opts.withLimit } reader := r.reader - if opts.withReader != nil { + if !util.IsNil(opts.withReader) { reader = opts.withReader } var mgs []*ManagedGroupMemberAccount @@ -232,7 +233,7 @@ func (r *Repository) ListManagedGroupMembershipsByGroup(ctx context.Context, wit limit = opts.withLimit } reader := r.reader - if opts.withReader != nil { + if !util.IsNil(opts.withReader) { reader = opts.withReader } var mgs []*ManagedGroupMemberAccount diff --git a/internal/iam/repository.go b/internal/iam/repository.go index 33c2d8b69a..52e87283a8 100644 --- a/internal/iam/repository.go +++ b/internal/iam/repository.go @@ -13,6 +13,7 @@ import ( "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" "github.com/hashicorp/boundary/internal/types/scope" + "github.com/hashicorp/boundary/internal/util" ) var ErrMetadataScopeNotFound = errors.New(context.Background(), errors.RecordNotFound, "iam", "scope not found for metadata", errors.WithoutEvent()) @@ -65,7 +66,7 @@ func (r *Repository) list(ctx context.Context, resources any, where string, args limit = opts.withLimit } reader := r.reader - if opts.withReader != nil { + if !util.IsNil(opts.withReader) { reader = opts.withReader } return reader.SearchWhere(ctx, resources, where, args, db.WithLimit(limit)) @@ -150,7 +151,7 @@ func (r *Repository) update(ctx context.Context, resource Resource, version uint reader := r.reader writer := r.writer needFreshReaderWriter := true - if opts.withReader != nil && opts.withWriter != nil { + if !util.IsNil(opts.withReader) && !util.IsNil(opts.withWriter) { reader = opts.withReader writer = opts.withWriter if !writer.IsTx(ctx) { diff --git a/internal/iam/repository_grant_scope.go b/internal/iam/repository_grant_scope.go index db2c3f6f65..efdf52aa28 100644 --- a/internal/iam/repository_grant_scope.go +++ b/internal/iam/repository_grant_scope.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" + "github.com/hashicorp/boundary/internal/util" ) // AddRoleGrantScopes will add role grant scopes associated with the role ID in @@ -235,7 +236,7 @@ func (r *Repository) SetRoleGrantScopes(ctx context.Context, roleId string, role writer := r.writer needFreshReaderWriter := true opts := getOpts(opt...) - if opts.withReader != nil && opts.withWriter != nil { + if !util.IsNil(opts.withReader) && !util.IsNil(opts.withWriter) { reader = opts.withReader writer = opts.withWriter needFreshReaderWriter = false diff --git a/internal/iam/repository_role.go b/internal/iam/repository_role.go index e0feb9cd06..9250d420a0 100644 --- a/internal/iam/repository_role.go +++ b/internal/iam/repository_role.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/util" "github.com/hashicorp/go-dbw" ) @@ -193,7 +194,7 @@ func (r *Repository) LookupRole(ctx context.Context, withPublicId string, opt .. } var err error - if opts.withReader != nil && opts.withWriter != nil { + if !util.IsNil(opts.withReader) && !util.IsNil(opts.withWriter) { if !opts.withWriter.IsTx(ctx) { return nil, nil, nil, nil, errors.New(ctx, errors.Internal, op, "writer is not in transaction") }