mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
275 lines
8.5 KiB
275 lines
8.5 KiB
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package ldap
|
|
|
|
import (
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"crypto/x509"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/hashicorp/boundary/internal/db"
|
|
"github.com/hashicorp/boundary/internal/errors"
|
|
"github.com/hashicorp/boundary/internal/iam"
|
|
"github.com/hashicorp/boundary/internal/kms"
|
|
"github.com/hashicorp/boundary/internal/oplog"
|
|
wrapping "github.com/hashicorp/go-kms-wrapping/v2"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/testing/protocmp"
|
|
)
|
|
|
|
func TestRepository_CreateAuthMethod(t *testing.T) {
|
|
testConn, _ := db.TestSetup(t, "postgres")
|
|
testRw := db.New(testConn)
|
|
testWrapper := db.TestWrapper(t)
|
|
testKms := kms.TestKms(t, testConn, testWrapper)
|
|
testCtx := context.Background()
|
|
org, _ := iam.TestScopes(t, iam.TestRepo(t, testConn, testWrapper))
|
|
testCert, _ := TestGenerateCA(t, "localhost")
|
|
testCert2, _ := TestGenerateCA(t, "localhost")
|
|
_, testPrivKey, err := ed25519.GenerateKey(rand.Reader)
|
|
require.NoError(t, err)
|
|
derPrivKey, err := x509.MarshalPKCS8PrivateKey(testPrivKey)
|
|
require.NoError(t, err)
|
|
|
|
testAm, err := NewAuthMethod(
|
|
testCtx,
|
|
org.PublicId,
|
|
WithUrls(testCtx, TestConvertToUrls(t, "ldaps://ldap1", "ldap://ldap2")...),
|
|
WithName(testCtx, "test-name"),
|
|
WithDescription(testCtx, "test-description"),
|
|
WithStartTLS(testCtx),
|
|
WithInsecureTLS(testCtx),
|
|
WithDiscoverDn(testCtx),
|
|
WithAnonGroupSearch(testCtx),
|
|
WithUpnDomain(testCtx, "alice.com"),
|
|
WithUserDn(testCtx, "user-dn"),
|
|
WithUserAttr(testCtx, "user-attr"),
|
|
WithUserFilter(testCtx, "user-filter"),
|
|
WithEnableGroups(testCtx),
|
|
WithUseTokenGroups(testCtx),
|
|
WithGroupDn(testCtx, "group-dn"),
|
|
WithGroupAttr(testCtx, "group-attr"),
|
|
WithGroupFilter(testCtx, "group-filter"),
|
|
WithBindCredential(testCtx, "bind-dn", "bind-password"),
|
|
WithCertificates(testCtx, testCert, testCert2),
|
|
WithClientCertificate(testCtx, derPrivKey, testCert), // not a client cert but good enough for this test.
|
|
WithAccountAttributeMap(testCtx, map[string]AccountToAttribute{
|
|
"mail": ToEmailAttribute,
|
|
}),
|
|
WithDerefAliases(testCtx, DerefFindingBaseObj),
|
|
WithMaximumPageSize(testCtx, 10),
|
|
)
|
|
require.NoError(t, err)
|
|
|
|
tests := []struct {
|
|
name string
|
|
kms kms.GetWrapperer
|
|
setup func(*testing.T) *AuthMethod
|
|
opt []Option
|
|
wantErrMatch *errors.Template
|
|
wantErrContains string
|
|
}{
|
|
{
|
|
name: "valid",
|
|
kms: testKms,
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
return testAm.clone()
|
|
},
|
|
},
|
|
{
|
|
name: "bind-cred-encrypt-err",
|
|
kms: &mockGetWrapperer{
|
|
returnDbWrapper: &kms.MockWrapper{
|
|
EncryptErr: errors.New(testCtx, errors.Encrypt, "test", "bind-cred-encrypt-err"),
|
|
},
|
|
},
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
return testAm.clone()
|
|
},
|
|
wantErrMatch: errors.T(errors.Unknown),
|
|
wantErrContains: "bind-cred-encrypt-err",
|
|
},
|
|
{
|
|
name: "get-db-wrapper-err",
|
|
kms: &mockGetWrapperer{
|
|
getErr: errors.New(testCtx, errors.Encrypt, "test", "get-db-wrapper-err"),
|
|
},
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
return testAm.clone()
|
|
},
|
|
wantErrMatch: errors.T(errors.Encrypt),
|
|
wantErrContains: "unable to get database wrapper",
|
|
},
|
|
{
|
|
name: "get-oplog-wrapper-err",
|
|
kms: &mockGetWrapperer{
|
|
getErr: errors.New(testCtx, errors.Encrypt, "test", "get-db-wrapper-err"),
|
|
returnDbWrapper: testWrapper,
|
|
},
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
return testAm.clone()
|
|
},
|
|
wantErrMatch: errors.T(errors.Encrypt),
|
|
wantErrContains: "unable to get oplog wrapper",
|
|
},
|
|
{
|
|
name: "bad-state",
|
|
kms: testKms,
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
am, err := NewAuthMethod(testCtx, org.PublicId, WithUrls(testCtx, TestConvertToUrls(t, "ldaps://ldap1")...))
|
|
require.NoError(t, err)
|
|
am.OperationalState = "not-a-valid-state"
|
|
return am
|
|
},
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "invalid state",
|
|
},
|
|
{
|
|
name: "missing-auth-method",
|
|
kms: testKms,
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
return nil
|
|
},
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing auth method",
|
|
},
|
|
{
|
|
name: "missing-scope",
|
|
kms: testKms,
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
am, err := NewAuthMethod(testCtx, org.PublicId, WithUrls(testCtx, TestConvertToUrls(t, "ldaps://ldap1")...))
|
|
require.NoError(t, err)
|
|
am.ScopeId = ""
|
|
return am
|
|
},
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing scope id",
|
|
},
|
|
{
|
|
name: "convert-err",
|
|
kms: testKms,
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
am, err := NewAuthMethod(testCtx, org.PublicId, WithUrls(testCtx, TestConvertToUrls(t, "ldaps://ldap1")...))
|
|
require.NoError(t, err)
|
|
am.BindDn = "bind-dn"
|
|
return am
|
|
},
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing password",
|
|
},
|
|
{
|
|
name: "missing-urls",
|
|
kms: testKms,
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
am, err := NewAuthMethod(testCtx, org.PublicId, WithUrls(testCtx, TestConvertToUrls(t, "ldaps://ldap1")...))
|
|
require.NoError(t, err)
|
|
am.Urls = nil
|
|
return am
|
|
},
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing urls (there must be at least one)",
|
|
},
|
|
{
|
|
name: "bad-public-id",
|
|
kms: testKms,
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
id, err := newAuthMethodId(testCtx)
|
|
require.NoError(t, err)
|
|
am := AllocAuthMethod()
|
|
am.PublicId = id
|
|
return &am
|
|
},
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "public id must be empty",
|
|
},
|
|
{
|
|
name: "bad-version",
|
|
kms: testKms,
|
|
setup: func(t *testing.T) *AuthMethod {
|
|
am := AllocAuthMethod()
|
|
am.Version = 22
|
|
return &am
|
|
},
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "version must be empty",
|
|
},
|
|
}
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
repo, err := NewRepository(testCtx, testRw, testRw, tc.kms)
|
|
assert.NoError(err)
|
|
require.NotNil(repo)
|
|
am := tc.setup(t)
|
|
got, err := repo.CreateAuthMethod(testCtx, am, tc.opt...)
|
|
if tc.wantErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Truef(errors.Match(tc.wantErrMatch, err), "want err code: %q got: %q", tc.wantErrMatch, err)
|
|
assert.Nil(got)
|
|
|
|
if am != nil {
|
|
err := db.TestVerifyOplog(t, testRw, am.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_CREATE), db.WithCreateNotBefore(10*time.Second))
|
|
require.Errorf(err, "should not have found oplog entry for %s", am.PublicId)
|
|
}
|
|
|
|
if tc.wantErrContains != "" {
|
|
assert.Contains(err.Error(), tc.wantErrContains)
|
|
}
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
am.PublicId = got.PublicId
|
|
am.CreateTime = got.CreateTime
|
|
am.UpdateTime = got.UpdateTime
|
|
am.Version = got.Version
|
|
am.BindPasswordHmac = got.BindPasswordHmac
|
|
am.ClientCertificateKeyHmac = got.ClientCertificateKeyHmac
|
|
TestSortAuthMethods(t, []*AuthMethod{am, got})
|
|
assert.Truef(proto.Equal(am.AuthMethod, got.AuthMethod), "got %+v expected %+v", got.AuthMethod, am.AuthMethod)
|
|
|
|
err = db.TestVerifyOplog(t, testRw, am.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_CREATE), db.WithCreateNotBefore(10*time.Second))
|
|
require.NoErrorf(err, "unexpected error verifying oplog entry: %s", err)
|
|
|
|
found, err := repo.LookupAuthMethod(testCtx, am.PublicId)
|
|
require.NoError(err)
|
|
found.CreateTime = got.CreateTime
|
|
found.UpdateTime = got.UpdateTime
|
|
found.Version = got.Version
|
|
TestSortAuthMethods(t, []*AuthMethod{found, am})
|
|
assert.Empty(cmp.Diff(found.AuthMethod, am.AuthMethod, protocmp.Transform()))
|
|
})
|
|
}
|
|
}
|
|
|
|
type mockGetWrapperer struct {
|
|
// kms is the underlying kms which is used to provide the mock's default
|
|
// behavior
|
|
kms kms.GetWrapperer
|
|
|
|
// getErr is a mock value to return for the GetWrapper(...) operation
|
|
getErr error
|
|
|
|
returnOplogWrapper wrapping.Wrapper
|
|
returnDbWrapper wrapping.Wrapper
|
|
}
|
|
|
|
func (m *mockGetWrapperer) GetWrapper(ctx context.Context, scopeId string, purpose kms.KeyPurpose, opt ...kms.Option) (wrapping.Wrapper, error) {
|
|
switch {
|
|
case purpose == kms.KeyPurposeOplog && m.returnOplogWrapper != nil:
|
|
return m.returnOplogWrapper, nil
|
|
case purpose == kms.KeyPurposeDatabase && m.returnDbWrapper != nil:
|
|
return m.returnDbWrapper, nil
|
|
case m.getErr != nil:
|
|
return nil, m.getErr
|
|
default:
|
|
return m.kms.GetWrapper(ctx, scopeId, purpose, opt...)
|
|
}
|
|
}
|