oplog support for null field updates (#107)

pull/114/head
Jim 6 years ago committed by GitHub
parent c70847fbac
commit 2d18177c9d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -55,7 +55,7 @@ require (
github.com/spf13/viper v1.6.3 // indirect
github.com/stretchr/testify v1.6.0
go.mongodb.org/mongo-driver v1.3.2 // indirect
golang.org/x/net v0.0.0-20200519113804-d87ec0cfa476
golang.org/x/net v0.0.0-20200519113804-d87ec0cfa476 // indirect
golang.org/x/sys v0.0.0-20200509044756-6aff5f38e54f // indirect
golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1
golang.org/x/tools v0.0.0-20200519175826-7521f6f42533
@ -64,5 +64,4 @@ require (
google.golang.org/protobuf v1.24.0
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/ini.v1 v1.55.0 // indirect
gotest.tools v2.2.0+incompatible
)

@ -575,8 +575,6 @@ github.com/hashicorp/consul/api v1.4.0 h1:jfESivXnO5uLdH650JU/6AnjRoHrLhULq0FnC3
github.com/hashicorp/consul/api v1.4.0/go.mod h1:xc8u05kyMa3Wjr9eEAsIAo3dg8+LywT5E/Cl7cNS5nU=
github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8=
github.com/hashicorp/consul/sdk v0.4.0/go.mod h1:fY08Y9z5SvJqevyZNy6WWPXiG3KwBPAvlcdx16zZ0fM=
github.com/hashicorp/dbassert v0.0.0-20200601131634-c3045f4c81e8 h1:0R3FvnfT4FC2abVevODTZ22zzhguKM2K5aufuysG0oo=
github.com/hashicorp/dbassert v0.0.0-20200601131634-c3045f4c81e8/go.mod h1:+B5eZq7vXqN4Gxkgjm4ub1bwG6OTcYzFL0r+bMvWorg=
github.com/hashicorp/dbassert v0.0.0-20200602132714-30dfb78f0133 h1:NT2GV/LkzarLLGLYf/+HdqazyFEuuGDA7tOZ2Jyo+80=
github.com/hashicorp/dbassert v0.0.0-20200602132714-30dfb78f0133/go.mod h1:+B5eZq7vXqN4Gxkgjm4ub1bwG6OTcYzFL0r+bMvWorg=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=

@ -35,7 +35,7 @@ func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []strin
return nil, errors.New("both fieldMaskPaths and setToNullPaths are zero len")
}
inter, maskPaths, nullPaths, err := intersection(fieldMaskPaths, setToNullPaths)
inter, maskPaths, nullPaths, err := Intersection(fieldMaskPaths, setToNullPaths)
if err != nil {
return nil, err
}
@ -111,11 +111,11 @@ func findMissingPaths(paths []string, foundPaths map[string]struct{}) []string {
return notFound
}
// intersection is a case-insensitive search for intersecting values. Returns
// []string of the intersection with values in lowercase, and map[string]string
// Intersection is a case-insensitive search for intersecting values. Returns
// []string of the Intersection with values in lowercase, and map[string]string
// of the original av and bv, with the key set to uppercase and value set to the
// original
func intersection(av, bv []string) ([]string, map[string]string, map[string]string, error) {
func Intersection(av, bv []string) ([]string, map[string]string, map[string]string, error) {
if av == nil {
return nil, nil, nil, fmt.Errorf("av is missing: %w", ErrNilParameter)
}

@ -165,7 +165,7 @@ func Test_intersection(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
got, got1, got2, err := intersection(tt.args.av, tt.args.bv)
got, got1, got2, err := Intersection(tt.args.av, tt.args.bv)
if err == nil && tt.wantErr {
assert.Error(err)
}

@ -284,7 +284,20 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string
}
rowsUpdated := int(underlying.RowsAffected)
if withOplog && rowsUpdated > 0 {
if err := rw.addOplog(ctx, UpdateOp, opts, i); err != nil {
// we don't want to change the inbound slices in opts, so we'll make our
// own copy to pass to addOplog()
oplogFieldMasks := make([]string, len(fieldMaskPaths))
copy(oplogFieldMasks, fieldMaskPaths)
oplogNullPaths := make([]string, len(setToNullPaths))
copy(oplogNullPaths, setToNullPaths)
oplogOpts := Options{
oplogOpts: opts.oplogOpts,
withOplog: opts.withOplog,
withDebug: opts.withDebug,
WithFieldMaskPaths: oplogFieldMasks,
WithNullPaths: oplogNullPaths,
}
if err := rw.addOplog(ctx, UpdateOp, oplogOpts, i); err != nil {
return rowsUpdated, fmt.Errorf("update: add oplog failed %w", err)
}
}
@ -384,14 +397,19 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter
if err != nil {
return err
}
var entryOp oplog.OpType
msg := oplog.Message{
Message: i.(proto.Message),
TypeName: replayable.TableName(),
}
switch opType {
case CreateOp:
entryOp = oplog.OpType_OP_TYPE_CREATE
msg.OpType = oplog.OpType_OP_TYPE_CREATE
case UpdateOp:
entryOp = oplog.OpType_OP_TYPE_UPDATE
msg.OpType = oplog.OpType_OP_TYPE_UPDATE
msg.FieldMaskPaths = opts.WithFieldMaskPaths
msg.SetToNullPaths = opts.WithNullPaths
case DeleteOp:
entryOp = oplog.OpType_OP_TYPE_DELETE
msg.OpType = oplog.OpType_OP_TYPE_DELETE
default:
return fmt.Errorf("error operation type %v is not supported", opType)
}
@ -399,7 +417,7 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter
ctx,
&oplog.GormWriter{Tx: gdb},
ticket,
&oplog.Message{Message: i.(proto.Message), TypeName: replayable.TableName(), OpType: entryOp},
&msg,
)
if err != nil {
return fmt.Errorf("error creating oplog entry %w for WithOplog", err)

@ -229,9 +229,8 @@ func TestDb_Update(t *testing.T) {
assert.NotEqual(publicId, foundUser.PublicId)
})
}
assert := assert.New(t)
t.Run("valid-WithOplog", func(t *testing.T) {
assert := assert.New(t)
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.NoError(err)
@ -264,6 +263,7 @@ func TestDb_Update(t *testing.T) {
assert.NoError(err)
})
t.Run("vet-for-write", func(t *testing.T) {
assert := assert.New(t)
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.NoError(err)
@ -294,6 +294,7 @@ func TestDb_Update(t *testing.T) {
assert.Equal(0, rowsUpdated)
})
t.Run("nil-tx", func(t *testing.T) {
assert := assert.New(t)
w := Db{underlying: nil}
id, err := uuid.GenerateUUID()
assert.NoError(err)
@ -305,6 +306,7 @@ func TestDb_Update(t *testing.T) {
assert.Equal("update: missing underlying db nil parameter", err.Error())
})
t.Run("no-wrapper-WithOplog", func(t *testing.T) {
assert := assert.New(t)
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.NoError(err)
@ -326,6 +328,7 @@ func TestDb_Update(t *testing.T) {
assert.Equal("update: oplog validation failed error no wrapper WithOplog: nil parameter", err.Error())
})
t.Run("no-metadata-WithOplog", func(t *testing.T) {
assert := assert.New(t)
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.NoError(err)

@ -26,15 +26,19 @@ const (
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
// OpType provides the type of database operation the Any message represents (create,
// update, delete)
// OpType provides the type of database operation the Any message represents
// (create, update, delete)
type OpType int32
const (
// OP_TYPE_UNSPECIFIED defines an unspecified operation.
OpType_OP_TYPE_UNSPECIFIED OpType = 0
OpType_OP_TYPE_CREATE OpType = 1
OpType_OP_TYPE_UPDATE OpType = 2
OpType_OP_TYPE_DELETE OpType = 3
// OP_TYPE_CREATE defines a create operation.
OpType_OP_TYPE_CREATE OpType = 1
// OP_TYPE_UPDATE defines an update operation.
OpType_OP_TYPE_UPDATE OpType = 2
// OP_TYPE_DELETE defines a delete operation.
OpType_OP_TYPE_DELETE OpType = 3
)
// Enum value maps for OpType.
@ -87,11 +91,16 @@ type AnyOperation struct {
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
TypeName string `protobuf:"bytes,1,opt,name=type_name,json=typeName,proto3" json:"type_name,omitempty"`
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
// type_name defines type of operation.
TypeName string `protobuf:"bytes,1,opt,name=type_name,json=typeName,proto3" json:"type_name,omitempty"`
// value are the bytes of a marshalled proto buff.
Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"`
// operation_type defines the type of database operation.
OperationType OpType `protobuf:"varint,3,opt,name=operation_type,json=operationType,proto3,enum=controller.storage.oplog.v1.OpType" json:"operation_type,omitempty"`
// mask is the mask of fields to update.
// field_mask is the mask of fields to update.
FieldMask *field_mask.FieldMask `protobuf:"bytes,4,opt,name=field_mask,json=fieldMask,proto3" json:"field_mask,omitempty"`
// null_mask is the mask of fields to set to null.
NullMask *field_mask.FieldMask `protobuf:"bytes,5,opt,name=null_mask,json=nullMask,proto3" json:"null_mask,omitempty"`
}
func (x *AnyOperation) Reset() {
@ -154,6 +163,13 @@ func (x *AnyOperation) GetFieldMask() *field_mask.FieldMask {
return nil
}
func (x *AnyOperation) GetNullMask() *field_mask.FieldMask {
if x != nil {
return x.NullMask
}
return nil
}
var File_controller_storage_oplog_v1_any_operation_proto protoreflect.FileDescriptor
var file_controller_storage_oplog_v1_any_operation_proto_rawDesc = []byte{
@ -164,7 +180,7 @@ var file_controller_storage_oplog_v1_any_operation_proto_rawDesc = []byte{
0x6f, 0x72, 0x61, 0x67, 0x65, 0x2e, 0x6f, 0x70, 0x6c, 0x6f, 0x67, 0x2e, 0x76, 0x31, 0x1a, 0x20,
0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f,
0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x6d, 0x61, 0x73, 0x6b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x22, 0xc8, 0x01, 0x0a, 0x0c, 0x41, 0x6e, 0x79, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f,
0x22, 0x81, 0x02, 0x0a, 0x0c, 0x41, 0x6e, 0x79, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f,
0x6e, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x79, 0x70, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01,
0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x79, 0x70, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x14,
0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76,
@ -176,17 +192,21 @@ var file_controller_storage_oplog_v1_any_operation_proto_rawDesc = []byte{
0x12, 0x39, 0x0a, 0x0a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x6d, 0x61, 0x73, 0x6b, 0x18, 0x04,
0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4d, 0x61, 0x73, 0x6b,
0x52, 0x09, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x4d, 0x61, 0x73, 0x6b, 0x2a, 0x5d, 0x0a, 0x06, 0x4f,
0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45,
0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x12,
0x0a, 0x0e, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45,
0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x50,
0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x0e, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50,
0x45, 0x5f, 0x44, 0x45, 0x4c, 0x45, 0x54, 0x45, 0x10, 0x03, 0x42, 0x36, 0x5a, 0x34, 0x67, 0x69,
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f,
0x72, 0x70, 0x2f, 0x77, 0x61, 0x74, 0x63, 0x68, 0x74, 0x6f, 0x77, 0x65, 0x72, 0x2f, 0x69, 0x6e,
0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6f, 0x70, 0x6c, 0x6f, 0x67, 0x3b, 0x6f, 0x70, 0x6c,
0x6f, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
0x52, 0x09, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x4d, 0x61, 0x73, 0x6b, 0x12, 0x37, 0x0a, 0x09, 0x6e,
0x75, 0x6c, 0x6c, 0x5f, 0x6d, 0x61, 0x73, 0x6b, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a,
0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66,
0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4d, 0x61, 0x73, 0x6b, 0x52, 0x08, 0x6e, 0x75, 0x6c, 0x6c,
0x4d, 0x61, 0x73, 0x6b, 0x2a, 0x5d, 0x0a, 0x06, 0x4f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x17,
0x0a, 0x13, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43,
0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x12, 0x0a, 0x0e, 0x4f, 0x50, 0x5f, 0x54, 0x59,
0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x4f,
0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12,
0x12, 0x0a, 0x0e, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x45, 0x4c, 0x45, 0x54,
0x45, 0x10, 0x03, 0x42, 0x36, 0x5a, 0x34, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f,
0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x77, 0x61, 0x74, 0x63,
0x68, 0x74, 0x6f, 0x77, 0x65, 0x72, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f,
0x6f, 0x70, 0x6c, 0x6f, 0x67, 0x3b, 0x6f, 0x70, 0x6c, 0x6f, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
}
var (
@ -211,11 +231,12 @@ var file_controller_storage_oplog_v1_any_operation_proto_goTypes = []interface{}
var file_controller_storage_oplog_v1_any_operation_proto_depIdxs = []int32{
0, // 0: controller.storage.oplog.v1.AnyOperation.operation_type:type_name -> controller.storage.oplog.v1.OpType
2, // 1: controller.storage.oplog.v1.AnyOperation.field_mask:type_name -> google.protobuf.FieldMask
2, // [2:2] is the sub-list for method output_type
2, // [2:2] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
2, // 2: controller.storage.oplog.v1.AnyOperation.null_mask:type_name -> google.protobuf.FieldMask
3, // [3:3] is the sub-list for method output_type
3, // [3:3] is the sub-list for method input_type
3, // [3:3] is the sub-list for extension type_name
3, // [3:3] is the sub-list for extension extendee
0, // [0:3] is the sub-list for field type_name
}
func init() { file_controller_storage_oplog_v1_any_operation_proto_init() }

@ -24,6 +24,7 @@ type Message struct {
TypeName string
OpType OpType
FieldMaskPaths []string
SetToNullPaths []string
}
// Entry represents an oplog entry
@ -98,7 +99,7 @@ func (e *Entry) UnmarshalData(types *TypeCatalog) ([]Message, error) {
Catalog: types,
}
for {
m, typ, fieldMaskPaths, err := queue.Remove()
m, typ, fieldMaskPaths, nullPaths, err := queue.Remove()
if err == io.EOF {
break
}
@ -109,7 +110,7 @@ func (e *Entry) UnmarshalData(types *TypeCatalog) ([]Message, error) {
if err != nil {
return nil, fmt.Errorf("error getting TypeName: %w", err)
}
msgs = append(msgs, Message{Message: m, TypeName: name, OpType: typ, FieldMaskPaths: fieldMaskPaths})
msgs = append(msgs, Message{Message: m, TypeName: name, OpType: typ, FieldMaskPaths: fieldMaskPaths, SetToNullPaths: nullPaths})
}
return msgs, nil
}
@ -131,7 +132,7 @@ func (e *Entry) WriteEntryWith(ctx context.Context, tx Writer, ticket *store.Tic
if m == nil {
return errors.New("bad message")
}
if err := queue.Add(m.Message, m.TypeName, m.OpType, WithFieldMaskPaths(m.FieldMaskPaths)); err != nil {
if err := queue.Add(m.Message, m.TypeName, m.OpType, WithFieldMaskPaths(m.FieldMaskPaths), WithSetToNullPaths(m.SetToNullPaths)); err != nil {
return fmt.Errorf("error adding message to queue: %w", err)
}
}
@ -214,7 +215,9 @@ func (e *Entry) Replay(ctx context.Context, tx Writer, types *TypeCatalog, table
*/
replayTable := origTableName + tableSuffix
if !tx.hasTable(replayTable) {
tx.createTableLike(origTableName, replayTable)
if err := tx.createTableLike(origTableName, replayTable); err != nil {
return fmt.Errorf("replay: %w", err)
}
}
em.SetTableName(replayTable)
@ -224,7 +227,7 @@ func (e *Entry) Replay(ctx context.Context, tx Writer, types *TypeCatalog, table
return fmt.Errorf("replay error: %w", err)
}
case OpType_OP_TYPE_UPDATE:
if err := tx.Update(m.Message, m.FieldMaskPaths); err != nil {
if err := tx.Update(m.Message, m.FieldMaskPaths, m.SetToNullPaths); err != nil {
return fmt.Errorf("replay error: %w", err)
}
case OpType_OP_TYPE_DELETE:

@ -2,74 +2,54 @@ package oplog
import (
"context"
"crypto/rand"
"database/sql"
"fmt"
"os"
"strings"
"testing"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
"github.com/hashicorp/go-uuid"
dbassert "github.com/hashicorp/dbassert/gorm"
"github.com/hashicorp/watchtower/internal/oplog/oplog_test"
"github.com/jinzhu/gorm"
"gotest.tools/assert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/ory/dockertest/v3"
)
// setup the tests (initialize the database one-time and intialized testDatabaseURL)
func setup(t *testing.T) (func(), *gorm.DB) {
cleanup := func() {}
var url string
var err error
cleanup, url, err = initDbInDocker(t)
if err != nil {
t.Fatal(err)
}
t.Helper()
require := require.New(t)
cleanup, url, err := testInitDbInDocker(t)
require.NoError(err)
db, err := gorm.Open("postgres", url)
if err != nil {
t.Fatal(err)
}
require.NoError(err)
oplog_test.Init(db)
return cleanup, db
}
// Test_BasicOplog provides some basic unit tests for oplogs
func Test_BasicOplog(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
t.Run("EncryptData/DecryptData/UnmarshalData", func(t *testing.T) {
cipherer := initWrapper(t)
assert, require := assert.New(t), require.New(t)
cipherer := testWrapper(t)
// now let's us optimistic locking via a ticketing system for a serialized oplog
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"})
assert.NilError(t, err)
require.NoError(err)
queue := Queue{Catalog: types}
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
user := oplog_test.TestUser{
Name: "foo-" + id,
}
resp := db.Create(&user)
assert.NilError(t, resp.Error)
id := testId(t)
user := testUser(t, db, "foo-"+id, "", "")
err = queue.Add(&user, "user", OpType_OP_TYPE_CREATE)
assert.NilError(t, err)
err = queue.Add(user, "user", OpType_OP_TYPE_CREATE)
require.NoError(err)
l, err := NewEntry(
"test-users",
Metadata{
@ -80,62 +60,55 @@ func Test_BasicOplog(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
l.Data = queue.Bytes()
err = l.EncryptData(context.Background())
assert.NilError(t, err)
require.NoError(err)
resp = db.Create(&l)
assert.NilError(t, resp.Error)
assert.Check(t, l.CreateTime != nil)
assert.Check(t, l.UpdateTime != nil)
err = db.Create(&l).Error
require.NoError(err)
assert.NotNil(l.CreateTime)
assert.NotNil(l.UpdateTime)
entryId := l.Id
var foundEntry Entry
err = db.Where("id = ?", entryId).First(&foundEntry).Error
assert.NilError(t, err)
require.NoError(err)
foundEntry.Cipherer = cipherer
err = foundEntry.DecryptData(context.Background())
assert.NilError(t, err)
require.NoError(err)
foundUsers, err := foundEntry.UnmarshalData(types)
assert.NilError(t, err)
assert.Assert(t, foundUsers[0].Message.(*oplog_test.TestUser).Name == user.Name)
require.NoError(err)
assert.Equal(foundUsers[0].Message.(*oplog_test.TestUser).Name, user.Name)
foundUsers, err = foundEntry.UnmarshalData(types)
assert.NilError(t, err)
assert.Assert(t, foundUsers[0].Message.(*oplog_test.TestUser).Name == user.Name)
require.NoError(err)
assert.Equal(foundUsers[0].Message.(*oplog_test.TestUser).Name, user.Name)
})
t.Run("write entry", func(t *testing.T) {
cipherer := initWrapper(t)
assert, require := assert.New(t), require.New(t)
cipherer := testWrapper(t)
// now let's us optimistic locking via a ticketing system for a serialized oplog
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"})
assert.NilError(t, err)
require.NoError(err)
queue := Queue{Catalog: types}
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
id := testId(t)
user := testUser(t, db, "foo-"+id, "", "")
user := oplog_test.TestUser{
Name: "foo-" + id,
}
resp := db.Create(&user)
assert.NilError(t, resp.Error)
assert.NilError(t, err)
ticket, err := ticketer.GetTicket("default")
assert.NilError(t, err)
require.NoError(err)
queue = Queue{}
err = queue.Add(&user, "user", OpType_OP_TYPE_CREATE)
assert.NilError(t, err)
err = queue.Add(user, "user", OpType_OP_TYPE_CREATE)
require.NoError(err)
newLogEntry, err := NewEntry(
"test-users",
@ -147,11 +120,11 @@ func Test_BasicOplog(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
newLogEntry.Data = queue.Bytes()
err = newLogEntry.Write(context.Background(), &GormWriter{db}, ticket)
assert.NilError(t, err)
assert.Assert(t, newLogEntry.Id != 0)
require.NoError(err)
assert.NotEmpty(newLogEntry.Id)
})
}
@ -159,12 +132,13 @@ func Test_BasicOplog(t *testing.T) {
// Test_NewEntry provides some basic unit tests for NewEntry
func Test_NewEntry(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
t.Run("valid", func(t *testing.T) {
cipherer := initWrapper(t)
require := require.New(t)
cipherer := testWrapper(t)
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
_, err = NewEntry(
"test-users",
@ -176,12 +150,13 @@ func Test_NewEntry(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
})
t.Run("no metadata success", func(t *testing.T) {
cipherer := initWrapper(t)
require := require.New(t)
cipherer := testWrapper(t)
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
_, err = NewEntry(
"test-users",
@ -189,12 +164,13 @@ func Test_NewEntry(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
})
t.Run("no aggregateName", func(t *testing.T) {
cipherer := initWrapper(t)
assert, require := assert.New(t), require.New(t)
cipherer := testWrapper(t)
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
_, err = NewEntry(
"",
@ -206,11 +182,13 @@ func Test_NewEntry(t *testing.T) {
cipherer,
ticketer,
)
assert.Equal(t, err.Error(), "error creating entry: entry aggregate name is not set")
require.Error(err)
assert.Equal(err.Error(), "error creating entry: entry aggregate name is not set")
})
t.Run("bad cipherer", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
_, err = NewEntry(
"test-users",
@ -222,10 +200,12 @@ func Test_NewEntry(t *testing.T) {
nil,
ticketer,
)
assert.Equal(t, err.Error(), "error creating entry: entry Cipherer is nil")
require.Error(err)
assert.Equal(err.Error(), "error creating entry: entry Cipherer is nil")
})
t.Run("bad ticket", func(t *testing.T) {
cipherer := initWrapper(t)
assert, require := assert.New(t), require.New(t)
cipherer := testWrapper(t)
_, err := NewEntry(
"test-users",
Metadata{
@ -236,34 +216,34 @@ func Test_NewEntry(t *testing.T) {
cipherer,
nil,
)
assert.Equal(t, err.Error(), "error creating entry: entry Ticketer is nil")
require.Error(err)
assert.Equal(err.Error(), "error creating entry: entry Ticketer is nil")
})
}
func Test_UnmarshalData(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
cipherer := initWrapper(t)
cipherer := testWrapper(t)
// now let's us optimistic locking via a ticketing system for a serialized oplog
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"})
assert.NilError(t, err)
require.NoError(t, err)
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
id := testId(t)
t.Run("valid", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
queue := Queue{Catalog: types}
user := oplog_test.TestUser{
Name: "foo-" + id,
}
err = queue.Add(&user, "user", OpType_OP_TYPE_CREATE)
assert.NilError(t, err)
require.NoError(err)
entry, err := NewEntry(
"test-users",
Metadata{
@ -274,18 +254,19 @@ func Test_UnmarshalData(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
entry.Data = queue.Bytes()
marshaledUsers, err := entry.UnmarshalData(types)
assert.NilError(t, err)
assert.Assert(t, marshaledUsers[0].Message.(*oplog_test.TestUser).Name == user.Name)
require.NoError(err)
assert.Equal(marshaledUsers[0].Message.(*oplog_test.TestUser).Name, user.Name)
take2marshaledUsers, err := entry.UnmarshalData(types)
assert.NilError(t, err)
assert.Assert(t, take2marshaledUsers[0].Message.(*oplog_test.TestUser).Name == user.Name)
require.NoError(err)
assert.Equal(take2marshaledUsers[0].Message.(*oplog_test.TestUser).Name, user.Name)
})
t.Run("no data", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
queue := Queue{Catalog: types}
entry, err := NewEntry(
"test-users",
@ -297,20 +278,22 @@ func Test_UnmarshalData(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
entry.Data = queue.Bytes()
_, err = entry.UnmarshalData(types)
assert.Equal(t, err.Error(), "no Data to unmarshal")
require.Error(err)
assert.Equal(err.Error(), "no Data to unmarshal")
})
t.Run("nil types", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
queue := Queue{Catalog: types}
user := oplog_test.TestUser{
Name: "foo-" + id,
}
err = queue.Add(&user, "user", OpType_OP_TYPE_CREATE)
assert.NilError(t, err)
require.NoError(err)
entry, err := NewEntry(
"test-users",
Metadata{
@ -321,19 +304,21 @@ func Test_UnmarshalData(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
_, err = entry.UnmarshalData(nil)
assert.Equal(t, err.Error(), "TypeCatalog is nil")
require.Error(err)
assert.Equal(err.Error(), "TypeCatalog is nil")
})
t.Run("missing type", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
queue := Queue{Catalog: types}
user := oplog_test.TestUser{
Name: "foo-" + id,
}
err = queue.Add(&user, "user", OpType_OP_TYPE_CREATE)
assert.NilError(t, err)
require.NoError(err)
entry, err := NewEntry(
"test-users",
Metadata{
@ -344,40 +329,39 @@ func Test_UnmarshalData(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
entry.Data = queue.Bytes()
types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "not-valid-name"})
assert.NilError(t, err)
require.NoError(err)
_, err = entry.UnmarshalData(types)
assert.Equal(t, err.Error(), "error removing item from queue: error getting the TypeName for Remove: error typeName is not found for Get")
require.Error(err)
assert.Equal(err.Error(), "error removing item from queue: error getting the TypeName for Remove: error typeName is not found for Get")
})
}
// Test_Replay provides some basic unit tests for replaying entries
func Test_Replay(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
cipherer := initWrapper(t)
cipherer := testWrapper(t)
id := testId(t)
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
// setup new tables for replay
tableSuffix := "_" + id
writer := GormWriter{Tx: db}
testUser := &oplog_test.TestUser{}
replayUserTable := fmt.Sprintf("%s%s", testUser.TableName(), tableSuffix)
defer func() { assert.NilError(t, writer.dropTableIfExists(replayUserTable)) }()
userModel := &oplog_test.TestUser{}
replayUserTable := fmt.Sprintf("%s%s", userModel.TableName(), tableSuffix)
defer func() { require.NoError(t, writer.dropTableIfExists(replayUserTable)) }()
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(t, err)
t.Run("replay:create/update", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
ticket, err := ticketer.GetTicket("default")
assert.NilError(t, err)
require.NoError(err)
newLogEntry, err := NewEntry(
"test-users",
@ -389,93 +373,93 @@ func Test_Replay(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
tx := db.Begin()
require.NoError(err)
id3, err := uuid.GenerateUUID()
assert.NilError(t, err)
userName := "foo-" + id3
// create a user that's replayable
userCreate := oplog_test.TestUser{
Name: userName,
}
err = tx.Create(&userCreate).Error
assert.NilError(t, err)
tx := db
userName := "foo-" + testId(t)
userCreate := testUser(t, db, userName, "", "")
userSave := oplog_test.TestUser{
Id: userCreate.Id,
Name: userCreate.Name,
Email: userName + "@hashicorp.com",
}
err = tx.Save(&userSave).Error
assert.NilError(t, err)
require.NoError(err)
userUpdate := oplog_test.TestUser{
Id: userCreate.Id,
PhoneNumber: "867-5309",
Id: userCreate.Id,
}
err = tx.Model(&userUpdate).Updates(map[string]interface{}{"PhoneNumber": "867-5309"}).Error
assert.NilError(t, err)
err = tx.Model(&userUpdate).Updates(map[string]interface{}{"PhoneNumber": "867-5309", "Name": gorm.Expr("NULL")}).Error
require.NoError(err)
dbassert := dbassert.New(t, tx.DB(), "postgres")
dbassert.IsNull(&userUpdate, "Name")
foundCreateUser := testFindUser(t, tx, userCreate.Id)
require.Equal(foundCreateUser.Id, userCreate.Id)
require.Equal(foundCreateUser.Name, "")
require.Equal(foundCreateUser.PhoneNumber, userUpdate.PhoneNumber)
err = newLogEntry.WriteEntryWith(context.Background(), &GormWriter{tx}, ticket,
&Message{Message: &userCreate, TypeName: "user", OpType: OpType_OP_TYPE_CREATE},
&Message{Message: &userSave, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE},
&Message{Message: &userUpdate, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, FieldMaskPaths: []string{"PhoneNumber"}},
&Message{Message: userCreate, TypeName: "user", OpType: OpType_OP_TYPE_CREATE},
&Message{Message: &userSave, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, FieldMaskPaths: []string{"Name", "Email"}},
&Message{Message: &userSave, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, FieldMaskPaths: []string{"Name", "Email"}, SetToNullPaths: nil},
&Message{Message: &userUpdate, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, SetToNullPaths: []string{"Name"}},
&Message{Message: &userUpdate, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, FieldMaskPaths: nil, SetToNullPaths: []string{"Name"}},
&Message{Message: &userUpdate, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, FieldMaskPaths: []string{"PhoneNumber"}, SetToNullPaths: []string{"Name"}},
)
assert.NilError(t, err)
require.NoError(err)
types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"})
assert.NilError(t, err)
require.NoError(err)
var foundEntry Entry
err = tx.Where("id = ?", newLogEntry.Id).First(&foundEntry).Error
assert.NilError(t, err)
require.NoError(err)
foundEntry.Cipherer = cipherer
err = foundEntry.DecryptData(context.Background())
assert.NilError(t, err)
require.NoError(err)
err = foundEntry.Replay(context.Background(), &GormWriter{tx}, types, tableSuffix)
assert.NilError(t, err)
var foundUser oplog_test.TestUser
err = tx.Where("id = ?", userCreate.Id).First(&foundUser).Error
assert.NilError(t, err)
require.NoError(err)
foundUser := testFindUser(t, tx, userCreate.Id)
var foundReplayedUser oplog_test.TestUser
foundReplayedUser.Table = foundReplayedUser.TableName() + tableSuffix
err = tx.Where("id = ?", userCreate.Id).First(&foundReplayedUser).Error
assert.NilError(t, err)
assert.Assert(t, foundUser.Id == foundReplayedUser.Id)
assert.Assert(t, foundUser.Name == foundReplayedUser.Name && foundUser.Name == userName)
assert.Assert(t, foundUser.PhoneNumber == foundReplayedUser.PhoneNumber && foundReplayedUser.PhoneNumber == "867-5309")
assert.Assert(t, foundUser.Email == foundReplayedUser.Email && foundReplayedUser.Email == userName+"@hashicorp.com")
tx.Commit()
require.NoError(err)
assert.Equal(foundUser.Id, foundReplayedUser.Id)
assert.Equal(foundUser.Name, foundReplayedUser.Name)
// assert.Equal(foundUser.Name, userName)
assert.Equal(foundUser.PhoneNumber, foundReplayedUser.PhoneNumber)
assert.Equal(foundReplayedUser.PhoneNumber, "867-5309")
assert.Equal(foundUser.Email, foundReplayedUser.Email)
assert.Equal(foundReplayedUser.Email, userName+"@hashicorp.com")
})
t.Run("replay:delete", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
// we need to test delete replays now...
tx2 := db.Begin()
defer tx2.Commit()
ticket2, err := ticketer.GetTicket("default")
assert.NilError(t, err)
require.NoError(err)
id4, err := uuid.GenerateUUID()
assert.NilError(t, err)
id4 := testId(t)
userName2 := "foo-" + id4
// create a user that's replayable
userCreate2 := oplog_test.TestUser{
Name: userName2,
}
err = tx2.Create(&userCreate2).Error
assert.NilError(t, err)
require.NoError(err)
deleteUser2 := oplog_test.TestUser{
Id: userCreate2.Id,
}
err = tx2.Delete(&deleteUser2).Error
assert.NilError(t, err)
require.NoError(err)
newLogEntry2, err := NewEntry(
"test-users",
@ -487,67 +471,63 @@ func Test_Replay(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
err = newLogEntry2.WriteEntryWith(context.Background(), &GormWriter{tx2}, ticket2,
&Message{Message: &userCreate2, TypeName: "user", OpType: OpType_OP_TYPE_CREATE},
&Message{Message: &deleteUser2, TypeName: "user", OpType: OpType_OP_TYPE_DELETE},
)
assert.NilError(t, err)
require.NoError(err)
var foundEntry2 Entry
err = tx2.Where("id = ?", newLogEntry2.Id).First(&foundEntry2).Error
assert.NilError(t, err)
require.NoError(err)
foundEntry2.Cipherer = cipherer
err = foundEntry2.DecryptData(context.Background())
assert.NilError(t, err)
require.NoError(err)
types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"})
assert.NilError(t, err)
require.NoError(err)
err = foundEntry2.Replay(context.Background(), &GormWriter{tx2}, types, tableSuffix)
assert.NilError(t, err)
require.NoError(err)
var foundUser2 oplog_test.TestUser
err = tx2.Where("id = ?", userCreate2.Id).First(&foundUser2).Error
assert.Assert(t, err == gorm.ErrRecordNotFound)
assert.Equal(err, gorm.ErrRecordNotFound)
var foundReplayedUser2 oplog_test.TestUser
err = tx2.Where("id = ?", userCreate2.Id).First(&foundReplayedUser2).Error
assert.Assert(t, err == gorm.ErrRecordNotFound)
assert.Equal(err, gorm.ErrRecordNotFound)
tx2.Commit()
})
}
// Test_WriteEntryWith provides unit tests for oplog.WriteEntryWith
func Test_WriteEntryWith(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
cipherer := initWrapper(t)
defer testCleanup(t, cleanup, db)
cipherer := testWrapper(t)
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
id := testId(t)
u := oplog_test.TestUser{
Name: "foo-" + id,
}
t.Log(&u)
id2, err := uuid.GenerateUUID()
assert.NilError(t, err)
id2 := testId(t)
u2 := oplog_test.TestUser{
Name: "foo-" + id2,
}
t.Log(&u2)
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(t, err)
ticket, err := ticketer.GetTicket("default")
assert.NilError(t, err)
require.NoError(t, err)
t.Run("successful", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
newLogEntry, err := NewEntry(
"test-users",
Metadata{
@ -558,25 +538,26 @@ func Test_WriteEntryWith(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
err = newLogEntry.WriteEntryWith(context.Background(), &GormWriter{db}, ticket,
&Message{Message: &u, TypeName: "user", OpType: OpType_OP_TYPE_CREATE},
&Message{Message: &u2, TypeName: "user", OpType: OpType_OP_TYPE_CREATE})
assert.NilError(t, err)
require.NoError(err)
var foundEntry Entry
err = db.Where("id = ?", newLogEntry.Id).First(&foundEntry).Error
assert.NilError(t, err)
require.NoError(err)
foundEntry.Cipherer = cipherer
types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"})
assert.NilError(t, err)
require.NoError(err)
err = foundEntry.DecryptData(context.Background())
assert.NilError(t, err)
require.NoError(err)
foundUsers, err := foundEntry.UnmarshalData(types)
assert.NilError(t, err)
assert.Assert(t, foundUsers[0].Message.(*oplog_test.TestUser).Name == u.Name)
require.NoError(err)
assert.Equal(foundUsers[0].Message.(*oplog_test.TestUser).Name, u.Name)
})
t.Run("nil writer", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
newLogEntry, err := NewEntry(
"test-users",
Metadata{
@ -587,13 +568,15 @@ func Test_WriteEntryWith(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
err = newLogEntry.WriteEntryWith(context.Background(), nil, ticket,
&Message{Message: &u, TypeName: "user", OpType: OpType_OP_TYPE_CREATE},
&Message{Message: &u2, TypeName: "user", OpType: OpType_OP_TYPE_CREATE})
assert.Equal(t, err.Error(), "bad writer")
require.Error(err)
assert.Equal(err.Error(), "bad writer")
})
t.Run("nil ticket", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
newLogEntry, err := NewEntry(
"test-users",
Metadata{
@ -604,13 +587,15 @@ func Test_WriteEntryWith(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
err = newLogEntry.WriteEntryWith(context.Background(), &GormWriter{db}, nil,
&Message{Message: &u, TypeName: "user", OpType: OpType_OP_TYPE_CREATE},
&Message{Message: &u2, TypeName: "user", OpType: OpType_OP_TYPE_CREATE})
assert.Equal(t, err.Error(), "bad ticket")
require.Error(err)
assert.Equal(err.Error(), "bad ticket")
})
t.Run("nil ticket", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
newLogEntry, err := NewEntry(
"test-users",
Metadata{
@ -621,37 +606,37 @@ func Test_WriteEntryWith(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
err = newLogEntry.WriteEntryWith(context.Background(), &GormWriter{db}, ticket, nil)
assert.Equal(t, err.Error(), "bad message")
require.Error(err)
assert.Equal(err.Error(), "bad message")
})
}
// Test_TicketSerialization provides unit tests for making sure oplog.Tickets properly serialize writes to oplog entries
func Test_TicketSerialization(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
assert, require := assert.New(t), require.New(t)
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
cipherer := initWrapper(t)
cipherer := testWrapper(t)
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
id := testId(t)
firstTx := db.Begin()
firstUser := oplog_test.TestUser{
Name: "foo-" + id,
}
err = firstTx.Create(&firstUser).Error
assert.NilError(t, err)
require.NoError(err)
firstTicket, err := ticketer.GetTicket("default")
assert.NilError(t, err)
require.NoError(err)
firstQueue := Queue{}
err = firstQueue.Add(&firstUser, "user", OpType_OP_TYPE_CREATE)
assert.NilError(t, err)
require.NoError(err)
firstLogEntry, err := NewEntry(
"test-users",
@ -663,23 +648,22 @@ func Test_TicketSerialization(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
firstLogEntry.Data = firstQueue.Bytes()
id2, err := uuid.GenerateUUID()
assert.NilError(t, err)
id2 := testId(t)
secondTx := db.Begin()
secondUser := oplog_test.TestUser{
Name: "foo-" + id2,
}
err = secondTx.Create(&secondUser).Error
assert.NilError(t, err)
require.NoError(err)
secondTicket, err := ticketer.GetTicket("default")
assert.NilError(t, err)
require.NoError(err)
secondQueue := Queue{}
err = secondQueue.Add(&secondUser, "user", OpType_OP_TYPE_CREATE)
assert.NilError(t, err)
require.NoError(err)
secondLogEntry, err := NewEntry(
"foobar",
@ -691,17 +675,17 @@ func Test_TicketSerialization(t *testing.T) {
cipherer,
ticketer,
)
assert.NilError(t, err)
require.NoError(err)
secondLogEntry.Data = secondQueue.Bytes()
err = secondLogEntry.Write(context.Background(), &GormWriter{secondTx}, secondTicket)
assert.NilError(t, err)
assert.Assert(t, secondLogEntry.Id != 0)
assert.Check(t, secondLogEntry.CreateTime != nil)
assert.Check(t, secondLogEntry.UpdateTime != nil)
require.NoError(err)
assert.NotEmpty(secondLogEntry.Id, 0)
assert.NotNil(secondLogEntry.CreateTime)
assert.NotNil(secondLogEntry.UpdateTime)
err = secondTx.Commit().Error
assert.NilError(t, err)
assert.Assert(t, secondLogEntry.Id != 0)
require.NoError(err)
assert.NotNil(secondLogEntry.Id)
err = firstLogEntry.Write(context.Background(), &GormWriter{firstTx}, firstTicket)
if err != nil {
@ -711,89 +695,3 @@ func Test_TicketSerialization(t *testing.T) {
t.Error("should have failed to write firstLogEntry")
}
}
// initDbInDocker initializes postgres within dockertest for the unit tests
func initDbInDocker(t *testing.T) (cleanup func(), retURL string, err error) {
if os.Getenv("PG_URL") != "" {
initTestStore(t, func() {}, os.Getenv("PG_URL"))
return func() {}, os.Getenv("PG_URL"), nil
}
pool, err := dockertest.NewPool("")
if err != nil {
return func() {}, "", fmt.Errorf("could not connect to docker: %w", err)
}
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=watchtower"})
if err != nil {
return func() {}, "", fmt.Errorf("could not start resource: %w", err)
}
c := func() {
cleanupResource(pool, resource)
}
url := fmt.Sprintf("postgres://postgres:secret@localhost:%s?sslmode=disable", resource.GetPort("5432/tcp"))
if err := pool.Retry(func() error {
db, err := sql.Open("postgres", url)
if err != nil {
return fmt.Errorf("error opening postgres dev container: %w", err)
}
if err := db.Ping(); err != nil {
return err
}
defer db.Close()
return nil
}); err != nil {
return func() {}, "", fmt.Errorf("could not connect to docker: %w", err)
}
initTestStore(t, c, url)
return c, url, nil
}
// initWrapper initializes an AEAD wrapping.Wrapper for testing the oplog
func initWrapper(t *testing.T) wrapping.Wrapper {
rootKey := make([]byte, 32)
n, err := rand.Read(rootKey)
if err != nil {
t.Fatal(err)
}
if n != 32 {
t.Fatal(n)
}
root := aead.NewWrapper(nil)
if err := root.SetAESGCMKeyBytes(rootKey); err != nil {
t.Fatal(err)
}
return root
}
// initTestStore will execute the migrations needed to initialize the store for tests
func initTestStore(t *testing.T, cleanup func(), url string) {
// run migrations
m, err := migrate.New("file://../db/migrations/postgres", url)
if err != nil {
cleanup()
t.Fatalf("Error creating migrations: %s", err)
}
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
cleanup()
t.Fatalf("Error running migrations: %s", err)
}
}
// cleanupResource will clean up the dockertest resources (postgres)
func cleanupResource(pool *dockertest.Pool, resource *dockertest.Resource) error {
var err error
for i := 0; i < 10; i++ {
err = pool.Purge(resource)
if err == nil {
return nil
}
}
if strings.Contains(err.Error(), "No such container") {
return nil
}
return fmt.Errorf("Failed to cleanup local container: %s", err)
}

@ -18,6 +18,7 @@ type Options map[string]interface{}
func getDefaultOptions() Options {
return Options{
optionWithFieldMaskPaths: []string{},
optionWithSetToNullPaths: []string{},
optionWithAggregateNames: false,
}
}
@ -32,6 +33,16 @@ func WithFieldMaskPaths(fieldMaskPaths []string) Option {
}
}
const optionWithSetToNullPaths = "optionWithSetToNullPaths"
// WithSetToNullPaths represents an optional set of symbolic field paths (for example: "f.a", "f.b.d") used
// to specify a subset of fields that should be set to null. (see google.golang.org/genproto/protobuf/field_mask)
func WithSetToNullPaths(setToNullPaths []string) Option {
return func(o Options) {
o[optionWithSetToNullPaths] = setToNullPaths
}
}
const optionWithAggregateNames = "optionWithAggregateNames"
// WithAggregateNames enables/disables the use of multiple aggregate names for Ticketers

@ -0,0 +1,31 @@
package oplog
import (
"testing"
"github.com/stretchr/testify/assert"
)
// Test_GetOpts provides unit tests for GetOpts and all the options
func Test_GetOpts(t *testing.T) {
t.Parallel()
assert := assert.New(t)
t.Run("WithFieldMaskPaths", func(t *testing.T) {
opts := GetOpts(WithFieldMaskPaths([]string{"test"}))
testOpts := getDefaultOptions()
testOpts[optionWithFieldMaskPaths] = []string{"test"}
assert.Equal(opts, testOpts)
})
t.Run("WithSetToNullPaths", func(t *testing.T) {
opts := GetOpts(WithSetToNullPaths([]string{"test"}))
testOpts := getDefaultOptions()
testOpts[optionWithSetToNullPaths] = []string{"test"}
assert.Equal(opts, testOpts)
})
t.Run("WithAggregateNames", func(t *testing.T) {
opts := GetOpts(WithAggregateNames(true))
testOpts := getDefaultOptions()
testOpts[optionWithAggregateNames] = true
assert.Equal(opts, testOpts)
})
}

@ -7,6 +7,7 @@ import (
"fmt"
"sync"
common "github.com/hashicorp/watchtower/internal/db/common"
"google.golang.org/genproto/protobuf/field_mask"
"google.golang.org/protobuf/proto"
)
@ -15,18 +16,24 @@ import (
type Queue struct {
// Buffer for the queue
bytes.Buffer
// Catalog provides a TypeCatalog for the types added to the queue
Catalog *TypeCatalog
mx sync.Mutex
}
// Add pb message to queue
func (r *Queue) Add(m proto.Message, typeName string, t OpType, opt ...Option) error {
// Add message to queue. typeName defines the type of message added to the
// queue and allows the msg to be removed using a TypeCatalog with a
// coresponding typeName entry. OpType defines the msg's operation (create, add,
// update, etc). If OpType == OpType_OP_TYPE_UPDATE, the WithFieldMaskPaths()
// and SetToNullPaths() options are supported.
func (q *Queue) Add(m proto.Message, typeName string, t OpType, opt ...Option) error {
// we're not checking the Catalog for nil, since it's not used
// when Adding messages to the queue
opts := GetOpts(opt...)
withPaths := opts[optionWithFieldMaskPaths].([]string)
withFieldMasks := opts[optionWithFieldMaskPaths].([]string)
withNullPaths := opts[optionWithSetToNullPaths].([]string)
if _, ok := m.(ReplayableMessage); !ok {
return fmt.Errorf("error %T is not a ReplayableMessage", m)
@ -35,23 +42,45 @@ func (r *Queue) Add(m proto.Message, typeName string, t OpType, opt ...Option) e
if err != nil {
return fmt.Errorf("error marshaling add parameter: %w", err)
}
if t == OpType_OP_TYPE_UPDATE {
if len(withFieldMasks) == 0 && len(withNullPaths) == 0 {
return fmt.Errorf("queue add: missing field masks or null paths for update")
}
fMasks := withFieldMasks
if fMasks == nil {
fMasks = []string{}
}
nullPaths := withNullPaths
if nullPaths == nil {
nullPaths = []string{}
}
i, _, _, err := common.Intersection(fMasks, nullPaths)
if err != nil {
return fmt.Errorf("queue add: %w", err)
}
if len(i) != 0 {
return fmt.Errorf("queue add: field masks and null paths intersect with: %s", i)
}
}
msg := &AnyOperation{
TypeName: typeName,
Value: value,
OperationType: t,
FieldMask: &field_mask.FieldMask{Paths: withPaths},
FieldMask: &field_mask.FieldMask{Paths: withFieldMasks},
NullMask: &field_mask.FieldMask{Paths: withNullPaths},
}
data, err := proto.Marshal(msg)
if err != nil {
return fmt.Errorf("error marhaling the msg for Add: %w", err)
}
r.mx.Lock()
defer r.mx.Unlock()
err = binary.Write(r, binary.LittleEndian, uint32(len(data)))
q.mx.Lock()
defer q.mx.Unlock()
err = binary.Write(q, binary.LittleEndian, uint32(len(data)))
if err != nil {
return err
}
n, err := r.Write(data)
n, err := q.Write(data)
if err != nil {
return fmt.Errorf("error writing to queue buffer: %w", err)
}
@ -61,34 +90,49 @@ func (r *Queue) Add(m proto.Message, typeName string, t OpType, opt ...Option) e
return nil
}
// Remove pb message from the queue and EOF if empty
func (r *Queue) Remove() (proto.Message, OpType, []string, error) {
if r.Catalog == nil {
return nil, OpType_OP_TYPE_UNSPECIFIED, nil, errors.New("remove Catalog is nil")
// Remove pb message from the queue and EOF if empty. It also returns the OpType
// for the msg and if it's OpType_OP_TYPE_UPDATE, the it will also return the
// fieldMask and setToNullPaths for the update operation.
func (q *Queue) Remove() (proto.Message, OpType, []string, []string, error) {
if q.Catalog == nil {
return nil, OpType_OP_TYPE_UNSPECIFIED, nil, nil, errors.New("remove Catalog is nil")
}
r.mx.Lock()
defer r.mx.Unlock()
q.mx.Lock()
defer q.mx.Unlock()
var n uint32
err := binary.Read(r, binary.LittleEndian, &n)
err := binary.Read(q, binary.LittleEndian, &n)
if err != nil {
return nil, 0, nil, err // intentionally not wrapping error so client can test for sentinel EOF error
return nil, 0, nil, nil, err // intentionally not wrapping error so client can test for sentinel EOF error
}
data := r.Next(int(n))
data := q.Next(int(n))
msg := new(AnyOperation)
err = proto.Unmarshal(data, msg)
if err != nil {
return nil, 0, nil, fmt.Errorf("error marshaling the msg for Remove: %w", err)
return nil, 0, nil, nil, fmt.Errorf("error marshaling the msg for Remove: %w", err)
}
if msg.Value == nil {
return nil, 0, nil, nil
return nil, 0, nil, nil, nil
}
any, err := r.Catalog.Get(msg.TypeName)
any, err := q.Catalog.Get(msg.TypeName)
if err != nil {
return nil, 0, nil, fmt.Errorf("error getting the TypeName for Remove: %w", err)
return nil, 0, nil, nil, fmt.Errorf("error getting the TypeName for Remove: %w", err)
}
pm := any.(proto.Message)
if err = proto.Unmarshal(msg.Value, pm); err != nil {
return nil, 0, nil, fmt.Errorf("error unmarshaling the value for Remove: %w", err)
return nil, 0, nil, nil, fmt.Errorf("error unmarshaling the value for Remove: %w", err)
}
var masks, nullPaths []string
if msg.OperationType == OpType_OP_TYPE_UPDATE {
if msg.FieldMask != nil {
masks = msg.FieldMask.GetPaths()
}
if msg.NullMask != nil {
nullPaths = msg.NullMask.GetPaths()
}
if len(masks) == 0 && len(nullPaths) == 0 {
return nil, 0, nil, nil, errors.New("error unmarshaling the value for Remove: field mask or null paths is required")
}
}
return pm, msg.OperationType, msg.FieldMask.GetPaths(), nil
return pm, msg.OperationType, masks, nullPaths, nil
}

@ -1,12 +1,12 @@
package oplog
import (
reflect "reflect"
"testing"
"github.com/hashicorp/watchtower/internal/oplog/oplog_test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"gotest.tools/assert"
)
// Test_Queue provides basic tests for the Queue type
@ -17,7 +17,7 @@ func Test_Queue(t *testing.T) {
Type{new(oplog_test.TestCar), "car"},
Type{new(oplog_test.TestRental), "rental"},
)
assert.NilError(t, err)
require.NoError(t, err)
queue := Queue{Catalog: types}
user := &oplog_test.TestUser{
@ -41,64 +41,92 @@ func Test_Queue(t *testing.T) {
Name: "bob",
Email: "bob@alice.com",
}
userNilUpdate := &oplog_test.TestUser{
Id: 1,
Name: "bob",
}
t.Run("add/remove", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
err = queue.Add(user, "user", OpType_OP_TYPE_CREATE)
assert.NilError(t, err)
require.NoError(err)
err = queue.Add(car, "car", OpType_OP_TYPE_CREATE)
assert.NilError(t, err)
require.NoError(err)
err = queue.Add(rental, "rental", OpType_OP_TYPE_CREATE)
assert.NilError(t, err)
require.NoError(err)
err = queue.Add(userUpdate, "user", OpType_OP_TYPE_UPDATE, WithFieldMaskPaths([]string{"Name", "Email"}))
assert.NilError(t, err)
require.NoError(err)
err = queue.Add(userNilUpdate, "user", OpType_OP_TYPE_UPDATE, WithFieldMaskPaths([]string{"Name"}), WithSetToNullPaths([]string{"Email"}))
require.NoError(err)
queuedUser, ty, fm, nm, err := queue.Remove()
require.NoError(err)
assert.True(proto.Equal(user, queuedUser))
assert.Equal(ty, OpType_OP_TYPE_CREATE)
assert.Nil(fm)
assert.Nil(nm)
queuedUser, ty, fm, err := queue.Remove()
assert.NilError(t, err)
assert.Assert(t, proto.Equal(user, queuedUser))
assert.Assert(t, ty == OpType_OP_TYPE_CREATE)
assert.Assert(t, fm == nil)
queuedCar, ty, fm, nm, err := queue.Remove()
require.NoError(err)
assert.True(proto.Equal(car, queuedCar))
assert.Equal(ty, OpType_OP_TYPE_CREATE)
assert.Nil(fm)
assert.Nil(nm)
queuedCar, ty, fm, err := queue.Remove()
assert.NilError(t, err)
assert.Assert(t, proto.Equal(car, queuedCar))
assert.Assert(t, ty == OpType_OP_TYPE_CREATE)
assert.Assert(t, fm == nil)
queuedRental, ty, fm, nm, err := queue.Remove()
require.NoError(err)
assert.True(proto.Equal(rental, queuedRental))
assert.Equal(ty, OpType_OP_TYPE_CREATE)
assert.Nil(fm)
assert.Nil(nm)
queuedRental, ty, fm, err := queue.Remove()
assert.NilError(t, err)
assert.Assert(t, proto.Equal(rental, queuedRental))
assert.Assert(t, ty == OpType_OP_TYPE_CREATE)
assert.Assert(t, fm == nil)
queuedUserUpdate, ty, fm, nm, err := queue.Remove()
require.NoError(err)
assert.True(proto.Equal(userUpdate, queuedUserUpdate))
assert.Equal(ty, OpType_OP_TYPE_UPDATE)
assert.Equal(fm, []string{"Name", "Email"})
assert.Nil(nm)
queuedUserUpdate, ty, fm, err := queue.Remove()
assert.NilError(t, err)
assert.Assert(t, proto.Equal(userUpdate, queuedUserUpdate))
assert.Assert(t, ty == OpType_OP_TYPE_UPDATE)
assert.Assert(t, reflect.DeepEqual(fm, []string{"Name", "Email"}))
queuedUserNilUpdate, ty, fm, nm, err := queue.Remove()
require.NoError(err)
assert.True(proto.Equal(userNilUpdate, queuedUserNilUpdate))
assert.Equal(ty, OpType_OP_TYPE_UPDATE)
assert.Equal(fm, []string{"Name"})
assert.Equal(nm, []string{"Email"})
})
t.Run("valid with nil type catalog", func(t *testing.T) {
require := require.New(t)
queue := Queue{}
assert.NilError(t, queue.Add(user, "user", OpType_OP_TYPE_CREATE))
require.NoError(queue.Add(user, "user", OpType_OP_TYPE_CREATE))
})
t.Run("error with nil type catalog", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
queue := Queue{}
_, _, _, err := queue.Remove()
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "remove Catalog is nil")
_, _, _, _, err = queue.Remove()
require.Error(err)
assert.Equal(err.Error(), "remove Catalog is nil")
})
t.Run("not replayable", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
u := &oplog_test.TestNonReplayableUser{
Name: "Alice",
PhoneNumber: "867-5309",
Email: "alice@bob.com",
}
err = queue.Add(u, "user", OpType_OP_TYPE_CREATE)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "error *oplog_test.TestNonReplayableUser is not a ReplayableMessage")
require.Error(err)
assert.Equal(err.Error(), "error *oplog_test.TestNonReplayableUser is not a ReplayableMessage")
})
t.Run("nil message", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
err = queue.Add(nil, "user", OpType_OP_TYPE_CREATE)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "error <nil> is not a ReplayableMessage")
require.Error(err)
assert.Equal(err.Error(), "error <nil> is not a ReplayableMessage")
})
t.Run("missing both field mask and null paths for update", func(t *testing.T) {
require.Error(t, queue.Add(userUpdate, "user", OpType_OP_TYPE_UPDATE))
})
t.Run("intersecting field masks and null paths", func(t *testing.T) {
require.Error(t, queue.Add(userNilUpdate, "user", OpType_OP_TYPE_UPDATE, WithFieldMaskPaths([]string{"Name"}), WithSetToNullPaths([]string{"Name"})))
})
}

@ -0,0 +1,135 @@
package oplog
import (
"crypto/rand"
"database/sql"
"fmt"
"os"
"strings"
"testing"
"github.com/golang-migrate/migrate/v4"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-kms-wrapping/wrappers/aead"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/watchtower/internal/oplog/oplog_test"
"github.com/jinzhu/gorm"
"github.com/ory/dockertest/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func testCleanup(t *testing.T, cleanupFunc func(), db *gorm.DB) {
t.Helper()
cleanupFunc()
err := db.Close()
assert.NoError(t, err)
}
func testUser(t *testing.T, db *gorm.DB, name, phoneNumber, email string) *oplog_test.TestUser {
t.Helper()
u := &oplog_test.TestUser{
Name: name,
PhoneNumber: phoneNumber,
Email: email,
}
w := GormWriter{db}
err := w.Create(&u)
require.NoError(t, err)
return u
}
func testFindUser(t *testing.T, db *gorm.DB, userId uint32) *oplog_test.TestUser {
t.Helper()
var foundUser oplog_test.TestUser
err := db.Where("id = ?", userId).First(&foundUser).Error
require.NoError(t, err)
return &foundUser
}
func testId(t *testing.T) string {
t.Helper()
id, err := uuid.GenerateUUID()
require.NoError(t, err)
return id
}
// testInitDbInDocker initializes postgres within dockertest for the unit tests
func testInitDbInDocker(t *testing.T) (cleanup func(), retURL string, err error) {
t.Helper()
if os.Getenv("PG_URL") != "" {
testInitStore(t, func() {}, os.Getenv("PG_URL"))
return func() {}, os.Getenv("PG_URL"), nil
}
pool, err := dockertest.NewPool("")
require.NoErrorf(t, err, "could not connect to docker: %w", err)
resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=watchtower"})
require.NoErrorf(t, err, "could not start resource: %w", err)
c := func() {
err := testCleanupResource(t, pool, resource)
assert.NoError(t, err)
}
url := fmt.Sprintf("postgres://postgres:secret@localhost:%s?sslmode=disable", resource.GetPort("5432/tcp"))
err = pool.Retry(func() error {
db, err := sql.Open("postgres", url)
if err != nil {
return fmt.Errorf("error opening postgres dev container: %w", err)
}
if err := db.Ping(); err != nil {
return err
}
defer db.Close()
return nil
})
require.NoErrorf(t, err, "could not connect to docker: %w", err)
testInitStore(t, c, url)
return c, url, nil
}
// testWrapper initializes an AEAD wrapping.Wrapper for testing the oplog
func testWrapper(t *testing.T) wrapping.Wrapper {
t.Helper()
rootKey := make([]byte, 32)
n, err := rand.Read(rootKey)
require.NoError(t, err)
require.Equal(t, n, 32)
root := aead.NewWrapper(nil)
err = root.SetAESGCMKeyBytes(rootKey)
require.NoError(t, err)
return root
}
// testInitStore will execute the migrations needed to initialize the store for tests
func testInitStore(t *testing.T, cleanup func(), url string) {
t.Helper()
// run migrations
m, err := migrate.New("file://../db/migrations/postgres", url)
require.NoError(t, err, "Error creating migrations")
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
cleanup()
require.NoError(t, err, "Error running migrations")
}
}
// testCleanupResource will clean up the dockertest resources (postgres)
func testCleanupResource(t *testing.T, pool *dockertest.Pool, resource *dockertest.Resource) error {
t.Helper()
var err error
for i := 0; i < 10; i++ {
err = pool.Purge(resource)
if err == nil {
return nil
}
}
if strings.Contains(err.Error(), "No such container") {
return nil
}
return fmt.Errorf("Failed to cleanup local container: %s", err)
}

@ -0,0 +1,87 @@
package oplog
import (
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_testCleanup(t *testing.T) {
assert := assert.New(t)
cleanup, db := setup(t)
testCleanup(t, cleanup, db)
err := db.Begin().Error
assert.Equal("sql: database is closed", err.Error())
}
func Test_testUser(t *testing.T) {
assert, require := assert.New(t), require.New(t)
cleanup, db := setup(t)
defer testCleanup(t, cleanup, db)
id := testId(t)
u := testUser(t, db, id, id, id)
require.NotNil(u)
assert.Equal(id, u.Name)
assert.Equal(id, u.PhoneNumber)
assert.Equal(id, u.Email)
}
func Test_testFindUser(t *testing.T) {
assert, require := assert.New(t), require.New(t)
cleanup, db := setup(t)
defer testCleanup(t, cleanup, db)
id := testId(t)
u := testUser(t, db, id, id, id)
require.NotNil(u)
found := testFindUser(t, db, u.Id)
require.NotNil(found)
assert.Equal(u, found)
}
func Test_testId(t *testing.T) {
require := require.New(t)
id := testId(t)
require.NotNil(id)
}
func Test_testInitDbInDocker(t *testing.T) {
require := require.New(t)
cleanup, url, err := testInitDbInDocker(t)
defer cleanup()
require.NoError(err)
require.NotEmpty(url)
require.NotNil(cleanup)
}
func Test_testWrapper(t *testing.T) {
w := testWrapper(t)
require.NotNil(t, w)
assert.Equal(t, "aead", w.Type())
}
func Test_testInitStore(t *testing.T) {
assert, require := assert.New(t), require.New(t)
cleanup, url, err := testInitDbInDocker(t)
require.NoError(err)
defer cleanup()
require.NotEmpty(url)
testInitStore(t, cleanup, url)
const query = `
select count(*) from information_schema."tables" t where table_name = 'db_test_user';
`
db, err := sql.Open("postgres", url)
require.NoError(err)
var cnt int
err = db.QueryRow(query).Scan(&cnt)
require.NoError(err)
assert.Equal(1, cnt)
}

@ -3,94 +3,108 @@ package oplog
import (
"testing"
"gotest.tools/assert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_NewGormTicketer provides unit tests for creating a Gorm ticketer
func Test_NewGormTicketer(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
t.Run("valid", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
assert.Assert(t, ticketer != nil)
require.NoError(err)
assert.NotNil(ticketer)
})
t.Run("bad db", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
_, err := NewGormTicketer(nil, WithAggregateNames(true))
assert.Equal(t, err.Error(), "tx is nil")
require.Error(err)
assert.Equal(err.Error(), "tx is nil")
})
}
// Test_GetTicket provides unit tests for getting oplog.Tickets
func Test_GetTicket(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
ticketer, err := NewGormTicketer(db, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(t, err)
t.Run("valid", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
ticket, err := ticketer.GetTicket("default")
assert.NilError(t, err)
assert.Equal(t, ticket.Name, "default")
assert.Check(t, ticket.Version != 0)
require.NoError(err)
assert.Equal(ticket.Name, "default")
assert.NotEqual(ticket.Version, 0)
})
t.Run("no name", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
ticket, err := ticketer.GetTicket("")
assert.Equal(t, err.Error(), "bad ticket name")
assert.Check(t, ticket == nil)
require.Error(err)
assert.Equal(err.Error(), "bad ticket name")
assert.Nil(ticket)
})
}
// Test_Redeem provides unit tests for redeeming tickets
func Test_Redeem(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
t.Run("valid", func(t *testing.T) {
require := require.New(t)
tx := db.Begin()
defer tx.Commit()
ticketer, err := NewGormTicketer(tx, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
ticket, err := ticketer.GetTicket("default")
assert.NilError(t, err)
require.NoError(err)
err = ticketer.Redeem(ticket)
assert.NilError(t, err)
tx.Commit()
require.NoError(err)
})
t.Run("nil ticket", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
tx := db.Begin()
defer tx.Commit()
ticketer, err := NewGormTicketer(tx, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
err = ticketer.Redeem(nil)
assert.Equal(t, err.Error(), "ticket is nil")
tx.Commit()
require.Error(err)
assert.Equal(err.Error(), "ticket is nil")
})
t.Run("detect two redemptions in separate concurrent transactions", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
tx := db.Begin()
ticketer, err := NewGormTicketer(tx, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
ticket, err := ticketer.GetTicket("default")
assert.NilError(t, err)
require.NoError(err)
secondTx := db.Begin()
secondTicketer, err := NewGormTicketer(secondTx, WithAggregateNames(true))
assert.NilError(t, err)
require.NoError(err)
secondTicket, err := secondTicketer.GetTicket("default")
assert.NilError(t, err)
require.NoError(err)
err = ticketer.Redeem(ticket)
assert.NilError(t, err)
require.NoError(err)
tx.Commit()
err = secondTicketer.Redeem(secondTicket)
assert.Equal(t, err.Error(), "ticket already redeemed")
assert.Equal(err.Error(), "ticket already redeemed")
})
}

@ -5,34 +5,35 @@ import (
"testing"
"github.com/hashicorp/watchtower/internal/oplog/oplog_test"
"gotest.tools/assert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_TypeCatalog provides basic red/green unit tests
func Test_TypeCatalog(t *testing.T) {
t.Parallel()
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog(
Type{new(oplog_test.TestUser), "user"},
Type{new(oplog_test.TestCar), "car"},
Type{new(oplog_test.TestRental), "rental"},
)
assert.NilError(t, err)
require.NoError(err)
name, err := types.GetTypeName(new(oplog_test.TestUser))
assert.NilError(t, err)
assert.Assert(t, name == "user")
require.NoError(err)
assert.Equal(name, "user")
_, err = types.GetTypeName(oplog_test.TestUser{})
assert.Assert(t, err != nil)
assert.Error(err)
s := "string"
_, err = types.GetTypeName(&s)
assert.Assert(t, err != nil)
assert.Error(err)
_, err = types.Get("unknown")
assert.Assert(t, err != nil)
assert.Error(err)
}
// Test_NewTypeCatalog provides unit tests for NewTypeCatalog
@ -40,39 +41,47 @@ func Test_NewTypeCatalog(t *testing.T) {
t.Parallel()
t.Run("valid", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog(
Type{new(oplog_test.TestUser), "user"},
Type{new(oplog_test.TestCar), "car"},
Type{new(oplog_test.TestRental), "rental"},
)
assert.NilError(t, err)
require.NoError(err)
u, err := types.Get("user")
assert.NilError(t, err)
assert.Equal(t, reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser)))
require.NoError(err)
assert.Equal(reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser)))
})
t.Run("missing Type.Name", func(t *testing.T) {
assert := assert.New(t)
types, err := NewTypeCatalog(
Type{new(oplog_test.TestUser), ""},
)
assert.Check(t, types == nil)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "error setting the type: typeName is an empty string for Set (in NewTypeCatalog)")
assert.Nil(types)
assert.Error(err)
assert.Equal(err.Error(), "error setting the type: typeName is an empty string for Set (in NewTypeCatalog)")
})
t.Run("missing Type.Interface", func(t *testing.T) {
assert := assert.New(t)
types, err := NewTypeCatalog(
Type{nil, ""},
)
assert.Check(t, types == nil)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "error type is {} (in NewTypeCatalog)")
assert.Nil(types)
assert.Error(err)
assert.Equal(err.Error(), "error type is {} (in NewTypeCatalog)")
})
t.Run("empty Type", func(t *testing.T) {
assert := assert.New(t)
types, err := NewTypeCatalog(
Type{},
)
assert.Check(t, types == nil)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "error type is {} (in NewTypeCatalog)")
assert.Nil(types)
assert.Error(err)
assert.Equal(err.Error(), "error type is {} (in NewTypeCatalog)")
})
}
@ -82,37 +91,43 @@ func Test_GetTypeName(t *testing.T) {
t.Parallel()
t.Run("valid", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog(
Type{new(oplog_test.TestUser), "user"},
Type{new(oplog_test.TestCar), "car"},
)
assert.NilError(t, err)
require.NoError(err)
n, err := types.GetTypeName(new(oplog_test.TestUser))
assert.NilError(t, err)
assert.Equal(t, n, "user")
require.NoError(err)
assert.Equal(n, "user")
})
t.Run("bad name", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog(
Type{new(oplog_test.TestUser), "user"},
)
assert.NilError(t, err)
require.NoError(err)
n, err := types.GetTypeName(new(oplog_test.TestCar))
assert.Check(t, err != nil)
assert.Equal(t, n, "")
assert.Equal(t, err.Error(), "error unknown name for interface: *oplog_test.TestCar")
require.Error(err)
assert.Equal(n, "")
assert.Equal(err.Error(), "error unknown name for interface: *oplog_test.TestCar")
})
t.Run("nil interface", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog(
Type{new(oplog_test.TestUser), "user"},
)
assert.NilError(t, err)
require.NoError(err)
n, err := types.GetTypeName(nil)
assert.Check(t, err != nil)
assert.Equal(t, n, "")
assert.Equal(t, err.Error(), "error interface parameter is nil for GetTypeName")
require.Error(err)
assert.Equal(n, "")
assert.Equal(err.Error(), "error interface parameter is nil for GetTypeName")
})
}
@ -121,37 +136,43 @@ func Test_Get(t *testing.T) {
t.Parallel()
t.Run("valid", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog(
Type{new(oplog_test.TestUser), "user"},
Type{new(oplog_test.TestCar), "car"},
)
assert.NilError(t, err)
require.NoError(err)
u, err := types.Get("user")
assert.NilError(t, err)
assert.Equal(t, reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser)))
require.NoError(err)
assert.Equal(reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser)))
})
t.Run("bad name", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog(
Type{new(oplog_test.TestUser), "user"},
)
assert.NilError(t, err)
require.NoError(err)
n, err := types.Get("car")
assert.Check(t, err != nil)
assert.Equal(t, n, nil)
assert.Equal(t, err.Error(), "error typeName is not found for Get")
require.Error(err)
assert.Equal(n, nil)
assert.Equal(err.Error(), "error typeName is not found for Get")
})
t.Run("bad typeName", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog(
Type{new(oplog_test.TestUser), "user"},
)
assert.NilError(t, err)
require.NoError(err)
n, err := types.Get("")
assert.Check(t, err != nil)
assert.Equal(t, n, nil)
assert.Equal(t, err.Error(), "error typeName is empty string for Get")
require.Error(err)
assert.Equal(n, nil)
assert.Equal(err.Error(), "error typeName is empty string for Get")
})
}
@ -160,28 +181,34 @@ func Test_Set(t *testing.T) {
t.Parallel()
t.Run("valid", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog()
assert.NilError(t, err)
require.NoError(err)
u := new(oplog_test.TestUser)
err = types.Set(u, "user")
assert.NilError(t, err)
assert.Equal(t, reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser)))
require.NoError(err)
assert.Equal(reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser)))
})
t.Run("bad name", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog()
assert.NilError(t, err)
require.NoError(err)
u := new(oplog_test.TestUser)
err = types.Set(u, "")
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "typeName is an empty string for Set")
assert.Assert(t, reflect.DeepEqual(types, &TypeCatalog{}))
require.Error(err)
assert.Equal(err.Error(), "typeName is an empty string for Set")
assert.Equal(types, &TypeCatalog{})
})
t.Run("bad interface", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
types, err := NewTypeCatalog()
assert.NilError(t, err)
require.NoError(err)
err = types.Set(nil, "")
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "error interface parameter is nil for Set")
assert.Assert(t, reflect.DeepEqual(types, &TypeCatalog{}))
require.Error(err)
assert.Equal(err.Error(), "error interface parameter is nil for Set")
assert.Equal(types, &TypeCatalog{})
})
}

@ -3,8 +3,8 @@ package oplog
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/hashicorp/watchtower/internal/db/common"
"github.com/jinzhu/gorm"
)
@ -14,8 +14,11 @@ type Writer interface {
// Create the entry
Create(interface{}) error
// Update the entry using the fieldMaskPaths, which are Paths from field_mask.proto
Update(entry interface{}, fieldMaskPaths []string) error
// Update the entry using the fieldMaskPaths and setNullPaths, which are
// Paths from field_mask.proto. fieldMaskPaths is required. setToNullPaths
// is optional. fieldMaskPaths and setNullPaths cannot instersect and both
// cannot be zero len.
Update(entry interface{}, fieldMaskPaths, setToNullPaths []string) error
// Delete the entry
Delete(interface{}) error
@ -23,7 +26,8 @@ type Writer interface {
// HasTable checks if tableName exists
hasTable(tableName string) bool
// CreateTableLike will create a newTableName using the existing table as a starting point
// CreateTableLike will create a newTableName using the existing table as a
// starting point
createTableLike(existingTableName string, newTableName string) error
// DropTableIfExists will drop the table if it exists
@ -49,45 +53,25 @@ func (w *GormWriter) Create(i interface{}) error {
return nil
}
// Update an object in storage, if there's a fieldMask then only the field_mask.proto paths are updated, otherwise
// we will send every field to the DB.
func (w *GormWriter) Update(i interface{}, fieldMaskPaths []string) error {
// Update the entry using the fieldMaskPaths and setNullPaths, which are
// Paths from field_mask.proto. fieldMaskPaths is required. setToNullPaths is
// optional. fieldMaskPaths and setNullPaths cannot intersect and both cannot be
// zero len.
func (w *GormWriter) Update(i interface{}, fieldMaskPaths, setToNullPaths []string) error {
if w.Tx == nil {
return errors.New("update Tx is nil")
}
if i == nil {
return errors.New("update interface is nil")
}
if len(fieldMaskPaths) == 0 {
if err := w.Tx.Save(i).Error; err != nil {
return fmt.Errorf("error updating: %w", err)
}
}
updateFields := map[string]interface{}{}
val := reflect.Indirect(reflect.ValueOf(i))
structTyp := val.Type()
for _, field := range fieldMaskPaths {
for i := 0; i < structTyp.NumField(); i++ {
// support for an embedded a gorm type
if structTyp.Field(i).Type.Kind() == reflect.Struct {
embType := structTyp.Field(i).Type
// check if the embedded field is exported via CanInterface()
if val.Field(i).CanInterface() {
embVal := reflect.Indirect(reflect.ValueOf(val.Field(i).Interface()))
for embFieldNum := 0; embFieldNum < embType.NumField(); embFieldNum++ {
if strings.EqualFold(embType.Field(embFieldNum).Name, field) {
updateFields[field] = embVal.Field(embFieldNum).Interface()
}
}
continue
}
}
// it's not an embedded type, so check if the field name matches
if strings.EqualFold(structTyp.Field(i).Name, field) {
updateFields[field] = val.Field(i).Interface()
}
}
if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 {
return errors.New("update both fieldMaskPaths and setToNullPaths are missing")
}
// common.UpdateFields will also check to ensure that fieldMaskPaths and
// setToNullPaths do not intersect.
updateFields, err := common.UpdateFields(i, fieldMaskPaths, setToNullPaths)
if err != nil {
return fmt.Errorf("error updating: unable to build update fields %w", err)
}
if err := w.Tx.Model(i).Updates(updateFields).Error; err != nil {
return fmt.Errorf("error updating: %w", err)

@ -3,302 +3,322 @@ package oplog
import (
"testing"
"github.com/hashicorp/go-uuid"
dbassert "github.com/hashicorp/dbassert/gorm"
"github.com/hashicorp/watchtower/internal/oplog/oplog_test"
"github.com/jinzhu/gorm"
"gotest.tools/assert"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// Test_GormWriterCreate provides unit tests for GormWriter Create
func Test_GormWriterCreate(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
t.Run("valid", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
tx := db.Begin()
defer tx.Rollback()
w := GormWriter{tx}
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
user := oplog_test.TestUser{
Name: "foo-" + id,
Name: "foo-" + testId(t),
}
assert.NilError(t, w.Create(&user))
require.NoError(w.Create(&user))
var foundUser oplog_test.TestUser
err = tx.Where("id = ?", user.Id).First(&foundUser).Error
assert.NilError(t, err)
assert.Equal(t, user.Name, foundUser.Name)
foundUser := testFindUser(t, tx, user.Id)
assert.Equal(user.Name, foundUser.Name)
})
t.Run("nil tx", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
w := GormWriter{nil}
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
err := w.Create(&oplog_test.TestUser{})
require.Error(err)
user := oplog_test.TestUser{
Name: "foo-" + id,
}
err = w.Create(&user)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "create Tx is nil")
assert.Equal(err.Error(), "create Tx is nil")
})
t.Run("nil model", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
tx := db.Begin()
defer tx.Rollback()
w := GormWriter{tx}
err := w.Create(nil)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "create interface is nil")
require.Error(err)
assert.Equal(err.Error(), "create interface is nil")
})
}
// Test_GormWriterDelete provides unit tests for GormWriter Delete
func Test_GormWriterDelete(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
t.Run("valid", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
tx := db.Begin()
defer tx.Rollback()
w := GormWriter{tx}
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
id := testId(t)
user := testUser(t, db, "foo-"+id, "", "")
foundUser := testFindUser(t, db, user.Id)
user := oplog_test.TestUser{
Name: "foo-" + id,
}
assert.NilError(t, w.Create(&user))
var foundUser oplog_test.TestUser
err = tx.Where("id = ?", user.Id).First(&foundUser).Error
assert.NilError(t, err)
assert.NilError(t, w.Delete(&user))
err = tx.Where("id = ?", user.Id).First(&foundUser).Error
assert.Check(t, err != nil)
assert.Equal(t, err, gorm.ErrRecordNotFound)
require.NoError(w.Delete(&user))
err := tx.Where("id = ?", user.Id).First(&foundUser).Error
require.Error(err)
assert.Equal(err, gorm.ErrRecordNotFound)
})
t.Run("nil tx", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
w := GormWriter{nil}
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
user := oplog_test.TestUser{
Name: "foo-" + id,
}
err = w.Delete(&user)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "delete Tx is nil")
err := w.Delete(&oplog_test.TestUser{})
require.Error(err)
assert.Equal(err.Error(), "delete Tx is nil")
})
t.Run("nil model", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
tx := db.Begin()
defer tx.Rollback()
w := GormWriter{tx}
err := w.Delete(nil)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "delete interface is nil")
})
}
// Test_GormWriterUpdate provides unit tests for GormWriter Update
func Test_GormWriterUpdate(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
t.Run("valid no fieldmask", func(t *testing.T) {
tx := db.Begin()
defer tx.Rollback()
w := GormWriter{tx}
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
user := oplog_test.TestUser{
Name: "foo-" + id,
}
assert.NilError(t, w.Create(&user))
var foundUser oplog_test.TestUser
err = tx.Where("id = ?", user.Id).First(&foundUser).Error
assert.NilError(t, err)
foundUser.Name = foundUser.Name + "_updated"
assert.NilError(t, w.Update(&foundUser, nil))
var updatedUser oplog_test.TestUser
err = tx.Where("id = ?", user.Id).First(&updatedUser).Error
assert.NilError(t, err)
assert.Equal(t, foundUser.Name, updatedUser.Name)
})
t.Run("valid with fieldmask", func(t *testing.T) {
tx := db.Begin()
defer tx.Rollback()
w := GormWriter{tx}
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
user := oplog_test.TestUser{
Name: "foo-" + id,
}
assert.NilError(t, w.Create(&user))
var foundUser oplog_test.TestUser
err = tx.Where("id = ?", user.Id).First(&foundUser).Error
assert.NilError(t, err)
foundUser.Name = foundUser.Name + "_updated"
assert.NilError(t, w.Update(&foundUser, []string{"Name"}))
var updatedUser oplog_test.TestUser
err = tx.Where("id = ?", user.Id).First(&updatedUser).Error
assert.NilError(t, err)
assert.Equal(t, foundUser.Name, updatedUser.Name)
})
t.Run("valid with incorrect fieldmask", func(t *testing.T) {
tx := db.Begin()
defer tx.Rollback()
w := GormWriter{tx}
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
user := oplog_test.TestUser{
Name: "foo-" + id,
}
assert.NilError(t, w.Create(&user))
var foundUser oplog_test.TestUser
err = tx.Where("id = ?", user.Id).First(&foundUser).Error
assert.NilError(t, err)
foundUser.Name = foundUser.Name + "_updated"
assert.NilError(t, w.Update(&foundUser, []string{"PhoneNumber"}))
var updatedUser oplog_test.TestUser
err = tx.Where("id = ?", user.Id).First(&updatedUser).Error
assert.NilError(t, err)
assert.Check(t, updatedUser.Name != foundUser.Name+"_updated")
})
t.Run("nil tx", func(t *testing.T) {
w := GormWriter{nil}
id, err := uuid.GenerateUUID()
assert.NilError(t, err)
user := oplog_test.TestUser{
Name: "foo-" + id,
}
err = w.Update(&user, nil)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "update Tx is nil")
})
t.Run("nil model", func(t *testing.T) {
tx := db.Begin()
defer tx.Rollback()
w := GormWriter{tx}
err := w.Update(nil, nil)
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "update interface is nil")
require.Error(err)
assert.Equal(err.Error(), "delete interface is nil")
})
}
// Test_GormWriterHasTable provides unit tests for GormWriter HasTable
func Test_GormWriterHasTable(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
w := GormWriter{Tx: db}
t.Run("success", func(t *testing.T) {
assert := require.New(t)
ok := w.hasTable("oplog_test_user")
assert.Equal(t, ok, true)
assert.Equal(ok, true)
})
t.Run("no table", func(t *testing.T) {
badTableName, err := uuid.GenerateUUID()
assert.NilError(t, err)
assert := assert.New(t)
badTableName := testId(t)
ok := w.hasTable(badTableName)
assert.Equal(t, ok, false)
assert.Equal(ok, false)
})
t.Run("blank table name", func(t *testing.T) {
assert := assert.New(t)
ok := w.hasTable("")
assert.Equal(t, ok, false)
assert.Equal(ok, false)
})
}
// Test_GormWriterCreateTable provides unit tests for GormWriter CreateTable
func Test_GormWriterCreateTable(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
t.Run("success", func(t *testing.T) {
assert := assert.New(t)
w := GormWriter{Tx: db}
suffix, err := uuid.GenerateUUID()
assert.NilError(t, err)
suffix := testId(t)
u := &oplog_test.TestUser{}
newTableName := u.TableName() + "_" + suffix
defer func() { assert.NilError(t, w.dropTableIfExists(newTableName)) }()
err = w.createTableLike(u.TableName(), newTableName)
assert.NilError(t, err)
defer func() { assert.NoError(w.dropTableIfExists(newTableName)) }()
err := w.createTableLike(u.TableName(), newTableName)
assert.NoError(err)
})
t.Run("call twice", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
w := GormWriter{Tx: db}
suffix, err := uuid.GenerateUUID()
assert.NilError(t, err)
suffix := testId(t)
u := &oplog_test.TestUser{}
newTableName := u.TableName() + "_" + suffix
defer func() { assert.NilError(t, w.dropTableIfExists(newTableName)) }()
err = w.createTableLike(u.TableName(), newTableName)
assert.NilError(t, err)
defer func() { assert.NoError(w.dropTableIfExists(newTableName)) }()
err := w.createTableLike(u.TableName(), newTableName)
require.NoError(err)
// should be an error to create the same table twice
err = w.createTableLike(u.TableName(), newTableName)
assert.Check(t, err != nil)
assert.Error(t, err, err.Error(), nil)
require.Error(err)
assert.Error(err, err.Error(), nil)
})
t.Run("empty existing", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
w := GormWriter{Tx: db}
suffix, err := uuid.GenerateUUID()
assert.NilError(t, err)
suffix := testId(t)
u := &oplog_test.TestUser{}
newTableName := u.TableName() + "_" + suffix
defer func() { assert.NilError(t, w.dropTableIfExists(newTableName)) }()
err = w.createTableLike("", newTableName)
assert.Check(t, err != nil)
assert.Error(t, err, err.Error(), nil)
assert.Equal(t, err.Error(), "error existingTableName is empty string")
defer func() { assert.NoError(w.dropTableIfExists(newTableName)) }()
err := w.createTableLike("", newTableName)
require.Error(err)
assert.Error(err, err.Error(), nil)
assert.Equal(err.Error(), "error existingTableName is empty string")
})
t.Run("blank name", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
w := GormWriter{Tx: db}
u := &oplog_test.TestUser{}
err := w.createTableLike(u.TableName(), "")
assert.Check(t, err != nil)
assert.Error(t, err, err.Error(), nil)
assert.Equal(t, err.Error(), "error newTableName is empty string")
require.Error(err)
assert.Error(err, err.Error(), nil)
assert.Equal(err.Error(), "error newTableName is empty string")
})
}
// Test_GormWriterDropTableIfExists provides unit tests for GormWriter DropTableIfExists
func Test_GormWriterDropTableIfExists(t *testing.T) {
cleanup, db := setup(t)
defer cleanup()
defer db.Close()
defer testCleanup(t, cleanup, db)
t.Run("success", func(t *testing.T) {
assert := assert.New(t)
w := GormWriter{Tx: db}
suffix, err := uuid.GenerateUUID()
assert.NilError(t, err)
suffix := testId(t)
u := &oplog_test.TestUser{}
newTableName := u.TableName() + "_" + suffix
err = w.createTableLike(u.TableName(), newTableName)
assert.NilError(t, err)
defer func() { assert.NilError(t, w.dropTableIfExists(newTableName)) }()
assert.NoError(w.createTableLike(u.TableName(), newTableName))
defer func() { assert.NoError(w.dropTableIfExists(newTableName)) }()
})
t.Run("success with blank", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
w := GormWriter{Tx: db}
err := w.dropTableIfExists("")
assert.Check(t, err != nil)
assert.Equal(t, err.Error(), "cannot drop table whose name is an empty string")
require.Error(err)
assert.Equal(err.Error(), "cannot drop table whose name is an empty string")
})
}
func TestGormWriter_Update(t *testing.T) {
cleanup, db := setup(t)
defer testCleanup(t, cleanup, db)
id := testId(t)
type fields struct {
Tx *gorm.DB
}
type args struct {
user *oplog_test.TestUser
fieldMaskPaths []string
setToNullPaths []string
}
tests := []struct {
name string
Tx *gorm.DB
args args
wantUser *oplog_test.TestUser
wantErr bool
}{
{
name: "valid-fieldmask",
Tx: db,
args: args{
user: &oplog_test.TestUser{
Name: "valid-fieldmask",
},
fieldMaskPaths: []string{"name"},
},
wantUser: &oplog_test.TestUser{
Name: "valid-fieldmask",
Email: id,
PhoneNumber: id,
},
wantErr: false,
},
{
name: "valid-setToNull",
Tx: db,
args: args{
user: &oplog_test.TestUser{
Name: "valid-setToNull",
},
fieldMaskPaths: nil,
setToNullPaths: []string{"name"},
},
wantUser: &oplog_test.TestUser{
Name: "",
Email: id,
PhoneNumber: id,
},
wantErr: false,
},
{
name: "valid-setToNull-and-fieldMask",
Tx: db,
args: args{
user: &oplog_test.TestUser{
Email: "valid-setToNull-and-fieldMask",
},
fieldMaskPaths: []string{"email"},
setToNullPaths: []string{"name"},
},
wantUser: &oplog_test.TestUser{
Name: "",
Email: "valid-setToNull-and-fieldMask",
PhoneNumber: id,
},
wantErr: false,
},
{
name: "no-field-mask",
Tx: db,
args: args{
user: &oplog_test.TestUser{
Name: "no-field-mask",
},
fieldMaskPaths: []string{""},
},
wantErr: true,
},
{
name: "nil-field-mask",
Tx: db,
args: args{
user: &oplog_test.TestUser{
Name: "nil-field-mask",
},
fieldMaskPaths: nil,
},
wantErr: true,
},
{
name: "nil-tx",
Tx: nil,
args: args{
user: &oplog_test.TestUser{
Name: "nil-txt",
},
fieldMaskPaths: []string{"name"},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require, dbassert := assert.New(t), require.New(t), dbassert.New(t, db.DB(), "postgres")
w := &GormWriter{
Tx: tt.Tx,
}
u := testUser(t, db, id, id, id) // intentionally, not relying on tt.Tx
u.Name = tt.args.user.Name
u.Email = tt.args.user.Email
u.PhoneNumber = tt.args.user.PhoneNumber
err := w.Update(tt.args.user, tt.args.fieldMaskPaths, tt.args.setToNullPaths)
if tt.wantErr {
require.Error(err)
return
}
require.NoError(err)
var foundUser oplog_test.TestUser
err = db.Where("id = ?", u.Id).First(&foundUser).Error
require.NoError(err)
tt.wantUser.Id = u.Id
assert.Equal(tt.wantUser, &foundUser)
for _, f := range tt.args.setToNullPaths {
dbassert.IsNull(u, f)
}
})
}
t.Run("nil model", func(t *testing.T) {
assert := assert.New(t)
w := GormWriter{db}
err := w.Create(nil)
assert.Error(err)
})
}

@ -5,21 +5,37 @@ import "google/protobuf/field_mask.proto";
package controller.storage.oplog.v1;
option go_package = "github.com/hashicorp/watchtower/internal/oplog;oplog";
// OpType provides the type of database operation the Any message represents (create,
// update, delete)
// OpType provides the type of database operation the Any message represents
// (create, update, delete)
enum OpType {
// OP_TYPE_UNSPECIFIED defines an unspecified operation.
OP_TYPE_UNSPECIFIED = 0;
// OP_TYPE_CREATE defines a create operation.
OP_TYPE_CREATE = 1;
// OP_TYPE_UPDATE defines an update operation.
OP_TYPE_UPDATE = 2;
// OP_TYPE_DELETE defines a delete operation.
OP_TYPE_DELETE = 3;
}
// AnyOperation provides a message for anything and the type of operation it
// represents.
message AnyOperation {
// type_name defines type of operation.
string type_name = 1;
// value are the bytes of a marshalled proto buff.
bytes value = 2;
// operation_type defines the type of database operation.
OpType operation_type = 3;
// mask is the mask of fields to update.
// field_mask is the mask of fields to update.
google.protobuf.FieldMask field_mask = 4;
// null_mask is the mask of fields to set to null.
google.protobuf.FieldMask null_mask = 5;
}
Loading…
Cancel
Save