Add lookup, list, update, and delete methods for auth methods. (#230)

pull/236/head
Todd Knight 6 years ago committed by GitHub
parent 41b2d36d98
commit 3204e54dc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -16,9 +16,16 @@ type AuthMethod struct {
tableName string
}
func allocAuthMethod() AuthMethod {
return AuthMethod{
AuthMethod: &store.AuthMethod{},
}
}
// NewAuthMethod creates a new in memory AuthMethod assigned to scopeId.
// Name and description are the only valid options. All other options are
// ignored.
// ignored. MinUserNameLength and MinPasswordLength are pre-set to the
// default values of 5 and 8 respectively.
func NewAuthMethod(scopeId string, opt ...Option) (*AuthMethod, error) {
if scopeId == "" {
return nil, fmt.Errorf("new: password auth method: no scope id: %w", db.ErrInvalidParameter)

@ -2,6 +2,7 @@ package password
import (
"fmt"
"strings"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/watchtower/internal/db"
@ -44,3 +45,12 @@ func NewRepository(r db.Reader, w db.Writer, wrapper wrapping.Wrapper, opt ...Op
defaultLimit: opts.withLimit,
}, nil
}
func contains(ss []string, t string) bool {
for _, s := range ss {
if strings.EqualFold(s, t) {
return true
}
}
return false
}

@ -2,9 +2,12 @@ package password
import (
"context"
"errors"
"fmt"
"strings"
"github.com/hashicorp/watchtower/internal/db"
dbcommon "github.com/hashicorp/watchtower/internal/db/common"
"github.com/hashicorp/watchtower/internal/oplog"
)
@ -76,3 +79,146 @@ func (r *Repository) CreateAuthMethod(ctx context.Context, m *AuthMethod, opt ..
}
return newAuthMethod, nil
}
// LookupAuthMethod will look up an auth method in the repository. If the auth method is not
// found, it will return nil, nil. All options are ignored.
func (r *Repository) LookupAuthMethod(ctx context.Context, publicId string, opt ...Option) (*AuthMethod, error) {
if publicId == "" {
return nil, fmt.Errorf("lookup: password auth method: missing public id %w", db.ErrInvalidParameter)
}
a := allocAuthMethod()
a.PublicId = publicId
if err := r.reader.LookupByPublicId(ctx, &a); err != nil {
if errors.Is(err, db.ErrRecordNotFound) {
return nil, nil
}
return nil, fmt.Errorf("lookup: password auth method: failed %w for %s", err, publicId)
}
return &a, nil
}
// ListAuthMethods returns a slice of AuthMethods for the scopeId. WithLimit is the only option supported.
func (r *Repository) ListAuthMethods(ctx context.Context, scopeId string, opt ...Option) ([]*AuthMethod, error) {
if scopeId == "" {
return nil, fmt.Errorf("list: password auth method: missing scope id: %w", db.ErrInvalidParameter)
}
opts := getOpts(opt...)
limit := r.defaultLimit
if opts.withLimit != 0 {
// non-zero signals an override of the default limit for the repo.
limit = opts.withLimit
}
var authMethods []*AuthMethod
err := r.reader.SearchWhere(ctx, &authMethods, "scope_id = ?", []interface{}{scopeId}, db.WithLimit(limit))
if err != nil {
return nil, fmt.Errorf("list: password auth method: %w", err)
}
return authMethods, nil
}
// DeleteAuthMethod deletes the auth method for the provided id from the repository returning a count of the
// number of records deleted. All options are ignored.
func (r *Repository) DeleteAuthMethod(ctx context.Context, publicId string, opt ...Option) (int, error) {
if publicId == "" {
return db.NoRowsAffected, fmt.Errorf("delete: password auth method: missing public id: %w", db.ErrInvalidParameter)
}
am := allocAuthMethod()
am.PublicId = publicId
var rowsDeleted int
_, err := r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(_ db.Reader, w db.Writer) (err error) {
metadata := am.oplog(oplog.OpType_OP_TYPE_DELETE)
dAc := am.clone()
rowsDeleted, err = w.Delete(ctx, dAc, db.WithOplog(r.wrapper, metadata))
if err == nil && rowsDeleted > 1 {
return db.ErrMultipleRecords
}
return err
},
)
if err != nil {
return db.NoRowsAffected, fmt.Errorf("delete: password auth method: %s: %w", publicId, err)
}
return rowsDeleted, nil
}
// TODO: Fix the MinPasswordLength and MinUserNameLength update path so they dont have
// to rely on the response of NewAuthMethod but instead can be unset in order to be
// set to the default values.
// UpdateAuthMethod will update an auth method in the repository and return
// the written auth method. MinPasswordLength and MinUserNameLength should
// not be set to null, but instead use the default values returned by
// NewAuthMethod. fieldMaskPaths provides field_mask.proto paths for fields
// that should be updated. Fields will be set to NULL if the field is a zero
// value and included in fieldMask. Name, Description, MinPasswordLength,
// and MinUserNameLength are the only updatable fields, If no updatable fields
// are included in the fieldMaskPaths, then an error is returned.
func (r *Repository) UpdateAuthMethod(ctx context.Context, authMethod *AuthMethod, fieldMaskPaths []string, opt ...Option) (*AuthMethod, int, error) {
if authMethod == nil {
return nil, db.NoRowsAffected, fmt.Errorf("update: password auth method: missing authMethod: %w", db.ErrNilParameter)
}
if authMethod.PublicId == "" {
return nil, db.NoRowsAffected, fmt.Errorf("update: password auth method: missing authMethod public id: %w", db.ErrInvalidParameter)
}
for _, f := range fieldMaskPaths {
switch {
case strings.EqualFold("name", f):
case strings.EqualFold("description", f):
case strings.EqualFold("MinUserNameLength", f):
case strings.EqualFold("MinPasswordLength", f):
default:
return nil, db.NoRowsAffected, fmt.Errorf("update: password auth method: field: %s: %w", f, db.ErrInvalidFieldMask)
}
}
var dbMask, nullFields []string
dbMask, nullFields = dbcommon.BuildUpdatePaths(
map[string]interface{}{
"Name": authMethod.Name,
"Description": authMethod.Description,
"MinPasswordLength": authMethod.MinPasswordLength,
"MinUserNameLength": authMethod.MinUserNameLength,
},
fieldMaskPaths,
)
if len(dbMask) == 0 && len(nullFields) == 0 {
return nil, db.NoRowsAffected, fmt.Errorf("update: password auth method: %w", db.ErrEmptyFieldMask)
}
upAuthMethod := authMethod.clone()
var rowsUpdated int
_, err := r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(_ db.Reader, w db.Writer) error {
dbOpts := []db.Option{db.WithOplog(r.wrapper, upAuthMethod.oplog(oplog.OpType_OP_TYPE_UPDATE))}
var err error
rowsUpdated, err = w.Update(
ctx,
upAuthMethod,
dbMask,
nullFields,
dbOpts...,
)
if err == nil && rowsUpdated > 1 {
// return err, which will result in a rollback of the update
return errors.New("error more than 1 resource would have been updated ")
}
return err
},
)
if err != nil {
if db.IsUniqueError(err) {
return nil, db.NoRowsAffected, fmt.Errorf("update: password auth method: authMethod %s already exists in scope %s", authMethod.Name, authMethod.ScopeId)
}
return nil, db.NoRowsAffected, fmt.Errorf("update: password auth method: %w for %s", err, authMethod.PublicId)
}
return upAuthMethod, rowsUpdated, err
}

