diff --git a/internal/host/options.go b/internal/host/options.go index edfe8c98d0..5d240326e5 100644 --- a/internal/host/options.go +++ b/internal/host/options.go @@ -6,7 +6,9 @@ package host import ( "errors" + "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/pagination" + "github.com/hashicorp/boundary/internal/util" ) // GetOpts - iterate the inbound Options and return a struct @@ -26,6 +28,8 @@ type Option func(*options) error // options = how options are represented type options struct { WithLimit int + WithReader db.Reader + WithWriter db.Writer WithOrderByCreateTime bool Ascending bool WithStartPageAfterItem pagination.Item @@ -66,3 +70,19 @@ func WithStartPageAfterItem(item pagination.Item) Option { return nil } } + +// WithReaderWriter is used to share the same database reader +// and writer when executing sql within a transaction. +func WithReaderWriter(r db.Reader, w db.Writer) Option { + return func(o *options) error { + if util.IsNil(r) { + return errors.New("reader cannot be nil") + } + if util.IsNil(w) { + return errors.New("writer cannot be nil") + } + o.WithReader = r + o.WithWriter = w + return nil + } +} diff --git a/internal/host/options_test.go b/internal/host/options_test.go index 90c47b493e..9c3336ce07 100644 --- a/internal/host/options_test.go +++ b/internal/host/options_test.go @@ -77,4 +77,23 @@ func Test_GetOpts(t *testing.T) { assert.Equal(opts.WithStartPageAfterItem.GetPublicId(), "s_1") assert.Equal(opts.WithStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime)) }) + t.Run("WithReaderWriter", func(t *testing.T) { + t.Parallel() + t.Run("nil writer", func(t *testing.T) { + t.Parallel() + _, err := GetOpts(WithReaderWriter(&db.Db{}, nil)) + require.Error(t, err) + }) + t.Run("nil reader", func(t *testing.T) { + t.Parallel() + _, err := GetOpts(WithReaderWriter(nil, &db.Db{})) + require.Error(t, err) + }) + reader := &db.Db{} + writer := &db.Db{} + opts, err := GetOpts(WithReaderWriter(reader, writer)) + require.NoError(t, err) + assert.Equal(t, reader, opts.WithReader) + assert.Equal(t, writer, opts.WithWriter) + }) } diff --git a/internal/host/plugin/options.go b/internal/host/plugin/options.go index 936f6b717d..50dba24add 100644 --- a/internal/host/plugin/options.go +++ b/internal/host/plugin/options.go @@ -4,6 +4,7 @@ package plugin import ( + "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/pagination" "google.golang.org/protobuf/types/known/structpb" ) @@ -38,6 +39,8 @@ type options struct { withSecretsHmac []byte withStartPageAfterItem pagination.Item withWorkerFilter string + WithReader db.Reader + withWriter db.Writer } func getDefaultOptions() options { @@ -162,3 +165,12 @@ func WithWorkerFilter(wf string) Option { o.withWorkerFilter = wf } } + +// WithReaderWriter is used to share the same database reader +// and writer when executing sql within a transaction. +func WithReaderWriter(r db.Reader, w db.Writer) Option { + return func(o *options) { + o.WithReader = r + o.withWriter = w + } +} diff --git a/internal/host/plugin/options_test.go b/internal/host/plugin/options_test.go index 80ef1df197..24fb5abe3b 100644 --- a/internal/host/plugin/options_test.go +++ b/internal/host/plugin/options_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/pagination" "github.com/stretchr/testify/assert" @@ -113,4 +114,11 @@ func Test_GetOpts(t *testing.T) { testOpts.withWorkerFilter = `"test" in "/tags/type"` assert.Equal(t, opts, testOpts) }) + t.Run("WithReaderWriter", func(t *testing.T) { + reader := &db.Db{} + writer := &db.Db{} + opts := getOpts(WithReaderWriter(reader, writer)) + assert.Equal(t, reader, opts.WithReader) + assert.Equal(t, writer, opts.withWriter) + }) } diff --git a/internal/host/plugin/repository_host_catalog.go b/internal/host/plugin/repository_host_catalog.go index 79830b76c1..9832e117a2 100644 --- a/internal/host/plugin/repository_host_catalog.go +++ b/internal/host/plugin/repository_host_catalog.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/event" + "github.com/hashicorp/boundary/internal/host" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/libs/patchstruct" "github.com/hashicorp/boundary/internal/oplog" @@ -404,7 +405,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version ctx, db.StdRetryCnt, db.ExpBackoff{}, - func(_ db.Reader, w db.Writer) error { + func(read db.Reader, w db.Writer) error { msgs := make([]*oplog.Message, 0, 3) ticket, err := w.GetTicket(ctx, newCatalog) if err != nil { @@ -528,7 +529,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version if needSetSync { // We also need to mark all host sets in this catalog to be // synced as well. - setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId) + setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId, host.WithReaderWriter(read, w)) if err != nil { return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get sets for host catalog")) } @@ -713,14 +714,19 @@ func (r *Repository) getCatalog(ctx context.Context, id string) (*HostCatalog, * return c, p, nil } -func (r *Repository) getPlugin(ctx context.Context, plgId string) (*plg.Plugin, error) { +func (r *Repository) getPlugin(ctx context.Context, plgId string, opts ...Option) (*plg.Plugin, error) { const op = "plugin.(Repository).getPlugin" if plgId == "" { return nil, errors.New(ctx, errors.InvalidParameter, op, "no plugin id") } + opt := getOpts(opts...) + reader := r.reader + if !util.IsNil(opt.WithReader) { + reader = opt.WithReader + } plg := plg.NewPlugin() plg.PublicId = plgId - if err := r.reader.LookupByPublicId(ctx, plg); err != nil { + if err := reader.LookupByPublicId(ctx, plg); err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to get host plugin with id %q", plgId))) } return plg, nil diff --git a/internal/host/plugin/repository_host_set.go b/internal/host/plugin/repository_host_set.go index a3b32c7857..919564a013 100644 --- a/internal/host/plugin/repository_host_set.go +++ b/internal/host/plugin/repository_host_set.go @@ -804,6 +804,15 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str limit = opts.WithLimit } + reader := r.reader + writer := r.writer + if !util.IsNil(opts.WithReader) { + reader = opts.WithReader + } + if !util.IsNil(opts.WithWriter) { + writer = opts.WithWriter + } + args := make([]any, 0, 1) var where string @@ -825,7 +834,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str } var aggHostSets []*hostSetAgg - if err := r.reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil { + if err := reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil { return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("in %s", publicId))) } @@ -844,7 +853,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str } var plg *plugin.Plugin if plgId != "" { - plg, err = r.getPlugin(ctx, plgId) + plg, err = r.getPlugin(ctx, plgId, WithReaderWriter(reader, writer)) if err != nil { return nil, nil, errors.Wrap(ctx, err, op) }