feature (db): Add support for multi-column PKs (#1658)

pull/1663/head
Jim 5 years ago committed by GitHub
parent 8d3f666bba
commit 4df3b40def
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -46,6 +46,16 @@ func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []strin
val := reflect.Indirect(reflect.ValueOf(i))
structTyp := val.Type()
for i := 0; i < structTyp.NumField(); i++ {
if f, ok := maskPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok {
updateFields[f] = val.Field(i).Interface()
found[strings.ToUpper(f)] = struct{}{}
continue
}
if f, ok := nullPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok {
updateFields[f] = gorm.Expr("NULL")
found[strings.ToUpper(f)] = struct{}{}
continue
}
kind := structTyp.Field(i).Type.Kind()
if kind == reflect.Struct || kind == reflect.Ptr {
embType := structTyp.Field(i).Type
@ -76,14 +86,6 @@ func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []strin
continue
}
}
if f, ok := maskPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok {
updateFields[f] = val.Field(i).Interface()
found[strings.ToUpper(f)] = struct{}{}
}
if f, ok := nullPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok {
updateFields[f] = gorm.Expr("NULL")
found[strings.ToUpper(f)] = struct{}{}
}
}
if missing := findMissingPaths(setToNullPaths, found); len(missing) != 0 {

@ -4,9 +4,13 @@ import (
"testing"
"github.com/hashicorp/boundary/internal/db/db_test"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/go-secure-stdlib/base62"
"github.com/hashicorp/go-uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm"
)
@ -398,6 +402,20 @@ func TestUpdateFields(t *testing.T) {
assert.Equal(tt.want, got)
})
}
t.Run("valid-embedded-timestamp", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
wantTs := &timestamp.Timestamp{
Timestamp: &timestamppb.Timestamp{
Seconds: 1,
Nanos: 1,
},
}
u := testUser(t, "", "")
u.UpdateTime = wantTs
got, err := UpdateFields(u, []string{"UpdateTime"}, nil)
require.NoError(err)
assert.True(proto.Equal(wantTs, got["UpdateTime"].(*timestamp.Timestamp)))
})
}
func testUser(t *testing.T, name, email string) *db_test.TestUser {

@ -2,15 +2,19 @@
package db_test
import (
"errors"
"github.com/hashicorp/go-secure-stdlib/base62"
"google.golang.org/protobuf/proto"
)
const (
defaultUserTablename = "db_test_user"
defaultCarTableName = "db_test_car"
defaultRentalTableName = "db_test_rental"
defaultScooterTableName = "db_test_scooter"
defaultUserTablename = "db_test_user"
defaultCarTableName = "db_test_car"
defaultRentalTableName = "db_test_rental"
defaultScooterTableName = "db_test_scooter"
defaultAccessoryTableName = "db_test_accessory"
defaultScooterAccessoryTableName = "db_test_scooter_accessory"
)
type TestUser struct {
@ -153,6 +157,71 @@ func (t *TestScooter) SetTableName(name string) {
t.table = name
}
type TestAccessory struct {
*StoreTestAccessory
table string `gorm:"-"`
}
func NewTestAccessory(description string) (*TestAccessory, error) {
if description == "" {
return nil, errors.New("missing description")
}
return &TestAccessory{StoreTestAccessory: &StoreTestAccessory{Description: description}}, nil
}
func (t *TestAccessory) Clone() interface{} {
s := proto.Clone(t.StoreTestAccessory)
return &TestAccessory{
StoreTestAccessory: s.(*StoreTestAccessory),
}
}
func (t *TestAccessory) TableName() string {
if t.table != "" {
return t.table
}
return defaultAccessoryTableName
}
func (t *TestAccessory) SetTableName(name string) {
t.table = name
}
type TestScooterAccessory struct {
*StoreTestScooterAccessory
table string `gorm:"-"`
}
func NewTestScooterAccessory(scooterId, accessoryId uint32) (*TestScooterAccessory, error) {
if accessoryId == 0 {
return nil, errors.New("mssing accessory id")
}
return &TestScooterAccessory{
StoreTestScooterAccessory: &StoreTestScooterAccessory{
ScooterId: scooterId,
AccessoryId: accessoryId,
},
}, nil
}
func (t *TestScooterAccessory) Clone() interface{} {
s := proto.Clone(t.StoreTestScooterAccessory)
return &TestScooterAccessory{
StoreTestScooterAccessory: s.(*StoreTestScooterAccessory),
}
}
func (t *TestScooterAccessory) TableName() string {
if t.table != "" {
return t.table
}
return defaultScooterAccessoryTableName
}
func (t *TestScooterAccessory) SetTableName(name string) {
t.table = name
}
type Cloner interface {
Clone() interface{}
}

@ -453,6 +453,133 @@ func (x *StoreTestScooter) GetMpg() int32 {
return 0
}
// StoreTestAccessory used in the db tests only and provides a gorm resource with
// an id that's not a private or public id
type StoreTestAccessory struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// @inject_tag: gorm:"primary_key"
AccessoryId uint32 `protobuf:"varint,1,opt,name=accessory_id,json=accessoryId,proto3" json:"accessory_id,omitempty" gorm:"primary_key"`
// @inject_tag: `gorm:"default:not_null"`
Description string `protobuf:"bytes,2,opt,name=description,proto3" json:"description,omitempty" gorm:"default:not_null"`
}
func (x *StoreTestAccessory) Reset() {
*x = StoreTestAccessory{}
if protoimpl.UnsafeEnabled {
mi := &file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[4]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *StoreTestAccessory) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StoreTestAccessory) ProtoMessage() {}
func (x *StoreTestAccessory) ProtoReflect() protoreflect.Message {
mi := &file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[4]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StoreTestAccessory.ProtoReflect.Descriptor instead.
func (*StoreTestAccessory) Descriptor() ([]byte, []int) {
return file_controller_storage_db_db_test_v1_db_test_proto_rawDescGZIP(), []int{4}
}
func (x *StoreTestAccessory) GetAccessoryId() uint32 {
if x != nil {
return x.AccessoryId
}
return 0
}
func (x *StoreTestAccessory) GetDescription() string {
if x != nil {
return x.Description
}
return ""
}
// StoreTestScooterAccessory used in the db tests only and provides a gorm
// resource with multiple pks
type StoreTestScooterAccessory struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// @inject_tag: gorm:"primary_key"
AccessoryId uint32 `protobuf:"varint,1,opt,name=accessory_id,json=accessoryId,proto3" json:"accessory_id,omitempty" gorm:"primary_key"`
// @inject_tag: gorm:"primary_key"
ScooterId uint32 `protobuf:"varint,2,opt,name=scooter_id,json=scooterId,proto3" json:"scooter_id,omitempty" gorm:"primary_key"`
// @inject_tag: `gorm:"default:null"`
Review string `protobuf:"bytes,3,opt,name=review,proto3" json:"review,omitempty" gorm:"default:null"`
}
func (x *StoreTestScooterAccessory) Reset() {
*x = StoreTestScooterAccessory{}
if protoimpl.UnsafeEnabled {
mi := &file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[5]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *StoreTestScooterAccessory) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StoreTestScooterAccessory) ProtoMessage() {}
func (x *StoreTestScooterAccessory) ProtoReflect() protoreflect.Message {
mi := &file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[5]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StoreTestScooterAccessory.ProtoReflect.Descriptor instead.
func (*StoreTestScooterAccessory) Descriptor() ([]byte, []int) {
return file_controller_storage_db_db_test_v1_db_test_proto_rawDescGZIP(), []int{5}
}
func (x *StoreTestScooterAccessory) GetAccessoryId() uint32 {
if x != nil {
return x.AccessoryId
}
return 0
}
func (x *StoreTestScooterAccessory) GetScooterId() uint32 {
if x != nil {
return x.ScooterId
}
return 0
}
func (x *StoreTestScooterAccessory) GetReview() string {
if x != nil {
return x.Review
}
return ""
}
var File_controller_storage_db_db_test_v1_db_test_proto protoreflect.FileDescriptor
var file_controller_storage_db_db_test_v1_db_test_proto_rawDesc = []byte{
@ -535,12 +662,25 @@ var file_controller_storage_db_db_test_v1_db_test_proto_rawDesc = []byte{
0x5f, 0x69, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x72, 0x69, 0x76, 0x61,
0x74, 0x65, 0x49, 0x64, 0x12, 0x14, 0x0a, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x06, 0x20,
0x01, 0x28, 0x09, 0x52, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x70,
0x67, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x70, 0x67, 0x42, 0x3b, 0x5a, 0x39,
0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69,
0x63, 0x6f, 0x72, 0x70, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2f, 0x69, 0x6e,
0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x64, 0x62, 0x2f, 0x64, 0x62, 0x5f, 0x74, 0x65, 0x73,
0x74, 0x3b, 0x64, 0x62, 0x5f, 0x74, 0x65, 0x73, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f,
0x33,
0x67, 0x18, 0x07, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6d, 0x70, 0x67, 0x22, 0x59, 0x0a, 0x12,
0x53, 0x74, 0x6f, 0x72, 0x65, 0x54, 0x65, 0x73, 0x74, 0x41, 0x63, 0x63, 0x65, 0x73, 0x73, 0x6f,
0x72, 0x79, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x6f, 0x72, 0x79, 0x5f,
0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73,
0x6f, 0x72, 0x79, 0x49, 0x64, 0x12, 0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70,
0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x63,
0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x22, 0x75, 0x0a, 0x19, 0x53, 0x74, 0x6f, 0x72, 0x65,
0x54, 0x65, 0x73, 0x74, 0x53, 0x63, 0x6f, 0x6f, 0x74, 0x65, 0x72, 0x41, 0x63, 0x63, 0x65, 0x73,
0x73, 0x6f, 0x72, 0x79, 0x12, 0x21, 0x0a, 0x0c, 0x61, 0x63, 0x63, 0x65, 0x73, 0x73, 0x6f, 0x72,
0x79, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0b, 0x61, 0x63, 0x63, 0x65,
0x73, 0x73, 0x6f, 0x72, 0x79, 0x49, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x63, 0x6f, 0x6f, 0x74,
0x65, 0x72, 0x5f, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x09, 0x73, 0x63, 0x6f,
0x6f, 0x74, 0x65, 0x72, 0x49, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x65, 0x76, 0x69, 0x65, 0x77,
0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x72, 0x65, 0x76, 0x69, 0x65, 0x77, 0x42, 0x3b,
0x5a, 0x39, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73,
0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2f,
0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x64, 0x62, 0x2f, 0x64, 0x62, 0x5f, 0x74,
0x65, 0x73, 0x74, 0x3b, 0x64, 0x62, 0x5f, 0x74, 0x65, 0x73, 0x74, 0x62, 0x06, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x33,
}
var (
@ -555,23 +695,25 @@ func file_controller_storage_db_db_test_v1_db_test_proto_rawDescGZIP() []byte {
return file_controller_storage_db_db_test_v1_db_test_proto_rawDescData
}
var file_controller_storage_db_db_test_v1_db_test_proto_msgTypes = make([]protoimpl.MessageInfo, 4)
var file_controller_storage_db_db_test_v1_db_test_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
var file_controller_storage_db_db_test_v1_db_test_proto_goTypes = []interface{}{
(*StoreTestUser)(nil), // 0: controller.storage.db.db_test.v1.StoreTestUser
(*StoreTestCar)(nil), // 1: controller.storage.db.db_test.v1.StoreTestCar
(*StoreTestRental)(nil), // 2: controller.storage.db.db_test.v1.StoreTestRental
(*StoreTestScooter)(nil), // 3: controller.storage.db.db_test.v1.StoreTestScooter
(*timestamp.Timestamp)(nil), // 4: controller.storage.timestamp.v1.Timestamp
(*StoreTestUser)(nil), // 0: controller.storage.db.db_test.v1.StoreTestUser
(*StoreTestCar)(nil), // 1: controller.storage.db.db_test.v1.StoreTestCar
(*StoreTestRental)(nil), // 2: controller.storage.db.db_test.v1.StoreTestRental
(*StoreTestScooter)(nil), // 3: controller.storage.db.db_test.v1.StoreTestScooter
(*StoreTestAccessory)(nil), // 4: controller.storage.db.db_test.v1.StoreTestAccessory
(*StoreTestScooterAccessory)(nil), // 5: controller.storage.db.db_test.v1.StoreTestScooterAccessory
(*timestamp.Timestamp)(nil), // 6: controller.storage.timestamp.v1.Timestamp
}
var file_controller_storage_db_db_test_v1_db_test_proto_depIdxs = []int32{
4, // 0: controller.storage.db.db_test.v1.StoreTestUser.create_time:type_name -> controller.storage.timestamp.v1.Timestamp
4, // 1: controller.storage.db.db_test.v1.StoreTestUser.update_time:type_name -> controller.storage.timestamp.v1.Timestamp
4, // 2: controller.storage.db.db_test.v1.StoreTestCar.create_time:type_name -> controller.storage.timestamp.v1.Timestamp
4, // 3: controller.storage.db.db_test.v1.StoreTestCar.update_time:type_name -> controller.storage.timestamp.v1.Timestamp
4, // 4: controller.storage.db.db_test.v1.StoreTestRental.create_time:type_name -> controller.storage.timestamp.v1.Timestamp
4, // 5: controller.storage.db.db_test.v1.StoreTestRental.update_time:type_name -> controller.storage.timestamp.v1.Timestamp
4, // 6: controller.storage.db.db_test.v1.StoreTestScooter.create_time:type_name -> controller.storage.timestamp.v1.Timestamp
4, // 7: controller.storage.db.db_test.v1.StoreTestScooter.update_time:type_name -> controller.storage.timestamp.v1.Timestamp
6, // 0: controller.storage.db.db_test.v1.StoreTestUser.create_time:type_name -> controller.storage.timestamp.v1.Timestamp
6, // 1: controller.storage.db.db_test.v1.StoreTestUser.update_time:type_name -> controller.storage.timestamp.v1.Timestamp
6, // 2: controller.storage.db.db_test.v1.StoreTestCar.create_time:type_name -> controller.storage.timestamp.v1.Timestamp
6, // 3: controller.storage.db.db_test.v1.StoreTestCar.update_time:type_name -> controller.storage.timestamp.v1.Timestamp
6, // 4: controller.storage.db.db_test.v1.StoreTestRental.create_time:type_name -> controller.storage.timestamp.v1.Timestamp
6, // 5: controller.storage.db.db_test.v1.StoreTestRental.update_time:type_name -> controller.storage.timestamp.v1.Timestamp
6, // 6: controller.storage.db.db_test.v1.StoreTestScooter.create_time:type_name -> controller.storage.timestamp.v1.Timestamp
6, // 7: controller.storage.db.db_test.v1.StoreTestScooter.update_time:type_name -> controller.storage.timestamp.v1.Timestamp
8, // [8:8] is the sub-list for method output_type
8, // [8:8] is the sub-list for method input_type
8, // [8:8] is the sub-list for extension type_name
@ -633,6 +775,30 @@ func file_controller_storage_db_db_test_v1_db_test_proto_init() {
return nil
}
}
file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*StoreTestAccessory); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_controller_storage_db_db_test_v1_db_test_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*StoreTestScooterAccessory); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
@ -640,7 +806,7 @@ func file_controller_storage_db_db_test_v1_db_test_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_controller_storage_db_db_test_v1_db_test_proto_rawDesc,
NumEnums: 0,
NumMessages: 4,
NumMessages: 6,
NumExtensions: 0,
NumServices: 0,
},

@ -43,9 +43,11 @@ const (
// Reader interface defines lookups/searching for resources
type Reader interface {
// LookupById will lookup a resource by its primary key id, which must be
// unique. The resourceWithIder must implement either ResourcePublicIder or
// ResourcePrivateIder interface.
LookupById(ctx context.Context, resourceWithIder interface{}, opt ...Option) error
// unique. If the resource implements either ResourcePublicIder or
// ResourcePrivateIder interface, then they are used as the resource's
// primary key for lookup. Otherwise, the resource tags are used to
// determine it's primary key(s) for lookup.
LookupById(ctx context.Context, resource interface{}, opt ...Option) error
// LookupByPublicId will lookup resource by its public_id which must be unique.
LookupByPublicId(ctx context.Context, resource ResourcePublicIder, opt ...Option) error
@ -1045,11 +1047,11 @@ func (rw *Db) LookupById(ctx context.Context, resourceWithIder interface{}, _ ..
if reflect.ValueOf(resourceWithIder).Kind() != reflect.Ptr {
return errors.New(ctx, errors.InvalidParameter, op, "interface parameter must to be a pointer")
}
primaryKey, where, err := primaryKeyWhere(ctx, resourceWithIder)
where, keys, err := rw.primaryKeysWhere(ctx, resourceWithIder)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if err := rw.underlying.Where(where, primaryKey).First(resourceWithIder).Error; err != nil {
if err := rw.underlying.Where(where, keys...).First(resourceWithIder).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return errors.E(ctx, errors.WithCode(errors.RecordNotFound), errors.WithOp(op), errors.WithoutEvent())
}
@ -1058,23 +1060,48 @@ func (rw *Db) LookupById(ctx context.Context, resourceWithIder interface{}, _ ..
return nil
}
func primaryKeyWhere(ctx context.Context, resourceWithIder interface{}) (pkey string, w string, e error) {
const op = "db.primaryKeyWhere"
var primaryKey, where string
switch resourceType := resourceWithIder.(type) {
func (rw *Db) primaryKeysWhere(ctx context.Context, i interface{}) (string, []interface{}, error) {
const op = "db.primaryKeysWhere"
var fieldNames []string
var fieldValues []interface{}
tx := rw.underlying.Model(i)
if err := tx.Statement.Parse(i); err != nil {
return "", nil, errors.Wrap(ctx, err, op, errors.WithoutEvent())
}
switch resourceType := i.(type) {
case ResourcePublicIder:
primaryKey = resourceType.GetPublicId()
where = "public_id = ?"
if resourceType.GetPublicId() == "" {
return "", nil, errors.New(ctx, errors.InvalidParameter, op, "missing primary key", errors.WithoutEvent())
}
fieldValues = []interface{}{resourceType.GetPublicId()}
fieldNames = []string{"public_id"}
case ResourcePrivateIder:
primaryKey = resourceType.GetPrivateId()
where = "private_id = ?"
if resourceType.GetPrivateId() == "" {
return "", nil, errors.New(ctx, errors.InvalidParameter, op, "missing primary key", errors.WithoutEvent())
}
fieldValues = []interface{}{resourceType.GetPrivateId()}
fieldNames = []string{"private_id"}
default:
return "", "", errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("unsupported interface type %T", resourceWithIder), errors.WithoutEvent())
v := reflect.ValueOf(i)
for _, f := range tx.Statement.Schema.PrimaryFields {
if f.PrimaryKey {
val, isZero := f.ValueOf(v)
if isZero {
return "", nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("primary field %s is zero", f.Name))
}
fieldNames = append(fieldNames, f.DBName)
fieldValues = append(fieldValues, val)
}
}
}
if len(fieldNames) == 0 {
return "", nil, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("no primary key(s) for %t", i))
}
if primaryKey == "" {
return "", "", errors.New(ctx, errors.InvalidParameter, op, "missing primary key", errors.WithoutEvent())
clauses := make([]string, 0, len(fieldNames))
for _, col := range fieldNames {
clauses = append(clauses, fmt.Sprintf("%s = ?", col))
}
return primaryKey, where, nil
return strings.Join(clauses, " and "), fieldValues, nil
}
// LookupByPublicId will lookup resource by its public_id, which must be unique.

@ -571,6 +571,17 @@ func TestDb_Update(t *testing.T) {
assert.Equal(0, rowsUpdated)
assert.Contains(err.Error(), "db.Update: oplog validation failed: db.validateOplogArgs: missing metadata: parameter violation: error #100")
})
t.Run("multi-column", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
w := Db{underlying: db}
scooter := testScooter(t, db, "", 0)
accessory := testAccessory(t, db, "test accessory")
scooterAccessory := testScooterAccessory(t, db, scooter.Id, accessory.AccessoryId)
scooterAccessory.Review = "this is great"
rowsUpdated, err := w.Update(context.Background(), scooterAccessory, []string{"Review"}, nil)
require.NoError(err)
assert.Equal(1, rowsUpdated)
})
}
// testUserWithVet gives us a model that implements VetForWrite() without any
@ -826,7 +837,7 @@ func TestDb_LookupByPublicId(t *testing.T) {
require.NoError(err)
err = w.LookupByPublicId(context.Background(), foundUser)
require.Error(err)
assert.Contains(err.Error(), "db.LookupById: db.primaryKeyWhere: missing primary key: parameter violation: error #100")
assert.Contains(err.Error(), "db.LookupById: db.primaryKeysWhere: missing primary key: parameter violation: error #100")
})
t.Run("not-found", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
@ -1448,6 +1459,16 @@ func TestDb_Delete(t *testing.T) {
err = TestVerifyOplog(t, &w, user.PublicId, WithOperation(oplog.OpType_OP_TYPE_UNSPECIFIED), WithCreateNotBefore(10*time.Second))
require.NoError(err)
})
t.Run("multi-column", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
w := Db{underlying: db}
scooter := testScooter(t, db, "", 0)
accessory := testAccessory(t, db, "test accessory")
scooterAccessory := testScooterAccessory(t, db, scooter.Id, accessory.AccessoryId)
rowsDeleted, err := w.Delete(context.Background(), scooterAccessory)
require.NoError(err)
assert.Equal(1, rowsDeleted)
})
}
}
@ -1983,14 +2004,41 @@ func testScooter(t *testing.T, conn *DB, model string, mpg int32) *db_test.TestS
return u
}
func testAccessory(t *testing.T, conn *DB, description string) *db_test.TestAccessory {
t.Helper()
require := require.New(t)
a, err := db_test.NewTestAccessory(description)
require.NoError(err)
if conn != nil {
err = conn.Create(a).Error
require.NoError(err)
}
return a
}
func testScooterAccessory(t *testing.T, conn *DB, scooterId, accessoryId uint32) *db_test.TestScooterAccessory {
t.Helper()
require := require.New(t)
a, err := db_test.NewTestScooterAccessory(scooterId, accessoryId)
require.NoError(err)
if conn != nil {
err = conn.Create(a).Error
require.NoError(err)
}
return a
}
func TestDb_LookupById(t *testing.T) {
t.Parallel()
db, _ := TestSetup(t, "postgres")
scooter := testScooter(t, db, "", 0)
user := testUser(t, db, "", "", "")
accessory := testAccessory(t, db, "test accessory")
scooterAccessory := testScooterAccessory(t, db, scooter.Id, accessory.AccessoryId)
type args struct {
resourceWithIder interface{}
opt []Option
resource interface{}
opt []Option
}
tests := []struct {
name string
@ -2004,7 +2052,7 @@ func TestDb_LookupById(t *testing.T) {
name: "simple-private-id",
underlying: db,
args: args{
resourceWithIder: scooter,
resource: scooter,
},
wantErr: false,
want: scooter,
@ -2013,16 +2061,47 @@ func TestDb_LookupById(t *testing.T) {
name: "simple-public-id",
underlying: db,
args: args{
resourceWithIder: user,
resource: user,
},
wantErr: false,
want: user,
},
{
name: "simple-non-public-non-private-id",
underlying: db,
args: args{
resource: accessory,
},
wantErr: false,
want: accessory,
},
{
name: "compond",
underlying: db,
args: args{
resource: scooterAccessory,
},
wantErr: false,
want: scooterAccessory,
},
{
name: "compond-with-zero-value-pk",
underlying: db,
args: args{
resource: func() interface{} {
cp := scooterAccessory.Clone()
cp.(*db_test.TestScooterAccessory).ScooterId = 0
return cp
}(),
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
},
{
name: "missing-public-id",
underlying: db,
args: args{
resourceWithIder: &db_test.TestUser{
resource: &db_test.TestUser{
StoreTestUser: &db_test.StoreTestUser{},
},
},
@ -2033,7 +2112,7 @@ func TestDb_LookupById(t *testing.T) {
name: "missing-private-id",
underlying: db,
args: args{
resourceWithIder: &db_test.TestScooter{
resource: &db_test.TestScooter{
StoreTestScooter: &db_test.StoreTestScooter{},
},
},
@ -2044,7 +2123,7 @@ func TestDb_LookupById(t *testing.T) {
name: "not-an-ider",
underlying: db,
args: args{
resourceWithIder: &db_test.NotIder{},
resource: &db_test.NotIder{},
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
@ -2053,7 +2132,7 @@ func TestDb_LookupById(t *testing.T) {
name: "missing-underlying-db",
underlying: nil,
args: args{
resourceWithIder: user,
resource: user,
},
wantErr: true,
wantIsErr: errors.InvalidParameter,
@ -2065,13 +2144,13 @@ func TestDb_LookupById(t *testing.T) {
rw := &Db{
underlying: tt.underlying,
}
cloner, ok := tt.args.resourceWithIder.(db_test.Cloner)
cloner, ok := tt.args.resource.(db_test.Cloner)
require.True(ok)
cp := cloner.Clone()
err := rw.LookupById(context.Background(), cp, tt.args.opt...)
if tt.wantErr {
require.Error(err)
require.True(errors.Match(errors.T(tt.wantIsErr), err))
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err code: %q got: %q", tt.wantIsErr, err)
return
}
require.NoError(err)

@ -0,0 +1,56 @@
begin;
create table if not exists db_test_accessory (
accessory_id bigint generated always as identity primary key,
create_time wt_timestamp,
update_time wt_timestamp,
description text not null
);
create trigger
update_time_column
before
update on db_test_accessory
for each row execute procedure update_time_column();
create trigger
immutable_columns
before
update on db_test_accessory
for each row execute procedure immutable_columns('create_time');
create trigger
default_create_time_column
before
insert on db_test_accessory
for each row execute procedure default_create_time();
create table if not exists db_test_scooter_accessory (
accessory_id bigint references db_test_accessory(accessory_id),
scooter_id bigint references db_test_scooter(id),
create_time wt_timestamp,
update_time wt_timestamp,
review text,
primary key(accessory_id, scooter_id)
);
create trigger
update_time_column
before
update on db_test_scooter_accessory
for each row execute procedure update_time_column();
create trigger
immutable_columns
before
update on db_test_scooter_accessory
for each row execute procedure immutable_columns('create_time');
create trigger
default_create_time_column
before
insert on db_test_scooter_accessory
for each row execute procedure default_create_time();
commit;

@ -118,4 +118,27 @@ message StoreTestScooter {
// @inject_tag: `gorm:"default:null"`
int32 mpg = 7;
}
// StoreTestAccessory used in the db tests only and provides a gorm resource with
// an id that's not a private or public id
message StoreTestAccessory {
// @inject_tag: gorm:"primary_key"
uint32 accessory_id = 1;
// @inject_tag: `gorm:"default:not_null"`
string description = 2;
}
// StoreTestScooterAccessory used in the db tests only and provides a gorm
// resource with multiple pks
message StoreTestScooterAccessory {
// @inject_tag: gorm:"primary_key"
uint32 accessory_id = 1;
// @inject_tag: gorm:"primary_key"
uint32 scooter_id = 2;
// @inject_tag: `gorm:"default:null"`
string review = 3;
}

@ -6,10 +6,12 @@ import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
"gorm.io/gorm/logger"
)
func TestSession_ImmutableFields(t *testing.T) {
@ -245,16 +247,20 @@ func TestConnectionState_ImmutableFields(t *testing.T) {
_, _ = iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
session := TestDefaultSession(t, conn, wrapper, iamRepo)
connection := TestConnection(t, conn, session.PublicId, "127.0.0.1", 22, "127.0.0.1", 2222)
state := TestConnectionState(t, conn, connection.PublicId, StatusClosed)
state := TestConnectionState(t, conn, connection.PublicId, StatusConnected)
var new ConnectionState
err := rw.LookupWhere(context.Background(), &new, "connection_id = ? and state = ?", state.ConnectionId, state.Status)
require.NoError(t, err)
conn.Logger = logger.Default.LogMode(logger.Info)
tests := []struct {
name string
update *ConnectionState
fieldMask []string
name string
update *ConnectionState
fieldMask []string
wantErrMatch *errors.Template
wantErrContains string
}{
{
name: "session_id",
@ -272,7 +278,9 @@ func TestConnectionState_ImmutableFields(t *testing.T) {
s.Status = "closed"
return s
}(),
fieldMask: []string{"Status"},
fieldMask: []string{"Status"},
wantErrMatch: errors.T(errors.NotSpecificIntegrity),
wantErrContains: "immutable column",
},
{
name: "start time",
@ -281,7 +289,9 @@ func TestConnectionState_ImmutableFields(t *testing.T) {
s.StartTime = &ts
return s
}(),
fieldMask: []string{"StartTime"},
fieldMask: []string{"StartTime"},
wantErrMatch: errors.T(errors.InvalidFieldMask),
wantErrContains: "parameter violation",
},
{
name: "previous_end_time",
@ -290,7 +300,9 @@ func TestConnectionState_ImmutableFields(t *testing.T) {
s.PreviousEndTime = &ts
return s
}(),
fieldMask: []string{"PreviousEndTime"},
fieldMask: []string{"PreviousEndTime"},
wantErrMatch: errors.T(errors.NotSpecificIntegrity),
wantErrContains: "immutable column",
},
}
for _, tt := range tests {
@ -304,7 +316,12 @@ func TestConnectionState_ImmutableFields(t *testing.T) {
rowsUpdated, err := rw.Update(context.Background(), tt.update, tt.fieldMask, nil, db.WithSkipVetForWrite(true))
require.Error(err)
assert.Equal(0, rowsUpdated)
if tt.wantErrMatch != nil {
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted error %s and got: %s", tt.wantErrMatch.Code, err.Error())
}
if tt.wantErrContains != "" {
assert.Contains(err.Error(), tt.wantErrContains)
}
after := new.Clone()
err = rw.LookupWhere(context.Background(), after, "connection_id = ? and start_time = ?", new.ConnectionId, new.StartTime)
require.NoError(err)

Loading…
Cancel
Save