@ -5,12 +5,18 @@ import (
"errors"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/watchtower/internal/auth/password/store"
"github.com/hashicorp/watchtower/internal/db"
dbassert "github.com/hashicorp/watchtower/internal/db/assert"
"github.com/hashicorp/watchtower/internal/iam"
"github.com/hashicorp/watchtower/internal/oplog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
)
func TestRepository_CreateAuthMethod(t *testing.T) {
@ -211,6 +217,434 @@ func TestRepository_CreateAuthMethod(t *testing.T) {
})
}
func TestRepository_LookupAuthMethod(t *testing.T) {
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
o, _ := iam.TestScopes(t, conn)
authMethod := TestAuthMethods(t, conn, o.GetPublicId(), 1)[0]
amId, err := newAuthMethodId()
require.NoError(t, err)
var tests = []struct {
name string
in string
want *AuthMethod
wantIsErr error
}{
{
name: "With no public id",
wantIsErr: db.ErrInvalidParameter,
},
{
name: "With non existing auth method id",
in: amId,
},
{
name: "With existing auth method id",
in: authMethod.GetPublicId(),
want: authMethod,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
repo, err := NewRepository(rw, rw, wrapper)
assert.NoError(err)
require.NotNil(repo)
got, err := repo.LookupAuthMethod(context.Background(), tt.in)
if tt.wantIsErr != nil {
assert.Truef(errors.Is(err, tt.wantIsErr), "want err: %q got: %q", tt.wantIsErr, err)
assert.Nil(got)
return
}
require.NoError(err)
assert.EqualValues(tt.want, got)
})
}
}
func TestRepository_DeleteAuthMethod(t *testing.T) {
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
o, _ := iam.TestScopes(t, conn)
authMethod := TestAuthMethods(t, conn, o.GetPublicId(), 1)[0]
newAuthMethodId, err := newAuthMethodId()
require.NoError(t, err)
var tests = []struct {
name string
in string
want int
wantIsErr error
}{
{
name: "With no public id",
wantIsErr: db.ErrInvalidParameter,
},
{
name: "With non existing auth method id",
in: newAuthMethodId,
want: 0,
},
{
name: "With existing auth method id",
in: authMethod.GetPublicId(),
want: 1,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
repo, err := NewRepository(rw, rw, wrapper)
assert.NoError(err)
require.NotNil(repo)
got, err := repo.DeleteAuthMethod(context.Background(), tt.in)
if tt.wantIsErr != nil {
assert.Truef(errors.Is(err, tt.wantIsErr), "want err: %q got: %q", tt.wantIsErr, err)
assert.Zero(got)
return
}
require.NoError(err)
assert.EqualValues(tt.want, got)
})
}
}
func TestRepository_ListAuthMethods(t *testing.T) {
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
noAuthMethodOrg, _ := iam.TestScopes(t, conn)
o, _ := iam.TestScopes(t, conn)
authMethods := TestAuthMethods(t, conn, o.GetPublicId(), 3)
var tests = []struct {
name string
in string
opts []Option
want []*AuthMethod
wantIsErr error
}{
{
name: "With no scope id",
wantIsErr: db.ErrInvalidParameter,
},
{
name: "Scope with no auth methods",
in: noAuthMethodOrg.GetPublicId(),
want: []*AuthMethod{},
},
{
name: "With populated scope id",
in: o.GetPublicId(),
want: authMethods,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
repo, err := NewRepository(rw, rw, wrapper)
assert.NoError(err)
require.NotNil(repo)
got, err := repo.ListAuthMethods(context.Background(), tt.in, tt.opts...)
if tt.wantIsErr != nil {
assert.Truef(errors.Is(err, tt.wantIsErr), "want err: %q got: %q", tt.wantIsErr, err)
assert.Nil(got)
return
}
require.NoError(err)
assert.Empty(cmp.Diff(tt.want, got, protocmp.Transform()))
})
}
}
func TestRepository_ListAuthMethods_Limits(t *testing.T) {
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
o, _ := iam.TestScopes(t, conn)
authMethodCount := 10
ams := TestAuthMethods(t, conn, o.GetPublicId(), authMethodCount)
var tests = []struct {
name string
repoOpts []Option
listOpts []Option
wantLen int
}{
{
name: "With no limits",
wantLen: authMethodCount,
},
{
name: "With repo limit",
repoOpts: []Option{WithLimit(3)},
wantLen: 3,
},
{
name: "With negative repo limit",
repoOpts: []Option{WithLimit(-1)},
wantLen: authMethodCount,
},
{
name: "With List limit",
listOpts: []Option{WithLimit(3)},
wantLen: 3,
},
{
name: "With negative List limit",
listOpts: []Option{WithLimit(-1)},
wantLen: authMethodCount,
},
{
name: "With repo smaller than list limit",
repoOpts: []Option{WithLimit(2)},
listOpts: []Option{WithLimit(6)},
wantLen: 6,
},
{
name: "With repo larger than list limit",
repoOpts: []Option{WithLimit(6)},
listOpts: []Option{WithLimit(2)},
wantLen: 2,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
repo, err := NewRepository(rw, rw, wrapper, tt.repoOpts...)
assert.NoError(err)
require.NotNil(repo)
got, err := repo.ListAuthMethods(context.Background(), ams[0].GetScopeId(), tt.listOpts...)
require.NoError(err)
assert.Len(got, tt.wantLen)
})
}
}
func TestRepository_UpdateAuthMethod(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
repo, err := NewRepository(rw, rw, wrapper)
require.NoError(t, err)
ctx := context.Background()
type args struct {
updates *store.AuthMethod
fieldMaskPaths []string
}
tests := []struct {
name string
args args
wantRowsUpdate int
wantErr bool
}{
{
name: "change name",
args: args{
updates: &store.AuthMethod{
Name: "updated",
},
fieldMaskPaths: []string{"Name"},
},
wantErr: false,
wantRowsUpdate: 1,
},
{
name: "null name",
args: args{
updates: &store.AuthMethod{},
fieldMaskPaths: []string{"Name"},
},
wantErr: false,
wantRowsUpdate: 1,
},
{
name: "change description",
args: args{
updates: &store.AuthMethod{
Description: "updated",
},
fieldMaskPaths: []string{"Description"},
},
wantErr: false,
wantRowsUpdate: 1,
},
{
name: "null description",
args: args{
updates: &store.AuthMethod{},
fieldMaskPaths: []string{"Description"},
},
wantErr: false,
wantRowsUpdate: 1,
},
{
name: "change min pw",
args: args{
updates: &store.AuthMethod{
MinPasswordLength: 13,
},
fieldMaskPaths: []string{"MinPasswordLength"},
},
wantErr: false,
wantRowsUpdate: 1,
},
{
name: "null min pw",
args: args{
updates: &store.AuthMethod{},
fieldMaskPaths: []string{"MinPasswordLength"},
},
wantErr: false,
wantRowsUpdate: 1,
},
{
name: "change min username",
args: args{
updates: &store.AuthMethod{
MinUserNameLength: 13,
},
fieldMaskPaths: []string{"MinUserNameLength"},
},
wantErr: false,
wantRowsUpdate: 1,
},
{
name: "null min username",
args: args{
updates: &store.AuthMethod{},
fieldMaskPaths: []string{"MinUserNameLength"},
},
wantErr: false,
wantRowsUpdate: 1,
},
{
name: "noop update",
args: args{
updates: &store.AuthMethod{
Name: "default",
},
fieldMaskPaths: []string{"name"},
},
wantErr: false,
wantRowsUpdate: 1,
},
{
name: "not fround",
args: args{
updates: &store.AuthMethod{
PublicId: func() string {
s, err := newAuthMethodId()
require.NoError(t, err)
return s
}(),
},
fieldMaskPaths: []string{"name"},
},
wantErr: true,
wantRowsUpdate: 0,
},
{
name: "empty field mask",
args: args{
updates: &store.AuthMethod{Name: "Test"},
fieldMaskPaths: []string{},
},
wantErr: true,
wantRowsUpdate: 0,
},
{
name: "nil field mask",
args: args{
updates: &store.AuthMethod{Name: "Test"},
fieldMaskPaths: nil,
},
wantErr: true,
wantRowsUpdate: 0,
},
{
name: "read-only-fields",
args: args{
updates: &store.AuthMethod{Name: "Test"},
fieldMaskPaths: []string{"CreateTime"},
},
wantErr: true,
wantRowsUpdate: 0,
},
{
name: "unknown fields",
args: args{
updates: &store.AuthMethod{Name: "Test"},
fieldMaskPaths: []string{"RandomUnknownName"},
},
wantErr: true,
wantRowsUpdate: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
// create the initial auth method
o, _ := iam.TestScopes(t, conn)
am, err := NewAuthMethod(o.GetPublicId(), WithName("default"), WithDescription("default"))
require.NoError(err)
origAM, err := repo.CreateAuthMethod(ctx, am)
require.NoError(err)
amToUpdate, err := NewAuthMethod(o.GetPublicId())
require.NoError(err)
amToUpdate.PublicId = origAM.GetPublicId()
proto.Merge(amToUpdate.AuthMethod, tt.args.updates)
updatedAM, updatedRows, err := repo.UpdateAuthMethod(ctx, amToUpdate, tt.args.fieldMaskPaths)
assert.Equal(tt.wantRowsUpdate, updatedRows)
if tt.wantErr {
require.Error(err)
assert.Nil(updatedAM)
err = db.TestVerifyOplog(t, rw, amToUpdate.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
require.Error(err)
assert.Equal("record not found", err.Error())
return
}
require.NoError(err)
assert.NotEqual(origAM.UpdateTime, updatedAM.UpdateTime)
foundAuthMethod, err := repo.LookupAuthMethod(ctx, origAM.PublicId)
require.NoError(err)
assert.Empty(cmp.Diff(updatedAM, foundAuthMethod, protocmp.Transform()))
dbassert := dbassert.New(t, rw)
if amToUpdate.Name == "" && contains(tt.args.fieldMaskPaths, "name") {
dbassert.IsNull(foundAuthMethod, "name")
}
if amToUpdate.Description == "" && contains(tt.args.fieldMaskPaths, "description") {
dbassert.IsNull(foundAuthMethod, "description")
}
err = db.TestVerifyOplog(t, rw, updatedAM.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
assert.NoError(err)
})
}
}
func assertPublicId(t *testing.T, prefix, actual string) {
t.Helper()
assert.NotEmpty(t, actual)

@ -124,7 +124,8 @@ func TestRepository_AuthenticateRehash(t *testing.T) {
wrapper := db.TestWrapper(t)
assert, require := assert.New(t), require.New(t)
authMethods := testAuthMethods(t, conn, 1)
o, _ := iam.TestScopes(t, conn)
authMethods := TestAuthMethods(t, conn, o.GetPublicId(), 1)
authMethod := authMethods[0]
authMethodId := authMethod.GetPublicId()
userName := "kazmierczak"

@ -141,3 +141,34 @@ func Intersection(av, bv []string) ([]string, map[string]string, map[string]stri
}
return s, ah, bh, nil
}
// BuildUpdatePaths takes a map of field names to field values and field masks
// and returns both a list of field names to udpate and a list of field names
// that should be set to null.
func BuildUpdatePaths(fieldValues map[string]interface{}, fieldMask []string) (masks []string, nulls []string) {
for f, v := range fieldValues {
if !contains(fieldMask, f) {
continue
}
switch {
case isZero(v):
nulls = append(nulls, f)
default:
masks = append(masks, f)
}
}
return masks, nulls
}
func contains(ss []string, t string) bool {
for _, s := range ss {
if strings.EqualFold(s, t) {
return true
}
}
return false
}
func isZero(i interface{}) bool {
return i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface())
}

@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"reflect"
"strings"
wrapping "github.com/hashicorp/go-kms-wrapping"
@ -235,22 +234,3 @@ func contains(ss []string, t string) bool {
}
return false
}
func buildUpdatePaths(fieldValues map[string]interface{}, fieldMask []string) (masks []string, nulls []string) {
for f, v := range fieldValues {
if !contains(fieldMask, f) {
continue
}
switch {
case isZero(v):
nulls = append(nulls, f)
default:
masks = append(masks, f)
}
}
return masks, nulls
}
func isZero(i interface{}) bool {
return i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface())
}

