Null updates (#78)

pull/95/head
Jim 6 years ago committed by GitHub
parent 6d123d91bb
commit ef6edbd515
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,143 @@
// common package contains functions from internal/db which need to be shared
// commonly with other packages that have a cyclic dependency on internal/db
// like internal/oplog.
package common
import (
"errors"
"fmt"
"reflect"
"strings"
"github.com/jinzhu/gorm"
)
var (
// ErrNilParameter is returned when a required parameter is nil.
ErrNilParameter = errors.New("nil parameter")
)
// UpdateFields will create a map[string]interface of the update values to be
// sent to the db. The map keys will be the field names for the fields to be
// updated. The caller provided fieldMaskPaths and setToNullPaths must not
// intersect. fieldMaskPaths and setToNullPaths cannot both be zero len.
func UpdateFields(i interface{}, fieldMaskPaths []string, setToNullPaths []string) (map[string]interface{}, error) {
if i == nil {
return nil, fmt.Errorf("interface is missing: %w", ErrNilParameter)
}
if fieldMaskPaths == nil {
fieldMaskPaths = []string{}
}
if setToNullPaths == nil {
setToNullPaths = []string{}
}
if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 {
return nil, errors.New("both fieldMaskPaths and setToNullPaths are zero len")
}
inter, maskPaths, nullPaths, err := intersection(fieldMaskPaths, setToNullPaths)
if err != nil {
return nil, err
}
if len(inter) != 0 {
return nil, fmt.Errorf("fieldMashPaths and setToNullPaths cannot intersect")
}
updateFields := map[string]interface{}{} // case sensitive update fields to values
found := map[string]struct{}{} // we need something to keep track of found fields (case insensitive)
val := reflect.Indirect(reflect.ValueOf(i))
structTyp := val.Type()
for i := 0; i < structTyp.NumField(); i++ {
kind := structTyp.Field(i).Type.Kind()
if kind == reflect.Struct || kind == reflect.Ptr {
embType := structTyp.Field(i).Type
// check if the embedded field is exported via CanInterface()
if val.Field(i).CanInterface() {
embVal := reflect.Indirect(reflect.ValueOf(val.Field(i).Interface()))
// if it's a ptr to a struct, then we need a few more bits before proceeding.
if kind == reflect.Ptr {
embVal = val.Field(i).Elem()
if !embVal.IsValid() {
continue
}
embType = embVal.Type()
if embType.Kind() != reflect.Struct {
continue
}
}
for embFieldNum := 0; embFieldNum < embType.NumField(); embFieldNum++ {
if f, ok := maskPaths[strings.ToUpper(embType.Field(embFieldNum).Name)]; ok {
updateFields[f] = embVal.Field(embFieldNum).Interface()
found[strings.ToUpper(f)] = struct{}{}
}
if f, ok := nullPaths[strings.ToUpper(embType.Field(embFieldNum).Name)]; ok {
updateFields[f] = gorm.Expr("NULL")
found[strings.ToUpper(f)] = struct{}{}
}
}
continue
}
}
if f, ok := maskPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok {
updateFields[f] = val.Field(i).Interface()
found[strings.ToUpper(f)] = struct{}{}
}
if f, ok := nullPaths[strings.ToUpper(structTyp.Field(i).Name)]; ok {
updateFields[f] = gorm.Expr("NULL")
found[strings.ToUpper(f)] = struct{}{}
}
}
if missing := findMissingPaths(setToNullPaths, found); len(missing) != 0 {
return nil, fmt.Errorf("null paths not found in resource: %s", missing)
}
if missing := findMissingPaths(fieldMaskPaths, found); len(missing) != 0 {
return nil, fmt.Errorf("field mask paths not found in resource: %s", missing)
}
return updateFields, nil
}
func findMissingPaths(paths []string, foundPaths map[string]struct{}) []string {
notFound := []string{}
for _, f := range paths {
if _, ok := foundPaths[strings.ToUpper(f)]; !ok {
notFound = append(notFound, f)
}
}
return notFound
}
// intersection is a case-insensitive search for intersecting values. Returns
// []string of the intersection with values in lowercase, and map[string]string
// of the original av and bv, with the key set to uppercase and value set to the
// original
func intersection(av, bv []string) ([]string, map[string]string, map[string]string, error) {
if av == nil {
return nil, nil, nil, fmt.Errorf("av is missing: %w", ErrNilParameter)
}
if bv == nil {
return nil, nil, nil, fmt.Errorf("bv is missing: %w", ErrNilParameter)
}
if len(av) == 0 && len(bv) == 0 {
return []string{}, map[string]string{}, map[string]string{}, nil
}
s := []string{}
ah := map[string]string{}
bh := map[string]string{}
for i := 0; i < len(av); i++ {
ah[strings.ToUpper(av[i])] = av[i]
}
for i := 0; i < len(bv); i++ {
k := strings.ToUpper(bv[i])
bh[k] = bv[i]
if _, found := ah[k]; found {
s = append(s, strings.ToLower(bh[k]))
}
}
return s, ah, bh, nil
}

@ -0,0 +1,419 @@
package common
import (
"testing"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/helper/base62"
"github.com/hashicorp/watchtower/internal/db/db_test"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func Test_intersection(t *testing.T) {
type args struct {
av []string
bv []string
}
tests := []struct {
name string
args args
want []string
want1 map[string]string
want2 map[string]string
wantErr bool
wantErrMsg string
}{
{
name: "intersect",
args: args{
av: []string{"alice"},
bv: []string{"alice", "bob"},
},
want: []string{"alice"},
want1: map[string]string{
"ALICE": "alice",
},
want2: map[string]string{
"ALICE": "alice",
"BOB": "bob",
},
},
{
name: "intersect-2",
args: args{
av: []string{"alice", "bob", "jane", "doe"},
bv: []string{"alice", "doe", "bert", "ernie", "bigbird"},
},
want: []string{"alice", "doe"},
want1: map[string]string{
"ALICE": "alice",
"BOB": "bob",
"JANE": "jane",
"DOE": "doe",
},
want2: map[string]string{
"ALICE": "alice",
"DOE": "doe",
"BERT": "bert",
"ERNIE": "ernie",
"BIGBIRD": "bigbird",
},
},
{
name: "intersect-mixed-case",
args: args{
av: []string{"AlicE"},
bv: []string{"alICe", "Bob"},
},
want: []string{"alice"},
want1: map[string]string{
"ALICE": "AlicE",
},
want2: map[string]string{
"ALICE": "alICe",
"BOB": "Bob",
},
},
{
name: "no-intersect-mixed-case",
args: args{
av: []string{"AliCe", "BOb", "jaNe", "DOE"},
bv: []string{"beRt", "ERnie", "bigBIRD"},
},
want: []string{},
want1: map[string]string{
"ALICE": "AliCe",
"BOB": "BOb",
"JANE": "jaNe",
"DOE": "DOE",
},
want2: map[string]string{
"BERT": "beRt",
"ERNIE": "ERnie",
"BIGBIRD": "bigBIRD",
},
},
{
name: "no-intersect-1",
args: args{
av: []string{"alice", "bob", "jane", "doe"},
bv: []string{"bert", "ernie", "bigbird"},
},
want: []string{},
want1: map[string]string{
"ALICE": "alice",
"BOB": "bob",
"JANE": "jane",
"DOE": "doe",
},
want2: map[string]string{
"BERT": "bert",
"ERNIE": "ernie",
"BIGBIRD": "bigbird",
},
},
{
name: "empty-av",
args: args{
av: []string{},
bv: []string{"bert", "ernie", "bigbird"},
},
want: []string{},
want1: map[string]string{},
want2: map[string]string{
"BERT": "bert",
"ERNIE": "ernie",
"BIGBIRD": "bigbird",
},
},
{
name: "empty-av-and-bv",
args: args{
av: []string{},
bv: []string{},
},
want: []string{},
want1: map[string]string{},
want2: map[string]string{},
},
{
name: "nil-av",
args: args{
av: nil,
bv: []string{"bert", "ernie", "bigbird"},
},
want: nil,
want1: nil,
want2: nil,
wantErr: true,
wantErrMsg: "av is missing: nil parameter",
},
{
name: "nil-bv",
args: args{
av: []string{},
bv: nil,
},
want: nil,
want1: nil,
want2: nil,
wantErr: true,
wantErrMsg: "bv is missing: nil parameter",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
got, got1, got2, err := intersection(tt.args.av, tt.args.bv)
if err == nil && tt.wantErr {
assert.Error(err)
}
if tt.wantErr {
assert.Error(err)
assert.Equal(tt.wantErrMsg, err.Error())
}
assert.Equal(tt.want, got)
assert.Equal(tt.want1, got1)
assert.Equal(tt.want2, got2)
})
}
}
func TestUpdateFields(t *testing.T) {
a := assert.New(t)
id, err := uuid.GenerateUUID()
a.NoError(err)
type args struct {
i interface{}
fieldMaskPaths []string
setToNullPaths []string
}
tests := []struct {
name string
args args
want map[string]interface{}
wantErr bool
wantErrMsg string
}{
{
name: "missing interface",
args: args{
i: nil,
fieldMaskPaths: []string{},
setToNullPaths: []string{},
},
want: nil,
wantErr: true,
wantErrMsg: "interface is missing: nil parameter",
},
{
name: "missing fieldmasks",
args: args{
i: testUser(t, id, id),
fieldMaskPaths: nil,
setToNullPaths: []string{},
},
want: nil,
wantErr: true,
wantErrMsg: "both fieldMaskPaths and setToNullPaths are zero len",
},
{
name: "missing null fields",
args: args{
i: testUser(t, id, id),
fieldMaskPaths: []string{"Name"},
setToNullPaths: nil,
},
want: map[string]interface{}{
"Name": id,
},
wantErr: false,
},
{
name: "all zero len",
args: args{
i: testUser(t, id, id),
fieldMaskPaths: []string{},
setToNullPaths: nil,
},
wantErr: true,
wantErrMsg: "both fieldMaskPaths and setToNullPaths are zero len",
},
{
name: "not found masks",
args: args{
i: testUser(t, id, id),
fieldMaskPaths: []string{"invalidFieldName"},
setToNullPaths: []string{},
},
want: nil,
wantErr: true,
wantErrMsg: "field mask paths not found in resource: [invalidFieldName]",
},
{
name: "not found null paths",
args: args{
i: testUser(t, id, id),
fieldMaskPaths: []string{"name"},
setToNullPaths: []string{"invalidFieldName"},
},
want: nil,
wantErr: true,
wantErrMsg: "null paths not found in resource: [invalidFieldName]",
},
{
name: "intersection",
args: args{
i: testUser(t, id, id),
fieldMaskPaths: []string{"name"},
setToNullPaths: []string{"name"},
},
want: nil,
wantErr: true,
wantErrMsg: "fieldMashPaths and setToNullPaths cannot intersect",
},
{
name: "valid",
args: args{
i: testUser(t, id, id),
fieldMaskPaths: []string{"name"},
setToNullPaths: []string{"email"},
},
want: map[string]interface{}{
"name": id,
"email": gorm.Expr("NULL"),
},
wantErr: false,
wantErrMsg: "",
},
{
name: "valid-just-masks",
args: args{
i: testUser(t, id, id),
fieldMaskPaths: []string{"name", "email"},
setToNullPaths: []string{},
},
want: map[string]interface{}{
"name": id,
"email": id,
},
wantErr: false,
wantErrMsg: "",
},
{
name: "valid-just-nulls",
args: args{
i: testUser(t, id, id),
fieldMaskPaths: []string{},
setToNullPaths: []string{"name", "email"},
},
want: map[string]interface{}{
"name": gorm.Expr("NULL"),
"email": gorm.Expr("NULL"),
},
wantErr: false,
wantErrMsg: "",
},
{
name: "valid-not-embedded",
args: args{
i: db_test.StoreTestUser{
PublicId: testPublicId(t),
Name: id,
Email: "",
},
fieldMaskPaths: []string{"name"},
setToNullPaths: []string{"email"},
},
want: map[string]interface{}{
"name": id,
"email": gorm.Expr("NULL"),
},
wantErr: false,
wantErrMsg: "",
},
{
name: "valid-not-embedded-just-masks",
args: args{
i: db_test.StoreTestUser{
PublicId: testPublicId(t),
Name: id,
Email: "",
},
fieldMaskPaths: []string{"name"},
setToNullPaths: nil,
},
want: map[string]interface{}{
"name": id,
},
wantErr: false,
wantErrMsg: "",
},
{
name: "valid-not-embedded-just-nulls",
args: args{
i: db_test.StoreTestUser{
PublicId: testPublicId(t),
Name: id,
Email: "",
},
fieldMaskPaths: nil,
setToNullPaths: []string{"email"},
},
want: map[string]interface{}{
"email": gorm.Expr("NULL"),
},
wantErr: false,
wantErrMsg: "",
},
{
name: "not found null paths - not embedded",
args: args{
i: db_test.StoreTestUser{
PublicId: testPublicId(t),
Name: id,
Email: "",
},
fieldMaskPaths: []string{"name"},
setToNullPaths: []string{"invalidFieldName"},
},
want: nil,
wantErr: true,
wantErrMsg: "null paths not found in resource: [invalidFieldName]",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
got, err := UpdateFields(tt.args.i, tt.args.fieldMaskPaths, tt.args.setToNullPaths)
if err == nil && tt.wantErr {
assert.Error(err)
}
if tt.wantErr {
assert.Error(err)
assert.Equal(tt.wantErrMsg, err.Error())
}
assert.Equal(tt.want, got)
})
}
}
func testUser(t *testing.T, name, email string) *db_test.TestUser {
t.Helper()
return &db_test.TestUser{
StoreTestUser: &db_test.StoreTestUser{
PublicId: testPublicId(t),
Name: name,
Email: email,
},
}
}
func testPublicId(t *testing.T) string {
t.Helper()
publicId, err := base62.Random(20)
assert.NoError(t, err)
return publicId
}

