From c413e013bdbec6bb9016cf3f7bb924dbefc94b3b Mon Sep 17 00:00:00 2001 From: Jim Date: Thu, 18 Jun 2020 21:24:55 -0400 Subject: [PATCH] get oplog Ticket before writing data (#127) --- internal/db/read_writer.go | 70 ++++++++++++++++++++++++++++++-------- 1 file changed, 55 insertions(+), 15 deletions(-) diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index f71d983f45..2f18146617 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/watchtower/internal/db/common" "github.com/hashicorp/watchtower/internal/oplog" + "github.com/hashicorp/watchtower/internal/oplog/store" "github.com/jinzhu/gorm" "google.golang.org/protobuf/proto" ) @@ -195,11 +196,19 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { return fmt.Errorf("create: vet for write failed %w", err) } } + var ticket *store.Ticket + if withOplog { + var err error + ticket, err = rw.getTicket(i) + if err != nil { + return fmt.Errorf("create: unable to get ticket: %w", err) + } + } if err := rw.underlying.Create(i).Error; err != nil { return fmt.Errorf("create: failed %w", err) } if withOplog { - if err := rw.addOplog(ctx, CreateOp, opts, i); err != nil { + if err := rw.addOplog(ctx, CreateOp, opts, ticket, i); err != nil { return err } } @@ -272,7 +281,14 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string return NoRowsAffected, fmt.Errorf("update: vet for write failed %w", err) } } - + var ticket *store.Ticket + if withOplog { + var err error + ticket, err = rw.getTicket(i) + if err != nil { + return NoRowsAffected, fmt.Errorf("update: unable to get ticket: %w", err) + } + } underlying := rw.underlying.Model(i).Updates(updateFields) if underlying.Error != nil { if err == gorm.ErrRecordNotFound { @@ -294,7 +310,7 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string WithFieldMaskPaths: oplogFieldMasks, WithNullPaths: oplogNullPaths, } - if err := rw.addOplog(ctx, UpdateOp, oplogOpts, i); err != nil { + if err := rw.addOplog(ctx, UpdateOp, oplogOpts, ticket, i); err != nil { return rowsUpdated, fmt.Errorf("update: add oplog failed %w", err) } } @@ -331,13 +347,21 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, er return NoRowsAffected, fmt.Errorf("delete: oplog validation failed %w", err) } } + var ticket *store.Ticket + if withOplog { + var err error + ticket, err = rw.getTicket(i) + if err != nil { + return NoRowsAffected, fmt.Errorf("delete: unable to get ticket: %w", err) + } + } underlying := rw.underlying.Delete(i) if underlying.Error != nil { return NoRowsAffected, fmt.Errorf("delete: failed %w", underlying.Error) } rowsDeleted := int(underlying.RowsAffected) if withOplog && rowsDeleted > 0 { - if err := rw.addOplog(ctx, DeleteOp, opts, i); err != nil { + if err := rw.addOplog(ctx, DeleteOp, opts, ticket, i); err != nil { return rowsDeleted, fmt.Errorf("delete: add oplog failed %w", err) } } @@ -359,22 +383,38 @@ func validateOplogArgs(i interface{}, opts Options) (oplog.ReplayableMessage, er return replayable, nil } -func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i interface{}) error { +func (rw *Db) getTicket(i interface{}) (*store.Ticket, error) { + if rw.underlying == nil { + return nil, fmt.Errorf("get ticket: underlying db missing: %w", ErrNilParameter) + } + replayable, ok := i.(oplog.ReplayableMessage) + if !ok { + return nil, fmt.Errorf("get ticket: not a replayable message %w", ErrInvalidParameter) + } + ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true)) + if err != nil { + return nil, fmt.Errorf("get ticket: unable to get Ticketer %w", err) + } + ticket, err := ticketer.GetTicket(replayable.TableName()) + if err != nil { + return nil, fmt.Errorf("get ticket: unable to get ticket %w", err) + } + return ticket, nil +} + +func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, ticket *store.Ticket, i interface{}) error { oplogArgs := opts.oplogOpts replayable, err := validateOplogArgs(i, opts) if err != nil { return err } - gdb := rw.underlying - ticketer, err := oplog.NewGormTicketer(gdb, oplog.WithAggregateNames(true)) - if err != nil { - return fmt.Errorf("error getting Ticketer %w for WithOplog", err) + if ticket == nil { + return fmt.Errorf("add oplog: missing ticket %w", ErrNilParameter) } - ticket, err := ticketer.GetTicket(replayable.TableName()) + ticketer, err := oplog.NewGormTicketer(rw.underlying, oplog.WithAggregateNames(true)) if err != nil { - return fmt.Errorf("error getting ticket %w for WithOplog", err) + return fmt.Errorf("add oplog: unable to get Ticketer %w", err) } - entry, err := oplog.NewEntry( replayable.TableName(), oplogArgs.metadata, @@ -398,16 +438,16 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter case DeleteOp: msg.OpType = oplog.OpType_OP_TYPE_DELETE default: - return fmt.Errorf("error operation type %v is not supported", opType) + return fmt.Errorf("add oplog: operation type %v is not supported", opType) } err = entry.WriteEntryWith( ctx, - &oplog.GormWriter{Tx: gdb}, + &oplog.GormWriter{Tx: rw.underlying}, ticket, &msg, ) if err != nil { - return fmt.Errorf("error creating oplog entry %w for WithOplog", err) + return fmt.Errorf("add oplog: unable to write oplog entry: %w", err) } return nil }