diff --git a/internal/db/common/update.go b/internal/db/common/update.go index 71b11e3fd3..2a76a6f599 100644 --- a/internal/db/common/update.go +++ b/internal/db/common/update.go @@ -46,6 +46,16 @@ func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []strin val := reflect.Indirect(reflect.ValueOf(i)) structTyp := val.Type() for i := 0; i < structTyp.NumField(); i++ { + if f, ok := maskPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok { + updateFields[f] = val.Field(i).Interface() + found[strings.ToUpper(f)] = struct{}{} + continue + } + if f, ok := nullPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok { + updateFields[f] = gorm.Expr("NULL") + found[strings.ToUpper(f)] = struct{}{} + continue + } kind := structTyp.Field(i).Type.Kind() if kind == reflect.Struct || kind == reflect.Ptr { embType := structTyp.Field(i).Type @@ -76,14 +86,6 @@ func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []strin continue } } - if f, ok := maskPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok { - updateFields[f] = val.Field(i).Interface() - found[strings.ToUpper(f)] = struct{}{} - } - if f, ok := nullPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok { - updateFields[f] = gorm.Expr("NULL") - found[strings.ToUpper(f)] = struct{}{} - } } if missing := findMissingPaths(setToNullPaths, found); len(missing) != 0 { diff --git a/internal/db/common/update_test.go b/internal/db/common/update_test.go index 4da3a16c67..22893f06e5 100644 --- a/internal/db/common/update_test.go +++ b/internal/db/common/update_test.go @@ -4,9 +4,13 @@ import ( "testing" "github.com/hashicorp/boundary/internal/db/db_test" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/go-secure-stdlib/base62" "github.com/hashicorp/go-uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" ) @@ -398,6 +402,20 @@ func TestUpdateFields(t *testing.T) { assert.Equal(tt.want, got) }) } + t.Run("valid-embedded-timestamp", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + wantTs := ×tamp.Timestamp{ + Timestamp: ×tamppb.Timestamp{ + Seconds: 1, + Nanos: 1, + }, + } + u := testUser(t, "", "") + u.UpdateTime = wantTs + got, err := UpdateFields(u, []string{"UpdateTime"}, nil) + require.NoError(err) + assert.True(proto.Equal(wantTs, got["UpdateTime"].(*timestamp.Timestamp))) + }) } func testUser(t *testing.T, name, email string) *db_test.TestUser { diff --git a/internal/db/db_test/db.go b/internal/db/db_test/db.go index 0d641b2a1e..12218629ac 100644 --- a/internal/db/db_test/db.go +++ b/internal/db/db_test/db.go @@ -2,15 +2,19 @@ package db_test import ( + "errors" + "github.com/hashicorp/go-secure-stdlib/base62" "google.golang.org/protobuf/proto" ) const ( - defaultUserTablename = "db_test_user" - defaultCarTableName = "db_test_car" - defaultRentalTableName = "db_test_rental" - defaultScooterTableName = "db_test_scooter" + defaultUserTablename = "db_test_user" + defaultCarTableName = "db_test_car" + defaultRentalTableName = "db_test_rental" + defaultScooterTableName = "db_test_scooter" + defaultAccessoryTableName = "db_test_accessory" + defaultScooterAccessoryTableName = "db_test_scooter_accessory" ) type TestUser struct { @@ -153,6 +157,71 @@ func (t *TestScooter) SetTableName(name string) { t.table = name } +type TestAccessory struct { + *StoreTestAccessory + table string `gorm:"-"` +} + +func NewTestAccessory(description string) (*TestAccessory, error) { + if description == "" { + return nil, errors.New("missing description") + } + return &TestAccessory{StoreTestAccessory: &StoreTestAccessory{Description: description}}, nil +} + +func (t *TestAccessory) Clone() interface{} { + s := proto.Clone(t.StoreTestAccessory) + return &TestAccessory{ + StoreTestAccessory: s.(*StoreTestAccessory), + } +} + +func (t *TestAccessory) TableName() string { + if t.table != "" { + return t.table + } + return defaultAccessoryTableName +} + +func (t *TestAccessory) SetTableName(name string) { + t.table = name +} + +type TestScooterAccessory struct { + *StoreTestScooterAccessory + table string `gorm:"-"` +} + +func NewTestScooterAccessory(scooterId, accessoryId uint32) (*TestScooterAccessory, error) { + if accessoryId == 0 { + return nil, errors.New("mssing accessory id") + } + return &TestScooterAccessory{ + StoreTestScooterAccessory: &StoreTestScooterAccessory{ + ScooterId: scooterId, + AccessoryId: accessoryId, + }, + }, nil +} + +func (t *TestScooterAccessory) Clone() interface{} { + s := proto.Clone(t.StoreTestScooterAccessory) + return &TestScooterAccessory{ + StoreTestScooterAccessory: s.(*StoreTestScooterAccessory), + } +} + +func (t *TestScooterAccessory) TableName() string { + if t.table != "" { + return t.table + } + return defaultScooterAccessoryTableName +} + +func (t *TestScooterAccessory) SetTableName(name string) { + t.table = name +} + type Cloner interface { Clone() interface{} } diff --git a/internal/db/db_test/db_test.pb.go b/internal/db/db_test/db_test.pb.go index b097025ad7..5a0d958975 100644 --- a/internal/db/db_test/db_test.pb.go +++ b/internal/db/db_test/db_test.pb.go @@ -453,6 +453,133 @@ func (x *StoreTestScooter) GetMpg() int32 { return 0 } +// StoreTestAccessory used in the db tests only and provides a gorm resource with +// an id that's not a private or public id +type StoreTestAccessory struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // @inject_tag: gorm:"primary_key" + AccessoryId uint32 `protobuf:"varint,1,opt,name=accessory_id,json=accessoryId,proto3" json:"accessory_id,omitempty" gorm:"primary_key"` + // @inject_tag: `gorm:"default:not_null"` + Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty" gorm:"default:not_null"` +} + +func (x *StoreTestAccessory) Reset() { + *x = StoreTestAccessory{} + if protoimpl.UnsafeEnabled { + mi := &file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StoreTestAccessory) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StoreTestAccessory) ProtoMessage() {} + +func (x *StoreTestAccessory) ProtoReflect() protoreflect.Message { + mi := &file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StoreTestAccessory.ProtoReflect.Descriptor instead. +func (*StoreTestAccessory) Descriptor() ([]byte, []int) { + return file_controller_storage_db_db_test_v1_db_test_proto_rawDescGZIP(), []int{4} +} + +func (x *StoreTestAccessory) GetAccessoryId() uint32 { + if x != nil { + return x.AccessoryId + } + return 0 +} + +func (x *StoreTestAccessory) GetDescription() string { + if x != nil { + return x.Description + } + return "" +} + +// StoreTestScooterAccessory used in the db tests only and provides a gorm +// resource with multiple pks +type StoreTestScooterAccessory struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // @inject_tag: gorm:"primary_key" + AccessoryId uint32 `protobuf:"varint,1,opt,name=accessory_id,json=accessoryId,proto3" json:"accessory_id,omitempty" gorm:"primary_key"` + // @inject_tag: gorm:"primary_key" + ScooterId uint32 `protobuf:"varint,2,opt,name=scooter_id,json=scooterId,proto3" json:"scooter_id,omitempty" gorm:"primary_key"` + // @inject_tag: `gorm:"default:null"` + Review string `protobuf:"bytes,3,opt,name=review,proto3" json:"review,omitempty" gorm:"default:null"` +} + +func (x *StoreTestScooterAccessory) Reset() { + *x = StoreTestScooterAccessory{} + if protoimpl.UnsafeEnabled { + mi := &file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *StoreTestScooterAccessory) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StoreTestScooterAccessory) ProtoMessage() {} + +func (x *StoreTestScooterAccessory) ProtoReflect() protoreflect.Message { + mi := &file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StoreTestScooterAccessory.ProtoReflect.Descriptor instead. +func (*StoreTestScooterAccessory) Descriptor() ([]byte, []int) { + return file_controller_storage_db_db_test_v1_db_test_proto_rawDescGZIP(), []int{5} +} + +func (x *StoreTestScooterAccessory) GetAccessoryId() uint32 { + if x != nil { + return x.AccessoryId + } + return 0 +} + +func (x *StoreTestScooterAccessory) GetScooterId() uint32 { + if x != nil { + return x.ScooterId + } + return 0 +} + +func (x *StoreTestScooterAccessory) GetReview() string { + if x != nil { + return x.Review + } + return "" +} + var File_controller_storage_db_db_test_v1_db_test_proto protoreflect.FileDescriptor var file_controller_storage_db_db_test_v1_db_test_proto_rawDesc = []byte{ @@ -535,12 +662,25 @@ var file_controller_storage_db_db_test_v1_db_test_proto_rawDesc = []byte{ 0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, 0x65, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x70, - 0x67, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x70, 0x67, 0x42, 0x3b, 0x5a, 0x39, - 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, - 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2f, 0x69, 0x6e, - 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x64, 0x62, 0x2f, 0x64, 0x62, 0x5f, 0x74, 0x65, 0x73, - 0x74, 0x3b, 0x64, 0x62, 0x5f, 0x74, 0x65, 0x73, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x67, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x70, 0x67, 0x22, 0x59, 0x0a, 0x12, + 0x53, 0x74, 0x6f, 0x72, 0x65, 0x54, 0x65, 0x73, 0x74, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x6f, + 0x72, 0x79, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x6f, 0x72, 0x79, 0x5f, + 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, + 0x6f, 0x72, 0x79, 0x49, 0x64, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, + 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x63, + 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x75, 0x0a, 0x19, 0x53, 0x74, 0x6f, 0x72, 0x65, + 0x54, 0x65, 0x73, 0x74, 0x53, 0x63, 0x6f, 0x6f, 0x74, 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73, + 0x73, 0x6f, 0x72, 0x79, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x6f, 0x72, + 0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, + 0x73, 0x73, 0x6f, 0x72, 0x79, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x63, 0x6f, 0x6f, 0x74, + 0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x73, 0x63, 0x6f, + 0x6f, 0x74, 0x65, 0x72, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x65, 0x76, 0x69, 0x65, 0x77, + 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x72, 0x65, 0x76, 0x69, 0x65, 0x77, 0x42, 0x3b, + 0x5a, 0x39, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, + 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2f, + 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x64, 0x62, 0x2f, 0x64, 0x62, 0x5f, 0x74, + 0x65, 0x73, 0x74, 0x3b, 0x64, 0x62, 0x5f, 0x74, 0x65, 0x73, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x33, } var ( @@ -555,23 +695,25 @@ func file_controller_storage_db_db_test_v1_db_test_proto_rawDescGZIP() []byte { return file_controller_storage_db_db_test_v1_db_test_proto_rawDescData } -var file_controller_storage_db_db_test_v1_db_test_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_controller_storage_db_db_test_v1_db_test_proto_msgTypes = make([]protoimpl.MessageInfo, 6) var file_controller_storage_db_db_test_v1_db_test_proto_goTypes = []interface{}{ - (*StoreTestUser)(nil), // 0: controller.storage.db.db_test.v1.StoreTestUser - (*StoreTestCar)(nil), // 1: controller.storage.db.db_test.v1.StoreTestCar - (*StoreTestRental)(nil), // 2: controller.storage.db.db_test.v1.StoreTestRental - (*StoreTestScooter)(nil), // 3: controller.storage.db.db_test.v1.StoreTestScooter - (*timestamp.Timestamp)(nil), // 4: controller.storage.timestamp.v1.Timestamp + (*StoreTestUser)(nil), // 0: controller.storage.db.db_test.v1.StoreTestUser + (*StoreTestCar)(nil), // 1: controller.storage.db.db_test.v1.StoreTestCar + (*StoreTestRental)(nil), // 2: controller.storage.db.db_test.v1.StoreTestRental + (*StoreTestScooter)(nil), // 3: controller.storage.db.db_test.v1.StoreTestScooter + (*StoreTestAccessory)(nil), // 4: controller.storage.db.db_test.v1.StoreTestAccessory + (*StoreTestScooterAccessory)(nil), // 5: controller.storage.db.db_test.v1.StoreTestScooterAccessory + (*timestamp.Timestamp)(nil), // 6: controller.storage.timestamp.v1.Timestamp } var file_controller_storage_db_db_test_v1_db_test_proto_depIdxs = []int32{ - 4, // 0: controller.storage.db.db_test.v1.StoreTestUser.create_time:type_name -> controller.storage.timestamp.v1.Timestamp - 4, // 1: controller.storage.db.db_test.v1.StoreTestUser.update_time:type_name -> controller.storage.timestamp.v1.Timestamp - 4, // 2: controller.storage.db.db_test.v1.StoreTestCar.create_time:type_name -> controller.storage.timestamp.v1.Timestamp - 4, // 3: controller.storage.db.db_test.v1.StoreTestCar.update_time:type_name -> controller.storage.timestamp.v1.Timestamp - 4, // 4: controller.storage.db.db_test.v1.StoreTestRental.create_time:type_name -> controller.storage.timestamp.v1.Timestamp - 4, // 5: controller.storage.db.db_test.v1.StoreTestRental.update_time:type_name -> controller.storage.timestamp.v1.Timestamp - 4, // 6: controller.storage.db.db_test.v1.StoreTestScooter.create_time:type_name -> controller.storage.timestamp.v1.Timestamp - 4, // 7: controller.storage.db.db_test.v1.StoreTestScooter.update_time:type_name -> controller.storage.timestamp.v1.Timestamp + 6, // 0: controller.storage.db.db_test.v1.StoreTestUser.create_time:type_name -> controller.storage.timestamp.v1.Timestamp + 6, // 1: controller.storage.db.db_test.v1.StoreTestUser.update_time:type_name -> controller.storage.timestamp.v1.Timestamp + 6, // 2: controller.storage.db.db_test.v1.StoreTestCar.create_time:type_name -> controller.storage.timestamp.v1.Timestamp + 6, // 3: controller.storage.db.db_test.v1.StoreTestCar.update_time:type_name -> controller.storage.timestamp.v1.Timestamp + 6, // 4: controller.storage.db.db_test.v1.StoreTestRental.create_time:type_name -> controller.storage.timestamp.v1.Timestamp + 6, // 5: controller.storage.db.db_test.v1.StoreTestRental.update_time:type_name -> controller.storage.timestamp.v1.Timestamp + 6, // 6: controller.storage.db.db_test.v1.StoreTestScooter.create_time:type_name -> controller.storage.timestamp.v1.Timestamp + 6, // 7: controller.storage.db.db_test.v1.StoreTestScooter.update_time:type_name -> controller.storage.timestamp.v1.Timestamp 8, // [8:8] is the sub-list for method output_type 8, // [8:8] is the sub-list for method input_type 8, // [8:8] is the sub-list for extension type_name @@ -633,6 +775,30 @@ func file_controller_storage_db_db_test_v1_db_test_proto_init() { return nil } } + file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StoreTestAccessory); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*StoreTestScooterAccessory); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } type x struct{} out := protoimpl.TypeBuilder{ @@ -640,7 +806,7 @@ func file_controller_storage_db_db_test_v1_db_test_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_controller_storage_db_db_test_v1_db_test_proto_rawDesc, NumEnums: 0, - NumMessages: 4, + NumMessages: 6, NumExtensions: 0, NumServices: 0, }, diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index e942147ae3..b4a78faec5 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -43,9 +43,11 @@ const ( // Reader interface defines lookups/searching for resources type Reader interface { // LookupById will lookup a resource by its primary key id, which must be - // unique. The resourceWithIder must implement either ResourcePublicIder or - // ResourcePrivateIder interface. - LookupById(ctx context.Context, resourceWithIder interface{}, opt ...Option) error + // unique. If the resource implements either ResourcePublicIder or + // ResourcePrivateIder interface, then they are used as the resource's + // primary key for lookup. Otherwise, the resource tags are used to + // determine it's primary key(s) for lookup. + LookupById(ctx context.Context, resource interface{}, opt ...Option) error // LookupByPublicId will lookup resource by its public_id which must be unique. LookupByPublicId(ctx context.Context, resource ResourcePublicIder, opt ...Option) error @@ -1045,11 +1047,11 @@ func (rw *Db) LookupById(ctx context.Context, resourceWithIder interface{}, _ .. if reflect.ValueOf(resourceWithIder).Kind() != reflect.Ptr { return errors.New(ctx, errors.InvalidParameter, op, "interface parameter must to be a pointer") } - primaryKey, where, err := primaryKeyWhere(ctx, resourceWithIder) + where, keys, err := rw.primaryKeysWhere(ctx, resourceWithIder) if err != nil { return errors.Wrap(ctx, err, op) } - if err := rw.underlying.Where(where, primaryKey).First(resourceWithIder).Error; err != nil { + if err := rw.underlying.Where(where, keys...).First(resourceWithIder).Error; err != nil { if err == gorm.ErrRecordNotFound { return errors.E(ctx, errors.WithCode(errors.RecordNotFound), errors.WithOp(op), errors.WithoutEvent()) } @@ -1058,23 +1060,48 @@ func (rw *Db) LookupById(ctx context.Context, resourceWithIder interface{}, _ .. return nil } -func primaryKeyWhere(ctx context.Context, resourceWithIder interface{}) (pkey string, w string, e error) { - const op = "db.primaryKeyWhere" - var primaryKey, where string - switch resourceType := resourceWithIder.(type) { +func (rw *Db) primaryKeysWhere(ctx context.Context, i interface{}) (string, []interface{}, error) { + const op = "db.primaryKeysWhere" + var fieldNames []string + var fieldValues []interface{} + tx := rw.underlying.Model(i) + if err := tx.Statement.Parse(i); err != nil { + return "", nil, errors.Wrap(ctx, err, op, errors.WithoutEvent()) + } + switch resourceType := i.(type) { case ResourcePublicIder: - primaryKey = resourceType.GetPublicId() - where = "public_id = ?" + if resourceType.GetPublicId() == "" { + return "", nil, errors.New(ctx, errors.InvalidParameter, op, "missing primary key", errors.WithoutEvent()) + } + fieldValues = []interface{}{resourceType.GetPublicId()} + fieldNames = []string{"public_id"} case ResourcePrivateIder: - primaryKey = resourceType.GetPrivateId() - where = "private_id = ?" + if resourceType.GetPrivateId() == "" { + return "", nil, errors.New(ctx, errors.InvalidParameter, op, "missing primary key", errors.WithoutEvent()) + } + fieldValues = []interface{}{resourceType.GetPrivateId()} + fieldNames = []string{"private_id"} default: - return "", "", errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unsupported interface type %T", resourceWithIder), errors.WithoutEvent()) + v := reflect.ValueOf(i) + for _, f := range tx.Statement.Schema.PrimaryFields { + if f.PrimaryKey { + val, isZero := f.ValueOf(v) + if isZero { + return "", nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("primary field %s is zero", f.Name)) + } + fieldNames = append(fieldNames, f.DBName) + fieldValues = append(fieldValues, val) + } + } + } + if len(fieldNames) == 0 { + return "", nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("no primary key(s) for %t", i)) } - if primaryKey == "" { - return "", "", errors.New(ctx, errors.InvalidParameter, op, "missing primary key", errors.WithoutEvent()) + clauses := make([]string, 0, len(fieldNames)) + for _, col := range fieldNames { + clauses = append(clauses, fmt.Sprintf("%s = ?", col)) } - return primaryKey, where, nil + return strings.Join(clauses, " and "), fieldValues, nil } // LookupByPublicId will lookup resource by its public_id, which must be unique. diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index 793663dd4a..86eac489a4 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -571,6 +571,17 @@ func TestDb_Update(t *testing.T) { assert.Equal(0, rowsUpdated) assert.Contains(err.Error(), "db.Update: oplog validation failed: db.validateOplogArgs: missing metadata: parameter violation: error #100") }) + t.Run("multi-column", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + w := Db{underlying: db} + scooter := testScooter(t, db, "", 0) + accessory := testAccessory(t, db, "test accessory") + scooterAccessory := testScooterAccessory(t, db, scooter.Id, accessory.AccessoryId) + scooterAccessory.Review = "this is great" + rowsUpdated, err := w.Update(context.Background(), scooterAccessory, []string{"Review"}, nil) + require.NoError(err) + assert.Equal(1, rowsUpdated) + }) } // testUserWithVet gives us a model that implements VetForWrite() without any @@ -826,7 +837,7 @@ func TestDb_LookupByPublicId(t *testing.T) { require.NoError(err) err = w.LookupByPublicId(context.Background(), foundUser) require.Error(err) - assert.Contains(err.Error(), "db.LookupById: db.primaryKeyWhere: missing primary key: parameter violation: error #100") + assert.Contains(err.Error(), "db.LookupById: db.primaryKeysWhere: missing primary key: parameter violation: error #100") }) t.Run("not-found", func(t *testing.T) { assert, require := assert.New(t), require.New(t) @@ -1448,6 +1459,16 @@ func TestDb_Delete(t *testing.T) { err = TestVerifyOplog(t, &w, user.PublicId, WithOperation(oplog.OpType_OP_TYPE_UNSPECIFIED), WithCreateNotBefore(10*time.Second)) require.NoError(err) }) + t.Run("multi-column", func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + w := Db{underlying: db} + scooter := testScooter(t, db, "", 0) + accessory := testAccessory(t, db, "test accessory") + scooterAccessory := testScooterAccessory(t, db, scooter.Id, accessory.AccessoryId) + rowsDeleted, err := w.Delete(context.Background(), scooterAccessory) + require.NoError(err) + assert.Equal(1, rowsDeleted) + }) } } @@ -1983,14 +2004,41 @@ func testScooter(t *testing.T, conn *DB, model string, mpg int32) *db_test.TestS return u } +func testAccessory(t *testing.T, conn *DB, description string) *db_test.TestAccessory { + t.Helper() + require := require.New(t) + a, err := db_test.NewTestAccessory(description) + require.NoError(err) + if conn != nil { + err = conn.Create(a).Error + require.NoError(err) + } + return a +} + +func testScooterAccessory(t *testing.T, conn *DB, scooterId, accessoryId uint32) *db_test.TestScooterAccessory { + t.Helper() + require := require.New(t) + a, err := db_test.NewTestScooterAccessory(scooterId, accessoryId) + require.NoError(err) + if conn != nil { + err = conn.Create(a).Error + require.NoError(err) + } + return a +} + func TestDb_LookupById(t *testing.T) { t.Parallel() db, _ := TestSetup(t, "postgres") scooter := testScooter(t, db, "", 0) user := testUser(t, db, "", "", "") + accessory := testAccessory(t, db, "test accessory") + scooterAccessory := testScooterAccessory(t, db, scooter.Id, accessory.AccessoryId) + type args struct { - resourceWithIder interface{} - opt []Option + resource interface{} + opt []Option } tests := []struct { name string @@ -2004,7 +2052,7 @@ func TestDb_LookupById(t *testing.T) { name: "simple-private-id", underlying: db, args: args{ - resourceWithIder: scooter, + resource: scooter, }, wantErr: false, want: scooter, @@ -2013,16 +2061,47 @@ func TestDb_LookupById(t *testing.T) { name: "simple-public-id", underlying: db, args: args{ - resourceWithIder: user, + resource: user, }, wantErr: false, want: user, }, + { + name: "simple-non-public-non-private-id", + underlying: db, + args: args{ + resource: accessory, + }, + wantErr: false, + want: accessory, + }, + { + name: "compond", + underlying: db, + args: args{ + resource: scooterAccessory, + }, + wantErr: false, + want: scooterAccessory, + }, + { + name: "compond-with-zero-value-pk", + underlying: db, + args: args{ + resource: func() interface{} { + cp := scooterAccessory.Clone() + cp.(*db_test.TestScooterAccessory).ScooterId = 0 + return cp + }(), + }, + wantErr: true, + wantIsErr: errors.InvalidParameter, + }, { name: "missing-public-id", underlying: db, args: args{ - resourceWithIder: &db_test.TestUser{ + resource: &db_test.TestUser{ StoreTestUser: &db_test.StoreTestUser{}, }, }, @@ -2033,7 +2112,7 @@ func TestDb_LookupById(t *testing.T) { name: "missing-private-id", underlying: db, args: args{ - resourceWithIder: &db_test.TestScooter{ + resource: &db_test.TestScooter{ StoreTestScooter: &db_test.StoreTestScooter{}, }, }, @@ -2044,7 +2123,7 @@ func TestDb_LookupById(t *testing.T) { name: "not-an-ider", underlying: db, args: args{ - resourceWithIder: &db_test.NotIder{}, + resource: &db_test.NotIder{}, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -2053,7 +2132,7 @@ func TestDb_LookupById(t *testing.T) { name: "missing-underlying-db", underlying: nil, args: args{ - resourceWithIder: user, + resource: user, }, wantErr: true, wantIsErr: errors.InvalidParameter, @@ -2065,13 +2144,13 @@ func TestDb_LookupById(t *testing.T) { rw := &Db{ underlying: tt.underlying, } - cloner, ok := tt.args.resourceWithIder.(db_test.Cloner) + cloner, ok := tt.args.resource.(db_test.Cloner) require.True(ok) cp := cloner.Clone() err := rw.LookupById(context.Background(), cp, tt.args.opt...) if tt.wantErr { require.Error(err) - require.True(errors.Match(errors.T(tt.wantIsErr), err)) + assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err code: %q got: %q", tt.wantIsErr, err) return } require.NoError(err) diff --git a/internal/db/schema/migrations/oss/postgres/19/01_db.up.sql b/internal/db/schema/migrations/oss/postgres/19/01_db.up.sql new file mode 100644 index 0000000000..352e80934b --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/19/01_db.up.sql @@ -0,0 +1,56 @@ +begin; + +create table if not exists db_test_accessory ( + accessory_id bigint generated always as identity primary key, + create_time wt_timestamp, + update_time wt_timestamp, + description text not null +); + +create trigger + update_time_column +before +update on db_test_accessory + for each row execute procedure update_time_column(); + +create trigger + immutable_columns +before +update on db_test_accessory + for each row execute procedure immutable_columns('create_time'); + +create trigger + default_create_time_column +before +insert on db_test_accessory + for each row execute procedure default_create_time(); + + +create table if not exists db_test_scooter_accessory ( + accessory_id bigint references db_test_accessory(accessory_id), + scooter_id bigint references db_test_scooter(id), + create_time wt_timestamp, + update_time wt_timestamp, + review text, + primary key(accessory_id, scooter_id) +); + +create trigger + update_time_column +before +update on db_test_scooter_accessory + for each row execute procedure update_time_column(); + +create trigger + immutable_columns +before +update on db_test_scooter_accessory + for each row execute procedure immutable_columns('create_time'); + +create trigger + default_create_time_column +before +insert on db_test_scooter_accessory + for each row execute procedure default_create_time(); + +commit; \ No newline at end of file diff --git a/internal/proto/local/controller/storage/db/db_test/v1/db_test.proto b/internal/proto/local/controller/storage/db/db_test/v1/db_test.proto index 22917ef95e..4519f5932a 100644 --- a/internal/proto/local/controller/storage/db/db_test/v1/db_test.proto +++ b/internal/proto/local/controller/storage/db/db_test/v1/db_test.proto @@ -118,4 +118,27 @@ message StoreTestScooter { // @inject_tag: `gorm:"default:null"` int32 mpg = 7; +} + +// StoreTestAccessory used in the db tests only and provides a gorm resource with +// an id that's not a private or public id +message StoreTestAccessory { + // @inject_tag: gorm:"primary_key" + uint32 accessory_id = 1; + + // @inject_tag: `gorm:"default:not_null"` + string description = 2; +} + +// StoreTestScooterAccessory used in the db tests only and provides a gorm +// resource with multiple pks +message StoreTestScooterAccessory { + // @inject_tag: gorm:"primary_key" + uint32 accessory_id = 1; + + // @inject_tag: gorm:"primary_key" + uint32 scooter_id = 2; + + // @inject_tag: `gorm:"default:null"` + string review = 3; } \ No newline at end of file diff --git a/internal/session/immutable_fields_test.go b/internal/session/immutable_fields_test.go index ec90235a16..96c7315e46 100644 --- a/internal/session/immutable_fields_test.go +++ b/internal/session/immutable_fields_test.go @@ -6,10 +6,12 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" + "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/iam" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm/logger" ) func TestSession_ImmutableFields(t *testing.T) { @@ -245,16 +247,20 @@ func TestConnectionState_ImmutableFields(t *testing.T) { _, _ = iam.TestScopes(t, iam.TestRepo(t, conn, wrapper)) session := TestDefaultSession(t, conn, wrapper, iamRepo) connection := TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222) - state := TestConnectionState(t, conn, connection.PublicId, StatusClosed) + state := TestConnectionState(t, conn, connection.PublicId, StatusConnected) var new ConnectionState err := rw.LookupWhere(context.Background(), &new, "connection_id = ? and state = ?", state.ConnectionId, state.Status) require.NoError(t, err) + conn.Logger = logger.Default.LogMode(logger.Info) + tests := []struct { - name string - update *ConnectionState - fieldMask []string + name string + update *ConnectionState + fieldMask []string + wantErrMatch *errors.Template + wantErrContains string }{ { name: "session_id", @@ -272,7 +278,9 @@ func TestConnectionState_ImmutableFields(t *testing.T) { s.Status = "closed" return s }(), - fieldMask: []string{"Status"}, + fieldMask: []string{"Status"}, + wantErrMatch: errors.T(errors.NotSpecificIntegrity), + wantErrContains: "immutable column", }, { name: "start time", @@ -281,7 +289,9 @@ func TestConnectionState_ImmutableFields(t *testing.T) { s.StartTime = &ts return s }(), - fieldMask: []string{"StartTime"}, + fieldMask: []string{"StartTime"}, + wantErrMatch: errors.T(errors.InvalidFieldMask), + wantErrContains: "parameter violation", }, { name: "previous_end_time", @@ -290,7 +300,9 @@ func TestConnectionState_ImmutableFields(t *testing.T) { s.PreviousEndTime = &ts return s }(), - fieldMask: []string{"PreviousEndTime"}, + fieldMask: []string{"PreviousEndTime"}, + wantErrMatch: errors.T(errors.NotSpecificIntegrity), + wantErrContains: "immutable column", }, } for _, tt := range tests { @@ -304,7 +316,12 @@ func TestConnectionState_ImmutableFields(t *testing.T) { rowsUpdated, err := rw.Update(context.Background(), tt.update, tt.fieldMask, nil, db.WithSkipVetForWrite(true)) require.Error(err) assert.Equal(0, rowsUpdated) - + if tt.wantErrMatch != nil { + assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted error %s and got: %s", tt.wantErrMatch.Code, err.Error()) + } + if tt.wantErrContains != "" { + assert.Contains(err.Error(), tt.wantErrContains) + } after := new.Clone() err = rw.LookupWhere(context.Background(), after, "connection_id = ? and start_time = ?", new.ConnectionId, new.StartTime) require.NoError(err)