diff --git a/internal/db/common/update.go b/internal/db/common/update.go new file mode 100644 index 0000000000..dacbb19153 --- /dev/null +++ b/internal/db/common/update.go @@ -0,0 +1,143 @@ +// common package contains functions from internal/db which need to be shared +// commonly with other packages that have a cyclic dependency on internal/db +// like internal/oplog. +package common + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/jinzhu/gorm" +) + +var ( + // ErrNilParameter is returned when a required parameter is nil. + ErrNilParameter = errors.New("nil parameter") +) + +// UpdateFields will create a map[string]interface of the update values to be +// sent to the db. The map keys will be the field names for the fields to be +// updated. The caller provided fieldMaskPaths and setToNullPaths must not +// intersect. fieldMaskPaths and setToNullPaths cannot both be zero len. +func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []string) (map[string]interface{}, error) { + if i == nil { + return nil, fmt.Errorf("interface is missing: %w", ErrNilParameter) + } + if fieldMaskPaths == nil { + fieldMaskPaths = []string{} + } + if setToNullPaths == nil { + setToNullPaths = []string{} + } + if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 { + return nil, errors.New("both fieldMaskPaths and setToNullPaths are zero len") + } + + inter, maskPaths, nullPaths, err := intersection(fieldMaskPaths, setToNullPaths) + if err != nil { + return nil, err + } + if len(inter) != 0 { + return nil, fmt.Errorf("fieldMashPaths and setToNullPaths cannot intersect") + } + + updateFields := map[string]interface{}{} // case sensitive update fields to values + + found := map[string]struct{}{} // we need something to keep track of found fields (case insensitive) + + val := reflect.Indirect(reflect.ValueOf(i)) + structTyp := val.Type() + for i := 0; i < structTyp.NumField(); i++ { + kind := structTyp.Field(i).Type.Kind() + if kind == reflect.Struct || kind == reflect.Ptr { + embType := structTyp.Field(i).Type + // check if the embedded field is exported via CanInterface() + if val.Field(i).CanInterface() { + embVal := reflect.Indirect(reflect.ValueOf(val.Field(i).Interface())) + // if it's a ptr to a struct, then we need a few more bits before proceeding. + if kind == reflect.Ptr { + embVal = val.Field(i).Elem() + if !embVal.IsValid() { + continue + } + embType = embVal.Type() + if embType.Kind() != reflect.Struct { + continue + } + } + for embFieldNum := 0; embFieldNum < embType.NumField(); embFieldNum++ { + if f, ok := maskPaths[strings.ToUpper(embType.Field(embFieldNum).Name)]; ok { + updateFields[f] = embVal.Field(embFieldNum).Interface() + found[strings.ToUpper(f)] = struct{}{} + } + if f, ok := nullPaths[strings.ToUpper(embType.Field(embFieldNum).Name)]; ok { + updateFields[f] = gorm.Expr("NULL") + found[strings.ToUpper(f)] = struct{}{} + } + } + continue + } + } + if f, ok := maskPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok { + updateFields[f] = val.Field(i).Interface() + found[strings.ToUpper(f)] = struct{}{} + } + if f, ok := nullPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok { + updateFields[f] = gorm.Expr("NULL") + found[strings.ToUpper(f)] = struct{}{} + } + } + + if missing := findMissingPaths(setToNullPaths, found); len(missing) != 0 { + return nil, fmt.Errorf("null paths not found in resource: %s", missing) + } + + if missing := findMissingPaths(fieldMaskPaths, found); len(missing) != 0 { + return nil, fmt.Errorf("field mask paths not found in resource: %s", missing) + } + + return updateFields, nil +} + +func findMissingPaths(paths []string, foundPaths map[string]struct{}) []string { + notFound := []string{} + for _, f := range paths { + if _, ok := foundPaths[strings.ToUpper(f)]; !ok { + notFound = append(notFound, f) + } + } + return notFound +} + +// intersection is a case-insensitive search for intersecting values. Returns +// []string of the intersection with values in lowercase, and map[string]string +// of the original av and bv, with the key set to uppercase and value set to the +// original +func intersection(av, bv []string) ([]string, map[string]string, map[string]string, error) { + if av == nil { + return nil, nil, nil, fmt.Errorf("av is missing: %w", ErrNilParameter) + } + if bv == nil { + return nil, nil, nil, fmt.Errorf("bv is missing: %w", ErrNilParameter) + } + if len(av) == 0 && len(bv) == 0 { + return []string{}, map[string]string{}, map[string]string{}, nil + } + s := []string{} + ah := map[string]string{} + bh := map[string]string{} + + for i := 0; i < len(av); i++ { + ah[strings.ToUpper(av[i])] = av[i] + } + for i := 0; i < len(bv); i++ { + k := strings.ToUpper(bv[i]) + bh[k] = bv[i] + if _, found := ah[k]; found { + s = append(s, strings.ToLower(bh[k])) + } + } + return s, ah, bh, nil +} diff --git a/internal/db/common/update_test.go b/internal/db/common/update_test.go new file mode 100644 index 0000000000..38ca0685c2 --- /dev/null +++ b/internal/db/common/update_test.go @@ -0,0 +1,419 @@ +package common + +import ( + "testing" + + "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/sdk/helper/base62" + "github.com/hashicorp/watchtower/internal/db/db_test" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" +) + +func Test_intersection(t *testing.T) { + type args struct { + av []string + bv []string + } + tests := []struct { + name string + args args + want []string + want1 map[string]string + want2 map[string]string + wantErr bool + wantErrMsg string + }{ + { + name: "intersect", + args: args{ + av: []string{"alice"}, + bv: []string{"alice", "bob"}, + }, + want: []string{"alice"}, + want1: map[string]string{ + "ALICE": "alice", + }, + want2: map[string]string{ + "ALICE": "alice", + "BOB": "bob", + }, + }, + { + name: "intersect-2", + args: args{ + av: []string{"alice", "bob", "jane", "doe"}, + bv: []string{"alice", "doe", "bert", "ernie", "bigbird"}, + }, + want: []string{"alice", "doe"}, + want1: map[string]string{ + "ALICE": "alice", + "BOB": "bob", + "JANE": "jane", + "DOE": "doe", + }, + want2: map[string]string{ + "ALICE": "alice", + "DOE": "doe", + "BERT": "bert", + "ERNIE": "ernie", + "BIGBIRD": "bigbird", + }, + }, + { + name: "intersect-mixed-case", + args: args{ + av: []string{"AlicE"}, + bv: []string{"alICe", "Bob"}, + }, + want: []string{"alice"}, + want1: map[string]string{ + "ALICE": "AlicE", + }, + want2: map[string]string{ + "ALICE": "alICe", + "BOB": "Bob", + }, + }, + { + name: "no-intersect-mixed-case", + args: args{ + av: []string{"AliCe", "BOb", "jaNe", "DOE"}, + bv: []string{"beRt", "ERnie", "bigBIRD"}, + }, + want: []string{}, + want1: map[string]string{ + "ALICE": "AliCe", + "BOB": "BOb", + "JANE": "jaNe", + "DOE": "DOE", + }, + want2: map[string]string{ + "BERT": "beRt", + "ERNIE": "ERnie", + "BIGBIRD": "bigBIRD", + }, + }, + { + name: "no-intersect-1", + args: args{ + av: []string{"alice", "bob", "jane", "doe"}, + bv: []string{"bert", "ernie", "bigbird"}, + }, + want: []string{}, + want1: map[string]string{ + "ALICE": "alice", + "BOB": "bob", + "JANE": "jane", + "DOE": "doe", + }, + want2: map[string]string{ + "BERT": "bert", + "ERNIE": "ernie", + "BIGBIRD": "bigbird", + }, + }, + { + name: "empty-av", + args: args{ + av: []string{}, + bv: []string{"bert", "ernie", "bigbird"}, + }, + want: []string{}, + want1: map[string]string{}, + want2: map[string]string{ + "BERT": "bert", + "ERNIE": "ernie", + "BIGBIRD": "bigbird", + }, + }, + { + name: "empty-av-and-bv", + args: args{ + av: []string{}, + bv: []string{}, + }, + want: []string{}, + want1: map[string]string{}, + want2: map[string]string{}, + }, + { + name: "nil-av", + args: args{ + av: nil, + bv: []string{"bert", "ernie", "bigbird"}, + }, + want: nil, + want1: nil, + want2: nil, + wantErr: true, + wantErrMsg: "av is missing: nil parameter", + }, + { + name: "nil-bv", + args: args{ + av: []string{}, + bv: nil, + }, + want: nil, + want1: nil, + want2: nil, + wantErr: true, + wantErrMsg: "bv is missing: nil parameter", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + got, got1, got2, err := intersection(tt.args.av, tt.args.bv) + if err == nil && tt.wantErr { + assert.Error(err) + } + if tt.wantErr { + assert.Error(err) + assert.Equal(tt.wantErrMsg, err.Error()) + } + assert.Equal(tt.want, got) + assert.Equal(tt.want1, got1) + assert.Equal(tt.want2, got2) + }) + } +} + +func TestUpdateFields(t *testing.T) { + a := assert.New(t) + id, err := uuid.GenerateUUID() + a.NoError(err) + + type args struct { + i interface{} + fieldMaskPaths []string + setToNullPaths []string + } + tests := []struct { + name string + args args + want map[string]interface{} + wantErr bool + wantErrMsg string + }{ + { + name: "missing interface", + args: args{ + i: nil, + fieldMaskPaths: []string{}, + setToNullPaths: []string{}, + }, + want: nil, + wantErr: true, + wantErrMsg: "interface is missing: nil parameter", + }, + { + name: "missing fieldmasks", + args: args{ + i: testUser(t, id, id), + fieldMaskPaths: nil, + setToNullPaths: []string{}, + }, + want: nil, + wantErr: true, + wantErrMsg: "both fieldMaskPaths and setToNullPaths are zero len", + }, + { + name: "missing null fields", + args: args{ + i: testUser(t, id, id), + fieldMaskPaths: []string{"Name"}, + setToNullPaths: nil, + }, + want: map[string]interface{}{ + "Name": id, + }, + wantErr: false, + }, + { + name: "all zero len", + args: args{ + i: testUser(t, id, id), + fieldMaskPaths: []string{}, + setToNullPaths: nil, + }, + wantErr: true, + wantErrMsg: "both fieldMaskPaths and setToNullPaths are zero len", + }, + { + name: "not found masks", + args: args{ + i: testUser(t, id, id), + fieldMaskPaths: []string{"invalidFieldName"}, + setToNullPaths: []string{}, + }, + want: nil, + wantErr: true, + wantErrMsg: "field mask paths not found in resource: [invalidFieldName]", + }, + { + name: "not found null paths", + args: args{ + i: testUser(t, id, id), + fieldMaskPaths: []string{"name"}, + setToNullPaths: []string{"invalidFieldName"}, + }, + want: nil, + wantErr: true, + wantErrMsg: "null paths not found in resource: [invalidFieldName]", + }, + { + name: "intersection", + args: args{ + i: testUser(t, id, id), + fieldMaskPaths: []string{"name"}, + setToNullPaths: []string{"name"}, + }, + want: nil, + wantErr: true, + wantErrMsg: "fieldMashPaths and setToNullPaths cannot intersect", + }, + { + name: "valid", + args: args{ + i: testUser(t, id, id), + fieldMaskPaths: []string{"name"}, + setToNullPaths: []string{"email"}, + }, + want: map[string]interface{}{ + "name": id, + "email": gorm.Expr("NULL"), + }, + wantErr: false, + wantErrMsg: "", + }, + { + name: "valid-just-masks", + args: args{ + i: testUser(t, id, id), + fieldMaskPaths: []string{"name", "email"}, + setToNullPaths: []string{}, + }, + want: map[string]interface{}{ + "name": id, + "email": id, + }, + wantErr: false, + wantErrMsg: "", + }, + { + name: "valid-just-nulls", + args: args{ + i: testUser(t, id, id), + fieldMaskPaths: []string{}, + setToNullPaths: []string{"name", "email"}, + }, + want: map[string]interface{}{ + "name": gorm.Expr("NULL"), + "email": gorm.Expr("NULL"), + }, + wantErr: false, + wantErrMsg: "", + }, + { + name: "valid-not-embedded", + args: args{ + i: db_test.StoreTestUser{ + PublicId: testPublicId(t), + Name: id, + Email: "", + }, + fieldMaskPaths: []string{"name"}, + setToNullPaths: []string{"email"}, + }, + want: map[string]interface{}{ + "name": id, + "email": gorm.Expr("NULL"), + }, + wantErr: false, + wantErrMsg: "", + }, + { + name: "valid-not-embedded-just-masks", + args: args{ + i: db_test.StoreTestUser{ + PublicId: testPublicId(t), + Name: id, + Email: "", + }, + fieldMaskPaths: []string{"name"}, + setToNullPaths: nil, + }, + want: map[string]interface{}{ + "name": id, + }, + wantErr: false, + wantErrMsg: "", + }, + { + name: "valid-not-embedded-just-nulls", + args: args{ + i: db_test.StoreTestUser{ + PublicId: testPublicId(t), + Name: id, + Email: "", + }, + fieldMaskPaths: nil, + setToNullPaths: []string{"email"}, + }, + want: map[string]interface{}{ + "email": gorm.Expr("NULL"), + }, + wantErr: false, + wantErrMsg: "", + }, + { + name: "not found null paths - not embedded", + args: args{ + i: db_test.StoreTestUser{ + PublicId: testPublicId(t), + Name: id, + Email: "", + }, + fieldMaskPaths: []string{"name"}, + setToNullPaths: []string{"invalidFieldName"}, + }, + want: nil, + wantErr: true, + wantErrMsg: "null paths not found in resource: [invalidFieldName]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + got, err := UpdateFields(tt.args.i, tt.args.fieldMaskPaths, tt.args.setToNullPaths) + if err == nil && tt.wantErr { + assert.Error(err) + } + if tt.wantErr { + assert.Error(err) + assert.Equal(tt.wantErrMsg, err.Error()) + } + assert.Equal(tt.want, got) + }) + } +} + +func testUser(t *testing.T, name, email string) *db_test.TestUser { + t.Helper() + return &db_test.TestUser{ + StoreTestUser: &db_test.StoreTestUser{ + PublicId: testPublicId(t), + Name: name, + Email: email, + }, + } +} + +func testPublicId(t *testing.T) string { + t.Helper() + publicId, err := base62.Random(20) + assert.NoError(t, err) + return publicId +} diff --git a/internal/db/domains_test.go b/internal/db/domains_test.go index 3ded10a4d7..ac0f319d4d 100644 --- a/internal/db/domains_test.go +++ b/internal/db/domains_test.go @@ -30,8 +30,11 @@ returning id; t.Error(err) } }() - defer conn.Close() - + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() db := conn.DB() if _, err := db.Exec(createTable); err != nil { t.Fatalf("query: \n%s\n error: %s", createTable, err) @@ -94,8 +97,11 @@ returning id; t.Error(err) } }() - defer conn.Close() - + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() db := conn.DB() if _, err := db.Exec(createTable); err != nil { t.Fatalf("query: \n%s\n error: %s", createTable, err) diff --git a/internal/db/migrations/postgres.gen.go b/internal/db/migrations/postgres.gen.go index e5849d4511..f7e88d45ea 100644 --- a/internal/db/migrations/postgres.gen.go +++ b/internal/db/migrations/postgres.gen.go @@ -478,6 +478,36 @@ before insert on iam_scope for each row execute procedure default_create_time(); + +-- iam_sub_names will allow us to enforce the different name constraints for +-- organizations and projects via a before update trigger on the iam_scope +-- table. +create or replace function + iam_sub_names() + returns trigger +as $$ +begin + if new.name != old.name then + if new.type = 'organization' then + update iam_scope_organization set name = new.name where scope_id = old.public_id; + return new; + end if; + if new.type = 'project' then + update iam_scope_project set name = new.name where scope_id = old.public_id; + return new; + end if; + raise exception 'unknown scope type'; + end if; + return new; +end; +$$ language plpgsql; + +create trigger + iam_sub_names +before +update on iam_scope + for each row execute procedure iam_sub_names(); + commit; `), diff --git a/internal/db/migrations/postgres/04_iam.up.sql b/internal/db/migrations/postgres/04_iam.up.sql index 1722bd1dd9..ad07499853 100644 --- a/internal/db/migrations/postgres/04_iam.up.sql +++ b/internal/db/migrations/postgres/04_iam.up.sql @@ -110,4 +110,34 @@ before insert on iam_scope for each row execute procedure default_create_time(); + +-- iam_sub_names will allow us to enforce the different name constraints for +-- organizations and projects via a before update trigger on the iam_scope +-- table. +create or replace function + iam_sub_names() + returns trigger +as $$ +begin + if new.name != old.name then + if new.type = 'organization' then + update iam_scope_organization set name = new.name where scope_id = old.public_id; + return new; + end if; + if new.type = 'project' then + update iam_scope_project set name = new.name where scope_id = old.public_id; + return new; + end if; + raise exception 'unknown scope type'; + end if; + return new; +end; +$$ language plpgsql; + +create trigger + iam_sub_names +before +update on iam_scope + for each row execute procedure iam_sub_names(); + commit; diff --git a/internal/db/option.go b/internal/db/option.go index fd9c9622e5..39c1ea669a 100644 --- a/internal/db/option.go +++ b/internal/db/option.go @@ -25,6 +25,8 @@ type Options struct { withLookup bool // WithFieldMaskPaths must be accessible from other packages WithFieldMaskPaths []string + // WithNullPaths must be accessible from other packages + WithNullPaths []string } type oplogOpts struct { @@ -42,6 +44,7 @@ func getDefaultOptions() Options { withDebug: false, withLookup: false, WithFieldMaskPaths: []string{}, + WithNullPaths: []string{}, } } @@ -70,9 +73,16 @@ func WithOplog(wrapper wrapping.Wrapper, md oplog.Metadata) Option { } } -// WithOplog provides an option to provide field mask paths +// WithFieldMaskPaths provides an option to provide field mask paths func WithFieldMaskPaths(paths []string) Option { return func(o *Options) { o.WithFieldMaskPaths = paths } } + +// WithNullPaths provides an option to provide null paths +func WithNullPaths(paths []string) Option { + return func(o *Options) { + o.WithNullPaths = paths + } +} diff --git a/internal/db/option_test.go b/internal/db/option_test.go index f5895d93fd..3465ee94e1 100644 --- a/internal/db/option_test.go +++ b/internal/db/option_test.go @@ -60,4 +60,30 @@ func Test_getOpts(t *testing.T) { testOpts.withDebug = true assert.Equal(opts, testOpts) }) + t.Run("WithFieldMaskPaths", func(t *testing.T) { + // test default of []string{} + opts := GetOpts() + testOpts := getDefaultOptions() + testOpts.WithFieldMaskPaths = []string{} + assert.Equal(opts, testOpts) + + testPaths := []string{"alice", "bob"} + opts = GetOpts(WithFieldMaskPaths(testPaths)) + testOpts = getDefaultOptions() + testOpts.WithFieldMaskPaths = testPaths + assert.Equal(opts, testOpts) + }) + t.Run("WithNullPaths", func(t *testing.T) { + // test default of []string{} + opts := GetOpts() + testOpts := getDefaultOptions() + testOpts.WithNullPaths = []string{} + assert.Equal(opts, testOpts) + + testPaths := []string{"alice", "bob"} + opts = GetOpts(WithNullPaths(testPaths)) + testOpts = getDefaultOptions() + testOpts.WithNullPaths = testPaths + assert.Equal(opts, testOpts) + }) } diff --git a/internal/db/read_writer.go b/internal/db/read_writer.go index c3a520f2cb..742a67b36e 100644 --- a/internal/db/read_writer.go +++ b/internal/db/read_writer.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/hashicorp/watchtower/internal/db/common" "github.com/hashicorp/watchtower/internal/oplog" "github.com/jinzhu/gorm" "google.golang.org/protobuf/proto" @@ -42,13 +43,17 @@ type Writer interface { // DoTx will wrap the TxHandler in a retryable transaction DoTx(ctx context.Context, retries uint, backOff Backoff, Handler TxHandler) (RetryInfo, error) - // Update an object in the db, if there's a fieldMask then only the - // field_mask.proto paths are updated, otherwise it will send every field to - // the DB. options: WithOplog the caller is responsible for the transaction - // life cycle of the writer and if an error is returned the caller must - // decide what to do with the transaction, which almost always should be to - // rollback. Update returns the number of rows updated or an error. - Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) (int, error) + // Update an object in the db, fieldMask is required and provides + // field_mask.proto paths for fields that should be updated. The i interface + // parameter is the type the caller wants to update in the db and its + // fields are set to the update values. setToNullPaths is optional and + // provides field_mask.proto paths for the fields that should be set to + // null. fieldMaskPaths and setToNullPaths must not intersect. The caller + // is responsible for the transaction life cycle of the writer and if an + // error is returned the caller must decide what to do with the transaction, + // which almost always should be to rollback. Update returns the number of + // rows updated or an error. Supported options: WithOplog. + Update(ctx context.Context, i interface{}, fieldMaskPaths []string, setToNullPaths []string, opt ...Option) (int, error) // Create an object in the db with options: WithOplog // the caller is responsible for the transaction life cycle of the writer @@ -99,7 +104,10 @@ const ( DeleteOp OpType = 3 ) -// VetForWriter provides an interface that Create and Update can use to vet the resource before sending it to the db +// VetForWriter provides an interface that Create and Update can use to vet the +// resource before before writing it to the db. For optType == UpdateOp, +// options WithFieldMaskPath and WithNullPaths are supported. For optType == +// CreateOp, no options are supported type VetForWriter interface { VetForWrite(ctx context.Context, r Reader, opType OpType, opt ...Option) error } @@ -130,7 +138,7 @@ func (rw *Db) ScanRows(rows *sql.Rows, result interface{}) error { if rw.underlying == nil { return fmt.Errorf("scan rows: missing underlying db %w", ErrNilParameter) } - if result == nil { + if isNil(result) { return fmt.Errorf("scan rows: result is missing %w", ErrNilParameter) } return rw.underlying.ScanRows(rows, result) @@ -152,12 +160,13 @@ func (rw *Db) lookupAfterWrite(ctx context.Context, i interface{}, opt ...Option return errors.New("not a resource with an id") } -// Create an object in the db with options: WithOplog and WithLookup (to force a lookup after create)) +// Create an object in the db with options: WithOplog and WithLookup (to force a +// lookup after create). func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { if rw.underlying == nil { return fmt.Errorf("create: missing underlying db %w", ErrNilParameter) } - if i == nil { + if isNil(i) { return fmt.Errorf("create: interface is missing %w", ErrNilParameter) } opts := GetOpts(opt...) @@ -174,9 +183,6 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { rw.underlying.LogMode(true) defer rw.underlying.LogMode(false) } - if i == nil { - return fmt.Errorf("create: missing interface %w", ErrNilParameter) - } // these fields should be nil, since they are not writeable and we want the // db to manage them setFieldsToNil(i, []string{"CreateTime", "UpdateTime"}) @@ -200,19 +206,40 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error { return nil } -// Update an object in the db, if there's a fieldMask then only the -// field_mask.proto paths are updated, otherwise it will send every field to the -// DB. Update supports embedding a struct (or structPtr) one level deep for -// updating. Update returns the number of rows updated and any errors. -func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) (int, error) { +// Update an object in the db, fieldMask is required and provides +// field_mask.proto paths for fields that should be updated. The i interface +// parameter is the type the caller wants to update in the db and its +// fields are set to the update values. setToNullPaths is optional and +// provides field_mask.proto paths for the fields that should be set to +// null. fieldMaskPaths and setToNullPaths must not intersect. The caller +// is responsible for the transaction life cycle of the writer and if an +// error is returned the caller must decide what to do with the transaction, +// which almost always should be to rollback. Update returns the number of +// rows updated. Supported options: WithOplog. +func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, setToNullPaths []string, opt ...Option) (int, error) { if rw.underlying == nil { return NoRowsAffected, fmt.Errorf("update: missing underlying db %w", ErrNilParameter) } - if i == nil { + if isNil(i) { return NoRowsAffected, fmt.Errorf("update: interface is missing %w", ErrNilParameter) } - if len(fieldMaskPaths) == 0 { - return NoRowsAffected, fmt.Errorf("update: missing fieldMaskPaths %w", ErrNilParameter) + if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 { + return NoRowsAffected, errors.New("update: both fieldMaskPaths and setToNullPaths are missing") + } + + // we need to filter out some non-updatable fields (like: CreateTime, etc) + fieldMaskPaths = filterPaths(fieldMaskPaths) + setToNullPaths = filterPaths(setToNullPaths) + if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 { + return NoRowsAffected, fmt.Errorf("update: after filtering non-updated fields, there are no fields left in fieldMaskPaths or setToNullPaths") + } + + updateFields, err := common.UpdateFields(i, fieldMaskPaths, setToNullPaths) + if err != nil { + return NoRowsAffected, fmt.Errorf("update: getting update fields failed: %w", err) + } + if len(updateFields) == 0 { + return NoRowsAffected, fmt.Errorf("update: no fields matched using fieldMaskPaths %s", fieldMaskPaths) } // This is not a watchtower scope, but rather a gorm Scope: @@ -237,54 +264,11 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string defer rw.underlying.LogMode(false) } if vetter, ok := i.(VetForWriter); ok { - if err := vetter.VetForWrite(ctx, rw, UpdateOp, WithFieldMaskPaths(fieldMaskPaths)); err != nil { + if err := vetter.VetForWrite(ctx, rw, UpdateOp, WithFieldMaskPaths(fieldMaskPaths), WithNullPaths(setToNullPaths)); err != nil { return NoRowsAffected, fmt.Errorf("update: vet for write failed %w", err) } } - // we need to filter out some non-updatable fields (like: CreateTime, etc) - fieldMaskPaths = filterFieldMask(fieldMaskPaths) - if len(fieldMaskPaths) == 0 { - return NoRowsAffected, fmt.Errorf("update: after filtering non-updated fields, there are no fields left in fieldMaskPaths: %s", fieldMaskPaths) - } - - updateFields := map[string]interface{}{} - - val := reflect.Indirect(reflect.ValueOf(i)) - structTyp := val.Type() - for _, field := range fieldMaskPaths { - for i := 0; i < structTyp.NumField(); i++ { - kind := structTyp.Field(i).Type.Kind() - if kind == reflect.Struct || kind == reflect.Ptr { - embType := structTyp.Field(i).Type - // check if the embedded field is exported via CanInterface() - if val.Field(i).CanInterface() { - embVal := reflect.Indirect(reflect.ValueOf(val.Field(i).Interface())) - // if it's a ptr to a struct, then we need a few more bits before proceeding. - if kind == reflect.Ptr { - embVal = val.Field(i).Elem() - embType = embVal.Type() - if embType.Kind() != reflect.Struct { - continue - } - } - for embFieldNum := 0; embFieldNum < embType.NumField(); embFieldNum++ { - if strings.EqualFold(embType.Field(embFieldNum).Name, field) { - updateFields[field] = embVal.Field(embFieldNum).Interface() - } - } - continue - } - } - // it's not an embedded type, so check if the field name matches - if strings.EqualFold(structTyp.Field(i).Name, field) { - updateFields[field] = val.Field(i).Interface() - } - } - } - if len(updateFields) == 0 { - return NoRowsAffected, fmt.Errorf("update: no fields matched using fieldMaskPaths %s", fieldMaskPaths) - } underlying := rw.underlying.Model(i).Updates(updateFields) if underlying.Error != nil { return NoRowsAffected, fmt.Errorf("update: failed %w", underlying.Error) @@ -311,7 +295,7 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, er if rw.underlying == nil { return NoRowsAffected, fmt.Errorf("delete: missing underlying db %w", ErrNilParameter) } - if i == nil { + if isNil(i) { return NoRowsAffected, fmt.Errorf("delete: interface is missing %w", ErrNilParameter) } // This is not a watchtower scope, but rather a gorm Scope: @@ -539,13 +523,13 @@ func (rw *Db) SearchWhere(ctx context.Context, resources interface{}, where stri return nil } -// filterFieldMasks will filter out non-updatable fields -func filterFieldMask(fieldMaskPaths []string) []string { - if len(fieldMaskPaths) == 0 { +// filterPaths will filter out non-updatable fields +func filterPaths(paths []string) []string { + if len(paths) == 0 { return nil } filtered := []string{} - for _, p := range fieldMaskPaths { + for _, p := range paths { switch { case strings.EqualFold(p, "CreateTime"): continue @@ -599,3 +583,14 @@ func setFieldsToNil(i interface{}, fieldNames []string) { } } } + +func isNil(i interface{}) bool { + if i == nil { + return true + } + switch reflect.TypeOf(i).Kind() { + case reflect.Ptr, reflect.Map, reflect.Array, reflect.Chan, reflect.Slice: + return reflect.ValueOf(i).IsNil() + } + return false +} diff --git a/internal/db/read_writer_test.go b/internal/db/read_writer_test.go index f72b0e94a2..3b634ea3d4 100644 --- a/internal/db/read_writer_test.go +++ b/internal/db/read_writer_test.go @@ -4,150 +4,277 @@ import ( "context" "errors" "fmt" + "strings" "testing" + "time" "github.com/golang/protobuf/ptypes" "github.com/hashicorp/go-uuid" + "github.com/hashicorp/vault/sdk/helper/base62" "github.com/hashicorp/watchtower/internal/db/db_test" "github.com/hashicorp/watchtower/internal/oplog" - "github.com/hashicorp/watchtower/internal/oplog/store" + "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" ) func TestDb_Update(t *testing.T) { - // intentionally not run with t.Parallel so we don't need to use DoTx for the Update tests cleanup, db, _ := TestSetup(t, "postgres") + now := &db_test.Timestamp{Timestamp: ptypes.TimestampNow()} + publicId, err := NewPublicId("testuser") + if err != nil { + t.Error(err) + } defer func() { if err := cleanup(); err != nil { t.Error(err) } }() + defer func() { + if err := db.Close(); err != nil { + t.Error(err) + } + }() + id := testId(t) + type args struct { + i *db_test.TestUser + fieldMaskPaths []string + setToNullPaths []string + opt []Option + } + tests := []struct { + name string + args args + want int + wantErr bool + wantErrMsg string + wantName string + wantEmail string + wantPhoneNumber string + }{ + { + name: "simple", + args: args{ + i: &db_test.TestUser{ + StoreTestUser: &db_test.StoreTestUser{ + Name: "simple-updated" + id, + Email: "updated" + id, + PhoneNumber: "updated" + id, + }, + }, + fieldMaskPaths: []string{"Name", "PhoneNumber"}, + setToNullPaths: []string{"Email"}, + }, + want: 1, + wantErr: false, + wantErrMsg: "", + wantName: "simple-updated" + id, + wantEmail: "", + wantPhoneNumber: "updated" + id, + }, + { + name: "multiple-null", + args: args{ + i: &db_test.TestUser{ + StoreTestUser: &db_test.StoreTestUser{ + Name: "multiple-null-updated" + id, + Email: "updated" + id, + PhoneNumber: "updated" + id, + }, + }, + fieldMaskPaths: []string{"Name"}, + setToNullPaths: []string{"Email", "PhoneNumber"}, + }, + want: 1, + wantErr: false, + wantErrMsg: "", + wantName: "multiple-null-updated" + id, + wantEmail: "", + wantPhoneNumber: "", + }, + { + name: "non-updatable", + args: args{ + i: &db_test.TestUser{ + StoreTestUser: &db_test.StoreTestUser{ + Name: "non-updatable" + id, + Email: "updated" + id, + PhoneNumber: "updated" + id, + PublicId: publicId, + CreateTime: now, + UpdateTime: now, + }, + }, + fieldMaskPaths: []string{"Name", "PhoneNumber", "CreateTime", "UpdateTime", "PublicId"}, + setToNullPaths: []string{"Email"}, + }, + want: 1, + wantErr: false, + wantErrMsg: "", + wantName: "non-updatable" + id, + wantEmail: "", + wantPhoneNumber: "updated" + id, + }, + { + name: "both are missing", + args: args{ + i: &db_test.TestUser{ + StoreTestUser: &db_test.StoreTestUser{ + Name: "both are missing-updated" + id, + Email: id, + PhoneNumber: id, + }, + }, + fieldMaskPaths: nil, + setToNullPaths: []string{}, + }, + want: 0, + wantErr: true, + wantErrMsg: "update: both fieldMaskPaths and setToNullPaths are missing", + }, + { + name: "i is nil", + args: args{ + i: nil, + fieldMaskPaths: []string{"Name", "PhoneNumber"}, + setToNullPaths: []string{"Email"}, + }, + want: 0, + wantErr: true, + wantErrMsg: "update: interface is missing nil parameter", + }, + { + name: "only read-only", + args: args{ + i: &db_test.TestUser{ + StoreTestUser: &db_test.StoreTestUser{ + Name: "only read-only" + id, + Email: id, + PhoneNumber: id, + }, + }, + fieldMaskPaths: []string{"CreateTime"}, + setToNullPaths: []string{"UpdateTime"}, + }, + want: 0, + wantErr: true, + wantErrMsg: "update: after filtering non-updated fields, there are no fields left in fieldMaskPaths or setToNullPaths", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + rw := &Db{ + underlying: db, + } + u := testUser(t, db, tt.name+id, id, id) + + if tt.args.i != nil { + tt.args.i.Id = u.Id + tt.args.i.PublicId = u.PublicId + } + rowsUpdated, err := rw.Update(context.Background(), tt.args.i, tt.args.fieldMaskPaths, tt.args.setToNullPaths, tt.args.opt...) + if tt.wantErr { + assert.Error(err) + assert.Equal(tt.want, rowsUpdated) + assert.Equal(tt.wantErrMsg, err.Error()) + return + } + assert.NoError(err) + assert.Equal(tt.want, rowsUpdated) + + foundUser, err := db_test.NewTestUser() + assert.NoError(err) + foundUser.PublicId = tt.args.i.PublicId + where := "public_id = ?" + for _, f := range tt.args.setToNullPaths { + switch { + case strings.EqualFold(f, "phonenumber"): + f = "phone_number" + } + where = fmt.Sprintf("%s and %s is null", where, f) + } + err = rw.LookupWhere(context.Background(), foundUser, where, tt.args.i.PublicId) + assert.NoError(err) + assert.Equal(tt.args.i.Id, foundUser.Id) + assert.Equal(tt.wantName, foundUser.Name) + assert.Equal(tt.wantEmail, foundUser.Email) + assert.Equal(tt.wantPhoneNumber, foundUser.PhoneNumber) + assert.NotEqual(now, foundUser.CreateTime) + assert.NotEqual(now, foundUser.UpdateTime) + assert.NotEqual(publicId, foundUser.PublicId) + }) + } + assert := assert.New(t) - defer db.Close() - t.Run("simple", func(t *testing.T) { + t.Run("valid-WithOplog", func(t *testing.T) { w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.NoError(err) - user, err := db_test.NewTestUser() - assert.NoError(err) - user.Name = "foo-" + id - err = w.Create(context.Background(), user) - assert.NoError(err) - assert.NotEmpty(user.Id) - - foundUser, err := db_test.NewTestUser() - assert.NoError(err) - foundUser.PublicId = user.PublicId - err = w.LookupByPublicId(context.Background(), foundUser) - assert.NoError(err) - assert.Equal(foundUser.Id, user.Id) + user := testUser(t, db, "foo-"+id, id, id) user.Name = "friendly-" + id - rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}) + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil, + // write oplogs for this update + WithOplog( + TestWrapper(t), + oplog.Metadata{ + "key-only": nil, + "deployment": []string{"amex"}, + "project": []string{"central-info-systems", "local-info-systems"}, + "resource-public-id": []string{user.PublicId}, + "op-type": []string{oplog.OpType_OP_TYPE_UPDATE.String()}, + }), + ) assert.NoError(err) assert.Equal(1, rowsUpdated) - err = w.LookupByPublicId(context.Background(), foundUser) - assert.NoError(err) - assert.Equal(foundUser.Name, user.Name) - }) - t.Run("non-updatable-fields", func(t *testing.T) { - w := Db{underlying: db} - id, err := uuid.GenerateUUID() - assert.NoError(err) - user, err := db_test.NewTestUser() - assert.NoError(err) - user.Name = "foo-" + id - err = w.Create(context.Background(), user) - db.LogMode(false) - assert.NoError(err) - assert.NotEmpty(user.Id) - foundUser, err := db_test.NewTestUser() assert.NoError(err) foundUser.PublicId = user.PublicId - err = w.LookupByPublicId(context.Background(), foundUser) - assert.NoError(err) - assert.Equal(foundUser.Id, user.Id) - - user.Name = "friendly-" + id - ts := &db_test.Timestamp{Timestamp: ptypes.TimestampNow()} - user.CreateTime = ts - user.UpdateTime = ts - rowsUpdated, err := w.Update(context.Background(), user, []string{"Name", "CreateTime", "UpdateTime"}) - assert.NoError(err) - assert.Equal(1, rowsUpdated) - err = w.LookupByPublicId(context.Background(), foundUser) assert.NoError(err) assert.Equal(foundUser.Name, user.Name) - assert.NotEqual(foundUser.CreateTime, ts) - assert.NotEqual(foundUser.UpdateTime, ts) - ts = &db_test.Timestamp{Timestamp: ptypes.TimestampNow()} - user.Name = id - user.CreateTime = ts - user.UpdateTime = ts - rowsUpdated, err = w.Update(context.Background(), user, nil) - assert.Error(err) - assert.Equal(0, rowsUpdated) - assert.Equal("update: missing fieldMaskPaths nil parameter", err.Error()) - assert.NotEqual(foundUser.CreateTime, ts) - assert.NotEqual(foundUser.UpdateTime, ts) + err = TestVerifyOplog(t, &w, user.PublicId, WithOperation(oplog.OpType_OP_TYPE_UPDATE), WithCreateNotBefore(10*time.Second)) + assert.NoError(err) }) - t.Run("valid-WithOplog", func(t *testing.T) { + t.Run("vet-for-write", func(t *testing.T) { w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.NoError(err) - user, err := db_test.NewTestUser() - assert.NoError(err) - user.Name = "foo-" + id - err = w.Create(context.Background(), user) - assert.NoError(err) - assert.NotEmpty(user.Id) - - foundUser, err := db_test.NewTestUser() - assert.NoError(err) - foundUser.PublicId = user.PublicId - err = w.LookupByPublicId(context.Background(), foundUser) + user := &testUserWithVet{ + PublicId: id, + Name: id, + PhoneNumber: id, + Email: id, + } + err = db.Create(user).Error assert.NoError(err) - assert.Equal(foundUser.Id, user.Id) user.Name = "friendly-" + id - rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, - // write oplogs for this update - WithOplog( - TestWrapper(t), - oplog.Metadata{ - "key-only": nil, - "deployment": []string{"amex"}, - "project": []string{"central-info-systems", "local-info-systems"}, - "resource-public-id": []string{user.GetPublicId()}, - }), - ) + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil) assert.NoError(err) assert.Equal(1, rowsUpdated) + foundUser := &testUserWithVet{PublicId: user.PublicId} + assert.NoError(err) + foundUser.PublicId = user.PublicId err = w.LookupByPublicId(context.Background(), foundUser) assert.NoError(err) assert.Equal(foundUser.Name, user.Name) - var metadata store.Metadata - err = db.Where("key = ? and value = ?", "resource-public-id", user.PublicId).First(&metadata).Error - assert.NoError(err) - - var foundEntry oplog.Entry - err = db.Where("id = ?", metadata.EntryId).First(&foundEntry).Error - assert.NoError(err) + user.PublicId = "not-allowed-by-vet-for-write" + rowsUpdated, err = w.Update(context.Background(), user, []string{"PublicId"}, nil) + assert.Error(err) + assert.Equal(0, rowsUpdated) }) t.Run("nil-tx", func(t *testing.T) { w := Db{underlying: nil} id, err := uuid.GenerateUUID() assert.NoError(err) - user, err := db_test.NewTestUser() - assert.NoError(err) - user.Name = "foo-" + id - rowsUpdated, err := w.Update(context.Background(), user, nil) + + user := testUser(t, db, "foo-"+id, id, id) + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil) assert.Error(err) assert.Equal(0, rowsUpdated) assert.Equal("update: missing underlying db nil parameter", err.Error()) @@ -156,22 +283,11 @@ func TestDb_Update(t *testing.T) { w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.NoError(err) - user, err := db_test.NewTestUser() - assert.NoError(err) - user.Name = "foo-" + id - err = w.Create(context.Background(), user) - assert.NoError(err) - assert.NotEmpty(user.Id) - - foundUser, err := db_test.NewTestUser() - assert.NoError(err) - foundUser.PublicId = user.PublicId - err = w.LookupByPublicId(context.Background(), foundUser) - assert.NoError(err) - assert.Equal(foundUser.Id, user.Id) + user := testUser(t, db, "foo-"+id, id, id) user.Name = "friendly-" + id rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, + nil, WithOplog( nil, oplog.Metadata{ @@ -188,22 +304,9 @@ func TestDb_Update(t *testing.T) { w := Db{underlying: db} id, err := uuid.GenerateUUID() assert.NoError(err) - user, err := db_test.NewTestUser() - assert.NoError(err) - user.Name = "foo-" + id - err = w.Create(context.Background(), user) - assert.NoError(err) - assert.NotEmpty(user.Id) - - foundUser, err := db_test.NewTestUser() - assert.NoError(err) - foundUser.PublicId = user.PublicId - err = w.LookupByPublicId(context.Background(), foundUser) - assert.NoError(err) - assert.Equal(foundUser.Id, user.Id) - + user := testUser(t, db, "foo-"+id, id, id) user.Name = "friendly-" + id - rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil, WithOplog( TestWrapper(t), nil, @@ -215,6 +318,43 @@ func TestDb_Update(t *testing.T) { }) } +// testUserWithVet gives us a model that implements VetForWrite() without any +// cyclic dependencies. +type testUserWithVet struct { + Id uint32 `gorm:"primary_key"` + PublicId string + Name string `gorm:"default:null"` + PhoneNumber string `gorm:"default:null"` + Email string `gorm:"default:null"` +} + +func (u *testUserWithVet) GetPublicId() string { + return u.PublicId +} +func (u *testUserWithVet) TableName() string { + return "db_test_user" +} +func (u *testUserWithVet) VetForWrite(ctx context.Context, r Reader, opType OpType, opt ...Option) error { + if u.PublicId == "" { + return errors.New("public id is empty string for user write") + } + if opType == UpdateOp { + dbOptions := GetOpts(opt...) + for _, path := range dbOptions.WithFieldMaskPaths { + switch path { + case "PublicId": + return errors.New("you cannot change the public id") + } + } + } + if opType == CreateOp { + if u.Id != 0 { + return errors.New("id is a db sequence") + } + } + return nil +} + func TestDb_Create(t *testing.T) { // intentionally not run with t.Parallel so we don't need to use DoTx for the Create tests cleanup, db, _ := TestSetup(t, "postgres") @@ -700,7 +840,7 @@ func TestDb_DoTx(t *testing.T) { _, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error { user.Name = "friendly-" + id - rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}) + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil) if err != nil { return err } @@ -722,7 +862,7 @@ func TestDb_DoTx(t *testing.T) { assert.NoError(err) _, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error { user2.Name = "friendly2-" + id - rowsUpdated, err := w.Update(context.Background(), user2, []string{"Name"}) + rowsUpdated, err := w.Update(context.Background(), user2, []string{"Name"}, nil) if err != nil { return err } @@ -738,7 +878,7 @@ func TestDb_DoTx(t *testing.T) { _, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error { user.Name = "friendly2-" + id - rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}) + rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil) if err != nil { return err } @@ -974,3 +1114,32 @@ func TestDb_ScanRows(t *testing.T) { } }) } + +func testUser(t *testing.T, conn *gorm.DB, name, email, phoneNumber string) *db_test.TestUser { + t.Helper() + assert := assert.New(t) + + publicId, err := base62.Random(20) + assert.NoError(err) + u := &db_test.TestUser{ + StoreTestUser: &db_test.StoreTestUser{ + PublicId: publicId, + Name: name, + Email: email, + PhoneNumber: phoneNumber, + }, + } + if conn != nil { + err = conn.Create(u).Error + assert.NoError(err) + } + return u +} + +func testId(t *testing.T) string { + t.Helper() + assert := assert.New(t) + id, err := uuid.GenerateUUID() + assert.NoError(err) + return id +} diff --git a/internal/db/testing_test.go b/internal/db/testing_test.go index de95829844..5ae181eb77 100644 --- a/internal/db/testing_test.go +++ b/internal/db/testing_test.go @@ -2,7 +2,6 @@ package db import ( "context" - "strconv" "testing" "time" @@ -17,8 +16,16 @@ import ( func Test_Utils(t *testing.T) { t.Parallel() cleanup, conn, _ := TestSetup(t, "postgres") - defer cleanup() - defer conn.Close() + defer func() { + if err := cleanup(); err != nil { + t.Error(err) + } + }() + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() t.Run("nothing", func(t *testing.T) { }) @@ -26,7 +33,11 @@ func Test_Utils(t *testing.T) { func TestVerifyOplogEntry(t *testing.T) { cleanup, db, _ := TestSetup(t, "postgres") - defer cleanup() + defer func() { + if err := cleanup(); err != nil { + t.Error(err) + } + }() assert := assert.New(t) defer db.Close() @@ -43,7 +54,7 @@ func TestVerifyOplogEntry(t *testing.T) { WithOplog( TestWrapper(t), oplog.Metadata{ - "op-type": []string{strconv.Itoa(int(oplog.OpType_OP_TYPE_CREATE))}, + "op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()}, "resource-public-id": []string{user.GetPublicId()}, }), ) @@ -56,7 +67,7 @@ func TestVerifyOplogEntry(t *testing.T) { err = rw.LookupByPublicId(context.Background(), foundUser) assert.NoError(err) assert.Equal(foundUser.Id, user.Id) - TestVerifyOplog(t, &rw, user.PublicId, WithOperation(oplog.OpType_OP_TYPE_CREATE), WithCreateNotBefore(5*time.Second)) + err = TestVerifyOplog(t, &rw, user.PublicId, WithOperation(oplog.OpType_OP_TYPE_CREATE), WithCreateNotBefore(5*time.Second)) assert.NoError(err) }) t.Run("should-fail", func(t *testing.T) { diff --git a/internal/host/static/repository.go b/internal/host/static/repository.go index 8a9f552083..c574d63db0 100644 --- a/internal/host/static/repository.go +++ b/internal/host/static/repository.go @@ -116,10 +116,18 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, fieldMas return nil, db.NoRowsAffected, fmt.Errorf("update: static host catalog: %w", db.ErrEmptyFieldMask) } + var dbMask, nullFields []string for _, f := range fieldMask { switch { - case strings.EqualFold("name", f): - case strings.EqualFold("description", f): + case strings.EqualFold("name", f) && c.Name == "": + nullFields = append(nullFields, "name") + case strings.EqualFold("name", f) && c.Name != "": + dbMask = append(dbMask, "name") + case strings.EqualFold("description", f) && c.Description == "": + nullFields = append(nullFields, "description") + case strings.EqualFold("description", f) && c.Description != "": + dbMask = append(dbMask, "description") + default: return nil, db.NoRowsAffected, fmt.Errorf("update: static host catalog: field: %s: %w", f, db.ErrInvalidFieldMask) } @@ -128,8 +136,6 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, fieldMas metadata := newCatalogMetadata(c, oplog.OpType_OP_TYPE_UPDATE) - // TODO(mgaffney,jimlambrt) 05/2020: uncomment the nullFields line - // below once support for setting columns to nil is added to db.Update. var rowsUpdated int var returnedCatalog *HostCatalog _, err := r.writer.DoTx( @@ -142,8 +148,8 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, fieldMas rowsUpdated, err = w.Update( ctx, returnedCatalog, - fieldMask, - // nullFields, + dbMask, + nullFields, db.WithOplog(r.wrapper, metadata), ) if err == nil && rowsUpdated > 1 { diff --git a/internal/host/static/repository_test.go b/internal/host/static/repository_test.go index 3008c3f851..f0e5b7c1fc 100644 --- a/internal/host/static/repository_test.go +++ b/internal/host/static/repository_test.go @@ -453,74 +453,76 @@ func TestRepository_UpdateCatalog(t *testing.T) { }, wantCount: 1, }, - // TODO(mgaffney,jimlambrt) 05/2020: enable these tests once - // support for setting columns to nil is added to db.Update. - /* - { - name: "delete-name", - orig: &HostCatalog{ - HostCatalog: &store.HostCatalog{ - Name: "test-name-repo", - Description: "test-description-repo", - }, + { + name: "delete-name", + orig: &HostCatalog{ + HostCatalog: &store.HostCatalog{ + Name: "test-name-repo", + Description: "test-description-repo", }, - masks: []string{"Name"}, - chgFn: combine(changeDescription("test-update-description-repo"), changeName("")), - want: &HostCatalog{ - HostCatalog: &store.HostCatalog{ - Description: "test-description-repo", - }, + }, + masks: []string{"Name"}, + chgFn: combine(changeDescription("test-update-description-repo"), changeName("")), + want: &HostCatalog{ + HostCatalog: &store.HostCatalog{ + Description: "test-description-repo", }, }, - { - name: "delete-description", - orig: &HostCatalog{ - HostCatalog: &store.HostCatalog{ - Name: "test-name-repo", - Description: "test-description-repo", - }, + wantCount: 1, + }, + { + name: "delete-description", + orig: &HostCatalog{ + HostCatalog: &store.HostCatalog{ + Name: "test-name-repo", + Description: "test-description-repo", }, - masks: []string{"Description"}, - chgFn: combine(changeDescription(""), changeName("test-update-name-repo")), - want: &HostCatalog{ - HostCatalog: &store.HostCatalog{ - Name: "test-name-repo", - }, + }, + masks: []string{"Description"}, + chgFn: combine(changeDescription(""), changeName("test-update-name-repo")), + want: &HostCatalog{ + HostCatalog: &store.HostCatalog{ + Name: "test-name-repo", }, }, - { - name: "do-not-delete-name", - orig: &HostCatalog{ - HostCatalog: &store.HostCatalog{ - Name: "test-name-repo", - Description: "test-description-repo", - }, + wantCount: 1, + }, + { + name: "do-not-delete-name", + orig: &HostCatalog{ + HostCatalog: &store.HostCatalog{ + Name: "test-name-repo", + Description: "test-description-repo", }, - chgFn: combine(changeDescription("test-update-description-repo"), changeName("")), - want: &HostCatalog{ - HostCatalog: &store.HostCatalog{ - Name: "test-name-repo", - Description: "test-update-description-repo", - }, + }, + masks: []string{"Description"}, + chgFn: combine(changeDescription("test-update-description-repo"), changeName("")), + want: &HostCatalog{ + HostCatalog: &store.HostCatalog{ + Name: "test-name-repo", + Description: "test-update-description-repo", }, }, - { - name: "do-not-delete-description", - orig: &HostCatalog{ - HostCatalog: &store.HostCatalog{ - Name: "test-name-repo", - Description: "test-description-repo", - }, + wantCount: 1, + }, + { + name: "do-not-delete-description", + orig: &HostCatalog{ + HostCatalog: &store.HostCatalog{ + Name: "test-name-repo", + Description: "test-description-repo", }, - chgFn: combine(changeDescription(""), changeName("test-update-name-repo")), - want: &HostCatalog{ - HostCatalog: &store.HostCatalog{ - Name: "test-update-name-repo", - Description: "test-description-repo", - }, + }, + masks: []string{"Name"}, + chgFn: combine(changeDescription(""), changeName("test-update-name-repo")), + want: &HostCatalog{ + HostCatalog: &store.HostCatalog{ + Name: "test-update-name-repo", + Description: "test-description-repo", }, }, - */ + wantCount: 1, + }, } for _, tt := range tests { diff --git a/internal/iam/repository.go b/internal/iam/repository.go index 7f69d1e4b8..8810c771a9 100644 --- a/internal/iam/repository.go +++ b/internal/iam/repository.go @@ -4,6 +4,8 @@ import ( "context" "errors" "fmt" + "reflect" + "strings" wrapping "github.com/hashicorp/go-kms-wrapping" "github.com/hashicorp/watchtower/internal/db" @@ -72,7 +74,7 @@ func (r *Repository) create(ctx context.Context, resource Resource, opt ...Optio } // update will update an iam resource in the db repository with an oplog entry -func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPaths []string, opt ...Option) (Resource, int, error) { +func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPaths []string, setToNullPaths []string, opt ...Option) (Resource, int, error) { if resource == nil { return nil, db.NoRowsAffected, errors.New("error updating resource that is nil") } @@ -99,6 +101,7 @@ func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPat ctx, returnedResource, fieldMaskPaths, + setToNullPaths, db.WithOplog(r.wrapper, metadata), ) if err == nil && rowsUpdated > 1 { @@ -152,24 +155,27 @@ func (r *Repository) delete(ctx context.Context, resource Resource, opt ...Optio func (r *Repository) stdMetadata(ctx context.Context, resource Resource) (oplog.Metadata, error) { if s, ok := resource.(*Scope); ok { - if s.Type == "" { - if err := r.reader.LookupByPublicId(ctx, s); err != nil { + scope := allocScope() + scope.PublicId = s.PublicId + scope.Type = s.Type + if scope.Type == "" { + if err := r.reader.LookupByPublicId(ctx, &scope); err != nil { return nil, ErrMetadataScopeNotFound } } - switch s.Type { + switch scope.Type { case OrganizationScope.String(): return oplog.Metadata{ "resource-public-id": []string{resource.GetPublicId()}, - "scope-id": []string{s.PublicId}, - "scope-type": []string{s.Type}, + "scope-id": []string{scope.PublicId}, + "scope-type": []string{scope.Type}, "resource-type": []string{resource.ResourceType().String()}, }, nil case ProjectScope.String(): return oplog.Metadata{ "resource-public-id": []string{resource.GetPublicId()}, - "scope-id": []string{s.ParentId}, - "scope-type": []string{s.Type}, + "scope-id": []string{scope.ParentId}, + "scope-type": []string{scope.Type}, "resource-type": []string{resource.ResourceType().String()}, }, nil default: @@ -191,3 +197,31 @@ func (r *Repository) stdMetadata(ctx context.Context, resource Resource) (oplog. "resource-type": []string{resource.ResourceType().String()}, }, nil } + +func contains(ss []string, t string) bool { + for _, s := range ss { + if strings.EqualFold(s, t) { + return true + } + } + return false +} + +func buildUpdatePaths(fieldValues map[string]interface{}, fieldMask []string) (masks []string, nulls []string) { + for f, v := range fieldValues { + if !contains(fieldMask, f) { + continue + } + switch { + case isZero(v): + nulls = append(nulls, f) + default: + masks = append(masks, f) + } + } + return masks, nulls +} + +func isZero(i interface{}) bool { + return i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) +} diff --git a/internal/iam/repository_scope.go b/internal/iam/repository_scope.go index 9467521520..96c978c95a 100644 --- a/internal/iam/repository_scope.go +++ b/internal/iam/repository_scope.go @@ -20,17 +20,41 @@ func (r *Repository) CreateScope(ctx context.Context, scope *Scope, opt ...Optio return resource.(*Scope), nil } -// UpdateScope will update a scope in the repository and return the written scope +// UpdateScope will update a scope in the repository and return the written +// scope. fieldMaskPaths provides field_mask.proto paths for fields that should +// be updated. Fields will be set to NULL if the field is a zero value and +// included in fieldMask. Name and Description are the only updatable fields, +// and everything else is ignored. If no updatable fields are included in the +// fieldMaskPaths, then an error is returned. func (r *Repository) UpdateScope(ctx context.Context, scope *Scope, fieldMaskPaths []string, opt ...Option) (*Scope, int, error) { if scope == nil { - return nil, db.NoRowsAffected, errors.New("scope is nil for update") + return nil, db.NoRowsAffected, fmt.Errorf("update scope: missing scope: %w", db.ErrNilParameter) } if scope.PublicId == "" { - return nil, db.NoRowsAffected, errors.New("scope public id is unset for update") + return nil, db.NoRowsAffected, fmt.Errorf("update scope: missing public id: %w", db.ErrNilParameter) } - resource, rowsUpdated, err := r.update(ctx, scope, fieldMaskPaths) + if contains(fieldMaskPaths, "ParentId") { + return nil, db.NoRowsAffected, fmt.Errorf("update scope: you cannot change a scope's parent: %w", db.ErrInvalidFieldMask) + } + var dbMask, nullFields []string + dbMask, nullFields = buildUpdatePaths( + map[string]interface{}{ + "name": scope.Name, + "description": scope.Description, + }, + fieldMaskPaths, + ) + // nada to update, so reload scope from db and return it + if len(dbMask) == 0 && len(nullFields) == 0 { + return nil, db.NoRowsAffected, fmt.Errorf("update scope: %w", db.ErrEmptyFieldMask) + } + + resource, rowsUpdated, err := r.update(ctx, scope, dbMask, nullFields) if err != nil { - return nil, db.NoRowsAffected, fmt.Errorf("failed to update scope: %w", err) + if db.IsUnique(err) { + return nil, db.NoRowsAffected, fmt.Errorf("update scope: %s name %s already exists: %w", scope.PublicId, scope.Name, db.ErrNotUnique) + } + return nil, db.NoRowsAffected, fmt.Errorf("update scope: failed for public id %s: %w", scope.PublicId, err) } return resource.(*Scope), rowsUpdated, err } diff --git a/internal/iam/repository_scope_test.go b/internal/iam/repository_scope_test.go index 9f94317e7f..e0cfdc23ee 100644 --- a/internal/iam/repository_scope_test.go +++ b/internal/iam/repository_scope_test.go @@ -2,11 +2,14 @@ package iam import ( "context" + "fmt" "testing" "time" + "github.com/golang/protobuf/ptypes" "github.com/hashicorp/go-uuid" "github.com/hashicorp/watchtower/internal/db" + iam_store "github.com/hashicorp/watchtower/internal/iam/store" "github.com/hashicorp/watchtower/internal/oplog" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" @@ -21,7 +24,11 @@ func Test_Repository_CreateScope(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() t.Run("valid-scope", func(t *testing.T) { rw := db.New(conn) @@ -107,7 +114,11 @@ func Test_Repository_UpdateScope(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() t.Run("valid-scope", func(t *testing.T) { rw := db.New(conn) @@ -176,7 +187,7 @@ func Test_Repository_UpdateScope(t *testing.T) { assert.Error(err) assert.Nil(project) assert.Equal(0, updatedRows) - assert.Equal("failed to update scope: update: vet for write failed you cannot change a scope's parent", err.Error()) + assert.Equal("update scope: you cannot change a scope's parent: invalid field mask", err.Error()) }) } @@ -188,8 +199,12 @@ func Test_Repository_LookupScope(t *testing.T) { t.Error(err) } }() + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() assert := assert.New(t) - defer conn.Close() t.Run("found-and-not-found", func(t *testing.T) { rw := db.New(conn) @@ -228,7 +243,11 @@ func Test_Repository_DeleteScope(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() rw := db.New(conn) wrapper := db.TestWrapper(t) repo, err := NewRepository(rw, rw, wrapper) @@ -268,3 +287,231 @@ func Test_Repository_DeleteScope(t *testing.T) { assert.Equal(0, rowsDeleted) }) } + +func TestRepository_UpdateScope(t *testing.T) { + cleanup, conn, _ := db.TestSetup(t, "postgres") + now := &iam_store.Timestamp{Timestamp: ptypes.TimestampNow()} + id := testId(t) + defer func() { + if err := cleanup(); err != nil { + t.Error(err) + } + }() + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + rw := db.New(conn) + wrapper := db.TestWrapper(t) + publicId := testPublicId(t, "o") + + type args struct { + scope *Scope + fieldMaskPaths []string + opt []Option + } + tests := []struct { + name string + args args + wantName string + wantDescription string + wantUpdatedRows int + wantErr bool + wantErrMsg string + wantNullFields []string + }{ + { + name: "valid-scope", + args: args{ + scope: &Scope{ + Scope: &iam_store.Scope{ + Name: "valid-scope" + id, + Description: "", + CreateTime: now, + UpdateTime: now, + PublicId: publicId, + }, + }, + fieldMaskPaths: []string{"Name", "Description", "CreateTime", "UpdateTime", "PublicId"}, + }, + wantName: "valid-scope" + id, + wantDescription: "", + wantUpdatedRows: 1, + wantErr: false, + wantErrMsg: "", + wantNullFields: []string{"Description"}, + }, + { + name: "nil-resource", + args: args{ + scope: nil, + fieldMaskPaths: []string{"Name"}, + }, + wantUpdatedRows: 0, + wantErr: true, + wantErrMsg: "update scope: missing scope: nil parameter", + wantNullFields: nil, + }, + { + name: "no-updates", + args: args{ + scope: &Scope{ + Scope: &iam_store.Scope{ + Name: "no-updates" + id, + Description: "updated" + id, + CreateTime: now, + UpdateTime: now, + PublicId: publicId, + }, + }, + fieldMaskPaths: []string{"CreateTime"}, + }, + wantUpdatedRows: 0, + wantErr: true, + wantErrMsg: "update scope: empty field mask", + wantNullFields: nil, + }, + { + name: "no-null", + args: args{ + scope: &Scope{ + Scope: &iam_store.Scope{ + Name: "no-null" + id, + Description: "updated" + id, + CreateTime: now, + UpdateTime: now, + PublicId: publicId, + }, + }, + fieldMaskPaths: []string{"Name"}, + }, + wantName: "no-null" + id, + wantDescription: "orig-" + id, + wantUpdatedRows: 1, + wantErr: false, + wantErrMsg: "", + wantNullFields: nil, + }, + { + name: "only-null", + args: args{ + scope: &Scope{ + Scope: &iam_store.Scope{ + Name: "", + Description: "", + CreateTime: now, + UpdateTime: now, + PublicId: publicId, + }, + }, + fieldMaskPaths: []string{"Name", "Description"}, + }, + wantName: "", + wantDescription: "", + wantUpdatedRows: 1, + wantErr: false, + wantErrMsg: "", + wantNullFields: nil, + }, + { + name: "parent", + args: args{ + scope: &Scope{ + Scope: &iam_store.Scope{ + Name: "parent" + id, + Description: "", + ParentId: publicId, + CreateTime: now, + UpdateTime: now, + PublicId: publicId, + }, + }, + fieldMaskPaths: []string{"ParentId", "CreateTime", "UpdateTime", "PublicId"}, + }, + wantName: "parent-orig-" + id, + wantDescription: "orig-" + id, + wantUpdatedRows: 0, + wantErr: true, + wantErrMsg: "update scope: you cannot change a scope's parent: invalid field mask", + wantNullFields: nil, + }, + { + name: "type", + args: args{ + scope: &Scope{ + Scope: &iam_store.Scope{ + Name: "type" + id, + Description: "", + Type: ProjectScope.String(), + CreateTime: now, + UpdateTime: now, + PublicId: publicId, + }, + }, + fieldMaskPaths: []string{"Type", "CreateTime", "UpdateTime", "PublicId"}, + }, + wantUpdatedRows: 0, + wantErr: true, + wantErrMsg: "update scope: empty field mask", + wantNullFields: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + r := &Repository{ + reader: rw, + writer: rw, + wrapper: wrapper, + } + org := testOrg(t, conn, tt.name+"-orig-"+id, "orig-"+id) + if tt.args.scope != nil { + tt.args.scope.PublicId = org.PublicId + } + updatedScope, rowsUpdated, err := r.UpdateScope(context.Background(), tt.args.scope, tt.args.fieldMaskPaths, tt.args.opt...) + if tt.wantErr { + assert.Error(err) + assert.Equal(tt.wantUpdatedRows, rowsUpdated) + assert.Equal(tt.wantErrMsg, err.Error()) + return + } + assert.NoError(err) + assert.Equal(tt.wantUpdatedRows, rowsUpdated) + if tt.wantUpdatedRows > 0 { + err = db.TestVerifyOplog(t, rw, org.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second)) + assert.NoError(err) + } + + foundScope := allocScope() + foundScope.PublicId = updatedScope.PublicId + where := "public_id = ?" + for _, f := range tt.wantNullFields { + where = fmt.Sprintf("%s and %s is null", where, f) + } + err = rw.LookupWhere(context.Background(), &foundScope, where, org.PublicId) + assert.NoError(err) + assert.Equal(org.PublicId, foundScope.PublicId) + assert.Equal(tt.wantName, foundScope.Name) + assert.Equal(tt.wantDescription, foundScope.Description) + assert.NotEqual(now, foundScope.CreateTime) + assert.NotEqual(now, foundScope.UpdateTime) + }) + } + t.Run("dup-name", func(t *testing.T) { + assert := assert.New(t) + r := &Repository{ + reader: rw, + writer: rw, + wrapper: wrapper, + } + id := testId(t) + _ = testOrg(t, conn, id, id) + org2 := testOrg(t, conn, "dup-"+id, id) + org2.Name = id + updatedScope, rowsUpdated, err := r.UpdateScope(context.Background(), org2, []string{"Name"}) + assert.Error(err) + assert.Equal(0, rowsUpdated, "updated rows should be 0") + assert.Nil(updatedScope, "scope should be nil") + }) +} diff --git a/internal/iam/repository_test.go b/internal/iam/repository_test.go index a9859d0c19..979a374f22 100644 --- a/internal/iam/repository_test.go +++ b/internal/iam/repository_test.go @@ -2,13 +2,16 @@ package iam import ( "context" + "fmt" "reflect" "testing" "time" + "github.com/golang/protobuf/ptypes" wrapping "github.com/hashicorp/go-kms-wrapping" "github.com/hashicorp/go-uuid" "github.com/hashicorp/watchtower/internal/db" + iam_store "github.com/hashicorp/watchtower/internal/iam/store" "github.com/hashicorp/watchtower/internal/oplog" "github.com/hashicorp/watchtower/internal/oplog/store" "github.com/stretchr/testify/assert" @@ -24,8 +27,11 @@ func TestNewRepository(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() - + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() rw := db.New(conn) wrapper := db.TestWrapper(t) type args struct { @@ -113,8 +119,11 @@ func Test_Repository_create(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() - + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() t.Run("valid-scope", func(t *testing.T) { rw := db.New(conn) wrapper := db.TestWrapper(t) @@ -155,7 +164,7 @@ func Test_Repository_create(t *testing.T) { }) } -func Test_Repository_update(t *testing.T) { +func Test_Repository_delete(t *testing.T) { t.Parallel() cleanup, conn, _ := db.TestSetup(t, "postgres") defer func() { @@ -164,67 +173,11 @@ func Test_Repository_update(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() - - t.Run("valid-scope", func(t *testing.T) { - rw := db.New(conn) - wrapper := db.TestWrapper(t) - repo, err := NewRepository(rw, rw, wrapper) - assert.NoError(err) - id, err := uuid.GenerateUUID() - assert.NoError(err) - - s, err := NewOrganization() - assert.NoError(err) - retScope, err := repo.create(context.Background(), s) - assert.NoError(err) - assert.NotNil(retScope) - assert.NotEmpty(retScope.GetPublicId()) - assert.Empty(retScope.GetName()) - - retScope.(*Scope).Name = "fname-" + id - retScope, updatedRows, err := repo.update(context.Background(), retScope, []string{"Name"}) - assert.NoError(err) - assert.NotNil(retScope) - assert.Equal(1, updatedRows) - assert.Equal(retScope.GetName(), "fname-"+id) - - foundScope, err := repo.LookupScope(context.Background(), s.PublicId) - assert.NoError(err) - assert.Equal(foundScope.GetPublicId(), retScope.GetPublicId()) - - var metadata store.Metadata - err = conn.Where("key = ? and value = ?", "resource-public-id", s.PublicId).First(&metadata).Error - assert.NoError(err) - - var foundEntry oplog.Entry - err = conn.Where("id = ?", metadata.EntryId).First(&foundEntry).Error - assert.NoError(err) - }) - t.Run("nil-resource", func(t *testing.T) { - rw := db.New(conn) - wrapper := db.TestWrapper(t) - repo, err := NewRepository(rw, rw, wrapper) - assert.NoError(err) - resource, updatedRows, err := repo.update(context.Background(), nil, nil) - assert.NotNil(err) - assert.Nil(resource) - assert.Equal(0, updatedRows) - assert.Equal(err.Error(), "error updating resource that is nil") - }) -} - -func Test_Repository_delete(t *testing.T) { - t.Parallel() - cleanup, conn, _ := db.TestSetup(t, "postgres") defer func() { - if err := cleanup(); err != nil { + if err := conn.Close(); err != nil { t.Error(err) } }() - assert := assert.New(t) - defer conn.Close() - t.Run("valid-org", func(t *testing.T) { rw := db.New(conn) wrapper := db.TestWrapper(t) @@ -257,3 +210,171 @@ func Test_Repository_delete(t *testing.T) { assert.Equal(err.Error(), "error deleting resource that is nil") }) } + +func TestRepository_update(t *testing.T) { + cleanup, conn, _ := db.TestSetup(t, "postgres") + now := &iam_store.Timestamp{Timestamp: ptypes.TimestampNow()} + id := testId(t) + defer func() { + if err := cleanup(); err != nil { + t.Error(err) + } + }() + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() + rw := db.New(conn) + wrapper := db.TestWrapper(t) + publicId := testPublicId(t, "o") + + type args struct { + resource Resource + fieldMaskPaths []string + setToNullPaths []string + opt []Option + } + tests := []struct { + name string + args args + wantName string + wantDescription string + wantUpdatedRows int + wantErr bool + wantErrMsg string + }{ + { + name: "valid-scope", + args: args{ + resource: &Scope{ + Scope: &iam_store.Scope{ + Name: "valid-scope" + id, + Description: "updated" + id, + CreateTime: now, + UpdateTime: now, + PublicId: publicId, + }, + }, + fieldMaskPaths: []string{"Name"}, + setToNullPaths: []string{"Description"}, + }, + wantName: "valid-scope" + id, + wantDescription: "", + wantUpdatedRows: 1, + wantErr: false, + wantErrMsg: "", + }, + { + name: "nil-resource", + args: args{ + resource: nil, + fieldMaskPaths: []string{"Name"}, + setToNullPaths: []string{"Description"}, + }, + wantUpdatedRows: 0, + wantErr: true, + wantErrMsg: "error updating resource that is nil", + }, + { + name: "intersection", + args: args{ + resource: &Scope{ + Scope: &iam_store.Scope{ + Name: "intersection" + id, + Description: "updated" + id, + CreateTime: now, + UpdateTime: now, + PublicId: publicId, + }, + }, + fieldMaskPaths: []string{"Name"}, + setToNullPaths: []string{"Name"}, + }, + wantUpdatedRows: 0, + wantErr: true, + wantErrMsg: "update: getting update fields failed: fieldMashPaths and setToNullPaths cannot intersect", + }, + { + name: "only-field-masks", + args: args{ + resource: &Scope{ + Scope: &iam_store.Scope{ + Name: "only-field-masks" + id, + Description: "updated" + id, + CreateTime: now, + UpdateTime: now, + PublicId: publicId, + }, + }, + fieldMaskPaths: []string{"Name"}, + setToNullPaths: nil, + }, + wantName: "only-field-masks" + id, + wantDescription: "orig-" + id, + wantUpdatedRows: 1, + wantErr: false, + wantErrMsg: "", + }, + { + name: "only-null-fields", + args: args{ + resource: &Scope{ + Scope: &iam_store.Scope{ + Name: "only-null-fields" + id, + Description: "updated" + id, + CreateTime: now, + UpdateTime: now, + PublicId: publicId, + }, + }, + fieldMaskPaths: nil, + setToNullPaths: []string{"Name", "Description"}, + }, + wantName: "", + wantDescription: "", + wantUpdatedRows: 1, + wantErr: false, + wantErrMsg: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + r := &Repository{ + reader: rw, + writer: rw, + wrapper: wrapper, + } + org := testOrg(t, conn, tt.name+"-orig-"+id, "orig-"+id) + if tt.args.resource != nil { + tt.args.resource.(*Scope).PublicId = org.PublicId + } + updatedResource, rowsUpdated, err := r.update(context.Background(), tt.args.resource, tt.args.fieldMaskPaths, tt.args.setToNullPaths, tt.args.opt...) + if tt.wantErr { + assert.Error(err) + assert.Equal(tt.wantUpdatedRows, rowsUpdated) + assert.Equal(tt.wantErrMsg, err.Error()) + return + } + assert.NoError(err) + assert.Equal(tt.wantUpdatedRows, rowsUpdated) + err = db.TestVerifyOplog(t, rw, org.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second)) + assert.NoError(err) + + foundResource := allocScope() + foundResource.PublicId = updatedResource.GetPublicId() + where := "public_id = ?" + for _, f := range tt.args.setToNullPaths { + where = fmt.Sprintf("%s and %s is null", where, f) + } + err = rw.LookupWhere(context.Background(), &foundResource, where, tt.args.resource.GetPublicId()) + assert.NoError(err) + assert.Equal(tt.args.resource.GetPublicId(), foundResource.GetPublicId()) + assert.Equal(tt.wantName, foundResource.GetName()) + assert.Equal(tt.wantDescription, foundResource.GetDescription()) + assert.NotEqual(now, foundResource.GetCreateTime()) + assert.NotEqual(now, foundResource.GetUpdateTime()) + }) + } +} diff --git a/internal/iam/scope_test.go b/internal/iam/scope_test.go index 297a811a1f..8a97643e2b 100644 --- a/internal/iam/scope_test.go +++ b/internal/iam/scope_test.go @@ -19,8 +19,11 @@ func Test_NewScope(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() - + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() t.Run("valid-org-with-project", func(t *testing.T) { w := db.New(conn) s, err := NewOrganization() @@ -60,8 +63,11 @@ func Test_ScopeCreate(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() - + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() t.Run("valid", func(t *testing.T) { w := db.New(conn) s, err := NewOrganization() @@ -104,8 +110,11 @@ func Test_ScopeUpdate(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() - + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() t.Run("valid", func(t *testing.T) { w := db.New(conn) s, err := NewOrganization() @@ -118,7 +127,7 @@ func Test_ScopeUpdate(t *testing.T) { id, err := uuid.GenerateUUID() assert.NoError(err) s.Name = id - updatedRows, err := w.Update(context.Background(), s, []string{"Name"}) + updatedRows, err := w.Update(context.Background(), s, []string{"Name"}, nil) assert.NoError(err) assert.Equal(1, updatedRows) }) @@ -132,7 +141,7 @@ func Test_ScopeUpdate(t *testing.T) { assert.NotEmpty(s.PublicId) s.Type = ProjectScope.String() - updatedRows, err := w.Update(context.Background(), s, []string{"Type"}) + updatedRows, err := w.Update(context.Background(), s, []string{"Type"}, nil) assert.NotNil(err) assert.Equal(0, updatedRows) }) @@ -146,7 +155,11 @@ func Test_ScopeGetScope(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() t.Run("valid-scope", func(t *testing.T) { w := db.New(conn) s, err := NewOrganization() @@ -212,8 +225,11 @@ func TestScope_Clone(t *testing.T) { } }() assert := assert.New(t) - defer conn.Close() - + defer func() { + if err := conn.Close(); err != nil { + t.Error(err) + } + }() t.Run("valid", func(t *testing.T) { w := db.New(conn) s, err := NewOrganization() diff --git a/internal/iam/testing.go b/internal/iam/testing.go index 8651324fa4..e2c9b45504 100644 --- a/internal/iam/testing.go +++ b/internal/iam/testing.go @@ -4,6 +4,7 @@ import ( "context" "testing" + "github.com/hashicorp/go-uuid" "github.com/hashicorp/watchtower/internal/db" "github.com/jinzhu/gorm" "github.com/stretchr/testify/assert" @@ -32,3 +33,36 @@ func TestScopes(t *testing.T, conn *gorm.DB) (org *Scope, prj *Scope) { return } + +func testOrg(t *testing.T, conn *gorm.DB, name, description string) (org *Scope) { + t.Helper() + assert := assert.New(t) + rw := db.New(conn) + wrapper := db.TestWrapper(t) + repo, err := NewRepository(rw, rw, wrapper) + assert.NoError(err) + + o, err := NewOrganization(WithDescription(description), WithName(name)) + assert.NoError(err) + o, err = repo.CreateScope(context.Background(), o) + assert.NoError(err) + assert.NotNil(o) + assert.NotEmpty(o.GetPublicId()) + return o +} + +func testId(t *testing.T) string { + t.Helper() + assert := assert.New(t) + id, err := uuid.GenerateUUID() + assert.NoError(err) + return id +} + +func testPublicId(t *testing.T, prefix string) string { + t.Helper() + assert := assert.New(t) + publicId, err := db.NewPublicId(prefix) + assert.NoError(err) + return publicId +} diff --git a/internal/iam/testing_test.go b/internal/iam/testing_test.go index 896207e6c5..7461809023 100644 --- a/internal/iam/testing_test.go +++ b/internal/iam/testing_test.go @@ -1,12 +1,45 @@ package iam import ( + "strings" "testing" "github.com/hashicorp/watchtower/internal/db" "github.com/stretchr/testify/assert" ) +func Test_testOrg(t *testing.T) { + assert := assert.New(t) + + cleanup, conn, _ := db.TestSetup(t, "postgres") + defer func() { + err := cleanup() + assert.NoError(err) + }() + defer func() { + err := conn.Close() + assert.NoError(err) + }() + id := testId(t) + + org := testOrg(t, conn, id, id) + assert.Equal(id, org.Name) + assert.Equal(id, org.Description) + assert.NotEmpty(org.PublicId) +} + +func Test_testId(t *testing.T) { + assert := assert.New(t) + id := testId(t) + assert.NotEmpty(id) +} + +func Test_testPublicId(t *testing.T) { + assert := assert.New(t) + id := testPublicId(t, "test") + assert.NotEmpty(id) + assert.True(strings.HasPrefix(id, "test_")) +} func Test_TestScopes(t *testing.T) { assert := assert.New(t)