From 2d18177c9da97e519df392bea624fc3de1d46c06 Mon Sep 17 00:00:00 2001 From: Jim Date: Wed, 10 Jun 2020 06:56:46 -0400 Subject: [PATCH] oplog support for null field updates (#107) --- go.mod | 3 +- go.sum | 2 - internal/db/common/update.go | 8 +- internal/db/common/update_test.go | 2 +- internal/db/read_writer.go | 30 +- internal/db/read_writer_test.go | 7 +- internal/oplog/any_operation.pb.go | 71 ++- internal/oplog/oplog.go | 13 +- internal/oplog/oplog_test.go | 496 +++++++----------- internal/oplog/options.go | 11 + internal/oplog/options_test.go | 31 ++ internal/oplog/queue.go | 90 +++- internal/oplog/queue_test.go | 100 ++-- internal/oplog/testing.go | 135 +++++ internal/oplog/testing_test.go | 87 +++ internal/oplog/ticketer_gorm_test.go | 74 +-- internal/oplog/type_catalog_test.go | 133 +++-- internal/oplog/writer.go | 60 +-- internal/oplog/writer_test.go | 404 +++++++------- .../storage/oplog/v1/any_operation.proto | 22 +- 20 files changed, 1058 insertions(+), 721 deletions(-) create mode 100644 internal/oplog/options_test.go create mode 100644 internal/oplog/testing.go create mode 100644 internal/oplog/testing_test.go diff --git a/go.mod b/go.mod index b48f25639b..cb6dc68098 100644 --- a/go.mod +++ b/go.mod @@ -55,7 +55,7 @@ require ( github.com/spf13/viper v1.6.3 // indirect github.com/stretchr/testify v1.6.0 go.mongodb.org/mongo-driver v1.3.2 // indirect - golang.org/x/net v0.0.0-20200519113804-d87ec0cfa476 + golang.org/x/net v0.0.0-20200519113804-d87ec0cfa476 // indirect golang.org/x/sys v0.0.0-20200509044756-6aff5f38e54f // indirect golang.org/x/time v0.0.0-20200416051211-89c76fbcd5d1 golang.org/x/tools v0.0.0-20200519175826-7521f6f42533 @@ -64,5 +64,4 @@ require ( google.golang.org/protobuf v1.24.0 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/ini.v1 v1.55.0 // indirect - gotest.tools v2.2.0+incompatible ) diff --git a/go.sum b/go.sum index 75f28d6f4c..2a8f6cf9ed 100644 --- a/go.sum +++ b/go.sum @@ -575,8 +575,6 @@ github.com/hashicorp/consul/api v1.4.0 h1:jfESivXnO5uLdH650JU/6AnjRoHrLhULq0FnC3 github.com/hashicorp/consul/api v1.4.0/go.mod h1:xc8u05kyMa3Wjr9eEAsIAo3dg8+LywT5E/Cl7cNS5nU= github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= github.com/hashicorp/consul/sdk v0.4.0/go.mod h1:fY08Y9z5SvJqevyZNy6WWPXiG3KwBPAvlcdx16zZ0fM= -github.com/hashicorp/dbassert v0.0.0-20200601131634-c3045f4c81e8 h1:0R3FvnfT4FC2abVevODTZ22zzhguKM2K5aufuysG0oo= -github.com/hashicorp/dbassert v0.0.0-20200601131634-c3045f4c81e8/go.mod h1:+B5eZq7vXqN4Gxkgjm4ub1bwG6OTcYzFL0r+bMvWorg= github.com/hashicorp/dbassert v0.0.0-20200602132714-30dfb78f0133 h1:NT2GV/LkzarLLGLYf/+HdqazyFEuuGDA7tOZ2Jyo+80= github.com/hashicorp/dbassert v0.0.0-20200602132714-30dfb78f0133/go.mod h1:+B5eZq7vXqN4Gxkgjm4ub1bwG6OTcYzFL0r+bMvWorg= github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= diff --git a/internal/db/common/update.go b/internal/db/common/update.go index dacbb19153..4205261397 100644 --- a/internal/db/common/update.go +++ b/internal/db/common/update.go @@ -35,7 +35,7 @@ func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []strin return nil, errors.New("both fieldMaskPaths and setToNullPaths are zero len") } - inter, maskPaths, nullPaths, err := intersection(fieldMaskPaths, setToNullPaths) + inter, maskPaths, nullPaths, err := Intersection(fieldMaskPaths, setToNullPaths) if err != nil { return nil, err } @@ -111,11 +111,11 @@ func findMissingPaths(paths []string, foundPaths map[string]struct{}) []string { return notFound } -// intersection is a case-insensitive search for intersecting values. Returns -// []string of the intersection with values in lowercase, and map[string]string +// Intersection is a case-insensitive search for intersecting values. Returns +// []string of the Intersection with values in lowercase, and map[string]string // of the original av and bv, with the key set to uppercase and value set to the // original -func intersection(av, bv []string) ([]string, map[string]string, map[string]string, error) { +func Intersection(av, bv []string) ([]string, map[string]string, map[string]string, error) { if av == nil { return nil, nil, nil, fmt.Errorf("av is missing: %w", ErrNilParameter) } diff --git a/internal/db/common/update_test.go b/internal/db/common/update_test.go index 38ca0685c2..e1c69cf860 100644 --- a/internal/db/common/update_test.go +++ b/internal/db/common/update_test.go @@ -165,7 +165,7 @@ func Test_intersection(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert := assert.New(t) - got, got1, got2, err := intersection(tt.args.av, tt.args.bv) + got, got1, got2, err := Intersection(tt.args.av, tt.args.bv) if err == nil && tt.wantErr { assert.Error(err) } diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index 13e8067c30..a51c97b685 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -284,7 +284,20 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string } rowsUpdated := int(underlying.RowsAffected) if withOplog && rowsUpdated > 0 { - if err := rw.addOplog(ctx, UpdateOp, opts, i); err != nil { + // we don't want to change the inbound slices in opts, so we'll make our + // own copy to pass to addOplog() + oplogFieldMasks := make([]string, len(fieldMaskPaths)) + copy(oplogFieldMasks, fieldMaskPaths) + oplogNullPaths := make([]string, len(setToNullPaths)) + copy(oplogNullPaths, setToNullPaths) + oplogOpts := Options{ + oplogOpts: opts.oplogOpts, + withOplog: opts.withOplog, + withDebug: opts.withDebug, + WithFieldMaskPaths: oplogFieldMasks, + WithNullPaths: oplogNullPaths, + } + if err := rw.addOplog(ctx, UpdateOp, oplogOpts, i); err != nil { return rowsUpdated, fmt.Errorf("update: add oplog failed %w", err) } } @@ -384,14 +397,19 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter if err != nil { return err } - var entryOp oplog.OpType + msg := oplog.Message{ + Message: i.(proto.Message), + TypeName: replayable.TableName(), + } switch opType { case CreateOp: - entryOp = oplog.OpType_OP_TYPE_CREATE + msg.OpType = oplog.OpType_OP_TYPE_CREATE case UpdateOp: - entryOp = oplog.OpType_OP_TYPE_UPDATE + msg.OpType = oplog.OpType_OP_TYPE_UPDATE + msg.FieldMaskPaths = opts.WithFieldMaskPaths + msg.SetToNullPaths = opts.WithNullPaths case DeleteOp: - entryOp = oplog.OpType_OP_TYPE_DELETE + msg.OpType = oplog.OpType_OP_TYPE_DELETE default: return fmt.Errorf("error operation type %v is not supported", opType) } @@ -399,7 +417,7 @@ func (rw *Db) addOplog(ctx context.Context, opType OpType, opts Options, i inter ctx, &oplog.GormWriter{Tx: gdb}, ticket, - &oplog.Message{Message: i.(proto.Message), TypeName: replayable.TableName(), OpType: entryOp}, + &msg, ) if err != nil { return fmt.Errorf("error creating oplog entry %w for WithOplog", err) diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index 007d6685a6..8d7c8cac51 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -229,9 +229,8 @@ func TestDb_Update(t *testing.T) { assert.NotEqual(publicId, foundUser.PublicId) }) } - - assert := assert.New(t) t.Run("valid-WithOplog", func(t *testing.T) { + assert := assert.New(t) w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.NoError(err) @@ -264,6 +263,7 @@ func TestDb_Update(t *testing.T) { assert.NoError(err) }) t.Run("vet-for-write", func(t *testing.T) { + assert := assert.New(t) w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.NoError(err) @@ -294,6 +294,7 @@ func TestDb_Update(t *testing.T) { assert.Equal(0, rowsUpdated) }) t.Run("nil-tx", func(t *testing.T) { + assert := assert.New(t) w := Db{underlying: nil} id, err := uuid.GenerateUUID() assert.NoError(err) @@ -305,6 +306,7 @@ func TestDb_Update(t *testing.T) { assert.Equal("update: missing underlying db nil parameter", err.Error()) }) t.Run("no-wrapper-WithOplog", func(t *testing.T) { + assert := assert.New(t) w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.NoError(err) @@ -326,6 +328,7 @@ func TestDb_Update(t *testing.T) { assert.Equal("update: oplog validation failed error no wrapper WithOplog: nil parameter", err.Error()) }) t.Run("no-metadata-WithOplog", func(t *testing.T) { + assert := assert.New(t) w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.NoError(err) diff --git a/internal/oplog/any_operation.pb.go b/internal/oplog/any_operation.pb.go index b2f0518351..3f69e13469 100644 --- a/internal/oplog/any_operation.pb.go +++ b/internal/oplog/any_operation.pb.go @@ -26,15 +26,19 @@ const ( // of the legacy proto package is being used. const _ = proto.ProtoPackageIsVersion4 -// OpType provides the type of database operation the Any message represents (create, -// update, delete) +// OpType provides the type of database operation the Any message represents +// (create, update, delete) type OpType int32 const ( + // OP_TYPE_UNSPECIFIED defines an unspecified operation. OpType_OP_TYPE_UNSPECIFIED OpType = 0 - OpType_OP_TYPE_CREATE OpType = 1 - OpType_OP_TYPE_UPDATE OpType = 2 - OpType_OP_TYPE_DELETE OpType = 3 + // OP_TYPE_CREATE defines a create operation. + OpType_OP_TYPE_CREATE OpType = 1 + // OP_TYPE_UPDATE defines an update operation. + OpType_OP_TYPE_UPDATE OpType = 2 + // OP_TYPE_DELETE defines a delete operation. + OpType_OP_TYPE_DELETE OpType = 3 ) // Enum value maps for OpType. @@ -87,11 +91,16 @@ type AnyOperation struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - TypeName string `protobuf:"bytes,1,opt,name=type_name,json=typeName,proto3" json:"type_name,omitempty"` - Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + // type_name defines type of operation. + TypeName string `protobuf:"bytes,1,opt,name=type_name,json=typeName,proto3" json:"type_name,omitempty"` + // value are the bytes of a marshalled proto buff. + Value []byte `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + // operation_type defines the type of database operation. OperationType OpType `protobuf:"varint,3,opt,name=operation_type,json=operationType,proto3,enum=controller.storage.oplog.v1.OpType" json:"operation_type,omitempty"` - // mask is the mask of fields to update. + // field_mask is the mask of fields to update. FieldMask *field_mask.FieldMask `protobuf:"bytes,4,opt,name=field_mask,json=fieldMask,proto3" json:"field_mask,omitempty"` + // null_mask is the mask of fields to set to null. + NullMask *field_mask.FieldMask `protobuf:"bytes,5,opt,name=null_mask,json=nullMask,proto3" json:"null_mask,omitempty"` } func (x *AnyOperation) Reset() { @@ -154,6 +163,13 @@ func (x *AnyOperation) GetFieldMask() *field_mask.FieldMask { return nil } +func (x *AnyOperation) GetNullMask() *field_mask.FieldMask { + if x != nil { + return x.NullMask + } + return nil +} + var File_controller_storage_oplog_v1_any_operation_proto protoreflect.FileDescriptor var file_controller_storage_oplog_v1_any_operation_proto_rawDesc = []byte{ @@ -164,7 +180,7 @@ var file_controller_storage_oplog_v1_any_operation_proto_rawDesc = []byte{ 0x6f, 0x72, 0x61, 0x67, 0x65, 0x2e, 0x6f, 0x70, 0x6c, 0x6f, 0x67, 0x2e, 0x76, 0x31, 0x1a, 0x20, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x6d, 0x61, 0x73, 0x6b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x22, 0xc8, 0x01, 0x0a, 0x0c, 0x41, 0x6e, 0x79, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, + 0x22, 0x81, 0x02, 0x0a, 0x0c, 0x41, 0x6e, 0x79, 0x4f, 0x70, 0x65, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x1b, 0x0a, 0x09, 0x74, 0x79, 0x70, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x74, 0x79, 0x70, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x76, @@ -176,17 +192,21 @@ var file_controller_storage_oplog_v1_any_operation_proto_rawDesc = []byte{ 0x12, 0x39, 0x0a, 0x0a, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x5f, 0x6d, 0x61, 0x73, 0x6b, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4d, 0x61, 0x73, 0x6b, - 0x52, 0x09, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x4d, 0x61, 0x73, 0x6b, 0x2a, 0x5d, 0x0a, 0x06, 0x4f, - 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, 0x0a, 0x13, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, - 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x12, - 0x0a, 0x0e, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, - 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x50, - 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x0e, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, - 0x45, 0x5f, 0x44, 0x45, 0x4c, 0x45, 0x54, 0x45, 0x10, 0x03, 0x42, 0x36, 0x5a, 0x34, 0x67, 0x69, - 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, - 0x72, 0x70, 0x2f, 0x77, 0x61, 0x74, 0x63, 0x68, 0x74, 0x6f, 0x77, 0x65, 0x72, 0x2f, 0x69, 0x6e, - 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x6f, 0x70, 0x6c, 0x6f, 0x67, 0x3b, 0x6f, 0x70, 0x6c, - 0x6f, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x52, 0x09, 0x66, 0x69, 0x65, 0x6c, 0x64, 0x4d, 0x61, 0x73, 0x6b, 0x12, 0x37, 0x0a, 0x09, 0x6e, + 0x75, 0x6c, 0x6c, 0x5f, 0x6d, 0x61, 0x73, 0x6b, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, + 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2e, 0x46, 0x69, 0x65, 0x6c, 0x64, 0x4d, 0x61, 0x73, 0x6b, 0x52, 0x08, 0x6e, 0x75, 0x6c, 0x6c, + 0x4d, 0x61, 0x73, 0x6b, 0x2a, 0x5d, 0x0a, 0x06, 0x4f, 0x70, 0x54, 0x79, 0x70, 0x65, 0x12, 0x17, + 0x0a, 0x13, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, + 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x12, 0x0a, 0x0e, 0x4f, 0x50, 0x5f, 0x54, 0x59, + 0x50, 0x45, 0x5f, 0x43, 0x52, 0x45, 0x41, 0x54, 0x45, 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x4f, + 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x55, 0x50, 0x44, 0x41, 0x54, 0x45, 0x10, 0x02, 0x12, + 0x12, 0x0a, 0x0e, 0x4f, 0x50, 0x5f, 0x54, 0x59, 0x50, 0x45, 0x5f, 0x44, 0x45, 0x4c, 0x45, 0x54, + 0x45, 0x10, 0x03, 0x42, 0x36, 0x5a, 0x34, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, + 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x77, 0x61, 0x74, 0x63, + 0x68, 0x74, 0x6f, 0x77, 0x65, 0x72, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, + 0x6f, 0x70, 0x6c, 0x6f, 0x67, 0x3b, 0x6f, 0x70, 0x6c, 0x6f, 0x67, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, } var ( @@ -211,11 +231,12 @@ var file_controller_storage_oplog_v1_any_operation_proto_goTypes = []interface{} var file_controller_storage_oplog_v1_any_operation_proto_depIdxs = []int32{ 0, // 0: controller.storage.oplog.v1.AnyOperation.operation_type:type_name -> controller.storage.oplog.v1.OpType 2, // 1: controller.storage.oplog.v1.AnyOperation.field_mask:type_name -> google.protobuf.FieldMask - 2, // [2:2] is the sub-list for method output_type - 2, // [2:2] is the sub-list for method input_type - 2, // [2:2] is the sub-list for extension type_name - 2, // [2:2] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name + 2, // 2: controller.storage.oplog.v1.AnyOperation.null_mask:type_name -> google.protobuf.FieldMask + 3, // [3:3] is the sub-list for method output_type + 3, // [3:3] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name } func init() { file_controller_storage_oplog_v1_any_operation_proto_init() } diff --git a/internal/oplog/oplog.go b/internal/oplog/oplog.go index 07eb56eeb8..fa1f8afd00 100644 --- a/internal/oplog/oplog.go +++ b/internal/oplog/oplog.go @@ -24,6 +24,7 @@ type Message struct { TypeName string OpType OpType FieldMaskPaths []string + SetToNullPaths []string } // Entry represents an oplog entry @@ -98,7 +99,7 @@ func (e *Entry) UnmarshalData(types *TypeCatalog) ([]Message, error) { Catalog: types, } for { - m, typ, fieldMaskPaths, err := queue.Remove() + m, typ, fieldMaskPaths, nullPaths, err := queue.Remove() if err == io.EOF { break } @@ -109,7 +110,7 @@ func (e *Entry) UnmarshalData(types *TypeCatalog) ([]Message, error) { if err != nil { return nil, fmt.Errorf("error getting TypeName: %w", err) } - msgs = append(msgs, Message{Message: m, TypeName: name, OpType: typ, FieldMaskPaths: fieldMaskPaths}) + msgs = append(msgs, Message{Message: m, TypeName: name, OpType: typ, FieldMaskPaths: fieldMaskPaths, SetToNullPaths: nullPaths}) } return msgs, nil } @@ -131,7 +132,7 @@ func (e *Entry) WriteEntryWith(ctx context.Context, tx Writer, ticket *store.Tic if m == nil { return errors.New("bad message") } - if err := queue.Add(m.Message, m.TypeName, m.OpType, WithFieldMaskPaths(m.FieldMaskPaths)); err != nil { + if err := queue.Add(m.Message, m.TypeName, m.OpType, WithFieldMaskPaths(m.FieldMaskPaths), WithSetToNullPaths(m.SetToNullPaths)); err != nil { return fmt.Errorf("error adding message to queue: %w", err) } } @@ -214,7 +215,9 @@ func (e *Entry) Replay(ctx context.Context, tx Writer, types *TypeCatalog, table */ replayTable := origTableName + tableSuffix if !tx.hasTable(replayTable) { - tx.createTableLike(origTableName, replayTable) + if err := tx.createTableLike(origTableName, replayTable); err != nil { + return fmt.Errorf("replay: %w", err) + } } em.SetTableName(replayTable) @@ -224,7 +227,7 @@ func (e *Entry) Replay(ctx context.Context, tx Writer, types *TypeCatalog, table return fmt.Errorf("replay error: %w", err) } case OpType_OP_TYPE_UPDATE: - if err := tx.Update(m.Message, m.FieldMaskPaths); err != nil { + if err := tx.Update(m.Message, m.FieldMaskPaths, m.SetToNullPaths); err != nil { return fmt.Errorf("replay error: %w", err) } case OpType_OP_TYPE_DELETE: diff --git a/internal/oplog/oplog_test.go b/internal/oplog/oplog_test.go index 15ba7c40b5..0824cbc80b 100644 --- a/internal/oplog/oplog_test.go +++ b/internal/oplog/oplog_test.go @@ -2,74 +2,54 @@ package oplog import ( "context" - "crypto/rand" - "database/sql" "fmt" - "os" - "strings" "testing" - wrapping "github.com/hashicorp/go-kms-wrapping" - "github.com/hashicorp/go-kms-wrapping/wrappers/aead" - "github.com/hashicorp/go-uuid" + dbassert "github.com/hashicorp/dbassert/gorm" + "github.com/hashicorp/watchtower/internal/oplog/oplog_test" "github.com/jinzhu/gorm" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" - "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/postgres" _ "github.com/golang-migrate/migrate/v4/source/file" - "github.com/ory/dockertest/v3" ) // setup the tests (initialize the database one-time and intialized testDatabaseURL) func setup(t *testing.T) (func(), *gorm.DB) { - cleanup := func() {} - var url string - var err error - - cleanup, url, err = initDbInDocker(t) - if err != nil { - t.Fatal(err) - } + t.Helper() + require := require.New(t) + cleanup, url, err := testInitDbInDocker(t) + require.NoError(err) db, err := gorm.Open("postgres", url) - if err != nil { - t.Fatal(err) - } + require.NoError(err) oplog_test.Init(db) - return cleanup, db } // Test_BasicOplog provides some basic unit tests for oplogs func Test_BasicOplog(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) t.Run("EncryptData/DecryptData/UnmarshalData", func(t *testing.T) { - cipherer := initWrapper(t) + assert, require := assert.New(t), require.New(t) + cipherer := testWrapper(t) // now let's us optimistic locking via a ticketing system for a serialized oplog ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"}) - assert.NilError(t, err) + require.NoError(err) queue := Queue{Catalog: types} - id, err := uuid.GenerateUUID() - assert.NilError(t, err) - - user := oplog_test.TestUser{ - Name: "foo-" + id, - } - - resp := db.Create(&user) - assert.NilError(t, resp.Error) + id := testId(t) + user := testUser(t, db, "foo-"+id, "", "") - err = queue.Add(&user, "user", OpType_OP_TYPE_CREATE) - assert.NilError(t, err) + err = queue.Add(user, "user", OpType_OP_TYPE_CREATE) + require.NoError(err) l, err := NewEntry( "test-users", Metadata{ @@ -80,62 +60,55 @@ func Test_BasicOplog(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) l.Data = queue.Bytes() err = l.EncryptData(context.Background()) - assert.NilError(t, err) + require.NoError(err) - resp = db.Create(&l) - assert.NilError(t, resp.Error) - assert.Check(t, l.CreateTime != nil) - assert.Check(t, l.UpdateTime != nil) + err = db.Create(&l).Error + require.NoError(err) + assert.NotNil(l.CreateTime) + assert.NotNil(l.UpdateTime) entryId := l.Id var foundEntry Entry err = db.Where("id = ?", entryId).First(&foundEntry).Error - assert.NilError(t, err) + require.NoError(err) foundEntry.Cipherer = cipherer err = foundEntry.DecryptData(context.Background()) - assert.NilError(t, err) + require.NoError(err) foundUsers, err := foundEntry.UnmarshalData(types) - assert.NilError(t, err) - assert.Assert(t, foundUsers[0].Message.(*oplog_test.TestUser).Name == user.Name) + require.NoError(err) + assert.Equal(foundUsers[0].Message.(*oplog_test.TestUser).Name, user.Name) foundUsers, err = foundEntry.UnmarshalData(types) - assert.NilError(t, err) - assert.Assert(t, foundUsers[0].Message.(*oplog_test.TestUser).Name == user.Name) + require.NoError(err) + assert.Equal(foundUsers[0].Message.(*oplog_test.TestUser).Name, user.Name) }) t.Run("write entry", func(t *testing.T) { - cipherer := initWrapper(t) + assert, require := assert.New(t), require.New(t) + cipherer := testWrapper(t) // now let's us optimistic locking via a ticketing system for a serialized oplog ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"}) - assert.NilError(t, err) + require.NoError(err) queue := Queue{Catalog: types} - id, err := uuid.GenerateUUID() - assert.NilError(t, err) + id := testId(t) + user := testUser(t, db, "foo-"+id, "", "") - user := oplog_test.TestUser{ - Name: "foo-" + id, - } - - resp := db.Create(&user) - assert.NilError(t, resp.Error) - - assert.NilError(t, err) ticket, err := ticketer.GetTicket("default") - assert.NilError(t, err) + require.NoError(err) queue = Queue{} - err = queue.Add(&user, "user", OpType_OP_TYPE_CREATE) - assert.NilError(t, err) + err = queue.Add(user, "user", OpType_OP_TYPE_CREATE) + require.NoError(err) newLogEntry, err := NewEntry( "test-users", @@ -147,11 +120,11 @@ func Test_BasicOplog(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) newLogEntry.Data = queue.Bytes() err = newLogEntry.Write(context.Background(), &GormWriter{db}, ticket) - assert.NilError(t, err) - assert.Assert(t, newLogEntry.Id != 0) + require.NoError(err) + assert.NotEmpty(newLogEntry.Id) }) } @@ -159,12 +132,13 @@ func Test_BasicOplog(t *testing.T) { // Test_NewEntry provides some basic unit tests for NewEntry func Test_NewEntry(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) + t.Run("valid", func(t *testing.T) { - cipherer := initWrapper(t) + require := require.New(t) + cipherer := testWrapper(t) ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) _, err = NewEntry( "test-users", @@ -176,12 +150,13 @@ func Test_NewEntry(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) }) t.Run("no metadata success", func(t *testing.T) { - cipherer := initWrapper(t) + require := require.New(t) + cipherer := testWrapper(t) ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) _, err = NewEntry( "test-users", @@ -189,12 +164,13 @@ func Test_NewEntry(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) }) t.Run("no aggregateName", func(t *testing.T) { - cipherer := initWrapper(t) + assert, require := assert.New(t), require.New(t) + cipherer := testWrapper(t) ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) _, err = NewEntry( "", @@ -206,11 +182,13 @@ func Test_NewEntry(t *testing.T) { cipherer, ticketer, ) - assert.Equal(t, err.Error(), "error creating entry: entry aggregate name is not set") + require.Error(err) + assert.Equal(err.Error(), "error creating entry: entry aggregate name is not set") }) t.Run("bad cipherer", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) _, err = NewEntry( "test-users", @@ -222,10 +200,12 @@ func Test_NewEntry(t *testing.T) { nil, ticketer, ) - assert.Equal(t, err.Error(), "error creating entry: entry Cipherer is nil") + require.Error(err) + assert.Equal(err.Error(), "error creating entry: entry Cipherer is nil") }) t.Run("bad ticket", func(t *testing.T) { - cipherer := initWrapper(t) + assert, require := assert.New(t), require.New(t) + cipherer := testWrapper(t) _, err := NewEntry( "test-users", Metadata{ @@ -236,34 +216,34 @@ func Test_NewEntry(t *testing.T) { cipherer, nil, ) - assert.Equal(t, err.Error(), "error creating entry: entry Ticketer is nil") + require.Error(err) + assert.Equal(err.Error(), "error creating entry: entry Ticketer is nil") }) } func Test_UnmarshalData(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) - cipherer := initWrapper(t) + cipherer := testWrapper(t) // now let's us optimistic locking via a ticketing system for a serialized oplog ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"}) - assert.NilError(t, err) + require.NoError(t, err) - id, err := uuid.GenerateUUID() - assert.NilError(t, err) + id := testId(t) t.Run("valid", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) queue := Queue{Catalog: types} user := oplog_test.TestUser{ Name: "foo-" + id, } err = queue.Add(&user, "user", OpType_OP_TYPE_CREATE) - assert.NilError(t, err) + require.NoError(err) entry, err := NewEntry( "test-users", Metadata{ @@ -274,18 +254,19 @@ func Test_UnmarshalData(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) entry.Data = queue.Bytes() marshaledUsers, err := entry.UnmarshalData(types) - assert.NilError(t, err) - assert.Assert(t, marshaledUsers[0].Message.(*oplog_test.TestUser).Name == user.Name) + require.NoError(err) + assert.Equal(marshaledUsers[0].Message.(*oplog_test.TestUser).Name, user.Name) take2marshaledUsers, err := entry.UnmarshalData(types) - assert.NilError(t, err) - assert.Assert(t, take2marshaledUsers[0].Message.(*oplog_test.TestUser).Name == user.Name) + require.NoError(err) + assert.Equal(take2marshaledUsers[0].Message.(*oplog_test.TestUser).Name, user.Name) }) t.Run("no data", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) queue := Queue{Catalog: types} entry, err := NewEntry( "test-users", @@ -297,20 +278,22 @@ func Test_UnmarshalData(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) entry.Data = queue.Bytes() _, err = entry.UnmarshalData(types) - assert.Equal(t, err.Error(), "no Data to unmarshal") + require.Error(err) + assert.Equal(err.Error(), "no Data to unmarshal") }) t.Run("nil types", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) queue := Queue{Catalog: types} user := oplog_test.TestUser{ Name: "foo-" + id, } err = queue.Add(&user, "user", OpType_OP_TYPE_CREATE) - assert.NilError(t, err) + require.NoError(err) entry, err := NewEntry( "test-users", Metadata{ @@ -321,19 +304,21 @@ func Test_UnmarshalData(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) _, err = entry.UnmarshalData(nil) - assert.Equal(t, err.Error(), "TypeCatalog is nil") + require.Error(err) + assert.Equal(err.Error(), "TypeCatalog is nil") }) t.Run("missing type", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) queue := Queue{Catalog: types} user := oplog_test.TestUser{ Name: "foo-" + id, } err = queue.Add(&user, "user", OpType_OP_TYPE_CREATE) - assert.NilError(t, err) + require.NoError(err) entry, err := NewEntry( "test-users", Metadata{ @@ -344,40 +329,39 @@ func Test_UnmarshalData(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) entry.Data = queue.Bytes() types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "not-valid-name"}) - assert.NilError(t, err) + require.NoError(err) _, err = entry.UnmarshalData(types) - assert.Equal(t, err.Error(), "error removing item from queue: error getting the TypeName for Remove: error typeName is not found for Get") + require.Error(err) + assert.Equal(err.Error(), "error removing item from queue: error getting the TypeName for Remove: error typeName is not found for Get") }) } // Test_Replay provides some basic unit tests for replaying entries func Test_Replay(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) - cipherer := initWrapper(t) + cipherer := testWrapper(t) + id := testId(t) - id, err := uuid.GenerateUUID() - assert.NilError(t, err) // setup new tables for replay tableSuffix := "_" + id writer := GormWriter{Tx: db} - testUser := &oplog_test.TestUser{} - replayUserTable := fmt.Sprintf("%s%s", testUser.TableName(), tableSuffix) - defer func() { assert.NilError(t, writer.dropTableIfExists(replayUserTable)) }() + userModel := &oplog_test.TestUser{} + replayUserTable := fmt.Sprintf("%s%s", userModel.TableName(), tableSuffix) + defer func() { require.NoError(t, writer.dropTableIfExists(replayUserTable)) }() ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(t, err) t.Run("replay:create/update", func(t *testing.T) { - + assert, require := assert.New(t), require.New(t) ticket, err := ticketer.GetTicket("default") - assert.NilError(t, err) + require.NoError(err) newLogEntry, err := NewEntry( "test-users", @@ -389,93 +373,93 @@ func Test_Replay(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) - - tx := db.Begin() + require.NoError(err) - id3, err := uuid.GenerateUUID() - assert.NilError(t, err) - userName := "foo-" + id3 - // create a user that's replayable - userCreate := oplog_test.TestUser{ - Name: userName, - } - err = tx.Create(&userCreate).Error - assert.NilError(t, err) + tx := db + userName := "foo-" + testId(t) + userCreate := testUser(t, db, userName, "", "") userSave := oplog_test.TestUser{ Id: userCreate.Id, Name: userCreate.Name, Email: userName + "@hashicorp.com", } err = tx.Save(&userSave).Error - assert.NilError(t, err) + require.NoError(err) userUpdate := oplog_test.TestUser{ - Id: userCreate.Id, - PhoneNumber: "867-5309", + Id: userCreate.Id, } - err = tx.Model(&userUpdate).Updates(map[string]interface{}{"PhoneNumber": "867-5309"}).Error - assert.NilError(t, err) + err = tx.Model(&userUpdate).Updates(map[string]interface{}{"PhoneNumber": "867-5309", "Name": gorm.Expr("NULL")}).Error + require.NoError(err) + dbassert := dbassert.New(t, tx.DB(), "postgres") + dbassert.IsNull(&userUpdate, "Name") + + foundCreateUser := testFindUser(t, tx, userCreate.Id) + require.Equal(foundCreateUser.Id, userCreate.Id) + require.Equal(foundCreateUser.Name, "") + require.Equal(foundCreateUser.PhoneNumber, userUpdate.PhoneNumber) err = newLogEntry.WriteEntryWith(context.Background(), &GormWriter{tx}, ticket, - &Message{Message: &userCreate, TypeName: "user", OpType: OpType_OP_TYPE_CREATE}, - &Message{Message: &userSave, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE}, - &Message{Message: &userUpdate, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, FieldMaskPaths: []string{"PhoneNumber"}}, + &Message{Message: userCreate, TypeName: "user", OpType: OpType_OP_TYPE_CREATE}, + &Message{Message: &userSave, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, FieldMaskPaths: []string{"Name", "Email"}}, + &Message{Message: &userSave, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, FieldMaskPaths: []string{"Name", "Email"}, SetToNullPaths: nil}, + &Message{Message: &userUpdate, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, SetToNullPaths: []string{"Name"}}, + &Message{Message: &userUpdate, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, FieldMaskPaths: nil, SetToNullPaths: []string{"Name"}}, + &Message{Message: &userUpdate, TypeName: "user", OpType: OpType_OP_TYPE_UPDATE, FieldMaskPaths: []string{"PhoneNumber"}, SetToNullPaths: []string{"Name"}}, ) - assert.NilError(t, err) + require.NoError(err) types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"}) - assert.NilError(t, err) + require.NoError(err) var foundEntry Entry err = tx.Where("id = ?", newLogEntry.Id).First(&foundEntry).Error - assert.NilError(t, err) + require.NoError(err) foundEntry.Cipherer = cipherer err = foundEntry.DecryptData(context.Background()) - assert.NilError(t, err) + require.NoError(err) err = foundEntry.Replay(context.Background(), &GormWriter{tx}, types, tableSuffix) - assert.NilError(t, err) - - var foundUser oplog_test.TestUser - err = tx.Where("id = ?", userCreate.Id).First(&foundUser).Error - assert.NilError(t, err) + require.NoError(err) + foundUser := testFindUser(t, tx, userCreate.Id) var foundReplayedUser oplog_test.TestUser + foundReplayedUser.Table = foundReplayedUser.TableName() + tableSuffix err = tx.Where("id = ?", userCreate.Id).First(&foundReplayedUser).Error - assert.NilError(t, err) - - assert.Assert(t, foundUser.Id == foundReplayedUser.Id) - assert.Assert(t, foundUser.Name == foundReplayedUser.Name && foundUser.Name == userName) - assert.Assert(t, foundUser.PhoneNumber == foundReplayedUser.PhoneNumber && foundReplayedUser.PhoneNumber == "867-5309") - assert.Assert(t, foundUser.Email == foundReplayedUser.Email && foundReplayedUser.Email == userName+"@hashicorp.com") - - tx.Commit() + require.NoError(err) + + assert.Equal(foundUser.Id, foundReplayedUser.Id) + assert.Equal(foundUser.Name, foundReplayedUser.Name) + // assert.Equal(foundUser.Name, userName) + assert.Equal(foundUser.PhoneNumber, foundReplayedUser.PhoneNumber) + assert.Equal(foundReplayedUser.PhoneNumber, "867-5309") + assert.Equal(foundUser.Email, foundReplayedUser.Email) + assert.Equal(foundReplayedUser.Email, userName+"@hashicorp.com") }) t.Run("replay:delete", func(t *testing.T) { - + assert, require := assert.New(t), require.New(t) // we need to test delete replays now... tx2 := db.Begin() + defer tx2.Commit() ticket2, err := ticketer.GetTicket("default") - assert.NilError(t, err) + require.NoError(err) - id4, err := uuid.GenerateUUID() - assert.NilError(t, err) + id4 := testId(t) userName2 := "foo-" + id4 // create a user that's replayable userCreate2 := oplog_test.TestUser{ Name: userName2, } err = tx2.Create(&userCreate2).Error - assert.NilError(t, err) + require.NoError(err) deleteUser2 := oplog_test.TestUser{ Id: userCreate2.Id, } err = tx2.Delete(&deleteUser2).Error - assert.NilError(t, err) + require.NoError(err) newLogEntry2, err := NewEntry( "test-users", @@ -487,67 +471,63 @@ func Test_Replay(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) err = newLogEntry2.WriteEntryWith(context.Background(), &GormWriter{tx2}, ticket2, &Message{Message: &userCreate2, TypeName: "user", OpType: OpType_OP_TYPE_CREATE}, &Message{Message: &deleteUser2, TypeName: "user", OpType: OpType_OP_TYPE_DELETE}, ) - assert.NilError(t, err) + require.NoError(err) var foundEntry2 Entry err = tx2.Where("id = ?", newLogEntry2.Id).First(&foundEntry2).Error - assert.NilError(t, err) + require.NoError(err) foundEntry2.Cipherer = cipherer err = foundEntry2.DecryptData(context.Background()) - assert.NilError(t, err) + require.NoError(err) types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"}) - assert.NilError(t, err) + require.NoError(err) err = foundEntry2.Replay(context.Background(), &GormWriter{tx2}, types, tableSuffix) - assert.NilError(t, err) + require.NoError(err) var foundUser2 oplog_test.TestUser err = tx2.Where("id = ?", userCreate2.Id).First(&foundUser2).Error - assert.Assert(t, err == gorm.ErrRecordNotFound) + assert.Equal(err, gorm.ErrRecordNotFound) var foundReplayedUser2 oplog_test.TestUser err = tx2.Where("id = ?", userCreate2.Id).First(&foundReplayedUser2).Error - assert.Assert(t, err == gorm.ErrRecordNotFound) + assert.Equal(err, gorm.ErrRecordNotFound) - tx2.Commit() }) } // Test_WriteEntryWith provides unit tests for oplog.WriteEntryWith func Test_WriteEntryWith(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() - - cipherer := initWrapper(t) + defer testCleanup(t, cleanup, db) + cipherer := testWrapper(t) - id, err := uuid.GenerateUUID() - assert.NilError(t, err) + id := testId(t) u := oplog_test.TestUser{ Name: "foo-" + id, } t.Log(&u) - id2, err := uuid.GenerateUUID() - assert.NilError(t, err) + id2 := testId(t) u2 := oplog_test.TestUser{ Name: "foo-" + id2, } t.Log(&u2) ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(t, err) ticket, err := ticketer.GetTicket("default") - assert.NilError(t, err) + require.NoError(t, err) t.Run("successful", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) newLogEntry, err := NewEntry( "test-users", Metadata{ @@ -558,25 +538,26 @@ func Test_WriteEntryWith(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) err = newLogEntry.WriteEntryWith(context.Background(), &GormWriter{db}, ticket, &Message{Message: &u, TypeName: "user", OpType: OpType_OP_TYPE_CREATE}, &Message{Message: &u2, TypeName: "user", OpType: OpType_OP_TYPE_CREATE}) - assert.NilError(t, err) + require.NoError(err) var foundEntry Entry err = db.Where("id = ?", newLogEntry.Id).First(&foundEntry).Error - assert.NilError(t, err) + require.NoError(err) foundEntry.Cipherer = cipherer types, err := NewTypeCatalog(Type{new(oplog_test.TestUser), "user"}) - assert.NilError(t, err) + require.NoError(err) err = foundEntry.DecryptData(context.Background()) - assert.NilError(t, err) + require.NoError(err) foundUsers, err := foundEntry.UnmarshalData(types) - assert.NilError(t, err) - assert.Assert(t, foundUsers[0].Message.(*oplog_test.TestUser).Name == u.Name) + require.NoError(err) + assert.Equal(foundUsers[0].Message.(*oplog_test.TestUser).Name, u.Name) }) t.Run("nil writer", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) newLogEntry, err := NewEntry( "test-users", Metadata{ @@ -587,13 +568,15 @@ func Test_WriteEntryWith(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) err = newLogEntry.WriteEntryWith(context.Background(), nil, ticket, &Message{Message: &u, TypeName: "user", OpType: OpType_OP_TYPE_CREATE}, &Message{Message: &u2, TypeName: "user", OpType: OpType_OP_TYPE_CREATE}) - assert.Equal(t, err.Error(), "bad writer") + require.Error(err) + assert.Equal(err.Error(), "bad writer") }) t.Run("nil ticket", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) newLogEntry, err := NewEntry( "test-users", Metadata{ @@ -604,13 +587,15 @@ func Test_WriteEntryWith(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) err = newLogEntry.WriteEntryWith(context.Background(), &GormWriter{db}, nil, &Message{Message: &u, TypeName: "user", OpType: OpType_OP_TYPE_CREATE}, &Message{Message: &u2, TypeName: "user", OpType: OpType_OP_TYPE_CREATE}) - assert.Equal(t, err.Error(), "bad ticket") + require.Error(err) + assert.Equal(err.Error(), "bad ticket") }) t.Run("nil ticket", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) newLogEntry, err := NewEntry( "test-users", Metadata{ @@ -621,37 +606,37 @@ func Test_WriteEntryWith(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) err = newLogEntry.WriteEntryWith(context.Background(), &GormWriter{db}, ticket, nil) - assert.Equal(t, err.Error(), "bad message") + require.Error(err) + assert.Equal(err.Error(), "bad message") }) } // Test_TicketSerialization provides unit tests for making sure oplog.Tickets properly serialize writes to oplog entries func Test_TicketSerialization(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) + assert, require := assert.New(t), require.New(t) ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) - cipherer := initWrapper(t) + cipherer := testWrapper(t) - id, err := uuid.GenerateUUID() - assert.NilError(t, err) + id := testId(t) firstTx := db.Begin() firstUser := oplog_test.TestUser{ Name: "foo-" + id, } err = firstTx.Create(&firstUser).Error - assert.NilError(t, err) + require.NoError(err) firstTicket, err := ticketer.GetTicket("default") - assert.NilError(t, err) + require.NoError(err) firstQueue := Queue{} err = firstQueue.Add(&firstUser, "user", OpType_OP_TYPE_CREATE) - assert.NilError(t, err) + require.NoError(err) firstLogEntry, err := NewEntry( "test-users", @@ -663,23 +648,22 @@ func Test_TicketSerialization(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) firstLogEntry.Data = firstQueue.Bytes() - id2, err := uuid.GenerateUUID() - assert.NilError(t, err) + id2 := testId(t) secondTx := db.Begin() secondUser := oplog_test.TestUser{ Name: "foo-" + id2, } err = secondTx.Create(&secondUser).Error - assert.NilError(t, err) + require.NoError(err) secondTicket, err := ticketer.GetTicket("default") - assert.NilError(t, err) + require.NoError(err) secondQueue := Queue{} err = secondQueue.Add(&secondUser, "user", OpType_OP_TYPE_CREATE) - assert.NilError(t, err) + require.NoError(err) secondLogEntry, err := NewEntry( "foobar", @@ -691,17 +675,17 @@ func Test_TicketSerialization(t *testing.T) { cipherer, ticketer, ) - assert.NilError(t, err) + require.NoError(err) secondLogEntry.Data = secondQueue.Bytes() err = secondLogEntry.Write(context.Background(), &GormWriter{secondTx}, secondTicket) - assert.NilError(t, err) - assert.Assert(t, secondLogEntry.Id != 0) - assert.Check(t, secondLogEntry.CreateTime != nil) - assert.Check(t, secondLogEntry.UpdateTime != nil) + require.NoError(err) + assert.NotEmpty(secondLogEntry.Id, 0) + assert.NotNil(secondLogEntry.CreateTime) + assert.NotNil(secondLogEntry.UpdateTime) err = secondTx.Commit().Error - assert.NilError(t, err) - assert.Assert(t, secondLogEntry.Id != 0) + require.NoError(err) + assert.NotNil(secondLogEntry.Id) err = firstLogEntry.Write(context.Background(), &GormWriter{firstTx}, firstTicket) if err != nil { @@ -711,89 +695,3 @@ func Test_TicketSerialization(t *testing.T) { t.Error("should have failed to write firstLogEntry") } } - -// initDbInDocker initializes postgres within dockertest for the unit tests -func initDbInDocker(t *testing.T) (cleanup func(), retURL string, err error) { - if os.Getenv("PG_URL") != "" { - initTestStore(t, func() {}, os.Getenv("PG_URL")) - return func() {}, os.Getenv("PG_URL"), nil - } - pool, err := dockertest.NewPool("") - if err != nil { - return func() {}, "", fmt.Errorf("could not connect to docker: %w", err) - } - - resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=watchtower"}) - if err != nil { - return func() {}, "", fmt.Errorf("could not start resource: %w", err) - } - - c := func() { - cleanupResource(pool, resource) - } - - url := fmt.Sprintf("postgres://postgres:secret@localhost:%s?sslmode=disable", resource.GetPort("5432/tcp")) - - if err := pool.Retry(func() error { - db, err := sql.Open("postgres", url) - if err != nil { - return fmt.Errorf("error opening postgres dev container: %w", err) - } - - if err := db.Ping(); err != nil { - return err - } - defer db.Close() - return nil - }); err != nil { - return func() {}, "", fmt.Errorf("could not connect to docker: %w", err) - } - initTestStore(t, c, url) - return c, url, nil -} - -// initWrapper initializes an AEAD wrapping.Wrapper for testing the oplog -func initWrapper(t *testing.T) wrapping.Wrapper { - rootKey := make([]byte, 32) - n, err := rand.Read(rootKey) - if err != nil { - t.Fatal(err) - } - if n != 32 { - t.Fatal(n) - } - root := aead.NewWrapper(nil) - if err := root.SetAESGCMKeyBytes(rootKey); err != nil { - t.Fatal(err) - } - return root -} - -// initTestStore will execute the migrations needed to initialize the store for tests -func initTestStore(t *testing.T, cleanup func(), url string) { - // run migrations - m, err := migrate.New("file://../db/migrations/postgres", url) - if err != nil { - cleanup() - t.Fatalf("Error creating migrations: %s", err) - } - if err := m.Up(); err != nil && err != migrate.ErrNoChange { - cleanup() - t.Fatalf("Error running migrations: %s", err) - } -} - -// cleanupResource will clean up the dockertest resources (postgres) -func cleanupResource(pool *dockertest.Pool, resource *dockertest.Resource) error { - var err error - for i := 0; i < 10; i++ { - err = pool.Purge(resource) - if err == nil { - return nil - } - } - if strings.Contains(err.Error(), "No such container") { - return nil - } - return fmt.Errorf("Failed to cleanup local container: %s", err) -} diff --git a/internal/oplog/options.go b/internal/oplog/options.go index 7c0ceb3e08..fbb6c4fdae 100644 --- a/internal/oplog/options.go +++ b/internal/oplog/options.go @@ -18,6 +18,7 @@ type Options map[string]interface{} func getDefaultOptions() Options { return Options{ optionWithFieldMaskPaths: []string{}, + optionWithSetToNullPaths: []string{}, optionWithAggregateNames: false, } } @@ -32,6 +33,16 @@ func WithFieldMaskPaths(fieldMaskPaths []string) Option { } } +const optionWithSetToNullPaths = "optionWithSetToNullPaths" + +// WithSetToNullPaths represents an optional set of symbolic field paths (for example: "f.a", "f.b.d") used +// to specify a subset of fields that should be set to null. (see google.golang.org/genproto/protobuf/field_mask) +func WithSetToNullPaths(setToNullPaths []string) Option { + return func(o Options) { + o[optionWithSetToNullPaths] = setToNullPaths + } +} + const optionWithAggregateNames = "optionWithAggregateNames" // WithAggregateNames enables/disables the use of multiple aggregate names for Ticketers diff --git a/internal/oplog/options_test.go b/internal/oplog/options_test.go new file mode 100644 index 0000000000..15a52591ec --- /dev/null +++ b/internal/oplog/options_test.go @@ -0,0 +1,31 @@ +package oplog + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// Test_GetOpts provides unit tests for GetOpts and all the options +func Test_GetOpts(t *testing.T) { + t.Parallel() + assert := assert.New(t) + t.Run("WithFieldMaskPaths", func(t *testing.T) { + opts := GetOpts(WithFieldMaskPaths([]string{"test"})) + testOpts := getDefaultOptions() + testOpts[optionWithFieldMaskPaths] = []string{"test"} + assert.Equal(opts, testOpts) + }) + t.Run("WithSetToNullPaths", func(t *testing.T) { + opts := GetOpts(WithSetToNullPaths([]string{"test"})) + testOpts := getDefaultOptions() + testOpts[optionWithSetToNullPaths] = []string{"test"} + assert.Equal(opts, testOpts) + }) + t.Run("WithAggregateNames", func(t *testing.T) { + opts := GetOpts(WithAggregateNames(true)) + testOpts := getDefaultOptions() + testOpts[optionWithAggregateNames] = true + assert.Equal(opts, testOpts) + }) +} diff --git a/internal/oplog/queue.go b/internal/oplog/queue.go index ed12e68ee7..ea5e667c0c 100644 --- a/internal/oplog/queue.go +++ b/internal/oplog/queue.go @@ -7,6 +7,7 @@ import ( "fmt" "sync" + common "github.com/hashicorp/watchtower/internal/db/common" "google.golang.org/genproto/protobuf/field_mask" "google.golang.org/protobuf/proto" ) @@ -15,18 +16,24 @@ import ( type Queue struct { // Buffer for the queue bytes.Buffer + // Catalog provides a TypeCatalog for the types added to the queue Catalog *TypeCatalog mx sync.Mutex } -// Add pb message to queue -func (r *Queue) Add(m proto.Message, typeName string, t OpType, opt ...Option) error { +// Add message to queue. typeName defines the type of message added to the +// queue and allows the msg to be removed using a TypeCatalog with a +// coresponding typeName entry. OpType defines the msg's operation (create, add, +// update, etc). If OpType == OpType_OP_TYPE_UPDATE, the WithFieldMaskPaths() +// and SetToNullPaths() options are supported. +func (q *Queue) Add(m proto.Message, typeName string, t OpType, opt ...Option) error { // we're not checking the Catalog for nil, since it's not used // when Adding messages to the queue opts := GetOpts(opt...) - withPaths := opts[optionWithFieldMaskPaths].([]string) + withFieldMasks := opts[optionWithFieldMaskPaths].([]string) + withNullPaths := opts[optionWithSetToNullPaths].([]string) if _, ok := m.(ReplayableMessage); !ok { return fmt.Errorf("error %T is not a ReplayableMessage", m) @@ -35,23 +42,45 @@ func (r *Queue) Add(m proto.Message, typeName string, t OpType, opt ...Option) e if err != nil { return fmt.Errorf("error marshaling add parameter: %w", err) } + if t == OpType_OP_TYPE_UPDATE { + if len(withFieldMasks) == 0 && len(withNullPaths) == 0 { + return fmt.Errorf("queue add: missing field masks or null paths for update") + } + fMasks := withFieldMasks + if fMasks == nil { + fMasks = []string{} + } + nullPaths := withNullPaths + if nullPaths == nil { + nullPaths = []string{} + } + i, _, _, err := common.Intersection(fMasks, nullPaths) + if err != nil { + return fmt.Errorf("queue add: %w", err) + } + if len(i) != 0 { + return fmt.Errorf("queue add: field masks and null paths intersect with: %s", i) + } + + } msg := &AnyOperation{ TypeName: typeName, Value: value, OperationType: t, - FieldMask: &field_mask.FieldMask{Paths: withPaths}, + FieldMask: &field_mask.FieldMask{Paths: withFieldMasks}, + NullMask: &field_mask.FieldMask{Paths: withNullPaths}, } data, err := proto.Marshal(msg) if err != nil { return fmt.Errorf("error marhaling the msg for Add: %w", err) } - r.mx.Lock() - defer r.mx.Unlock() - err = binary.Write(r, binary.LittleEndian, uint32(len(data))) + q.mx.Lock() + defer q.mx.Unlock() + err = binary.Write(q, binary.LittleEndian, uint32(len(data))) if err != nil { return err } - n, err := r.Write(data) + n, err := q.Write(data) if err != nil { return fmt.Errorf("error writing to queue buffer: %w", err) } @@ -61,34 +90,49 @@ func (r *Queue) Add(m proto.Message, typeName string, t OpType, opt ...Option) e return nil } -// Remove pb message from the queue and EOF if empty -func (r *Queue) Remove() (proto.Message, OpType, []string, error) { - if r.Catalog == nil { - return nil, OpType_OP_TYPE_UNSPECIFIED, nil, errors.New("remove Catalog is nil") +// Remove pb message from the queue and EOF if empty. It also returns the OpType +// for the msg and if it's OpType_OP_TYPE_UPDATE, the it will also return the +// fieldMask and setToNullPaths for the update operation. +func (q *Queue) Remove() (proto.Message, OpType, []string, []string, error) { + if q.Catalog == nil { + return nil, OpType_OP_TYPE_UNSPECIFIED, nil, nil, errors.New("remove Catalog is nil") } - r.mx.Lock() - defer r.mx.Unlock() + q.mx.Lock() + defer q.mx.Unlock() var n uint32 - err := binary.Read(r, binary.LittleEndian, &n) + err := binary.Read(q, binary.LittleEndian, &n) if err != nil { - return nil, 0, nil, err // intentionally not wrapping error so client can test for sentinel EOF error + return nil, 0, nil, nil, err // intentionally not wrapping error so client can test for sentinel EOF error } - data := r.Next(int(n)) + data := q.Next(int(n)) msg := new(AnyOperation) err = proto.Unmarshal(data, msg) if err != nil { - return nil, 0, nil, fmt.Errorf("error marshaling the msg for Remove: %w", err) + return nil, 0, nil, nil, fmt.Errorf("error marshaling the msg for Remove: %w", err) } if msg.Value == nil { - return nil, 0, nil, nil + return nil, 0, nil, nil, nil } - any, err := r.Catalog.Get(msg.TypeName) + any, err := q.Catalog.Get(msg.TypeName) if err != nil { - return nil, 0, nil, fmt.Errorf("error getting the TypeName for Remove: %w", err) + return nil, 0, nil, nil, fmt.Errorf("error getting the TypeName for Remove: %w", err) } pm := any.(proto.Message) if err = proto.Unmarshal(msg.Value, pm); err != nil { - return nil, 0, nil, fmt.Errorf("error unmarshaling the value for Remove: %w", err) + return nil, 0, nil, nil, fmt.Errorf("error unmarshaling the value for Remove: %w", err) + } + var masks, nullPaths []string + if msg.OperationType == OpType_OP_TYPE_UPDATE { + if msg.FieldMask != nil { + masks = msg.FieldMask.GetPaths() + } + if msg.NullMask != nil { + nullPaths = msg.NullMask.GetPaths() + } + if len(masks) == 0 && len(nullPaths) == 0 { + return nil, 0, nil, nil, errors.New("error unmarshaling the value for Remove: field mask or null paths is required") + } + } - return pm, msg.OperationType, msg.FieldMask.GetPaths(), nil + return pm, msg.OperationType, masks, nullPaths, nil } diff --git a/internal/oplog/queue_test.go b/internal/oplog/queue_test.go index 95531b86be..f428fc8a80 100644 --- a/internal/oplog/queue_test.go +++ b/internal/oplog/queue_test.go @@ -1,12 +1,12 @@ package oplog import ( - reflect "reflect" "testing" "github.com/hashicorp/watchtower/internal/oplog/oplog_test" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" - "gotest.tools/assert" ) // Test_Queue provides basic tests for the Queue type @@ -17,7 +17,7 @@ func Test_Queue(t *testing.T) { Type{new(oplog_test.TestCar), "car"}, Type{new(oplog_test.TestRental), "rental"}, ) - assert.NilError(t, err) + require.NoError(t, err) queue := Queue{Catalog: types} user := &oplog_test.TestUser{ @@ -41,64 +41,92 @@ func Test_Queue(t *testing.T) { Name: "bob", Email: "bob@alice.com", } + + userNilUpdate := &oplog_test.TestUser{ + Id: 1, + Name: "bob", + } t.Run("add/remove", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) err = queue.Add(user, "user", OpType_OP_TYPE_CREATE) - assert.NilError(t, err) + require.NoError(err) err = queue.Add(car, "car", OpType_OP_TYPE_CREATE) - assert.NilError(t, err) + require.NoError(err) err = queue.Add(rental, "rental", OpType_OP_TYPE_CREATE) - assert.NilError(t, err) + require.NoError(err) err = queue.Add(userUpdate, "user", OpType_OP_TYPE_UPDATE, WithFieldMaskPaths([]string{"Name", "Email"})) - assert.NilError(t, err) + require.NoError(err) + err = queue.Add(userNilUpdate, "user", OpType_OP_TYPE_UPDATE, WithFieldMaskPaths([]string{"Name"}), WithSetToNullPaths([]string{"Email"})) + require.NoError(err) + + queuedUser, ty, fm, nm, err := queue.Remove() + require.NoError(err) + assert.True(proto.Equal(user, queuedUser)) + assert.Equal(ty, OpType_OP_TYPE_CREATE) + assert.Nil(fm) + assert.Nil(nm) - queuedUser, ty, fm, err := queue.Remove() - assert.NilError(t, err) - assert.Assert(t, proto.Equal(user, queuedUser)) - assert.Assert(t, ty == OpType_OP_TYPE_CREATE) - assert.Assert(t, fm == nil) + queuedCar, ty, fm, nm, err := queue.Remove() + require.NoError(err) + assert.True(proto.Equal(car, queuedCar)) + assert.Equal(ty, OpType_OP_TYPE_CREATE) + assert.Nil(fm) + assert.Nil(nm) - queuedCar, ty, fm, err := queue.Remove() - assert.NilError(t, err) - assert.Assert(t, proto.Equal(car, queuedCar)) - assert.Assert(t, ty == OpType_OP_TYPE_CREATE) - assert.Assert(t, fm == nil) + queuedRental, ty, fm, nm, err := queue.Remove() + require.NoError(err) + assert.True(proto.Equal(rental, queuedRental)) + assert.Equal(ty, OpType_OP_TYPE_CREATE) + assert.Nil(fm) + assert.Nil(nm) - queuedRental, ty, fm, err := queue.Remove() - assert.NilError(t, err) - assert.Assert(t, proto.Equal(rental, queuedRental)) - assert.Assert(t, ty == OpType_OP_TYPE_CREATE) - assert.Assert(t, fm == nil) + queuedUserUpdate, ty, fm, nm, err := queue.Remove() + require.NoError(err) + assert.True(proto.Equal(userUpdate, queuedUserUpdate)) + assert.Equal(ty, OpType_OP_TYPE_UPDATE) + assert.Equal(fm, []string{"Name", "Email"}) + assert.Nil(nm) - queuedUserUpdate, ty, fm, err := queue.Remove() - assert.NilError(t, err) - assert.Assert(t, proto.Equal(userUpdate, queuedUserUpdate)) - assert.Assert(t, ty == OpType_OP_TYPE_UPDATE) - assert.Assert(t, reflect.DeepEqual(fm, []string{"Name", "Email"})) + queuedUserNilUpdate, ty, fm, nm, err := queue.Remove() + require.NoError(err) + assert.True(proto.Equal(userNilUpdate, queuedUserNilUpdate)) + assert.Equal(ty, OpType_OP_TYPE_UPDATE) + assert.Equal(fm, []string{"Name"}) + assert.Equal(nm, []string{"Email"}) }) t.Run("valid with nil type catalog", func(t *testing.T) { + require := require.New(t) queue := Queue{} - assert.NilError(t, queue.Add(user, "user", OpType_OP_TYPE_CREATE)) + require.NoError(queue.Add(user, "user", OpType_OP_TYPE_CREATE)) }) t.Run("error with nil type catalog", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) queue := Queue{} - _, _, _, err := queue.Remove() - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "remove Catalog is nil") + _, _, _, _, err = queue.Remove() + require.Error(err) + assert.Equal(err.Error(), "remove Catalog is nil") }) t.Run("not replayable", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) u := &oplog_test.TestNonReplayableUser{ Name: "Alice", PhoneNumber: "867-5309", Email: "alice@bob.com", } err = queue.Add(u, "user", OpType_OP_TYPE_CREATE) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "error *oplog_test.TestNonReplayableUser is not a ReplayableMessage") + require.Error(err) + assert.Equal(err.Error(), "error *oplog_test.TestNonReplayableUser is not a ReplayableMessage") }) t.Run("nil message", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) err = queue.Add(nil, "user", OpType_OP_TYPE_CREATE) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "error is not a ReplayableMessage") + require.Error(err) + assert.Equal(err.Error(), "error is not a ReplayableMessage") + }) + t.Run("missing both field mask and null paths for update", func(t *testing.T) { + require.Error(t, queue.Add(userUpdate, "user", OpType_OP_TYPE_UPDATE)) + }) + t.Run("intersecting field masks and null paths", func(t *testing.T) { + require.Error(t, queue.Add(userNilUpdate, "user", OpType_OP_TYPE_UPDATE, WithFieldMaskPaths([]string{"Name"}), WithSetToNullPaths([]string{"Name"}))) }) - } diff --git a/internal/oplog/testing.go b/internal/oplog/testing.go new file mode 100644 index 0000000000..d59a1e8c4f --- /dev/null +++ b/internal/oplog/testing.go @@ -0,0 +1,135 @@ +package oplog + +import ( + "crypto/rand" + "database/sql" + "fmt" + "os" + "strings" + "testing" + + "github.com/golang-migrate/migrate/v4" + wrapping "github.com/hashicorp/go-kms-wrapping" + "github.com/hashicorp/go-kms-wrapping/wrappers/aead" + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/watchtower/internal/oplog/oplog_test" + "github.com/jinzhu/gorm" + "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testCleanup(t *testing.T, cleanupFunc func(), db *gorm.DB) { + t.Helper() + cleanupFunc() + err := db.Close() + assert.NoError(t, err) +} + +func testUser(t *testing.T, db *gorm.DB, name, phoneNumber, email string) *oplog_test.TestUser { + t.Helper() + u := &oplog_test.TestUser{ + Name: name, + PhoneNumber: phoneNumber, + Email: email, + } + w := GormWriter{db} + + err := w.Create(&u) + require.NoError(t, err) + return u +} + +func testFindUser(t *testing.T, db *gorm.DB, userId uint32) *oplog_test.TestUser { + t.Helper() + var foundUser oplog_test.TestUser + err := db.Where("id = ?", userId).First(&foundUser).Error + require.NoError(t, err) + return &foundUser +} + +func testId(t *testing.T) string { + t.Helper() + id, err := uuid.GenerateUUID() + require.NoError(t, err) + return id +} + +// testInitDbInDocker initializes postgres within dockertest for the unit tests +func testInitDbInDocker(t *testing.T) (cleanup func(), retURL string, err error) { + t.Helper() + if os.Getenv("PG_URL") != "" { + testInitStore(t, func() {}, os.Getenv("PG_URL")) + return func() {}, os.Getenv("PG_URL"), nil + } + pool, err := dockertest.NewPool("") + require.NoErrorf(t, err, "could not connect to docker: %w", err) + + resource, err := pool.Run("postgres", "latest", []string{"POSTGRES_PASSWORD=secret", "POSTGRES_DB=watchtower"}) + require.NoErrorf(t, err, "could not start resource: %w", err) + + c := func() { + err := testCleanupResource(t, pool, resource) + assert.NoError(t, err) + } + + url := fmt.Sprintf("postgres://postgres:secret@localhost:%s?sslmode=disable", resource.GetPort("5432/tcp")) + + err = pool.Retry(func() error { + db, err := sql.Open("postgres", url) + if err != nil { + return fmt.Errorf("error opening postgres dev container: %w", err) + } + + if err := db.Ping(); err != nil { + return err + } + defer db.Close() + return nil + }) + require.NoErrorf(t, err, "could not connect to docker: %w", err) + testInitStore(t, c, url) + return c, url, nil +} + +// testWrapper initializes an AEAD wrapping.Wrapper for testing the oplog +func testWrapper(t *testing.T) wrapping.Wrapper { + t.Helper() + rootKey := make([]byte, 32) + n, err := rand.Read(rootKey) + require.NoError(t, err) + require.Equal(t, n, 32) + root := aead.NewWrapper(nil) + err = root.SetAESGCMKeyBytes(rootKey) + require.NoError(t, err) + return root +} + +// testInitStore will execute the migrations needed to initialize the store for tests +func testInitStore(t *testing.T, cleanup func(), url string) { + t.Helper() + // run migrations + m, err := migrate.New("file://../db/migrations/postgres", url) + require.NoError(t, err, "Error creating migrations") + + if err := m.Up(); err != nil && err != migrate.ErrNoChange { + cleanup() + require.NoError(t, err, "Error running migrations") + } +} + +// testCleanupResource will clean up the dockertest resources (postgres) +func testCleanupResource(t *testing.T, pool *dockertest.Pool, resource *dockertest.Resource) error { + t.Helper() + var err error + for i := 0; i < 10; i++ { + err = pool.Purge(resource) + if err == nil { + return nil + } + } + if strings.Contains(err.Error(), "No such container") { + return nil + } + return fmt.Errorf("Failed to cleanup local container: %s", err) +} diff --git a/internal/oplog/testing_test.go b/internal/oplog/testing_test.go new file mode 100644 index 0000000000..b5d1030a52 --- /dev/null +++ b/internal/oplog/testing_test.go @@ -0,0 +1,87 @@ +package oplog + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_testCleanup(t *testing.T) { + assert := assert.New(t) + cleanup, db := setup(t) + testCleanup(t, cleanup, db) + + err := db.Begin().Error + assert.Equal("sql: database is closed", err.Error()) +} + +func Test_testUser(t *testing.T) { + assert, require := assert.New(t), require.New(t) + cleanup, db := setup(t) + defer testCleanup(t, cleanup, db) + + id := testId(t) + + u := testUser(t, db, id, id, id) + require.NotNil(u) + assert.Equal(id, u.Name) + assert.Equal(id, u.PhoneNumber) + assert.Equal(id, u.Email) +} + +func Test_testFindUser(t *testing.T) { + assert, require := assert.New(t), require.New(t) + cleanup, db := setup(t) + defer testCleanup(t, cleanup, db) + id := testId(t) + u := testUser(t, db, id, id, id) + require.NotNil(u) + + found := testFindUser(t, db, u.Id) + require.NotNil(found) + assert.Equal(u, found) +} + +func Test_testId(t *testing.T) { + require := require.New(t) + id := testId(t) + require.NotNil(id) +} + +func Test_testInitDbInDocker(t *testing.T) { + require := require.New(t) + cleanup, url, err := testInitDbInDocker(t) + defer cleanup() + require.NoError(err) + require.NotEmpty(url) + require.NotNil(cleanup) +} + +func Test_testWrapper(t *testing.T) { + w := testWrapper(t) + require.NotNil(t, w) + assert.Equal(t, "aead", w.Type()) +} + +func Test_testInitStore(t *testing.T) { + assert, require := assert.New(t), require.New(t) + cleanup, url, err := testInitDbInDocker(t) + require.NoError(err) + defer cleanup() + require.NotEmpty(url) + + testInitStore(t, cleanup, url) + + const query = ` +select count(*) from information_schema."tables" t where table_name = 'db_test_user'; +` + db, err := sql.Open("postgres", url) + require.NoError(err) + + var cnt int + err = db.QueryRow(query).Scan(&cnt) + require.NoError(err) + assert.Equal(1, cnt) +} diff --git a/internal/oplog/ticketer_gorm_test.go b/internal/oplog/ticketer_gorm_test.go index 68fa3d5cc3..8fb084ff74 100644 --- a/internal/oplog/ticketer_gorm_test.go +++ b/internal/oplog/ticketer_gorm_test.go @@ -3,94 +3,108 @@ package oplog import ( "testing" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // Test_NewGormTicketer provides unit tests for creating a Gorm ticketer func Test_NewGormTicketer(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) + t.Run("valid", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) - assert.Assert(t, ticketer != nil) + require.NoError(err) + assert.NotNil(ticketer) }) t.Run("bad db", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) _, err := NewGormTicketer(nil, WithAggregateNames(true)) - assert.Equal(t, err.Error(), "tx is nil") + require.Error(err) + assert.Equal(err.Error(), "tx is nil") }) } // Test_GetTicket provides unit tests for getting oplog.Tickets func Test_GetTicket(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) + ticketer, err := NewGormTicketer(db, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(t, err) t.Run("valid", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) ticket, err := ticketer.GetTicket("default") - assert.NilError(t, err) - assert.Equal(t, ticket.Name, "default") - assert.Check(t, ticket.Version != 0) + require.NoError(err) + assert.Equal(ticket.Name, "default") + assert.NotEqual(ticket.Version, 0) }) t.Run("no name", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) ticket, err := ticketer.GetTicket("") - assert.Equal(t, err.Error(), "bad ticket name") - assert.Check(t, ticket == nil) + require.Error(err) + assert.Equal(err.Error(), "bad ticket name") + assert.Nil(ticket) }) } // Test_Redeem provides unit tests for redeeming tickets func Test_Redeem(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) + t.Run("valid", func(t *testing.T) { + require := require.New(t) + tx := db.Begin() + defer tx.Commit() ticketer, err := NewGormTicketer(tx, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) ticket, err := ticketer.GetTicket("default") - assert.NilError(t, err) + require.NoError(err) err = ticketer.Redeem(ticket) - assert.NilError(t, err) - tx.Commit() + require.NoError(err) }) - t.Run("nil ticket", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + tx := db.Begin() + defer tx.Commit() + ticketer, err := NewGormTicketer(tx, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) err = ticketer.Redeem(nil) - assert.Equal(t, err.Error(), "ticket is nil") - tx.Commit() + require.Error(err) + assert.Equal(err.Error(), "ticket is nil") }) t.Run("detect two redemptions in separate concurrent transactions", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + tx := db.Begin() ticketer, err := NewGormTicketer(tx, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) ticket, err := ticketer.GetTicket("default") - assert.NilError(t, err) + require.NoError(err) secondTx := db.Begin() secondTicketer, err := NewGormTicketer(secondTx, WithAggregateNames(true)) - assert.NilError(t, err) + require.NoError(err) secondTicket, err := secondTicketer.GetTicket("default") - assert.NilError(t, err) + require.NoError(err) err = ticketer.Redeem(ticket) - assert.NilError(t, err) + require.NoError(err) tx.Commit() err = secondTicketer.Redeem(secondTicket) - assert.Equal(t, err.Error(), "ticket already redeemed") + assert.Equal(err.Error(), "ticket already redeemed") }) } diff --git a/internal/oplog/type_catalog_test.go b/internal/oplog/type_catalog_test.go index 60ff099238..16e011c97f 100644 --- a/internal/oplog/type_catalog_test.go +++ b/internal/oplog/type_catalog_test.go @@ -5,34 +5,35 @@ import ( "testing" "github.com/hashicorp/watchtower/internal/oplog/oplog_test" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // Test_TypeCatalog provides basic red/green unit tests func Test_TypeCatalog(t *testing.T) { t.Parallel() + assert, require := assert.New(t), require.New(t) types, err := NewTypeCatalog( Type{new(oplog_test.TestUser), "user"}, Type{new(oplog_test.TestCar), "car"}, Type{new(oplog_test.TestRental), "rental"}, ) - assert.NilError(t, err) + require.NoError(err) name, err := types.GetTypeName(new(oplog_test.TestUser)) - assert.NilError(t, err) - assert.Assert(t, name == "user") + require.NoError(err) + assert.Equal(name, "user") _, err = types.GetTypeName(oplog_test.TestUser{}) - assert.Assert(t, err != nil) + assert.Error(err) s := "string" _, err = types.GetTypeName(&s) - assert.Assert(t, err != nil) + assert.Error(err) _, err = types.Get("unknown") - assert.Assert(t, err != nil) - + assert.Error(err) } // Test_NewTypeCatalog provides unit tests for NewTypeCatalog @@ -40,39 +41,47 @@ func Test_NewTypeCatalog(t *testing.T) { t.Parallel() t.Run("valid", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + types, err := NewTypeCatalog( Type{new(oplog_test.TestUser), "user"}, Type{new(oplog_test.TestCar), "car"}, Type{new(oplog_test.TestRental), "rental"}, ) - assert.NilError(t, err) + require.NoError(err) u, err := types.Get("user") - assert.NilError(t, err) - assert.Equal(t, reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser))) + require.NoError(err) + assert.Equal(reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser))) }) t.Run("missing Type.Name", func(t *testing.T) { + assert := assert.New(t) + types, err := NewTypeCatalog( Type{new(oplog_test.TestUser), ""}, ) - assert.Check(t, types == nil) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "error setting the type: typeName is an empty string for Set (in NewTypeCatalog)") + assert.Nil(types) + assert.Error(err) + assert.Equal(err.Error(), "error setting the type: typeName is an empty string for Set (in NewTypeCatalog)") }) t.Run("missing Type.Interface", func(t *testing.T) { + assert := assert.New(t) + types, err := NewTypeCatalog( Type{nil, ""}, ) - assert.Check(t, types == nil) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "error type is {} (in NewTypeCatalog)") + assert.Nil(types) + assert.Error(err) + assert.Equal(err.Error(), "error type is {} (in NewTypeCatalog)") }) t.Run("empty Type", func(t *testing.T) { + assert := assert.New(t) + types, err := NewTypeCatalog( Type{}, ) - assert.Check(t, types == nil) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "error type is {} (in NewTypeCatalog)") + assert.Nil(types) + assert.Error(err) + assert.Equal(err.Error(), "error type is {} (in NewTypeCatalog)") }) } @@ -82,37 +91,43 @@ func Test_GetTypeName(t *testing.T) { t.Parallel() t.Run("valid", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + types, err := NewTypeCatalog( Type{new(oplog_test.TestUser), "user"}, Type{new(oplog_test.TestCar), "car"}, ) - assert.NilError(t, err) + require.NoError(err) n, err := types.GetTypeName(new(oplog_test.TestUser)) - assert.NilError(t, err) - assert.Equal(t, n, "user") + require.NoError(err) + assert.Equal(n, "user") }) t.Run("bad name", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + types, err := NewTypeCatalog( Type{new(oplog_test.TestUser), "user"}, ) - assert.NilError(t, err) + require.NoError(err) n, err := types.GetTypeName(new(oplog_test.TestCar)) - assert.Check(t, err != nil) - assert.Equal(t, n, "") - assert.Equal(t, err.Error(), "error unknown name for interface: *oplog_test.TestCar") + require.Error(err) + assert.Equal(n, "") + assert.Equal(err.Error(), "error unknown name for interface: *oplog_test.TestCar") }) t.Run("nil interface", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + types, err := NewTypeCatalog( Type{new(oplog_test.TestUser), "user"}, ) - assert.NilError(t, err) + require.NoError(err) n, err := types.GetTypeName(nil) - assert.Check(t, err != nil) - assert.Equal(t, n, "") - assert.Equal(t, err.Error(), "error interface parameter is nil for GetTypeName") + require.Error(err) + assert.Equal(n, "") + assert.Equal(err.Error(), "error interface parameter is nil for GetTypeName") }) } @@ -121,37 +136,43 @@ func Test_Get(t *testing.T) { t.Parallel() t.Run("valid", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + types, err := NewTypeCatalog( Type{new(oplog_test.TestUser), "user"}, Type{new(oplog_test.TestCar), "car"}, ) - assert.NilError(t, err) + require.NoError(err) u, err := types.Get("user") - assert.NilError(t, err) - assert.Equal(t, reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser))) + require.NoError(err) + assert.Equal(reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser))) }) t.Run("bad name", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + types, err := NewTypeCatalog( Type{new(oplog_test.TestUser), "user"}, ) - assert.NilError(t, err) + require.NoError(err) n, err := types.Get("car") - assert.Check(t, err != nil) - assert.Equal(t, n, nil) - assert.Equal(t, err.Error(), "error typeName is not found for Get") + require.Error(err) + assert.Equal(n, nil) + assert.Equal(err.Error(), "error typeName is not found for Get") }) t.Run("bad typeName", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + types, err := NewTypeCatalog( Type{new(oplog_test.TestUser), "user"}, ) - assert.NilError(t, err) + require.NoError(err) n, err := types.Get("") - assert.Check(t, err != nil) - assert.Equal(t, n, nil) - assert.Equal(t, err.Error(), "error typeName is empty string for Get") + require.Error(err) + assert.Equal(n, nil) + assert.Equal(err.Error(), "error typeName is empty string for Get") }) } @@ -160,28 +181,34 @@ func Test_Set(t *testing.T) { t.Parallel() t.Run("valid", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + types, err := NewTypeCatalog() - assert.NilError(t, err) + require.NoError(err) u := new(oplog_test.TestUser) err = types.Set(u, "user") - assert.NilError(t, err) - assert.Equal(t, reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser))) + require.NoError(err) + assert.Equal(reflect.TypeOf(u), reflect.TypeOf(new(oplog_test.TestUser))) }) t.Run("bad name", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + types, err := NewTypeCatalog() - assert.NilError(t, err) + require.NoError(err) u := new(oplog_test.TestUser) err = types.Set(u, "") - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "typeName is an empty string for Set") - assert.Assert(t, reflect.DeepEqual(types, &TypeCatalog{})) + require.Error(err) + assert.Equal(err.Error(), "typeName is an empty string for Set") + assert.Equal(types, &TypeCatalog{}) }) t.Run("bad interface", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + types, err := NewTypeCatalog() - assert.NilError(t, err) + require.NoError(err) err = types.Set(nil, "") - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "error interface parameter is nil for Set") - assert.Assert(t, reflect.DeepEqual(types, &TypeCatalog{})) + require.Error(err) + assert.Equal(err.Error(), "error interface parameter is nil for Set") + assert.Equal(types, &TypeCatalog{}) }) } diff --git a/internal/oplog/writer.go b/internal/oplog/writer.go index 11a8eb12e1..6581c8086b 100644 --- a/internal/oplog/writer.go +++ b/internal/oplog/writer.go @@ -3,8 +3,8 @@ package oplog import ( "errors" "fmt" - "reflect" - "strings" + + "github.com/hashicorp/watchtower/internal/db/common" "github.com/jinzhu/gorm" ) @@ -14,8 +14,11 @@ type Writer interface { // Create the entry Create(interface{}) error - // Update the entry using the fieldMaskPaths, which are Paths from field_mask.proto - Update(entry interface{}, fieldMaskPaths []string) error + // Update the entry using the fieldMaskPaths and setNullPaths, which are + // Paths from field_mask.proto. fieldMaskPaths is required. setToNullPaths + // is optional. fieldMaskPaths and setNullPaths cannot instersect and both + // cannot be zero len. + Update(entry interface{}, fieldMaskPaths, setToNullPaths []string) error // Delete the entry Delete(interface{}) error @@ -23,7 +26,8 @@ type Writer interface { // HasTable checks if tableName exists hasTable(tableName string) bool - // CreateTableLike will create a newTableName using the existing table as a starting point + // CreateTableLike will create a newTableName using the existing table as a + // starting point createTableLike(existingTableName string, newTableName string) error // DropTableIfExists will drop the table if it exists @@ -49,45 +53,25 @@ func (w *GormWriter) Create(i interface{}) error { return nil } -// Update an object in storage, if there's a fieldMask then only the field_mask.proto paths are updated, otherwise -// we will send every field to the DB. -func (w *GormWriter) Update(i interface{}, fieldMaskPaths []string) error { +// Update the entry using the fieldMaskPaths and setNullPaths, which are +// Paths from field_mask.proto. fieldMaskPaths is required. setToNullPaths is +// optional. fieldMaskPaths and setNullPaths cannot intersect and both cannot be +// zero len. +func (w *GormWriter) Update(i interface{}, fieldMaskPaths, setToNullPaths []string) error { if w.Tx == nil { return errors.New("update Tx is nil") } if i == nil { return errors.New("update interface is nil") } - if len(fieldMaskPaths) == 0 { - if err := w.Tx.Save(i).Error; err != nil { - return fmt.Errorf("error updating: %w", err) - } - } - updateFields := map[string]interface{}{} - - val := reflect.Indirect(reflect.ValueOf(i)) - structTyp := val.Type() - for _, field := range fieldMaskPaths { - for i := 0; i < structTyp.NumField(); i++ { - // support for an embedded a gorm type - if structTyp.Field(i).Type.Kind() == reflect.Struct { - embType := structTyp.Field(i).Type - // check if the embedded field is exported via CanInterface() - if val.Field(i).CanInterface() { - embVal := reflect.Indirect(reflect.ValueOf(val.Field(i).Interface())) - for embFieldNum := 0; embFieldNum < embType.NumField(); embFieldNum++ { - if strings.EqualFold(embType.Field(embFieldNum).Name, field) { - updateFields[field] = embVal.Field(embFieldNum).Interface() - } - } - continue - } - } - // it's not an embedded type, so check if the field name matches - if strings.EqualFold(structTyp.Field(i).Name, field) { - updateFields[field] = val.Field(i).Interface() - } - } + if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 { + return errors.New("update both fieldMaskPaths and setToNullPaths are missing") + } + // common.UpdateFields will also check to ensure that fieldMaskPaths and + // setToNullPaths do not intersect. + updateFields, err := common.UpdateFields(i, fieldMaskPaths, setToNullPaths) + if err != nil { + return fmt.Errorf("error updating: unable to build update fields %w", err) } if err := w.Tx.Model(i).Updates(updateFields).Error; err != nil { return fmt.Errorf("error updating: %w", err) diff --git a/internal/oplog/writer_test.go b/internal/oplog/writer_test.go index cbcc5bea67..9c3d8e4fe9 100644 --- a/internal/oplog/writer_test.go +++ b/internal/oplog/writer_test.go @@ -3,302 +3,322 @@ package oplog import ( "testing" - "github.com/hashicorp/go-uuid" + dbassert "github.com/hashicorp/dbassert/gorm" "github.com/hashicorp/watchtower/internal/oplog/oplog_test" "github.com/jinzhu/gorm" - "gotest.tools/assert" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) // Test_GormWriterCreate provides unit tests for GormWriter Create func Test_GormWriterCreate(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) t.Run("valid", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) tx := db.Begin() defer tx.Rollback() w := GormWriter{tx} - - id, err := uuid.GenerateUUID() - assert.NilError(t, err) - user := oplog_test.TestUser{ - Name: "foo-" + id, + Name: "foo-" + testId(t), } - assert.NilError(t, w.Create(&user)) + require.NoError(w.Create(&user)) - var foundUser oplog_test.TestUser - err = tx.Where("id = ?", user.Id).First(&foundUser).Error - assert.NilError(t, err) - assert.Equal(t, user.Name, foundUser.Name) + foundUser := testFindUser(t, tx, user.Id) + assert.Equal(user.Name, foundUser.Name) }) t.Run("nil tx", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) w := GormWriter{nil} - id, err := uuid.GenerateUUID() - assert.NilError(t, err) + err := w.Create(&oplog_test.TestUser{}) + require.Error(err) - user := oplog_test.TestUser{ - Name: "foo-" + id, - } - err = w.Create(&user) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "create Tx is nil") + assert.Equal(err.Error(), "create Tx is nil") }) t.Run("nil model", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) tx := db.Begin() defer tx.Rollback() w := GormWriter{tx} err := w.Create(nil) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "create interface is nil") + require.Error(err) + assert.Equal(err.Error(), "create interface is nil") }) } // Test_GormWriterDelete provides unit tests for GormWriter Delete func Test_GormWriterDelete(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) t.Run("valid", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) tx := db.Begin() defer tx.Rollback() w := GormWriter{tx} - id, err := uuid.GenerateUUID() - assert.NilError(t, err) + id := testId(t) + user := testUser(t, db, "foo-"+id, "", "") + foundUser := testFindUser(t, db, user.Id) - user := oplog_test.TestUser{ - Name: "foo-" + id, - } - assert.NilError(t, w.Create(&user)) - var foundUser oplog_test.TestUser - err = tx.Where("id = ?", user.Id).First(&foundUser).Error - assert.NilError(t, err) - - assert.NilError(t, w.Delete(&user)) - err = tx.Where("id = ?", user.Id).First(&foundUser).Error - assert.Check(t, err != nil) - assert.Equal(t, err, gorm.ErrRecordNotFound) + require.NoError(w.Delete(&user)) + err := tx.Where("id = ?", user.Id).First(&foundUser).Error + require.Error(err) + assert.Equal(err, gorm.ErrRecordNotFound) }) t.Run("nil tx", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) w := GormWriter{nil} - - id, err := uuid.GenerateUUID() - assert.NilError(t, err) - - user := oplog_test.TestUser{ - Name: "foo-" + id, - } - err = w.Delete(&user) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "delete Tx is nil") + err := w.Delete(&oplog_test.TestUser{}) + require.Error(err) + assert.Equal(err.Error(), "delete Tx is nil") }) t.Run("nil model", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) tx := db.Begin() defer tx.Rollback() w := GormWriter{tx} err := w.Delete(nil) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "delete interface is nil") - }) -} - -// Test_GormWriterUpdate provides unit tests for GormWriter Update -func Test_GormWriterUpdate(t *testing.T) { - cleanup, db := setup(t) - defer cleanup() - defer db.Close() - t.Run("valid no fieldmask", func(t *testing.T) { - tx := db.Begin() - defer tx.Rollback() - w := GormWriter{tx} - - id, err := uuid.GenerateUUID() - assert.NilError(t, err) - - user := oplog_test.TestUser{ - Name: "foo-" + id, - } - assert.NilError(t, w.Create(&user)) - var foundUser oplog_test.TestUser - err = tx.Where("id = ?", user.Id).First(&foundUser).Error - assert.NilError(t, err) - - foundUser.Name = foundUser.Name + "_updated" - assert.NilError(t, w.Update(&foundUser, nil)) - - var updatedUser oplog_test.TestUser - err = tx.Where("id = ?", user.Id).First(&updatedUser).Error - assert.NilError(t, err) - assert.Equal(t, foundUser.Name, updatedUser.Name) - }) - t.Run("valid with fieldmask", func(t *testing.T) { - tx := db.Begin() - defer tx.Rollback() - w := GormWriter{tx} - - id, err := uuid.GenerateUUID() - assert.NilError(t, err) - - user := oplog_test.TestUser{ - Name: "foo-" + id, - } - assert.NilError(t, w.Create(&user)) - var foundUser oplog_test.TestUser - err = tx.Where("id = ?", user.Id).First(&foundUser).Error - assert.NilError(t, err) - - foundUser.Name = foundUser.Name + "_updated" - assert.NilError(t, w.Update(&foundUser, []string{"Name"})) - - var updatedUser oplog_test.TestUser - err = tx.Where("id = ?", user.Id).First(&updatedUser).Error - assert.NilError(t, err) - assert.Equal(t, foundUser.Name, updatedUser.Name) - }) - t.Run("valid with incorrect fieldmask", func(t *testing.T) { - tx := db.Begin() - defer tx.Rollback() - w := GormWriter{tx} - - id, err := uuid.GenerateUUID() - assert.NilError(t, err) - - user := oplog_test.TestUser{ - Name: "foo-" + id, - } - assert.NilError(t, w.Create(&user)) - var foundUser oplog_test.TestUser - err = tx.Where("id = ?", user.Id).First(&foundUser).Error - assert.NilError(t, err) - - foundUser.Name = foundUser.Name + "_updated" - assert.NilError(t, w.Update(&foundUser, []string{"PhoneNumber"})) - - var updatedUser oplog_test.TestUser - err = tx.Where("id = ?", user.Id).First(&updatedUser).Error - assert.NilError(t, err) - assert.Check(t, updatedUser.Name != foundUser.Name+"_updated") - }) - t.Run("nil tx", func(t *testing.T) { - w := GormWriter{nil} - - id, err := uuid.GenerateUUID() - assert.NilError(t, err) - - user := oplog_test.TestUser{ - Name: "foo-" + id, - } - err = w.Update(&user, nil) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "update Tx is nil") - }) - t.Run("nil model", func(t *testing.T) { - tx := db.Begin() - defer tx.Rollback() - w := GormWriter{tx} - err := w.Update(nil, nil) - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "update interface is nil") + require.Error(err) + assert.Equal(err.Error(), "delete interface is nil") }) } // Test_GormWriterHasTable provides unit tests for GormWriter HasTable func Test_GormWriterHasTable(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) w := GormWriter{Tx: db} t.Run("success", func(t *testing.T) { + assert := require.New(t) ok := w.hasTable("oplog_test_user") - assert.Equal(t, ok, true) + assert.Equal(ok, true) }) t.Run("no table", func(t *testing.T) { - badTableName, err := uuid.GenerateUUID() - assert.NilError(t, err) + assert := assert.New(t) + badTableName := testId(t) ok := w.hasTable(badTableName) - assert.Equal(t, ok, false) + assert.Equal(ok, false) }) t.Run("blank table name", func(t *testing.T) { + assert := assert.New(t) ok := w.hasTable("") - assert.Equal(t, ok, false) + assert.Equal(ok, false) }) } // Test_GormWriterCreateTable provides unit tests for GormWriter CreateTable func Test_GormWriterCreateTable(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) t.Run("success", func(t *testing.T) { + assert := assert.New(t) w := GormWriter{Tx: db} - suffix, err := uuid.GenerateUUID() - assert.NilError(t, err) + suffix := testId(t) u := &oplog_test.TestUser{} newTableName := u.TableName() + "_" + suffix - defer func() { assert.NilError(t, w.dropTableIfExists(newTableName)) }() - err = w.createTableLike(u.TableName(), newTableName) - assert.NilError(t, err) + defer func() { assert.NoError(w.dropTableIfExists(newTableName)) }() + err := w.createTableLike(u.TableName(), newTableName) + assert.NoError(err) }) t.Run("call twice", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) w := GormWriter{Tx: db} - suffix, err := uuid.GenerateUUID() - assert.NilError(t, err) + suffix := testId(t) u := &oplog_test.TestUser{} newTableName := u.TableName() + "_" + suffix - defer func() { assert.NilError(t, w.dropTableIfExists(newTableName)) }() - err = w.createTableLike(u.TableName(), newTableName) - assert.NilError(t, err) + defer func() { assert.NoError(w.dropTableIfExists(newTableName)) }() + err := w.createTableLike(u.TableName(), newTableName) + require.NoError(err) // should be an error to create the same table twice err = w.createTableLike(u.TableName(), newTableName) - assert.Check(t, err != nil) - assert.Error(t, err, err.Error(), nil) + require.Error(err) + assert.Error(err, err.Error(), nil) }) t.Run("empty existing", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) w := GormWriter{Tx: db} - suffix, err := uuid.GenerateUUID() - assert.NilError(t, err) + suffix := testId(t) u := &oplog_test.TestUser{} newTableName := u.TableName() + "_" + suffix - defer func() { assert.NilError(t, w.dropTableIfExists(newTableName)) }() - err = w.createTableLike("", newTableName) - assert.Check(t, err != nil) - assert.Error(t, err, err.Error(), nil) - assert.Equal(t, err.Error(), "error existingTableName is empty string") + defer func() { assert.NoError(w.dropTableIfExists(newTableName)) }() + err := w.createTableLike("", newTableName) + require.Error(err) + assert.Error(err, err.Error(), nil) + assert.Equal(err.Error(), "error existingTableName is empty string") }) t.Run("blank name", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) w := GormWriter{Tx: db} u := &oplog_test.TestUser{} err := w.createTableLike(u.TableName(), "") - assert.Check(t, err != nil) - assert.Error(t, err, err.Error(), nil) - assert.Equal(t, err.Error(), "error newTableName is empty string") + require.Error(err) + assert.Error(err, err.Error(), nil) + assert.Equal(err.Error(), "error newTableName is empty string") }) } // Test_GormWriterDropTableIfExists provides unit tests for GormWriter DropTableIfExists func Test_GormWriterDropTableIfExists(t *testing.T) { cleanup, db := setup(t) - defer cleanup() - defer db.Close() + defer testCleanup(t, cleanup, db) t.Run("success", func(t *testing.T) { + assert := assert.New(t) w := GormWriter{Tx: db} - suffix, err := uuid.GenerateUUID() - assert.NilError(t, err) + suffix := testId(t) + u := &oplog_test.TestUser{} newTableName := u.TableName() + "_" + suffix - err = w.createTableLike(u.TableName(), newTableName) - assert.NilError(t, err) - defer func() { assert.NilError(t, w.dropTableIfExists(newTableName)) }() + + assert.NoError(w.createTableLike(u.TableName(), newTableName)) + defer func() { assert.NoError(w.dropTableIfExists(newTableName)) }() }) t.Run("success with blank", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) w := GormWriter{Tx: db} err := w.dropTableIfExists("") - assert.Check(t, err != nil) - assert.Equal(t, err.Error(), "cannot drop table whose name is an empty string") + require.Error(err) + assert.Equal(err.Error(), "cannot drop table whose name is an empty string") + }) +} + +func TestGormWriter_Update(t *testing.T) { + cleanup, db := setup(t) + defer testCleanup(t, cleanup, db) + id := testId(t) + type fields struct { + Tx *gorm.DB + } + type args struct { + user *oplog_test.TestUser + fieldMaskPaths []string + setToNullPaths []string + } + tests := []struct { + name string + Tx *gorm.DB + args args + wantUser *oplog_test.TestUser + wantErr bool + }{ + { + name: "valid-fieldmask", + Tx: db, + args: args{ + user: &oplog_test.TestUser{ + Name: "valid-fieldmask", + }, + fieldMaskPaths: []string{"name"}, + }, + wantUser: &oplog_test.TestUser{ + Name: "valid-fieldmask", + Email: id, + PhoneNumber: id, + }, + wantErr: false, + }, + { + name: "valid-setToNull", + Tx: db, + args: args{ + user: &oplog_test.TestUser{ + Name: "valid-setToNull", + }, + fieldMaskPaths: nil, + setToNullPaths: []string{"name"}, + }, + wantUser: &oplog_test.TestUser{ + Name: "", + Email: id, + PhoneNumber: id, + }, + wantErr: false, + }, + { + name: "valid-setToNull-and-fieldMask", + Tx: db, + args: args{ + user: &oplog_test.TestUser{ + Email: "valid-setToNull-and-fieldMask", + }, + fieldMaskPaths: []string{"email"}, + setToNullPaths: []string{"name"}, + }, + wantUser: &oplog_test.TestUser{ + Name: "", + Email: "valid-setToNull-and-fieldMask", + PhoneNumber: id, + }, + wantErr: false, + }, + { + name: "no-field-mask", + Tx: db, + args: args{ + user: &oplog_test.TestUser{ + Name: "no-field-mask", + }, + fieldMaskPaths: []string{""}, + }, + wantErr: true, + }, + { + name: "nil-field-mask", + Tx: db, + args: args{ + user: &oplog_test.TestUser{ + Name: "nil-field-mask", + }, + fieldMaskPaths: nil, + }, + wantErr: true, + }, + { + name: "nil-tx", + Tx: nil, + args: args{ + user: &oplog_test.TestUser{ + Name: "nil-txt", + }, + fieldMaskPaths: []string{"name"}, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require, dbassert := assert.New(t), require.New(t), dbassert.New(t, db.DB(), "postgres") + + w := &GormWriter{ + Tx: tt.Tx, + } + u := testUser(t, db, id, id, id) // intentionally, not relying on tt.Tx + u.Name = tt.args.user.Name + u.Email = tt.args.user.Email + u.PhoneNumber = tt.args.user.PhoneNumber + err := w.Update(tt.args.user, tt.args.fieldMaskPaths, tt.args.setToNullPaths) + if tt.wantErr { + require.Error(err) + return + } + require.NoError(err) + var foundUser oplog_test.TestUser + err = db.Where("id = ?", u.Id).First(&foundUser).Error + require.NoError(err) + tt.wantUser.Id = u.Id + assert.Equal(tt.wantUser, &foundUser) + for _, f := range tt.args.setToNullPaths { + dbassert.IsNull(u, f) + } + }) + } + t.Run("nil model", func(t *testing.T) { + assert := assert.New(t) + w := GormWriter{db} + err := w.Create(nil) + assert.Error(err) }) } diff --git a/internal/proto/local/controller/storage/oplog/v1/any_operation.proto b/internal/proto/local/controller/storage/oplog/v1/any_operation.proto index 5cc426a123..d655ccfd94 100644 --- a/internal/proto/local/controller/storage/oplog/v1/any_operation.proto +++ b/internal/proto/local/controller/storage/oplog/v1/any_operation.proto @@ -5,21 +5,37 @@ import "google/protobuf/field_mask.proto"; package controller.storage.oplog.v1; option go_package = "github.com/hashicorp/watchtower/internal/oplog;oplog"; -// OpType provides the type of database operation the Any message represents (create, -// update, delete) +// OpType provides the type of database operation the Any message represents +// (create, update, delete) enum OpType { + // OP_TYPE_UNSPECIFIED defines an unspecified operation. OP_TYPE_UNSPECIFIED = 0; + + // OP_TYPE_CREATE defines a create operation. OP_TYPE_CREATE = 1; + + // OP_TYPE_UPDATE defines an update operation. OP_TYPE_UPDATE = 2; + + // OP_TYPE_DELETE defines a delete operation. OP_TYPE_DELETE = 3; } // AnyOperation provides a message for anything and the type of operation it // represents. message AnyOperation { + // type_name defines type of operation. string type_name = 1; + + // value are the bytes of a marshalled proto buff. bytes value = 2; + + // operation_type defines the type of database operation. OpType operation_type = 3; - // mask is the mask of fields to update. + + // field_mask is the mask of fields to update. google.protobuf.FieldMask field_mask = 4; + + // null_mask is the mask of fields to set to null. + google.protobuf.FieldMask null_mask = 5; } \ No newline at end of file