fix: use the db transaction reader and writer

pull/5522/head
Damian Debkowski 1 year ago
parent 2d802fe3f5
commit 0529264f93

@ -179,7 +179,7 @@ func (r *Repository) upsertAccount(ctx context.Context, am *AuthMethod, IdTokenC
var rowCnt int
for rows.Next() {
rowCnt += 1
err = r.reader.ScanRows(ctx, rows, &result)
err = reader.ScanRows(ctx, rows, &result)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to scan rows for account"))
}

@ -147,7 +147,7 @@ func (amr *AuthMethodRepository) ListDeletedIds(ctx context.Context, since time.
var deletedAuthMethodIDs []string
var transactionTimestamp time.Time
if _, err := amr.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
rows, err := amr.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
if err != nil {
return errors.Wrap(ctx, err, op)
}

@ -118,7 +118,7 @@ func (s *StoreRepository) ListDeletedIds(ctx context.Context, since time.Time) (
var deletedStoreIDs []string
var transactionTimestamp time.Time
if _, err := s.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
rows, err := s.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
if err != nil {
return errors.Wrap(ctx, err, op)
}

@ -119,7 +119,7 @@ func (s *CatalogRepository) ListDeletedIds(ctx context.Context, since time.Time)
var deletedCatalogIDs []string
var transactionTimestamp time.Time
if _, err := s.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
rows, err := s.writer.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
rows, err := w.Query(ctx, listDeletedIdsQuery, []any{sql.Named("since", since)})
if err != nil {
return errors.Wrap(ctx, err, op)
}

@ -171,7 +171,7 @@ func (r *Repository) UpdateHost(ctx context.Context, projectId string, h *Host,
var rowsUpdated int
var returnedHost *Host
_, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{},
func(_ db.Reader, w db.Writer) error {
func(r db.Reader, w db.Writer) error {
returnedHost = h.clone()
var err error
rowsUpdated, err = w.Update(ctx, returnedHost, dbMask, nullFields,
@ -186,7 +186,7 @@ func (r *Repository) UpdateHost(ctx context.Context, projectId string, h *Host,
ha := &hostAgg{
PublicId: h.PublicId,
}
if err := r.reader.LookupByPublicId(ctx, ha); err != nil {
if err := r.LookupByPublicId(ctx, ha); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("failed to lookup host after update"))
}
returnedHost.SetIds = ha.getSetIds()

@ -325,7 +325,7 @@ func (r *Repository) queryRoles(ctx context.Context, whereClause string, args []
for _, retRole := range retRoles {
roleIds = append(roleIds, retRole.PublicId)
}
retRoleGrantScopes, err = r.ListRoleGrantScopes(ctx, roleIds)
retRoleGrantScopes, err = r.ListRoleGrantScopes(ctx, roleIds, WithReaderWriter(rd, w))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("failed to query role grant scopes"))
}

@ -359,7 +359,7 @@ func (r *Repository) SetRoleGrants(ctx context.Context, roleId string, roleVersi
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to write oplog"))
}
currentRoleGrants, err = r.ListRoleGrants(ctx, roleId)
currentRoleGrants, err = r.ListRoleGrants(ctx, roleId, WithReaderWriter(reader, w))
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to retrieve current role grants after set"))
}

@ -7,6 +7,7 @@ import (
"context"
"time"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/version"
"github.com/hashicorp/nodeenrollment/types"
)
@ -63,6 +64,8 @@ type options struct {
withWorkerPool []string
withFilterWorkersByStorageBucketCredentialState *StorageBucketCredentialInfo
withFilterWorkersByLocalStorageState bool
WithReader db.Reader
WithWriter db.Writer
}
func getDefaultOptions() options {
@ -303,3 +306,12 @@ func WithFilterWorkersByLocalStorageState(filter bool) Option {
o.withFilterWorkersByLocalStorageState = filter
}
}
// WithReaderWriter is used to share the same database reader
// and writer when executing sql within a transaction.
func WithReaderWriter(r db.Reader, w db.Writer) Option {
return func(o *options) {
o.WithReader = r
o.WithWriter = w
}
}

@ -10,6 +10,7 @@ import (
"testing"
"time"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/version"
"github.com/stretchr/testify/assert"
)
@ -252,4 +253,19 @@ func Test_GetOpts(t *testing.T) {
testOpts.withNewIdFunc = nil
assert.Equal(t, opts, testOpts)
})
t.Run("WithReaderWriter", func(t *testing.T) {
reader := &db.Db{}
writer := &db.Db{}
testOpts := getDefaultOptions()
assert.Nil(t, testOpts.WithReader)
assert.Nil(t, testOpts.WithWriter)
testOpts.WithReader = reader
testOpts.WithWriter = writer
opts := GetOpts(WithReaderWriter(reader, writer))
opts.withNewIdFunc = nil
testOpts.withNewIdFunc = nil
assert.Equal(t, reader, opts.WithReader)
assert.Equal(t, writer, opts.WithWriter)
assert.Equal(t, opts, testOpts)
})
}

@ -600,7 +600,7 @@ func (r *Repository) UpdateWorker(ctx context.Context, worker *Worker, version u
if err != nil {
return errors.Wrap(ctx, err, op)
}
ret.RemoteStorageStates, err = r.ListWorkerStorageBucketCredentialState(ctx, ret.GetPublicId())
ret.RemoteStorageStates, err = r.ListWorkerStorageBucketCredentialState(ctx, ret.GetPublicId(), WithReaderWriter(reader, w))
if err != nil {
return err
}
@ -925,7 +925,7 @@ func (r *Repository) SelectSessionWorkers(ctx context.Context,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
rows, err := r.reader.Query(ctx, query, []any{})
rows, err := reader.Query(ctx, query, []any{})
if err != nil {
return err
}
@ -935,7 +935,7 @@ func (r *Repository) SelectSessionWorkers(ctx context.Context,
// a Worker object can hold, only a subset. Check the query to
// learn exactly what fields are present.
var worker Worker
if err := r.reader.ScanRows(ctx, rows, &worker); err != nil {
if err := reader.ScanRows(ctx, rows, &worker); err != nil {
return err
}
livingWorkers = append(livingWorkers, &worker)

@ -9,11 +9,12 @@ import (
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/util"
plgpb "github.com/hashicorp/boundary/sdk/pbs/plugin"
)
// ListWorkerStorageBucketCredentialState returns a list of storage bucket credential states for the given worker.
func (r *Repository) ListWorkerStorageBucketCredentialState(ctx context.Context, workerId string) (map[string]*plgpb.StorageBucketCredentialState, error) {
func (r *Repository) ListWorkerStorageBucketCredentialState(ctx context.Context, workerId string, opts ...Option) (map[string]*plgpb.StorageBucketCredentialState, error) {
const op = "server.(Repository).ListWorkerStorageBucketCredentialState"
if workerId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "empty worker id")
@ -25,7 +26,12 @@ func (r *Repository) ListWorkerStorageBucketCredentialState(ctx context.Context,
CheckedAt *timestamp.Timestamp
ErrorDetails string
}
rows, err := r.reader.Query(ctx, getStorageBucketCredentialStatesByWorkerId, []any{sql.Named("worker_id", workerId)})
opt := GetOpts(opts...)
reader := r.reader
if !util.IsNil(opt.WithReader) {
reader = opt.WithReader
}
rows, err := reader.Query(ctx, getStorageBucketCredentialStatesByWorkerId, []any{sql.Named("worker_id", workerId)})
if err != nil && !errors.Match(errors.T(errors.RecordNotFound), err) {
return nil, errors.Wrap(ctx, err, op)
}
@ -36,7 +42,7 @@ func (r *Repository) ListWorkerStorageBucketCredentialState(ctx context.Context,
return nil, errors.Wrap(ctx, err, op)
}
var row remoteStorageState
if err := r.reader.ScanRows(ctx, rows, &row); err != nil {
if err := reader.ScanRows(ctx, rows, &row); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("failed to fetch remote storage state"))
}
s, ok := remoteStorageStates[row.StorageBucketId]

Loading…
Cancel
Save