@ -30,8 +30,11 @@ returning id;
t.Error(err)
}
}()
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
db := conn.DB()
if _, err := db.Exec(createTable); err != nil {
t.Fatalf("query: \n%s\n error: %s", createTable, err)
@ -94,8 +97,11 @@ returning id;
t.Error(err)
}
}()
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
db := conn.DB()
if _, err := db.Exec(createTable); err != nil {
t.Fatalf("query: \n%s\n error: %s", createTable, err)

@ -478,6 +478,36 @@ before
insert on iam_scope
for each row execute procedure default_create_time();
-- iam_sub_names will allow us to enforce the different name constraints for
-- organizations and projects via a before update trigger on the iam_scope
-- table.
create or replace function
iam_sub_names()
returns trigger
as $$
begin
if new.name != old.name then
if new.type = 'organization' then
update iam_scope_organization set name = new.name where scope_id = old.public_id;
return new;
end if;
if new.type = 'project' then
update iam_scope_project set name = new.name where scope_id = old.public_id;
return new;
end if;
raise exception 'unknown scope type';
end if;
return new;
end;
$$ language plpgsql;
create trigger
iam_sub_names
before
update on iam_scope
for each row execute procedure iam_sub_names();
commit;
`),

@ -110,4 +110,34 @@ before
insert on iam_scope
for each row execute procedure default_create_time();
-- iam_sub_names will allow us to enforce the different name constraints for
-- organizations and projects via a before update trigger on the iam_scope
-- table.
create or replace function
iam_sub_names()
returns trigger
as $$
begin
if new.name != old.name then
if new.type = 'organization' then
update iam_scope_organization set name = new.name where scope_id = old.public_id;
return new;
end if;
if new.type = 'project' then
update iam_scope_project set name = new.name where scope_id = old.public_id;
return new;
end if;
raise exception 'unknown scope type';
end if;
return new;
end;
$$ language plpgsql;
create trigger
iam_sub_names
before
update on iam_scope
for each row execute procedure iam_sub_names();
commit;

@ -25,6 +25,8 @@ type Options struct {
withLookup bool
// WithFieldMaskPaths must be accessible from other packages
WithFieldMaskPaths []string
// WithNullPaths must be accessible from other packages
WithNullPaths []string
}
type oplogOpts struct {
@ -42,6 +44,7 @@ func getDefaultOptions() Options {
withDebug: false,
withLookup: false,
WithFieldMaskPaths: []string{},
WithNullPaths: []string{},
}
}
@ -70,9 +73,16 @@ func WithOplog(wrapper wrapping.Wrapper, md oplog.Metadata) Option {
}
}
// WithOplog provides an option to provide field mask paths
// WithFieldMaskPaths provides an option to provide field mask paths
func WithFieldMaskPaths(paths []string) Option {
return func(o *Options) {
o.WithFieldMaskPaths = paths
}
}
// WithNullPaths provides an option to provide null paths
func WithNullPaths(paths []string) Option {
return func(o *Options) {
o.WithNullPaths = paths
}
}

@ -60,4 +60,30 @@ func Test_getOpts(t *testing.T) {
testOpts.withDebug = true
assert.Equal(opts, testOpts)
})
t.Run("WithFieldMaskPaths", func(t *testing.T) {
// test default of []string{}
opts := GetOpts()
testOpts := getDefaultOptions()
testOpts.WithFieldMaskPaths = []string{}
assert.Equal(opts, testOpts)
testPaths := []string{"alice", "bob"}
opts = GetOpts(WithFieldMaskPaths(testPaths))
testOpts = getDefaultOptions()
testOpts.WithFieldMaskPaths = testPaths
assert.Equal(opts, testOpts)
})
t.Run("WithNullPaths", func(t *testing.T) {
// test default of []string{}
opts := GetOpts()
testOpts := getDefaultOptions()
testOpts.WithNullPaths = []string{}
assert.Equal(opts, testOpts)
testPaths := []string{"alice", "bob"}
opts = GetOpts(WithNullPaths(testPaths))
testOpts = getDefaultOptions()
testOpts.WithNullPaths = testPaths
assert.Equal(opts, testOpts)
})
}

@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/hashicorp/watchtower/internal/db/common"
"github.com/hashicorp/watchtower/internal/oplog"
"github.com/jinzhu/gorm"
"google.golang.org/protobuf/proto"
@ -42,13 +43,17 @@ type Writer interface {
// DoTx will wrap the TxHandler in a retryable transaction
DoTx(ctx context.Context, retries uint, backOff Backoff, Handler TxHandler) (RetryInfo, error)
// Update an object in the db, if there's a fieldMask then only the
// field_mask.proto paths are updated, otherwise it will send every field to
// the DB. options: WithOplog the caller is responsible for the transaction
// life cycle of the writer and if an error is returned the caller must
// decide what to do with the transaction, which almost always should be to
// rollback. Update returns the number of rows updated or an error.
Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) (int, error)
// Update an object in the db, fieldMask is required and provides
// field_mask.proto paths for fields that should be updated. The i interface
// parameter is the type the caller wants to update in the db and its
// fields are set to the update values. setToNullPaths is optional and
// provides field_mask.proto paths for the fields that should be set to
// null. fieldMaskPaths and setToNullPaths must not intersect. The caller
// is responsible for the transaction life cycle of the writer and if an
// error is returned the caller must decide what to do with the transaction,
// which almost always should be to rollback. Update returns the number of
// rows updated or an error. Supported options: WithOplog.
Update(ctx context.Context, i interface{}, fieldMaskPaths []string, setToNullPaths []string, opt ...Option) (int, error)
// Create an object in the db with options: WithOplog
// the caller is responsible for the transaction life cycle of the writer
@ -99,7 +104,10 @@ const (
DeleteOp OpType = 3
)
// VetForWriter provides an interface that Create and Update can use to vet the resource before sending it to the db
// VetForWriter provides an interface that Create and Update can use to vet the
// resource before before writing it to the db. For optType == UpdateOp,
// options WithFieldMaskPath and WithNullPaths are supported. For optType ==
// CreateOp, no options are supported
type VetForWriter interface {
VetForWrite(ctx context.Context, r Reader, opType OpType, opt ...Option) error
}
@ -130,7 +138,7 @@ func (rw *Db) ScanRows(rows *sql.Rows, result interface{}) error {
if rw.underlying == nil {
return fmt.Errorf("scan rows: missing underlying db %w", ErrNilParameter)
}
if result == nil {
if isNil(result) {
return fmt.Errorf("scan rows: result is missing %w", ErrNilParameter)
}
return rw.underlying.ScanRows(rows, result)
@ -152,12 +160,13 @@ func (rw *Db) lookupAfterWrite(ctx context.Context, i interface{}, opt ...Option
return errors.New("not a resource with an id")
}
// Create an object in the db with options: WithOplog and WithLookup (to force a lookup after create))
// Create an object in the db with options: WithOplog and WithLookup (to force a
// lookup after create).
func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
if rw.underlying == nil {
return fmt.Errorf("create: missing underlying db %w", ErrNilParameter)
}
if i == nil {
if isNil(i) {
return fmt.Errorf("create: interface is missing %w", ErrNilParameter)
}
opts := GetOpts(opt...)
@ -174,9 +183,6 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
rw.underlying.LogMode(true)
defer rw.underlying.LogMode(false)
}
if i == nil {
return fmt.Errorf("create: missing interface %w", ErrNilParameter)
}
// these fields should be nil, since they are not writeable and we want the
// db to manage them
setFieldsToNil(i, []string{"CreateTime", "UpdateTime"})
@ -200,19 +206,40 @@ func (rw *Db) Create(ctx context.Context, i interface{}, opt ...Option) error {
return nil
}
// Update an object in the db, if there's a fieldMask then only the
// field_mask.proto paths are updated, otherwise it will send every field to the
// DB. Update supports embedding a struct (or structPtr) one level deep for
// updating. Update returns the number of rows updated and any errors.
func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, opt ...Option) (int, error) {
// Update an object in the db, fieldMask is required and provides
// field_mask.proto paths for fields that should be updated. The i interface
// parameter is the type the caller wants to update in the db and its
// fields are set to the update values. setToNullPaths is optional and
// provides field_mask.proto paths for the fields that should be set to
// null. fieldMaskPaths and setToNullPaths must not intersect. The caller
// is responsible for the transaction life cycle of the writer and if an
// error is returned the caller must decide what to do with the transaction,
// which almost always should be to rollback. Update returns the number of
// rows updated. Supported options: WithOplog.
func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string, setToNullPaths []string, opt ...Option) (int, error) {
if rw.underlying == nil {
return NoRowsAffected, fmt.Errorf("update: missing underlying db %w", ErrNilParameter)
}
if i == nil {
if isNil(i) {
return NoRowsAffected, fmt.Errorf("update: interface is missing %w", ErrNilParameter)
}
if len(fieldMaskPaths) == 0 {
return NoRowsAffected, fmt.Errorf("update: missing fieldMaskPaths %w", ErrNilParameter)
if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 {
return NoRowsAffected, errors.New("update: both fieldMaskPaths and setToNullPaths are missing")
}
// we need to filter out some non-updatable fields (like: CreateTime, etc)
fieldMaskPaths = filterPaths(fieldMaskPaths)
setToNullPaths = filterPaths(setToNullPaths)
if len(fieldMaskPaths) == 0 && len(setToNullPaths) == 0 {
return NoRowsAffected, fmt.Errorf("update: after filtering non-updated fields, there are no fields left in fieldMaskPaths or setToNullPaths")
}
updateFields, err := common.UpdateFields(i, fieldMaskPaths, setToNullPaths)
if err != nil {
return NoRowsAffected, fmt.Errorf("update: getting update fields failed: %w", err)
}
if len(updateFields) == 0 {
return NoRowsAffected, fmt.Errorf("update: no fields matched using fieldMaskPaths %s", fieldMaskPaths)
}
// This is not a watchtower scope, but rather a gorm Scope:
@ -237,54 +264,11 @@ func (rw *Db) Update(ctx context.Context, i interface{}, fieldMaskPaths []string
defer rw.underlying.LogMode(false)
}
if vetter, ok := i.(VetForWriter); ok {
if err := vetter.VetForWrite(ctx, rw, UpdateOp, WithFieldMaskPaths(fieldMaskPaths)); err != nil {
if err := vetter.VetForWrite(ctx, rw, UpdateOp, WithFieldMaskPaths(fieldMaskPaths), WithNullPaths(setToNullPaths)); err != nil {
return NoRowsAffected, fmt.Errorf("update: vet for write failed %w", err)
}
}
// we need to filter out some non-updatable fields (like: CreateTime, etc)
fieldMaskPaths = filterFieldMask(fieldMaskPaths)
if len(fieldMaskPaths) == 0 {
return NoRowsAffected, fmt.Errorf("update: after filtering non-updated fields, there are no fields left in fieldMaskPaths: %s", fieldMaskPaths)
}
updateFields := map[string]interface{}{}
val := reflect.Indirect(reflect.ValueOf(i))
structTyp := val.Type()
for _, field := range fieldMaskPaths {
for i := 0; i < structTyp.NumField(); i++ {
kind := structTyp.Field(i).Type.Kind()
if kind == reflect.Struct || kind == reflect.Ptr {
embType := structTyp.Field(i).Type
// check if the embedded field is exported via CanInterface()
if val.Field(i).CanInterface() {
embVal := reflect.Indirect(reflect.ValueOf(val.Field(i).Interface()))
// if it's a ptr to a struct, then we need a few more bits before proceeding.
if kind == reflect.Ptr {
embVal = val.Field(i).Elem()
embType = embVal.Type()
if embType.Kind() != reflect.Struct {
continue
}
}
for embFieldNum := 0; embFieldNum < embType.NumField(); embFieldNum++ {
if strings.EqualFold(embType.Field(embFieldNum).Name, field) {
updateFields[field] = embVal.Field(embFieldNum).Interface()
}
}
continue
}
}
// it's not an embedded type, so check if the field name matches
if strings.EqualFold(structTyp.Field(i).Name, field) {
updateFields[field] = val.Field(i).Interface()
}
}
}
if len(updateFields) == 0 {
return NoRowsAffected, fmt.Errorf("update: no fields matched using fieldMaskPaths %s", fieldMaskPaths)
}
underlying := rw.underlying.Model(i).Updates(updateFields)
if underlying.Error != nil {
return NoRowsAffected, fmt.Errorf("update: failed %w", underlying.Error)
@ -311,7 +295,7 @@ func (rw *Db) Delete(ctx context.Context, i interface{}, opt ...Option) (int, er
if rw.underlying == nil {
return NoRowsAffected, fmt.Errorf("delete: missing underlying db %w", ErrNilParameter)
}
if i == nil {
if isNil(i) {
return NoRowsAffected, fmt.Errorf("delete: interface is missing %w", ErrNilParameter)
}
// This is not a watchtower scope, but rather a gorm Scope:
@ -539,13 +523,13 @@ func (rw *Db) SearchWhere(ctx context.Context, resources interface{}, where stri
return nil
}
// filterFieldMasks will filter out non-updatable fields
func filterFieldMask(fieldMaskPaths []string) []string {
if len(fieldMaskPaths) == 0 {
// filterPaths will filter out non-updatable fields
func filterPaths(paths []string) []string {
if len(paths) == 0 {
return nil
}
filtered := []string{}
for _, p := range fieldMaskPaths {
for _, p := range paths {
switch {
case strings.EqualFold(p, "CreateTime"):
continue
@ -599,3 +583,14 @@ func setFieldsToNil(i interface{}, fieldNames []string) {
}
}
}
func isNil(i interface{}) bool {
if i == nil {
return true
}
switch reflect.TypeOf(i).Kind() {
case reflect.Ptr, reflect.Map, reflect.Array, reflect.Chan, reflect.Slice:
return reflect.ValueOf(i).IsNil()
}
return false
}

@ -4,150 +4,277 @@ import (
"context"
"errors"
"fmt"
"strings"
"testing"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/helper/base62"
"github.com/hashicorp/watchtower/internal/db/db_test"
"github.com/hashicorp/watchtower/internal/oplog"
"github.com/hashicorp/watchtower/internal/oplog/store"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
)
func TestDb_Update(t *testing.T) {
// intentionally not run with t.Parallel so we don't need to use DoTx for the Update tests
cleanup, db, _ := TestSetup(t, "postgres")
now := &db_test.Timestamp{Timestamp: ptypes.TimestampNow()}
publicId, err := NewPublicId("testuser")
if err != nil {
t.Error(err)
}
defer func() {
if err := cleanup(); err != nil {
t.Error(err)
}
}()
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()
id := testId(t)
type args struct {
i *db_test.TestUser
fieldMaskPaths []string
setToNullPaths []string
opt []Option
}
tests := []struct {
name string
args args
want int
wantErr bool
wantErrMsg string
wantName string
wantEmail string
wantPhoneNumber string
}{
{
name: "simple",
args: args{
i: &db_test.TestUser{
StoreTestUser: &db_test.StoreTestUser{
Name: "simple-updated" + id,
Email: "updated" + id,
PhoneNumber: "updated" + id,
},
},
fieldMaskPaths: []string{"Name", "PhoneNumber"},
setToNullPaths: []string{"Email"},
},
want: 1,
wantErr: false,
wantErrMsg: "",
wantName: "simple-updated" + id,
wantEmail: "",
wantPhoneNumber: "updated" + id,
},
{
name: "multiple-null",
args: args{
i: &db_test.TestUser{
StoreTestUser: &db_test.StoreTestUser{
Name: "multiple-null-updated" + id,
Email: "updated" + id,
PhoneNumber: "updated" + id,
},
},
fieldMaskPaths: []string{"Name"},
setToNullPaths: []string{"Email", "PhoneNumber"},
},
want: 1,
wantErr: false,
wantErrMsg: "",
wantName: "multiple-null-updated" + id,
wantEmail: "",
wantPhoneNumber: "",
},
{
name: "non-updatable",
args: args{
i: &db_test.TestUser{
StoreTestUser: &db_test.StoreTestUser{
Name: "non-updatable" + id,
Email: "updated" + id,
PhoneNumber: "updated" + id,
PublicId: publicId,
CreateTime: now,
UpdateTime: now,
},
},
fieldMaskPaths: []string{"Name", "PhoneNumber", "CreateTime", "UpdateTime", "PublicId"},
setToNullPaths: []string{"Email"},
},
want: 1,
wantErr: false,
wantErrMsg: "",
wantName: "non-updatable" + id,
wantEmail: "",
wantPhoneNumber: "updated" + id,
},
{
name: "both are missing",
args: args{
i: &db_test.TestUser{
StoreTestUser: &db_test.StoreTestUser{
Name: "both are missing-updated" + id,
Email: id,
PhoneNumber: id,
},
},
fieldMaskPaths: nil,
setToNullPaths: []string{},
},
want: 0,
wantErr: true,
wantErrMsg: "update: both fieldMaskPaths and setToNullPaths are missing",
},
{
name: "i is nil",
args: args{
i: nil,
fieldMaskPaths: []string{"Name", "PhoneNumber"},
setToNullPaths: []string{"Email"},
},
want: 0,
wantErr: true,
wantErrMsg: "update: interface is missing nil parameter",
},
{
name: "only read-only",
args: args{
i: &db_test.TestUser{
StoreTestUser: &db_test.StoreTestUser{
Name: "only read-only" + id,
Email: id,
PhoneNumber: id,
},
},
fieldMaskPaths: []string{"CreateTime"},
setToNullPaths: []string{"UpdateTime"},
},
want: 0,
wantErr: true,
wantErrMsg: "update: after filtering non-updated fields, there are no fields left in fieldMaskPaths or setToNullPaths",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
rw := &Db{
underlying: db,
}
u := testUser(t, db, tt.name+id, id, id)
if tt.args.i != nil {
tt.args.i.Id = u.Id
tt.args.i.PublicId = u.PublicId
}
rowsUpdated, err := rw.Update(context.Background(), tt.args.i, tt.args.fieldMaskPaths, tt.args.setToNullPaths, tt.args.opt...)
if tt.wantErr {
assert.Error(err)
assert.Equal(tt.want, rowsUpdated)
assert.Equal(tt.wantErrMsg, err.Error())
return
}
assert.NoError(err)
assert.Equal(tt.want, rowsUpdated)
foundUser, err := db_test.NewTestUser()
assert.NoError(err)
foundUser.PublicId = tt.args.i.PublicId
where := "public_id = ?"
for _, f := range tt.args.setToNullPaths {
switch {
case strings.EqualFold(f, "phonenumber"):
f = "phone_number"
}
where = fmt.Sprintf("%s and %s is null", where, f)
}
err = rw.LookupWhere(context.Background(), foundUser, where, tt.args.i.PublicId)
assert.NoError(err)
assert.Equal(tt.args.i.Id, foundUser.Id)
assert.Equal(tt.wantName, foundUser.Name)
assert.Equal(tt.wantEmail, foundUser.Email)
assert.Equal(tt.wantPhoneNumber, foundUser.PhoneNumber)
assert.NotEqual(now, foundUser.CreateTime)
assert.NotEqual(now, foundUser.UpdateTime)
assert.NotEqual(publicId, foundUser.PublicId)
})
}
assert := assert.New(t)
defer db.Close()
t.Run("simple", func(t *testing.T) {
t.Run("valid-WithOplog", func(t *testing.T) {
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.NoError(err)
user, err := db_test.NewTestUser()
assert.NoError(err)
user.Name = "foo-" + id
err = w.Create(context.Background(), user)
assert.NoError(err)
assert.NotEmpty(user.Id)
foundUser, err := db_test.NewTestUser()
assert.NoError(err)
foundUser.PublicId = user.PublicId
err = w.LookupByPublicId(context.Background(), foundUser)
assert.NoError(err)
assert.Equal(foundUser.Id, user.Id)
user := testUser(t, db, "foo-"+id, id, id)
user.Name = "friendly-" + id
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"})
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil,
// write oplogs for this update
WithOplog(
TestWrapper(t),
oplog.Metadata{
"key-only": nil,
"deployment": []string{"amex"},
"project": []string{"central-info-systems", "local-info-systems"},
"resource-public-id": []string{user.PublicId},
"op-type": []string{oplog.OpType_OP_TYPE_UPDATE.String()},
}),
)
assert.NoError(err)
assert.Equal(1, rowsUpdated)
err = w.LookupByPublicId(context.Background(), foundUser)
assert.NoError(err)
assert.Equal(foundUser.Name, user.Name)
})
t.Run("non-updatable-fields", func(t *testing.T) {
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.NoError(err)
user, err := db_test.NewTestUser()
assert.NoError(err)
user.Name = "foo-" + id
err = w.Create(context.Background(), user)
db.LogMode(false)
assert.NoError(err)
assert.NotEmpty(user.Id)
foundUser, err := db_test.NewTestUser()
assert.NoError(err)
foundUser.PublicId = user.PublicId
err = w.LookupByPublicId(context.Background(), foundUser)
assert.NoError(err)
assert.Equal(foundUser.Id, user.Id)
user.Name = "friendly-" + id
ts := &db_test.Timestamp{Timestamp: ptypes.TimestampNow()}
user.CreateTime = ts
user.UpdateTime = ts
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name", "CreateTime", "UpdateTime"})
assert.NoError(err)
assert.Equal(1, rowsUpdated)
err = w.LookupByPublicId(context.Background(), foundUser)
assert.NoError(err)
assert.Equal(foundUser.Name, user.Name)
assert.NotEqual(foundUser.CreateTime, ts)
assert.NotEqual(foundUser.UpdateTime, ts)
ts = &db_test.Timestamp{Timestamp: ptypes.TimestampNow()}
user.Name = id
user.CreateTime = ts
user.UpdateTime = ts
rowsUpdated, err = w.Update(context.Background(), user, nil)
assert.Error(err)
assert.Equal(0, rowsUpdated)
assert.Equal("update: missing fieldMaskPaths nil parameter", err.Error())
assert.NotEqual(foundUser.CreateTime, ts)
assert.NotEqual(foundUser.UpdateTime, ts)
err = TestVerifyOplog(t, &w, user.PublicId, WithOperation(oplog.OpType_OP_TYPE_UPDATE), WithCreateNotBefore(10*time.Second))
assert.NoError(err)
})
t.Run("valid-WithOplog", func(t *testing.T) {
t.Run("vet-for-write", func(t *testing.T) {
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.NoError(err)
user, err := db_test.NewTestUser()
assert.NoError(err)
user.Name = "foo-" + id
err = w.Create(context.Background(), user)
assert.NoError(err)
assert.NotEmpty(user.Id)
foundUser, err := db_test.NewTestUser()
assert.NoError(err)
foundUser.PublicId = user.PublicId
err = w.LookupByPublicId(context.Background(), foundUser)
user := &testUserWithVet{
PublicId: id,
Name: id,
PhoneNumber: id,
Email: id,
}
err = db.Create(user).Error
assert.NoError(err)
assert.Equal(foundUser.Id, user.Id)
user.Name = "friendly-" + id
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"},
// write oplogs for this update
WithOplog(
TestWrapper(t),
oplog.Metadata{
"key-only": nil,
"deployment": []string{"amex"},
"project": []string{"central-info-systems", "local-info-systems"},
"resource-public-id": []string{user.GetPublicId()},
}),
)
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil)
assert.NoError(err)
assert.Equal(1, rowsUpdated)
foundUser := &testUserWithVet{PublicId: user.PublicId}
assert.NoError(err)
foundUser.PublicId = user.PublicId
err = w.LookupByPublicId(context.Background(), foundUser)
assert.NoError(err)
assert.Equal(foundUser.Name, user.Name)
var metadata store.Metadata
err = db.Where("key = ? and value = ?", "resource-public-id", user.PublicId).First(&metadata).Error
assert.NoError(err)
var foundEntry oplog.Entry
err = db.Where("id = ?", metadata.EntryId).First(&foundEntry).Error
assert.NoError(err)
user.PublicId = "not-allowed-by-vet-for-write"
rowsUpdated, err = w.Update(context.Background(), user, []string{"PublicId"}, nil)
assert.Error(err)
assert.Equal(0, rowsUpdated)
})
t.Run("nil-tx", func(t *testing.T) {
w := Db{underlying: nil}
id, err := uuid.GenerateUUID()
assert.NoError(err)
user, err := db_test.NewTestUser()
assert.NoError(err)
user.Name = "foo-" + id
rowsUpdated, err := w.Update(context.Background(), user, nil)
user := testUser(t, db, "foo-"+id, id, id)
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil)
assert.Error(err)
assert.Equal(0, rowsUpdated)
assert.Equal("update: missing underlying db nil parameter", err.Error())
@ -156,22 +283,11 @@ func TestDb_Update(t *testing.T) {
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.NoError(err)
user, err := db_test.NewTestUser()
assert.NoError(err)
user.Name = "foo-" + id
err = w.Create(context.Background(), user)
assert.NoError(err)
assert.NotEmpty(user.Id)
foundUser, err := db_test.NewTestUser()
assert.NoError(err)
foundUser.PublicId = user.PublicId
err = w.LookupByPublicId(context.Background(), foundUser)
assert.NoError(err)
assert.Equal(foundUser.Id, user.Id)
user := testUser(t, db, "foo-"+id, id, id)
user.Name = "friendly-" + id
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"},
nil,
WithOplog(
nil,
oplog.Metadata{
@ -188,22 +304,9 @@ func TestDb_Update(t *testing.T) {
w := Db{underlying: db}
id, err := uuid.GenerateUUID()
assert.NoError(err)
user, err := db_test.NewTestUser()
assert.NoError(err)
user.Name = "foo-" + id
err = w.Create(context.Background(), user)
assert.NoError(err)
assert.NotEmpty(user.Id)
foundUser, err := db_test.NewTestUser()
assert.NoError(err)
foundUser.PublicId = user.PublicId
err = w.LookupByPublicId(context.Background(), foundUser)
assert.NoError(err)
assert.Equal(foundUser.Id, user.Id)
user := testUser(t, db, "foo-"+id, id, id)
user.Name = "friendly-" + id
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"},
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil,
WithOplog(
TestWrapper(t),
nil,
@ -215,6 +318,43 @@ func TestDb_Update(t *testing.T) {
})
}
// testUserWithVet gives us a model that implements VetForWrite() without any
// cyclic dependencies.
type testUserWithVet struct {
Id uint32 `gorm:"primary_key"`
PublicId string
Name string `gorm:"default:null"`
PhoneNumber string `gorm:"default:null"`
Email string `gorm:"default:null"`
}
func (u *testUserWithVet) GetPublicId() string {
return u.PublicId
}
func (u *testUserWithVet) TableName() string {
return "db_test_user"
}
func (u *testUserWithVet) VetForWrite(ctx context.Context, r Reader, opType OpType, opt ...Option) error {
if u.PublicId == "" {
return errors.New("public id is empty string for user write")
}
if opType == UpdateOp {
dbOptions := GetOpts(opt...)
for _, path := range dbOptions.WithFieldMaskPaths {
switch path {
case "PublicId":
return errors.New("you cannot change the public id")
}
}
}
if opType == CreateOp {
if u.Id != 0 {
return errors.New("id is a db sequence")
}
}
return nil
}
func TestDb_Create(t *testing.T) {
// intentionally not run with t.Parallel so we don't need to use DoTx for the Create tests
cleanup, db, _ := TestSetup(t, "postgres")
@ -700,7 +840,7 @@ func TestDb_DoTx(t *testing.T) {
_, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error {
user.Name = "friendly-" + id
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"})
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil)
if err != nil {
return err
}
@ -722,7 +862,7 @@ func TestDb_DoTx(t *testing.T) {
assert.NoError(err)
_, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error {
user2.Name = "friendly2-" + id
rowsUpdated, err := w.Update(context.Background(), user2, []string{"Name"})
rowsUpdated, err := w.Update(context.Background(), user2, []string{"Name"}, nil)
if err != nil {
return err
}
@ -738,7 +878,7 @@ func TestDb_DoTx(t *testing.T) {
_, err = w.DoTx(context.Background(), 10, ExpBackoff{}, func(w Writer) error {
user.Name = "friendly2-" + id
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"})
rowsUpdated, err := w.Update(context.Background(), user, []string{"Name"}, nil)
if err != nil {
return err
}
@ -974,3 +1114,32 @@ func TestDb_ScanRows(t *testing.T) {
}
})
}
func testUser(t *testing.T, conn *gorm.DB, name, email, phoneNumber string) *db_test.TestUser {
t.Helper()
assert := assert.New(t)
publicId, err := base62.Random(20)
assert.NoError(err)
u := &db_test.TestUser{
StoreTestUser: &db_test.StoreTestUser{
PublicId: publicId,
Name: name,
Email: email,
PhoneNumber: phoneNumber,
},
}
if conn != nil {
err = conn.Create(u).Error
assert.NoError(err)
}
return u
}
func testId(t *testing.T) string {
t.Helper()
assert := assert.New(t)
id, err := uuid.GenerateUUID()
assert.NoError(err)
return id
}

@ -2,7 +2,6 @@ package db
import (
"context"
"strconv"
"testing"
"time"
@ -17,8 +16,16 @@ import (
func Test_Utils(t *testing.T) {
t.Parallel()
cleanup, conn, _ := TestSetup(t, "postgres")
defer cleanup()
defer conn.Close()
defer func() {
if err := cleanup(); err != nil {
t.Error(err)
}
}()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
t.Run("nothing", func(t *testing.T) {
})
@ -26,7 +33,11 @@ func Test_Utils(t *testing.T) {
func TestVerifyOplogEntry(t *testing.T) {
cleanup, db, _ := TestSetup(t, "postgres")
defer cleanup()
defer func() {
if err := cleanup(); err != nil {
t.Error(err)
}
}()
assert := assert.New(t)
defer db.Close()
@ -43,7 +54,7 @@ func TestVerifyOplogEntry(t *testing.T) {
WithOplog(
TestWrapper(t),
oplog.Metadata{
"op-type": []string{strconv.Itoa(int(oplog.OpType_OP_TYPE_CREATE))},
"op-type": []string{oplog.OpType_OP_TYPE_CREATE.String()},
"resource-public-id": []string{user.GetPublicId()},
}),
)
@ -56,7 +67,7 @@ func TestVerifyOplogEntry(t *testing.T) {
err = rw.LookupByPublicId(context.Background(), foundUser)
assert.NoError(err)
assert.Equal(foundUser.Id, user.Id)
TestVerifyOplog(t, &rw, user.PublicId, WithOperation(oplog.OpType_OP_TYPE_CREATE), WithCreateNotBefore(5*time.Second))
err = TestVerifyOplog(t, &rw, user.PublicId, WithOperation(oplog.OpType_OP_TYPE_CREATE), WithCreateNotBefore(5*time.Second))
assert.NoError(err)
})
t.Run("should-fail", func(t *testing.T) {

@ -116,10 +116,18 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, fieldMas
return nil, db.NoRowsAffected, fmt.Errorf("update: static host catalog: %w", db.ErrEmptyFieldMask)
}
var dbMask, nullFields []string
for _, f := range fieldMask {
switch {
case strings.EqualFold("name", f):
case strings.EqualFold("description", f):
case strings.EqualFold("name", f) && c.Name == "":
nullFields = append(nullFields, "name")
case strings.EqualFold("name", f) && c.Name != "":
dbMask = append(dbMask, "name")
case strings.EqualFold("description", f) && c.Description == "":
nullFields = append(nullFields, "description")
case strings.EqualFold("description", f) && c.Description != "":
dbMask = append(dbMask, "description")
default:
return nil, db.NoRowsAffected, fmt.Errorf("update: static host catalog: field: %s: %w", f, db.ErrInvalidFieldMask)
}
@ -128,8 +136,6 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, fieldMas
metadata := newCatalogMetadata(c, oplog.OpType_OP_TYPE_UPDATE)
// TODO(mgaffney,jimlambrt) 05/2020: uncomment the nullFields line
// below once support for setting columns to nil is added to db.Update.
var rowsUpdated int
var returnedCatalog *HostCatalog
_, err := r.writer.DoTx(
@ -142,8 +148,8 @@ func (r *Repository) UpdateCatalog(ctx context.Context, c *HostCatalog, fieldMas
rowsUpdated, err = w.Update(
ctx,
returnedCatalog,
fieldMask,
// nullFields,
dbMask,
nullFields,
db.WithOplog(r.wrapper, metadata),
)
if err == nil && rowsUpdated > 1 {

@ -453,74 +453,76 @@ func TestRepository_UpdateCatalog(t *testing.T) {
},
wantCount: 1,
},
// TODO(mgaffney,jimlambrt) 05/2020: enable these tests once
// support for setting columns to nil is added to db.Update.
/*
{
name: "delete-name",
orig: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
Description: "test-description-repo",
},
{
name: "delete-name",
orig: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
Description: "test-description-repo",
},
masks: []string{"Name"},
chgFn: combine(changeDescription("test-update-description-repo"), changeName("")),
want: &HostCatalog{
HostCatalog: &store.HostCatalog{
Description: "test-description-repo",
},
},
masks: []string{"Name"},
chgFn: combine(changeDescription("test-update-description-repo"), changeName("")),
want: &HostCatalog{
HostCatalog: &store.HostCatalog{
Description: "test-description-repo",
},
},
{
name: "delete-description",
orig: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
Description: "test-description-repo",
},
wantCount: 1,
},
{
name: "delete-description",
orig: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
Description: "test-description-repo",
},
masks: []string{"Description"},
chgFn: combine(changeDescription(""), changeName("test-update-name-repo")),
want: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
},
},
masks: []string{"Description"},
chgFn: combine(changeDescription(""), changeName("test-update-name-repo")),
want: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
},
},
{
name: "do-not-delete-name",
orig: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
Description: "test-description-repo",
},
wantCount: 1,
},
{
name: "do-not-delete-name",
orig: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
Description: "test-description-repo",
},
chgFn: combine(changeDescription("test-update-description-repo"), changeName("")),
want: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
Description: "test-update-description-repo",
},
},
masks: []string{"Description"},
chgFn: combine(changeDescription("test-update-description-repo"), changeName("")),
want: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
Description: "test-update-description-repo",
},
},
{
name: "do-not-delete-description",
orig: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
Description: "test-description-repo",
},
wantCount: 1,
},
{
name: "do-not-delete-description",
orig: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-name-repo",
Description: "test-description-repo",
},
chgFn: combine(changeDescription(""), changeName("test-update-name-repo")),
want: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-update-name-repo",
Description: "test-description-repo",
},
},
masks: []string{"Name"},
chgFn: combine(changeDescription(""), changeName("test-update-name-repo")),
want: &HostCatalog{
HostCatalog: &store.HostCatalog{
Name: "test-update-name-repo",
Description: "test-description-repo",
},
},
*/
wantCount: 1,
},
}
for _, tt := range tests {

@ -4,6 +4,8 @@ import (
"context"
"errors"
"fmt"
"reflect"
"strings"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/watchtower/internal/db"
@ -72,7 +74,7 @@ func (r *Repository) create(ctx context.Context, resource Resource, opt ...Optio
}
// update will update an iam resource in the db repository with an oplog entry
func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPaths []string, opt ...Option) (Resource, int, error) {
func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPaths []string, setToNullPaths []string, opt ...Option) (Resource, int, error) {
if resource == nil {
return nil, db.NoRowsAffected, errors.New("error updating resource that is nil")
}
@ -99,6 +101,7 @@ func (r *Repository) update(ctx context.Context, resource Resource, fieldMaskPat
ctx,
returnedResource,
fieldMaskPaths,
setToNullPaths,
db.WithOplog(r.wrapper, metadata),
)
if err == nil && rowsUpdated > 1 {
@ -152,24 +155,27 @@ func (r *Repository) delete(ctx context.Context, resource Resource, opt ...Optio
func (r *Repository) stdMetadata(ctx context.Context, resource Resource) (oplog.Metadata, error) {
if s, ok := resource.(*Scope); ok {
if s.Type == "" {
if err := r.reader.LookupByPublicId(ctx, s); err != nil {
scope := allocScope()
scope.PublicId = s.PublicId
scope.Type = s.Type
if scope.Type == "" {
if err := r.reader.LookupByPublicId(ctx, &scope); err != nil {
return nil, ErrMetadataScopeNotFound
}
}
switch s.Type {
switch scope.Type {
case OrganizationScope.String():
return oplog.Metadata{
"resource-public-id": []string{resource.GetPublicId()},
"scope-id": []string{s.PublicId},
"scope-type": []string{s.Type},
"scope-id": []string{scope.PublicId},
"scope-type": []string{scope.Type},
"resource-type": []string{resource.ResourceType().String()},
}, nil
case ProjectScope.String():
return oplog.Metadata{
"resource-public-id": []string{resource.GetPublicId()},
"scope-id": []string{s.ParentId},
"scope-type": []string{s.Type},
"scope-id": []string{scope.ParentId},
"scope-type": []string{scope.Type},
"resource-type": []string{resource.ResourceType().String()},
}, nil
default:
@ -191,3 +197,31 @@ func (r *Repository) stdMetadata(ctx context.Context, resource Resource) (oplog.
"resource-type": []string{resource.ResourceType().String()},
}, nil
}
func contains(ss []string, t string) bool {
for _, s := range ss {
if strings.EqualFold(s, t) {
return true
}
}
return false
}
func buildUpdatePaths(fieldValues map[string]interface{}, fieldMask []string) (masks []string, nulls []string) {
for f, v := range fieldValues {
if !contains(fieldMask, f) {
continue
}
switch {
case isZero(v):
nulls = append(nulls, f)
default:
masks = append(masks, f)
}
}
return masks, nulls
}
func isZero(i interface{}) bool {
return i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface())
}

@ -20,17 +20,41 @@ func (r *Repository) CreateScope(ctx context.Context, scope *Scope, opt ...Optio
return resource.(*Scope), nil
}
// UpdateScope will update a scope in the repository and return the written scope
// UpdateScope will update a scope in the repository and return the written
// scope. fieldMaskPaths provides field_mask.proto paths for fields that should
// be updated. Fields will be set to NULL if the field is a zero value and
// included in fieldMask. Name and Description are the only updatable fields,
// and everything else is ignored. If no updatable fields are included in the
// fieldMaskPaths, then an error is returned.
func (r *Repository) UpdateScope(ctx context.Context, scope *Scope, fieldMaskPaths []string, opt ...Option) (*Scope, int, error) {
if scope == nil {
return nil, db.NoRowsAffected, errors.New("scope is nil for update")
return nil, db.NoRowsAffected, fmt.Errorf("update scope: missing scope: %w", db.ErrNilParameter)
}
if scope.PublicId == "" {
return nil, db.NoRowsAffected, errors.New("scope public id is unset for update")
return nil, db.NoRowsAffected, fmt.Errorf("update scope: missing public id: %w", db.ErrNilParameter)
}
resource, rowsUpdated, err := r.update(ctx, scope, fieldMaskPaths)
if contains(fieldMaskPaths, "ParentId") {
return nil, db.NoRowsAffected, fmt.Errorf("update scope: you cannot change a scope's parent: %w", db.ErrInvalidFieldMask)
}
var dbMask, nullFields []string
dbMask, nullFields = buildUpdatePaths(
map[string]interface{}{
"name": scope.Name,
"description": scope.Description,
},
fieldMaskPaths,
)
// nada to update, so reload scope from db and return it
if len(dbMask) == 0 && len(nullFields) == 0 {
return nil, db.NoRowsAffected, fmt.Errorf("update scope: %w", db.ErrEmptyFieldMask)
}
resource, rowsUpdated, err := r.update(ctx, scope, dbMask, nullFields)
if err != nil {
return nil, db.NoRowsAffected, fmt.Errorf("failed to update scope: %w", err)
if db.IsUnique(err) {
return nil, db.NoRowsAffected, fmt.Errorf("update scope: %s name %s already exists: %w", scope.PublicId, scope.Name, db.ErrNotUnique)
}
return nil, db.NoRowsAffected, fmt.Errorf("update scope: failed for public id %s: %w", scope.PublicId, err)
}
return resource.(*Scope), rowsUpdated, err
}

@ -2,11 +2,14 @@ package iam
import (
"context"
"fmt"
"testing"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/watchtower/internal/db"
iam_store "github.com/hashicorp/watchtower/internal/iam/store"
"github.com/hashicorp/watchtower/internal/oplog"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/proto"
@ -21,7 +24,11 @@ func Test_Repository_CreateScope(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
t.Run("valid-scope", func(t *testing.T) {
rw := db.New(conn)
@ -107,7 +114,11 @@ func Test_Repository_UpdateScope(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
t.Run("valid-scope", func(t *testing.T) {
rw := db.New(conn)
@ -176,7 +187,7 @@ func Test_Repository_UpdateScope(t *testing.T) {
assert.Error(err)
assert.Nil(project)
assert.Equal(0, updatedRows)
assert.Equal("failed to update scope: update: vet for write failed you cannot change a scope's parent", err.Error())
assert.Equal("update scope: you cannot change a scope's parent: invalid field mask", err.Error())
})
}
@ -188,8 +199,12 @@ func Test_Repository_LookupScope(t *testing.T) {
t.Error(err)
}
}()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
assert := assert.New(t)
defer conn.Close()
t.Run("found-and-not-found", func(t *testing.T) {
rw := db.New(conn)
@ -228,7 +243,11 @@ func Test_Repository_DeleteScope(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
rw := db.New(conn)
wrapper := db.TestWrapper(t)
repo, err := NewRepository(rw, rw, wrapper)
@ -268,3 +287,231 @@ func Test_Repository_DeleteScope(t *testing.T) {
assert.Equal(0, rowsDeleted)
})
}
func TestRepository_UpdateScope(t *testing.T) {
cleanup, conn, _ := db.TestSetup(t, "postgres")
now := &iam_store.Timestamp{Timestamp: ptypes.TimestampNow()}
id := testId(t)
defer func() {
if err := cleanup(); err != nil {
t.Error(err)
}
}()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
rw := db.New(conn)
wrapper := db.TestWrapper(t)
publicId := testPublicId(t, "o")
type args struct {
scope *Scope
fieldMaskPaths []string
opt []Option
}
tests := []struct {
name string
args args
wantName string
wantDescription string
wantUpdatedRows int
wantErr bool
wantErrMsg string
wantNullFields []string
}{
{
name: "valid-scope",
args: args{
scope: &Scope{
Scope: &iam_store.Scope{
Name: "valid-scope" + id,
Description: "",
CreateTime: now,
UpdateTime: now,
PublicId: publicId,
},
},
fieldMaskPaths: []string{"Name", "Description", "CreateTime", "UpdateTime", "PublicId"},
},
wantName: "valid-scope" + id,
wantDescription: "",
wantUpdatedRows: 1,
wantErr: false,
wantErrMsg: "",
wantNullFields: []string{"Description"},
},
{
name: "nil-resource",
args: args{
scope: nil,
fieldMaskPaths: []string{"Name"},
},
wantUpdatedRows: 0,
wantErr: true,
wantErrMsg: "update scope: missing scope: nil parameter",
wantNullFields: nil,
},
{
name: "no-updates",
args: args{
scope: &Scope{
Scope: &iam_store.Scope{
Name: "no-updates" + id,
Description: "updated" + id,
CreateTime: now,
UpdateTime: now,
PublicId: publicId,
},
},
fieldMaskPaths: []string{"CreateTime"},
},
wantUpdatedRows: 0,
wantErr: true,
wantErrMsg: "update scope: empty field mask",
wantNullFields: nil,
},
{
name: "no-null",
args: args{
scope: &Scope{
Scope: &iam_store.Scope{
Name: "no-null" + id,
Description: "updated" + id,
CreateTime: now,
UpdateTime: now,
PublicId: publicId,
},
},
fieldMaskPaths: []string{"Name"},
},
wantName: "no-null" + id,
wantDescription: "orig-" + id,
wantUpdatedRows: 1,
wantErr: false,
wantErrMsg: "",
wantNullFields: nil,
},
{
name: "only-null",
args: args{
scope: &Scope{
Scope: &iam_store.Scope{
Name: "",
Description: "",
CreateTime: now,
UpdateTime: now,
PublicId: publicId,
},
},
fieldMaskPaths: []string{"Name", "Description"},
},
wantName: "",
wantDescription: "",
wantUpdatedRows: 1,
wantErr: false,
wantErrMsg: "",
wantNullFields: nil,
},
{
name: "parent",
args: args{
scope: &Scope{
Scope: &iam_store.Scope{
Name: "parent" + id,
Description: "",
ParentId: publicId,
CreateTime: now,
UpdateTime: now,
PublicId: publicId,
},
},
fieldMaskPaths: []string{"ParentId", "CreateTime", "UpdateTime", "PublicId"},
},
wantName: "parent-orig-" + id,
wantDescription: "orig-" + id,
wantUpdatedRows: 0,
wantErr: true,
wantErrMsg: "update scope: you cannot change a scope's parent: invalid field mask",
wantNullFields: nil,
},
{
name: "type",
args: args{
scope: &Scope{
Scope: &iam_store.Scope{
Name: "type" + id,
Description: "",
Type: ProjectScope.String(),
CreateTime: now,
UpdateTime: now,
PublicId: publicId,
},
},
fieldMaskPaths: []string{"Type", "CreateTime", "UpdateTime", "PublicId"},
},
wantUpdatedRows: 0,
wantErr: true,
wantErrMsg: "update scope: empty field mask",
wantNullFields: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
r := &Repository{
reader: rw,
writer: rw,
wrapper: wrapper,
}
org := testOrg(t, conn, tt.name+"-orig-"+id, "orig-"+id)
if tt.args.scope != nil {
tt.args.scope.PublicId = org.PublicId
}
updatedScope, rowsUpdated, err := r.UpdateScope(context.Background(), tt.args.scope, tt.args.fieldMaskPaths, tt.args.opt...)
if tt.wantErr {
assert.Error(err)
assert.Equal(tt.wantUpdatedRows, rowsUpdated)
assert.Equal(tt.wantErrMsg, err.Error())
return
}
assert.NoError(err)
assert.Equal(tt.wantUpdatedRows, rowsUpdated)
if tt.wantUpdatedRows > 0 {
err = db.TestVerifyOplog(t, rw, org.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
assert.NoError(err)
}
foundScope := allocScope()
foundScope.PublicId = updatedScope.PublicId
where := "public_id = ?"
for _, f := range tt.wantNullFields {
where = fmt.Sprintf("%s and %s is null", where, f)
}
err = rw.LookupWhere(context.Background(), &foundScope, where, org.PublicId)
assert.NoError(err)
assert.Equal(org.PublicId, foundScope.PublicId)
assert.Equal(tt.wantName, foundScope.Name)
assert.Equal(tt.wantDescription, foundScope.Description)
assert.NotEqual(now, foundScope.CreateTime)
assert.NotEqual(now, foundScope.UpdateTime)
})
}
t.Run("dup-name", func(t *testing.T) {
assert := assert.New(t)
r := &Repository{
reader: rw,
writer: rw,
wrapper: wrapper,
}
id := testId(t)
_ = testOrg(t, conn, id, id)
org2 := testOrg(t, conn, "dup-"+id, id)
org2.Name = id
updatedScope, rowsUpdated, err := r.UpdateScope(context.Background(), org2, []string{"Name"})
assert.Error(err)
assert.Equal(0, rowsUpdated, "updated rows should be 0")
assert.Nil(updatedScope, "scope should be nil")
})
}

@ -2,13 +2,16 @@ package iam
import (
"context"
"fmt"
"reflect"
"testing"
"time"
"github.com/golang/protobuf/ptypes"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/watchtower/internal/db"
iam_store "github.com/hashicorp/watchtower/internal/iam/store"
"github.com/hashicorp/watchtower/internal/oplog"
"github.com/hashicorp/watchtower/internal/oplog/store"
"github.com/stretchr/testify/assert"
@ -24,8 +27,11 @@ func TestNewRepository(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
rw := db.New(conn)
wrapper := db.TestWrapper(t)
type args struct {
@ -113,8 +119,11 @@ func Test_Repository_create(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
t.Run("valid-scope", func(t *testing.T) {
rw := db.New(conn)
wrapper := db.TestWrapper(t)
@ -155,7 +164,7 @@ func Test_Repository_create(t *testing.T) {
})
}
func Test_Repository_update(t *testing.T) {
func Test_Repository_delete(t *testing.T) {
t.Parallel()
cleanup, conn, _ := db.TestSetup(t, "postgres")
defer func() {
@ -164,67 +173,11 @@ func Test_Repository_update(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
t.Run("valid-scope", func(t *testing.T) {
rw := db.New(conn)
wrapper := db.TestWrapper(t)
repo, err := NewRepository(rw, rw, wrapper)
assert.NoError(err)
id, err := uuid.GenerateUUID()
assert.NoError(err)
s, err := NewOrganization()
assert.NoError(err)
retScope, err := repo.create(context.Background(), s)
assert.NoError(err)
assert.NotNil(retScope)
assert.NotEmpty(retScope.GetPublicId())
assert.Empty(retScope.GetName())
retScope.(*Scope).Name = "fname-" + id
retScope, updatedRows, err := repo.update(context.Background(), retScope, []string{"Name"})
assert.NoError(err)
assert.NotNil(retScope)
assert.Equal(1, updatedRows)
assert.Equal(retScope.GetName(), "fname-"+id)
foundScope, err := repo.LookupScope(context.Background(), s.PublicId)
assert.NoError(err)
assert.Equal(foundScope.GetPublicId(), retScope.GetPublicId())
var metadata store.Metadata
err = conn.Where("key = ? and value = ?", "resource-public-id", s.PublicId).First(&metadata).Error
assert.NoError(err)
var foundEntry oplog.Entry
err = conn.Where("id = ?", metadata.EntryId).First(&foundEntry).Error
assert.NoError(err)
})
t.Run("nil-resource", func(t *testing.T) {
rw := db.New(conn)
wrapper := db.TestWrapper(t)
repo, err := NewRepository(rw, rw, wrapper)
assert.NoError(err)
resource, updatedRows, err := repo.update(context.Background(), nil, nil)
assert.NotNil(err)
assert.Nil(resource)
assert.Equal(0, updatedRows)
assert.Equal(err.Error(), "error updating resource that is nil")
})
}
func Test_Repository_delete(t *testing.T) {
t.Parallel()
cleanup, conn, _ := db.TestSetup(t, "postgres")
defer func() {
if err := cleanup(); err != nil {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
assert := assert.New(t)
defer conn.Close()
t.Run("valid-org", func(t *testing.T) {
rw := db.New(conn)
wrapper := db.TestWrapper(t)
@ -257,3 +210,171 @@ func Test_Repository_delete(t *testing.T) {
assert.Equal(err.Error(), "error deleting resource that is nil")
})
}
func TestRepository_update(t *testing.T) {
cleanup, conn, _ := db.TestSetup(t, "postgres")
now := &iam_store.Timestamp{Timestamp: ptypes.TimestampNow()}
id := testId(t)
defer func() {
if err := cleanup(); err != nil {
t.Error(err)
}
}()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
rw := db.New(conn)
wrapper := db.TestWrapper(t)
publicId := testPublicId(t, "o")
type args struct {
resource Resource
fieldMaskPaths []string
setToNullPaths []string
opt []Option
}
tests := []struct {
name string
args args
wantName string
wantDescription string
wantUpdatedRows int
wantErr bool
wantErrMsg string
}{
{
name: "valid-scope",
args: args{
resource: &Scope{
Scope: &iam_store.Scope{
Name: "valid-scope" + id,
Description: "updated" + id,
CreateTime: now,
UpdateTime: now,
PublicId: publicId,
},
},
fieldMaskPaths: []string{"Name"},
setToNullPaths: []string{"Description"},
},
wantName: "valid-scope" + id,
wantDescription: "",
wantUpdatedRows: 1,
wantErr: false,
wantErrMsg: "",
},
{
name: "nil-resource",
args: args{
resource: nil,
fieldMaskPaths: []string{"Name"},
setToNullPaths: []string{"Description"},
},
wantUpdatedRows: 0,
wantErr: true,
wantErrMsg: "error updating resource that is nil",
},
{
name: "intersection",
args: args{
resource: &Scope{
Scope: &iam_store.Scope{
Name: "intersection" + id,
Description: "updated" + id,
CreateTime: now,
UpdateTime: now,
PublicId: publicId,
},
},
fieldMaskPaths: []string{"Name"},
setToNullPaths: []string{"Name"},
},
wantUpdatedRows: 0,
wantErr: true,
wantErrMsg: "update: getting update fields failed: fieldMashPaths and setToNullPaths cannot intersect",
},
{
name: "only-field-masks",
args: args{
resource: &Scope{
Scope: &iam_store.Scope{
Name: "only-field-masks" + id,
Description: "updated" + id,
CreateTime: now,
UpdateTime: now,
PublicId: publicId,
},
},
fieldMaskPaths: []string{"Name"},
setToNullPaths: nil,
},
wantName: "only-field-masks" + id,
wantDescription: "orig-" + id,
wantUpdatedRows: 1,
wantErr: false,
wantErrMsg: "",
},
{
name: "only-null-fields",
args: args{
resource: &Scope{
Scope: &iam_store.Scope{
Name: "only-null-fields" + id,
Description: "updated" + id,
CreateTime: now,
UpdateTime: now,
PublicId: publicId,
},
},
fieldMaskPaths: nil,
setToNullPaths: []string{"Name", "Description"},
},
wantName: "",
wantDescription: "",
wantUpdatedRows: 1,
wantErr: false,
wantErrMsg: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
r := &Repository{
reader: rw,
writer: rw,
wrapper: wrapper,
}
org := testOrg(t, conn, tt.name+"-orig-"+id, "orig-"+id)
if tt.args.resource != nil {
tt.args.resource.(*Scope).PublicId = org.PublicId
}
updatedResource, rowsUpdated, err := r.update(context.Background(), tt.args.resource, tt.args.fieldMaskPaths, tt.args.setToNullPaths, tt.args.opt...)
if tt.wantErr {
assert.Error(err)
assert.Equal(tt.wantUpdatedRows, rowsUpdated)
assert.Equal(tt.wantErrMsg, err.Error())
return
}
assert.NoError(err)
assert.Equal(tt.wantUpdatedRows, rowsUpdated)
err = db.TestVerifyOplog(t, rw, org.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
assert.NoError(err)
foundResource := allocScope()
foundResource.PublicId = updatedResource.GetPublicId()
where := "public_id = ?"
for _, f := range tt.args.setToNullPaths {
where = fmt.Sprintf("%s and %s is null", where, f)
}
err = rw.LookupWhere(context.Background(), &foundResource, where, tt.args.resource.GetPublicId())
assert.NoError(err)
assert.Equal(tt.args.resource.GetPublicId(), foundResource.GetPublicId())
assert.Equal(tt.wantName, foundResource.GetName())
assert.Equal(tt.wantDescription, foundResource.GetDescription())
assert.NotEqual(now, foundResource.GetCreateTime())
assert.NotEqual(now, foundResource.GetUpdateTime())
})
}
}

@ -19,8 +19,11 @@ func Test_NewScope(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
t.Run("valid-org-with-project", func(t *testing.T) {
w := db.New(conn)
s, err := NewOrganization()
@ -60,8 +63,11 @@ func Test_ScopeCreate(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
t.Run("valid", func(t *testing.T) {
w := db.New(conn)
s, err := NewOrganization()
@ -104,8 +110,11 @@ func Test_ScopeUpdate(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
t.Run("valid", func(t *testing.T) {
w := db.New(conn)
s, err := NewOrganization()
@ -118,7 +127,7 @@ func Test_ScopeUpdate(t *testing.T) {
id, err := uuid.GenerateUUID()
assert.NoError(err)
s.Name = id
updatedRows, err := w.Update(context.Background(), s, []string{"Name"})
updatedRows, err := w.Update(context.Background(), s, []string{"Name"}, nil)
assert.NoError(err)
assert.Equal(1, updatedRows)
})
@ -132,7 +141,7 @@ func Test_ScopeUpdate(t *testing.T) {
assert.NotEmpty(s.PublicId)
s.Type = ProjectScope.String()
updatedRows, err := w.Update(context.Background(), s, []string{"Type"})
updatedRows, err := w.Update(context.Background(), s, []string{"Type"}, nil)
assert.NotNil(err)
assert.Equal(0, updatedRows)
})
@ -146,7 +155,11 @@ func Test_ScopeGetScope(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
t.Run("valid-scope", func(t *testing.T) {
w := db.New(conn)
s, err := NewOrganization()
@ -212,8 +225,11 @@ func TestScope_Clone(t *testing.T) {
}
}()
assert := assert.New(t)
defer conn.Close()
defer func() {
if err := conn.Close(); err != nil {
t.Error(err)
}
}()
t.Run("valid", func(t *testing.T) {
w := db.New(conn)
s, err := NewOrganization()

@ -4,6 +4,7 @@ import (
"context"
"testing"
"github.com/hashicorp/go-uuid"
"github.com/hashicorp/watchtower/internal/db"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
@ -32,3 +33,36 @@ func TestScopes(t *testing.T, conn *gorm.DB) (org *Scope, prj *Scope) {
return
}
func testOrg(t *testing.T, conn *gorm.DB, name, description string) (org *Scope) {
t.Helper()
assert := assert.New(t)
rw := db.New(conn)
wrapper := db.TestWrapper(t)
repo, err := NewRepository(rw, rw, wrapper)
assert.NoError(err)
o, err := NewOrganization(WithDescription(description), WithName(name))
assert.NoError(err)
o, err = repo.CreateScope(context.Background(), o)
assert.NoError(err)
assert.NotNil(o)
assert.NotEmpty(o.GetPublicId())
return o
}
func testId(t *testing.T) string {
t.Helper()
assert := assert.New(t)
id, err := uuid.GenerateUUID()
assert.NoError(err)
return id
}
func testPublicId(t *testing.T, prefix string) string {
t.Helper()
assert := assert.New(t)
publicId, err := db.NewPublicId(prefix)
assert.NoError(err)
return publicId
}

@ -1,12 +1,45 @@
package iam
import (
"strings"
"testing"
"github.com/hashicorp/watchtower/internal/db"
"github.com/stretchr/testify/assert"
)
func Test_testOrg(t *testing.T) {
assert := assert.New(t)
cleanup, conn, _ := db.TestSetup(t, "postgres")
defer func() {
err := cleanup()
assert.NoError(err)
}()
defer func() {
err := conn.Close()
assert.NoError(err)
}()
id := testId(t)
org := testOrg(t, conn, id, id)
assert.Equal(id, org.Name)
assert.Equal(id, org.Description)
assert.NotEmpty(org.PublicId)
}
func Test_testId(t *testing.T) {
assert := assert.New(t)
id := testId(t)
assert.NotEmpty(id)
}
func Test_testPublicId(t *testing.T) {
assert := assert.New(t)
id := testPublicId(t, "test")
assert.NotEmpty(id)
assert.True(strings.HasPrefix(id, "test_"))
}
func Test_TestScopes(t *testing.T) {
assert := assert.New(t)

Loading…
Cancel
Save