@ -6,6 +6,7 @@ import (
"strings"
"github.com/hashicorp/watchtower/internal/db"
dbcommon "github.com/hashicorp/watchtower/internal/db/common"
"github.com/hashicorp/watchtower/internal/oplog"
)
@ -64,7 +65,7 @@ func (r *Repository) UpdateGroup(ctx context.Context, group *Group, fieldMaskPat
}
}
var dbMask, nullFields []string
dbMask, nullFields = buildUpdatePaths(
dbMask, nullFields = dbcommon.BuildUpdatePaths(
map[string]interface{}{
"name": group.Name,
"description": group.Description,

@ -6,6 +6,7 @@ import (
"strings"
"github.com/hashicorp/watchtower/internal/db"
dbcommon "github.com/hashicorp/watchtower/internal/db/common"
)
// CreateRole will create a role in the repository and return the written
@ -65,7 +66,7 @@ func (r *Repository) UpdateRole(ctx context.Context, role *Role, fieldMaskPaths
}
}
var dbMask, nullFields []string
dbMask, nullFields = buildUpdatePaths(
dbMask, nullFields = dbcommon.BuildUpdatePaths(
map[string]interface{}{
"name": role.Name,
"description": role.Description,

@ -7,6 +7,7 @@ import (
"strings"
"github.com/hashicorp/watchtower/internal/db"
dbcommon "github.com/hashicorp/watchtower/internal/db/common"
"github.com/hashicorp/watchtower/internal/types/scope"
)
@ -72,7 +73,7 @@ func (r *Repository) UpdateScope(ctx context.Context, scope *Scope, fieldMaskPat
return nil, db.NoRowsAffected, fmt.Errorf("update scope: you cannot change a scope's parent: %w", db.ErrInvalidFieldMask)
}
var dbMask, nullFields []string
dbMask, nullFields = buildUpdatePaths(
dbMask, nullFields = dbcommon.BuildUpdatePaths(
map[string]interface{}{
"name": scope.Name,
"description": scope.Description,

@ -7,6 +7,7 @@ import (
"strings"
"github.com/hashicorp/watchtower/internal/db"
dbcommon "github.com/hashicorp/watchtower/internal/db/common"
"github.com/hashicorp/watchtower/internal/oplog"
"github.com/hashicorp/watchtower/internal/types/scope"
)
@ -56,7 +57,7 @@ func (r *Repository) UpdateUser(ctx context.Context, user *User, fieldMaskPaths
}
}
var dbMask, nullFields []string
dbMask, nullFields = buildUpdatePaths(
dbMask, nullFields = dbcommon.BuildUpdatePaths(
map[string]interface{}{
"name": user.Name,
"description": user.Description,

Loading…
Cancel
Save