backport of commit 961d9d7d16

pull/5523/head
Damian Debkowski 1 year ago
parent deeac5c457
commit 42bf2f032c

@ -6,7 +6,9 @@ package host
import (
"errors"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/pagination"
"github.com/hashicorp/boundary/internal/util"
)
// GetOpts - iterate the inbound Options and return a struct
@ -26,6 +28,8 @@ type Option func(*options) error
// options = how options are represented
type options struct {
WithLimit int
WithReader db.Reader
WithWriter db.Writer
WithOrderByCreateTime bool
Ascending bool
WithStartPageAfterItem pagination.Item
@ -66,3 +70,19 @@ func WithStartPageAfterItem(item pagination.Item) Option {
return nil
}
}
// WithReaderWriter is used to share the same database reader
// and writer when executing sql within a transaction.
func WithReaderWriter(r db.Reader, w db.Writer) Option {
return func(o *options) error {
if util.IsNil(r) {
return errors.New("reader cannot be nil")
}
if util.IsNil(w) {
return errors.New("writer cannot be nil")
}
o.WithReader = r
o.WithWriter = w
return nil
}
}

@ -77,4 +77,23 @@ func Test_GetOpts(t *testing.T) {
assert.Equal(opts.WithStartPageAfterItem.GetPublicId(), "s_1")
assert.Equal(opts.WithStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime))
})
t.Run("WithReaderWriter", func(t *testing.T) {
t.Parallel()
t.Run("nil writer", func(t *testing.T) {
t.Parallel()
_, err := GetOpts(WithReaderWriter(&db.Db{}, nil))
require.Error(t, err)
})
t.Run("nil reader", func(t *testing.T) {
t.Parallel()
_, err := GetOpts(WithReaderWriter(nil, &db.Db{}))
require.Error(t, err)
})
reader := &db.Db{}
writer := &db.Db{}
opts, err := GetOpts(WithReaderWriter(reader, writer))
require.NoError(t, err)
assert.Equal(t, reader, opts.WithReader)
assert.Equal(t, writer, opts.WithWriter)
})
}

@ -4,6 +4,7 @@
package plugin
import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/pagination"
"google.golang.org/protobuf/types/known/structpb"
)
@ -38,6 +39,8 @@ type options struct {
withSecretsHmac []byte
withStartPageAfterItem pagination.Item
withWorkerFilter string
WithReader db.Reader
withWriter db.Writer
}
func getDefaultOptions() options {
@ -162,3 +165,12 @@ func WithWorkerFilter(wf string) Option {
o.withWorkerFilter = wf
}
}
// WithReaderWriter is used to share the same database reader
// and writer when executing sql within a transaction.
func WithReaderWriter(r db.Reader, w db.Writer) Option {
return func(o *options) {
o.WithReader = r
o.withWriter = w
}
}

@ -7,6 +7,7 @@ import (
"testing"
"time"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/pagination"
"github.com/stretchr/testify/assert"
@ -113,4 +114,11 @@ func Test_GetOpts(t *testing.T) {
testOpts.withWorkerFilter = `"test" in "/tags/type"`
assert.Equal(t, opts, testOpts)
})
t.Run("WithReaderWriter", func(t *testing.T) {
reader := &db.Db{}
writer := &db.Db{}
opts := getOpts(WithReaderWriter(reader, writer))
assert.Equal(t, reader, opts.WithReader)
assert.Equal(t, writer, opts.withWriter)
})
}

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
"github.com/hashicorp/boundary/internal/host"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/libs/patchstruct"
"github.com/hashicorp/boundary/internal/oplog"
@ -404,7 +405,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(_ db.Reader, w db.Writer) error {
func(read db.Reader, w db.Writer) error {
msgs := make([]*oplog.Message, 0, 3)
ticket, err := w.GetTicket(ctx, newCatalog)
if err != nil {
@ -528,7 +529,7 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, version
if needSetSync {
// We also need to mark all host sets in this catalog to be
// synced as well.
setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId)
setsForCatalog, _, err := r.getSets(ctx, "", returnedCatalog.PublicId, host.WithReaderWriter(read, w))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get sets for host catalog"))
}
@ -713,14 +714,19 @@ func (r *Repository) getCatalog(ctx context.Context, id string) (*HostCatalog, *
return c, p, nil
}
func (r *Repository) getPlugin(ctx context.Context, plgId string) (*plg.Plugin, error) {
func (r *Repository) getPlugin(ctx context.Context, plgId string, opts ...Option) (*plg.Plugin, error) {
const op = "plugin.(Repository).getPlugin"
if plgId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "no plugin id")
}
opt := getOpts(opts...)
reader := r.reader
if !util.IsNil(opt.WithReader) {
reader = opt.WithReader
}
plg := plg.NewPlugin()
plg.PublicId = plgId
if err := r.reader.LookupByPublicId(ctx, plg); err != nil {
if err := reader.LookupByPublicId(ctx, plg); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("unable to get host plugin with id %q", plgId)))
}
return plg, nil

@ -804,6 +804,15 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
limit = opts.WithLimit
}
reader := r.reader
writer := r.writer
if !util.IsNil(opts.WithReader) {
reader = opts.WithReader
}
if !util.IsNil(opts.WithWriter) {
writer = opts.WithWriter
}
args := make([]any, 0, 1)
var where string
@ -825,7 +834,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
}
var aggHostSets []*hostSetAgg
if err := r.reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil {
if err := reader.SearchWhere(ctx, &aggHostSets, where, args, dbArgs...); err != nil {
return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("in %s", publicId)))
}
@ -844,7 +853,7 @@ func (r *Repository) getSets(ctx context.Context, publicId string, catalogId str
}
var plg *plugin.Plugin
if plgId != "" {
plg, err = r.getPlugin(ctx, plgId)
plg, err = r.getPlugin(ctx, plgId, WithReaderWriter(reader, writer))
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
}

Loading…
Cancel
Save