You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/auth/password/repository_authmethod_test.go

636 lines
17 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package password
import (
"context"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/auth/password/store"
"github.com/hashicorp/boundary/internal/db"
dbassert "github.com/hashicorp/boundary/internal/db/assert"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
)
// TestRepository_CreateAuthMethod runs Repository.CreateAuthMethod against a
// table of inputs. Cases that set wantIsErr must fail with that error code and
// with exactly the wantErrMsg text. All other cases must succeed and return a
// new AuthMethod whose public id carries the password auth method prefix,
// whose name and description match the want value, and whose create and
// update times are equal.
func TestRepository_CreateAuthMethod(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kms := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, _ := iam.TestScopes(t, iamRepo)

	tests := []struct {
		name       string
		in         *AuthMethod
		opts       []Option
		want       *AuthMethod
		wantIsErr  errors.Code
		wantErrMsg string
	}{
		{
			name:       "nil-AuthMethod",
			wantIsErr:  errors.InvalidParameter,
			wantErrMsg: "password.(Repository).CreateAuthMethod: missing AuthMethod: parameter violation: error #100",
		},
		{
			name:       "nil-embedded-AuthMethod",
			in:         &AuthMethod{},
			wantIsErr:  errors.InvalidParameter,
			wantErrMsg: "password.(Repository).CreateAuthMethod: missing embedded AuthMethod: parameter violation: error #100",
		},
		{
			name: "invalid-no-scope-id",
			in: &AuthMethod{
				AuthMethod: &store.AuthMethod{},
			},
			wantIsErr:  errors.InvalidParameter,
			wantErrMsg: "password.(Repository).CreateAuthMethod: missing scope id: parameter violation: error #100",
		},
		{
			name: "invalid-public-id-set",
			in: &AuthMethod{
				AuthMethod: &store.AuthMethod{
					ScopeId:  org.PublicId,
					PublicId: "hcst_OOOOOOOOOO",
				},
			},
			wantIsErr:  errors.InvalidParameter,
			wantErrMsg: "password.(Repository).CreateAuthMethod: public id not empty: parameter violation: error #100",
		},
		{
			name: "valid-no-options",
			in: &AuthMethod{
				AuthMethod: &store.AuthMethod{
					ScopeId: org.PublicId,
				},
			},
			want: &AuthMethod{
				AuthMethod: &store.AuthMethod{
					ScopeId: org.PublicId,
				},
			},
		},
		{
			name: "valid-with-name",
			in: &AuthMethod{
				AuthMethod: &store.AuthMethod{
					ScopeId: org.PublicId,
					Name:    "test-name-repo",
				},
			},
			want: &AuthMethod{
				AuthMethod: &store.AuthMethod{
					ScopeId: org.PublicId,
					Name:    "test-name-repo",
				},
			},
		},
		{
			name: "valid-with-description",
			in: &AuthMethod{
				AuthMethod: &store.AuthMethod{
					ScopeId:     org.PublicId,
					Description: "test-description-repo",
				},
			},
			want: &AuthMethod{
				AuthMethod: &store.AuthMethod{
					ScopeId:     org.PublicId,
					Description: "test-description-repo",
				},
			},
		},
		{
			name: "invalid-with-config-nil-embedded-config",
			in: &AuthMethod{
				AuthMethod: &store.AuthMethod{
					ScopeId: org.PublicId,
				},
			},
			opts: []Option{
				WithConfiguration(&Argon2Configuration{}),
			},
			wantIsErr:  errors.PasswordInvalidConfiguration,
			wantErrMsg: "password.(Repository).CreateAuthMethod: password.(Argon2Configuration).validate: missing embedded config: password violation: error #202",
		},
		{
			name: "invalid-with-config-unknown-config-type",
			in: &AuthMethod{
				AuthMethod: &store.AuthMethod{
					ScopeId: org.PublicId,
				},
			},
			opts: []Option{
				WithConfiguration(tconf(0)),
			},
			wantIsErr:  errors.PasswordUnsupportedConfiguration,
			wantErrMsg: "password.(Repository).CreateAuthMethod: unknown configuration: password violation: error #201",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			repo, err := NewRepository(ctx, rw, rw, kms)
			require.NoError(err)
			require.NotNil(repo)

			got, err := repo.CreateAuthMethod(ctx, tt.in, tt.opts...)
			if tt.wantIsErr != 0 {
				// Abort the subtest when no error came back: calling
				// err.Error() on a nil error below would panic instead of
				// reporting a clean failure.
				require.Error(err)
				assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
				assert.Equal(tt.wantErrMsg, err.Error())
				return
			}
			require.NoError(err)
			// The input must not have been given a public id by the call.
			assert.Empty(tt.in.PublicId)
			require.NotNil(got)
			assertPublicId(t, globals.PasswordAuthMethodPrefix, got.PublicId)
			assert.NotSame(tt.in, got)
			assert.Equal(tt.want.Name, got.Name)
			assert.Equal(tt.want.Description, got.Description)
			assert.Equal(got.CreateTime, got.UpdateTime)
		})
	}
}
// TestRepository_CreateAuthMethod_DupeNames covers name uniqueness for
// CreateAuthMethod: creating the same named auth method twice in one scope
// must fail with errors.NotUnique, while the same name in two different
// scopes must succeed both times.
func TestRepository_CreateAuthMethod_DupeNames(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kms := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)

	// verifyCreated holds the checks shared by every successful create in
	// this test: no error, a non-nil result distinct from the input, a
	// correctly prefixed public id, copied name/description and equal
	// create/update timestamps.
	verifyCreated := func(t *testing.T, input, created *AuthMethod, createErr error) {
		require.NoError(t, createErr)
		require.NotNil(t, created)
		assertPublicId(t, globals.PasswordAuthMethodPrefix, created.PublicId)
		assert.NotSame(t, input, created)
		assert.Equal(t, input.Name, created.Name)
		assert.Equal(t, input.Description, created.Description)
		assert.Equal(t, created.CreateTime, created.UpdateTime)
	}

	t.Run("invalid-duplicate-names", func(t *testing.T) {
		repo, err := NewRepository(ctx, rw, rw, kms)
		require.NoError(t, err)
		require.NotNil(t, repo)

		scope, _ := iam.TestScopes(t, iamRepo)
		input := &AuthMethod{
			AuthMethod: &store.AuthMethod{
				ScopeId: scope.GetPublicId(),
				Name:    "test-name-repo",
			},
		}

		first, err := repo.CreateAuthMethod(ctx, input)
		verifyCreated(t, input, first, err)

		// Same input again: the name is already taken within this scope.
		second, err := repo.CreateAuthMethod(ctx, input)
		assert.Truef(t, errors.Match(errors.T(errors.NotUnique), err), "Unexpected error %s", err)
		assert.Nil(t, second)
	})

	t.Run("valid-duplicate-names-diff-scopes", func(t *testing.T) {
		repo, err := NewRepository(ctx, rw, rw, kms)
		require.NoError(t, err)
		require.NotNil(t, repo)

		scopeA, _ := iam.TestScopes(t, iamRepo)
		inputA := &AuthMethod{
			AuthMethod: &store.AuthMethod{
				Name: "test-name-repo",
			},
		}
		// Clone before the scope id is assigned so each copy gets its own.
		inputB := inputA.Clone()

		inputA.ScopeId = scopeA.GetPublicId()
		createdA, err := repo.CreateAuthMethod(ctx, inputA)
		verifyCreated(t, inputA, createdA, err)

		scopeB, _ := iam.TestScopes(t, iamRepo)
		inputB.ScopeId = scopeB.GetPublicId()
		createdB, err := repo.CreateAuthMethod(ctx, inputB)
		verifyCreated(t, inputB, createdB, err)
	})
}
// TestRepository_CreateAuthMethod_PublicId checks the WithPublicId option of
// CreateAuthMethod: an id produced by newAuthMethodId is used as the public id
// of the created auth method, while an id with a bad prefix is rejected with
// errors.InvalidPublicId and no result.
func TestRepository_CreateAuthMethod_PublicId(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kms := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)

	t.Run("valid-with-publicid", func(t *testing.T) {
		a, r := assert.New(t), require.New(t)
		repo, err := NewRepository(ctx, rw, rw, kms)
		r.NoError(err)
		r.NotNil(repo)

		scope, _ := iam.TestScopes(t, iamRepo)
		input := allocAuthMethod()
		wantId, err := newAuthMethodId(ctx)
		r.NoError(err)
		input.ScopeId = scope.GetPublicId()

		created, err := repo.CreateAuthMethod(ctx, &input, WithPublicId(wantId))
		r.NoError(err)
		r.NotNil(created)
		a.Equal(wantId, created.GetPublicId())
		a.Equal(created.CreateTime, created.UpdateTime)
	})

	t.Run("invalid-with-badpublicid", func(t *testing.T) {
		a, r := assert.New(t), require.New(t)
		repo, err := NewRepository(ctx, rw, rw, kms)
		r.NoError(err)
		r.NotNil(repo)

		scope, _ := iam.TestScopes(t, iamRepo)
		input := allocAuthMethod()
		input.ScopeId = scope.GetPublicId()

		created, err := repo.CreateAuthMethod(ctx, &input, WithPublicId("invalid_idwithabadprefix"))
		a.Error(err)
		a.Nil(created)
		a.Truef(errors.Match(errors.T(errors.InvalidPublicId), err), "Unexpected error %s", err)
	})
}
// TestRepository_LookupAuthMethod exercises Repository.LookupAuthMethod with
// an empty id (expects errors.InvalidPublicId and an exact message), a freshly
// generated id that was never stored (expects a nil result and no error), and
// the id of an auth method created by TestAuthMethods (expects that auth
// method back).
func TestRepository_LookupAuthMethod(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kms := kms.TestKms(t, conn, wrapper)

	o, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
	authMethod := TestAuthMethods(t, conn, o.GetPublicId(), 1)[0]

	// amId is well-formed but never persisted.
	amId, err := newAuthMethodId(ctx)
	require.NoError(t, err)

	tests := []struct {
		name       string
		in         string
		want       *AuthMethod
		wantIsErr  errors.Code
		wantErrMsg string
	}{
		{
			name:       "With no public id",
			wantIsErr:  errors.InvalidPublicId,
			wantErrMsg: "password.(Repository).LookupAuthMethod: missing public id: parameter violation: error #102",
		},
		{
			name: "With non existing auth method id",
			in:   amId,
		},
		{
			name: "With existing auth method id",
			in:   authMethod.GetPublicId(),
			want: authMethod,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			repo, err := NewRepository(ctx, rw, rw, kms)
			// A failed constructor makes the rest of the subtest meaningless,
			// so stop immediately, as the other tests in this file do.
			require.NoError(err)
			require.NotNil(repo)

			got, err := repo.LookupAuthMethod(ctx, tt.in)
			if tt.wantIsErr != 0 {
				// Abort the subtest when no error came back: calling
				// err.Error() on a nil error below would panic instead of
				// reporting a clean failure.
				require.Error(err)
				assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
				assert.Equal(tt.wantErrMsg, err.Error())
				return
			}
			require.NoError(err)
			assert.EqualValues(tt.want, got)
		})
	}
}
// TestRepository_DeleteAuthMethod exercises Repository.DeleteAuthMethod with
// an empty id (expects errors.InvalidPublicId and an exact message), a freshly
// generated id that was never stored (expects a result of 0), and the id of an
// auth method created by TestAuthMethods (expects a result of 1).
func TestRepository_DeleteAuthMethod(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kms := kms.TestKms(t, conn, wrapper)

	o, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
	authMethod := TestAuthMethods(t, conn, o.GetPublicId(), 1)[0]

	// unknownId is well-formed but never persisted. It is deliberately not
	// named newAuthMethodId so the package-level function of that name is not
	// shadowed for the rest of this test.
	unknownId, err := newAuthMethodId(ctx)
	require.NoError(t, err)

	tests := []struct {
		name       string
		in         string
		want       int
		wantIsErr  errors.Code
		wantErrMsg string
	}{
		{
			name:       "With no public id",
			wantIsErr:  errors.InvalidPublicId,
			wantErrMsg: "password.(Repository).DeleteAuthMethod: missing public id: parameter violation: error #102",
		},
		{
			name: "With non existing auth method id",
			in:   unknownId,
			want: 0,
		},
		{
			name: "With existing auth method id",
			in:   authMethod.GetPublicId(),
			want: 1,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			repo, err := NewRepository(ctx, rw, rw, kms)
			// A failed constructor makes the rest of the subtest meaningless,
			// so stop immediately, as the other tests in this file do.
			require.NoError(err)
			require.NotNil(repo)

			got, err := repo.DeleteAuthMethod(ctx, o.GetPublicId(), tt.in)
			if tt.wantIsErr != 0 {
				// Abort the subtest when no error came back: calling
				// err.Error() on a nil error below would panic instead of
				// reporting a clean failure.
				require.Error(err)
				assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
				assert.Equal(tt.wantErrMsg, err.Error())
				return
			}
			require.NoError(err)
			assert.EqualValues(tt.want, got)
		})
	}
}
// TestRepository_UpdateAuthMethod drives Repository.UpdateAuthMethod through a
// table of field-mask updates. Every subtest creates a fresh scope holding an
// auth method named and described "default", merges the case's changes into a
// copy carrying the same public id and version, and calls UpdateAuthMethod
// with the case's field mask.
//
// Failing cases must return an error, a nil auth method, and leave no update
// entry for TestVerifyOplog to find. Passing cases must report the expected
// row count, bump the version to 2 (unless the case opts out), change the
// update time, match what LookupAuthMethod returns afterwards, store NULL for
// a cleared name or description, and produce an update oplog entry.
func TestRepository_UpdateAuthMethod(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kms := kms.TestKms(t, conn, wrapper)
	repo, err := NewRepository(ctx, rw, rw, kms)
	require.NoError(t, err)
	iamRepo := iam.TestRepo(t, conn, wrapper)

	// A well-formed id that is never persisted, used by the not-found case.
	missingId, err := newAuthMethodId(ctx)
	require.NoError(t, err)

	// patch bundles the field values to merge in with the mask naming which
	// of them the update should apply.
	type patch struct {
		fields *store.AuthMethod
		mask   []string
	}
	cases := []struct {
		name         string
		in           patch
		wantRows     int
		wantErr      bool
		looseVersion bool // when true, the resulting version is not asserted
	}{
		{
			name:     "change name",
			in:       patch{fields: &store.AuthMethod{Name: "updated"}, mask: []string{"Name"}},
			wantRows: 1,
		},
		{
			name:     "null name",
			in:       patch{fields: &store.AuthMethod{}, mask: []string{"Name"}},
			wantRows: 1,
		},
		{
			name:     "change description",
			in:       patch{fields: &store.AuthMethod{Description: "updated"}, mask: []string{"Description"}},
			wantRows: 1,
		},
		{
			name:         "null description",
			in:           patch{fields: &store.AuthMethod{}, mask: []string{"Description"}},
			wantRows:     1,
			looseVersion: true,
		},
		{
			name:     "null name ignored description",
			in:       patch{fields: &store.AuthMethod{Description: "ignored"}, mask: []string{"name"}},
			wantRows: 1,
		},
		{
			name:     "change min pw",
			in:       patch{fields: &store.AuthMethod{MinPasswordLength: 13}, mask: []string{"MinPasswordLength"}},
			wantRows: 1,
		},
		{
			name:         "null min pw",
			in:           patch{fields: &store.AuthMethod{}, mask: []string{"MinPasswordLength"}},
			wantRows:     1,
			looseVersion: true,
		},
		{
			name:     "change min login name",
			in:       patch{fields: &store.AuthMethod{MinLoginNameLength: 13}, mask: []string{"MinLoginNameLength"}},
			wantRows: 1,
		},
		{
			name:         "null min login name",
			in:           patch{fields: &store.AuthMethod{}, mask: []string{"MinLoginNameLength"}},
			wantRows:     1,
			looseVersion: true,
		},
		{
			name:         "noop update",
			in:           patch{fields: &store.AuthMethod{Name: "default"}, mask: []string{"name"}},
			wantRows:     1,
			looseVersion: true,
		},
		{
			name:     "not fround",
			in:       patch{fields: &store.AuthMethod{PublicId: missingId}, mask: []string{"name"}},
			wantErr:  true,
			wantRows: 0,
		},
		{
			name:     "empty field mask",
			in:       patch{fields: &store.AuthMethod{Name: "Test"}, mask: []string{}},
			wantErr:  true,
			wantRows: 0,
		},
		{
			name:     "nil field mask",
			in:       patch{fields: &store.AuthMethod{Name: "Test"}, mask: nil},
			wantErr:  true,
			wantRows: 0,
		},
		{
			name:     "read-only-fields",
			in:       patch{fields: &store.AuthMethod{Name: "Test"}, mask: []string{"CreateTime"}},
			wantErr:  true,
			wantRows: 0,
		},
		{
			name:     "unknown fields",
			in:       patch{fields: &store.AuthMethod{Name: "Test"}, mask: []string{"RandomUnknownName"}},
			wantErr:  true,
			wantRows: 0,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			a, r := assert.New(t), require.New(t)

			// Seed a fresh scope with the auth method that will be updated.
			scope, _ := iam.TestScopes(t, iamRepo)
			seed, err := NewAuthMethod(ctx, scope.GetPublicId(), WithName("default"), WithDescription("default"))
			r.NoError(err)
			original, err := repo.CreateAuthMethod(ctx, seed)
			r.NoError(err)
			a.EqualValues(1, original.Version)

			// Build the update payload: same identity and version as the
			// stored auth method, overlaid with this case's field values.
			target, err := NewAuthMethod(ctx, scope.GetPublicId())
			r.NoError(err)
			target.PublicId = original.GetPublicId()
			target.Version = original.Version
			proto.Merge(target.AuthMethod, tc.in.fields)
			a.EqualValues(1, target.Version)

			updated, rows, err := repo.UpdateAuthMethod(ctx, target, target.Version, tc.in.mask)
			a.Equal(tc.wantRows, rows)

			if tc.wantErr {
				r.Error(err)
				a.Nil(updated)
				// A rejected update must not have written an oplog entry.
				oplogErr := db.TestVerifyOplog(t, rw, target.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
				r.Error(oplogErr)
				a.Contains(oplogErr.Error(), "record not found")
				return
			}

			r.NoError(err)
			if !tc.looseVersion {
				a.EqualValues(2, updated.Version)
			}
			a.NotEqual(original.UpdateTime, updated.UpdateTime)

			// The returned auth method must match what is now stored.
			stored, err := repo.LookupAuthMethod(ctx, original.PublicId)
			r.NoError(err)
			a.Empty(cmp.Diff(updated, stored, protocmp.Transform()))

			sqlDB, err := conn.SqlDB(ctx)
			r.NoError(err)
			nullCheck := dbassert.New(t, sqlDB)
			// A masked field left empty in the payload should be NULL in the
			// database.
			if target.Name == "" && contains(tc.in.mask, "name") {
				nullCheck.IsNull(stored, "name")
			}
			if target.Description == "" && contains(tc.in.mask, "description") {
				nullCheck.IsNull(stored, "description")
			}

			oplogErr := db.TestVerifyOplog(t, rw, updated.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
			a.NoError(oplogErr)
		})
	}
}
// assertPublicId records a test failure on t unless actual is non-empty,
// splits into exactly two parts on "_", and has prefix as its first part.
// All three checks are non-fatal, so every one of them is always evaluated.
func assertPublicId(t *testing.T, prefix, actual string) {
	t.Helper()
	assert.NotEmpty(t, actual)

	segments := strings.Split(actual, "_")
	gotPrefix := segments[0]
	assert.Equalf(t, 2, len(segments), "want one '_' in PublicId, got multiple in %q", actual)
	assert.Equalf(t, prefix, gotPrefix, "PublicId want prefix: %q, got: %q in %q", prefix, gotPrefix, actual)
}