You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/auth/ldap/auth_method_test.go

624 lines
19 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package ldap
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/x509"
"net/url"
"sort"
"testing"
"github.com/hashicorp/boundary/internal/auth/ldap/store"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
func TestNewAuthMethod(t *testing.T) {
t.Parallel()
testCtx := context.Background()
testCert, testCertEncoded := TestGenerateCA(t, "localhost")
_, testPrivKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
derPrivKey, err := x509.MarshalPKCS8PrivateKey(testPrivKey)
require.NoError(t, err)
tests := []struct {
name string
ctx context.Context
scopeId string
urls []*url.URL
opts []Option
want *AuthMethod
wantErr bool
wantErrCode errors.Code
wantErrContains string
}{
{
name: "valid-all-opts",
ctx: testCtx,
scopeId: "global",
urls: TestConvertToUrls(t, "ldaps://alice.com"), // converted to an option
opts: []Option{
WithName(testCtx, "test-name"),
WithDescription(testCtx, "test-description"),
WithStartTLS(testCtx),
WithInsecureTLS(testCtx),
WithDiscoverDn(testCtx),
WithAnonGroupSearch(testCtx),
WithUpnDomain(testCtx, "alice.com"),
WithUserDn(testCtx, "user-dn"),
WithUserAttr(testCtx, "user-attr"),
WithUserFilter(testCtx, "user-filter"),
WithGroupDn(testCtx, "group-dn"),
WithEnableGroups(testCtx),
WithGroupAttr(testCtx, "group-attr"),
WithGroupFilter(testCtx, "group-filter"),
WithBindCredential(testCtx, "bind-dn", "bind-password"),
WithCertificates(testCtx, testCert),
WithClientCertificate(testCtx, derPrivKey, testCert), // not a client cert but good enough for this test.
WithAccountAttributeMap(testCtx, map[string]AccountToAttribute{"mail": "email"}),
WithDerefAliases(testCtx, DerefAlways),
WithMaximumPageSize(testCtx, 10),
},
want: &AuthMethod{
AuthMethod: &store.AuthMethod{
ScopeId: "global",
Urls: []string{"ldaps://alice.com"},
OperationalState: string(InactiveState),
Name: "test-name",
Description: "test-description",
StartTls: true,
InsecureTls: true,
DiscoverDn: true,
AnonGroupSearch: true,
UpnDomain: "alice.com",
UserDn: "user-dn",
UserAttr: "user-attr",
UserFilter: "user-filter",
EnableGroups: true,
GroupDn: "group-dn",
GroupAttr: "group-attr",
GroupFilter: "group-filter",
BindDn: "bind-dn",
BindPassword: "bind-password",
Certificates: []string{testCertEncoded},
ClientCertificate: testCertEncoded,
ClientCertificateKey: derPrivKey,
AccountAttributeMaps: []string{"mail=email"},
DereferenceAliases: string(DerefAlways),
MaximumPageSize: 10,
},
},
},
{
name: "valid-no-opts",
ctx: testCtx,
scopeId: "global",
urls: TestConvertToUrls(t, "ldaps://alice.com"),
want: &AuthMethod{
AuthMethod: &store.AuthMethod{
ScopeId: "global",
Urls: []string{"ldaps://alice.com"},
OperationalState: string(InactiveState),
},
},
},
{
name: "missing-scope",
ctx: testCtx,
urls: TestConvertToUrls(t, "ldaps://alice.com"),
wantErr: true,
wantErrCode: errors.InvalidParameter,
wantErrContains: "missing scope id",
},
{
name: "invalid-urls",
ctx: testCtx,
scopeId: "global",
urls: func() []*url.URL {
parsed, err := url.Parse("https://alice.com")
require.NoError(t, err)
return []*url.URL{parsed}
}(),
wantErr: true,
wantErrCode: errors.InvalidParameter,
wantErrContains: `https scheme in url "https://alice.com" is not either ldap or ldaps`,
},
{
name: "invalid-deref-aliases",
ctx: testCtx,
scopeId: "global",
urls: TestConvertToUrls(t, "ldaps://alice.com"),
opts: []Option{
WithDerefAliases(testCtx, "invalid"),
},
wantErr: true,
wantErrCode: errors.InvalidParameter,
wantErrContains: `"invalid" is not a valid ldap dereference alias type`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
tc.opts = append(tc.opts, WithUrls(tc.ctx, tc.urls...))
am, err := NewAuthMethod(tc.ctx, tc.scopeId, tc.opts...)
if tc.wantErr {
require.Error(err)
if tc.wantErrCode != errors.Unknown {
assert.True(errors.Match(errors.T(tc.wantErrCode), err))
}
if tc.wantErrContains != "" {
assert.Contains(err.Error(), tc.wantErrContains)
}
return
}
require.NoError(err)
assert.Equal(tc.want, am)
})
}
}
func TestAuthMethod_SetTableName(t *testing.T) {
t.Parallel()
defaultTableName := authMethodTableName
tests := []struct {
name string
setNameTo string
want string
}{
{
name: "new-name",
setNameTo: "new-name",
want: "new-name",
},
{
name: "reset to default",
setNameTo: "",
want: defaultTableName,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
def := AllocAuthMethod()
require.Equal(defaultTableName, def.TableName())
m := AllocAuthMethod()
m.SetTableName(tt.setNameTo)
assert.Equal(tt.want, m.TableName())
})
}
}
func TestAuthMethod_clone(t *testing.T) {
t.Parallel()
testCtx := context.Background()
t.Run("valid", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
testCert, _ := TestGenerateCA(t, "localhost")
_, testPrivKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(err)
derPrivKey, err := x509.MarshalPKCS8PrivateKey(testPrivKey)
require.NoError(err)
am, err := NewAuthMethod(
testCtx,
"global",
WithUrls(testCtx, TestConvertToUrls(t, "ldaps://alice.com")...),
WithStartTLS(testCtx),
WithInsecureTLS(testCtx),
WithDiscoverDn(testCtx),
WithAnonGroupSearch(testCtx),
WithUpnDomain(testCtx, "alice.com"),
WithUserDn(testCtx, "user-dn"),
WithUserAttr(testCtx, "user-attr"),
WithUserFilter(testCtx, "user-filter"),
WithGroupDn(testCtx, "group-dn"),
WithGroupAttr(testCtx, "group-attr"),
WithGroupFilter(testCtx, "group-filter"),
WithBindCredential(testCtx, "bind-dn", "bind-password"),
WithCertificates(testCtx, testCert),
WithClientCertificate(testCtx, derPrivKey, testCert), // not a client cert but good enough for this test.
WithDerefAliases(testCtx, DerefAlways),
WithMaximumPageSize(testCtx, 10),
)
require.NoError(err)
cp := am.clone()
assert.True(proto.Equal(cp.AuthMethod, am.AuthMethod))
})
t.Run("not-equal", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
am, err := NewAuthMethod(testCtx, "global", WithUrls(testCtx, TestConvertToUrls(t, "ldaps://alice.com")...))
require.NoError(err)
am2, err := NewAuthMethod(testCtx, "global", WithUrls(testCtx, TestConvertToUrls(t, "ldaps://bob.com")...))
require.NoError(err)
cp := am.clone()
assert.True(!proto.Equal(cp.AuthMethod, am2.AuthMethod))
})
}
func TestAuthMethod_oplog(t *testing.T) {
t.Parallel()
testCtx := context.Background()
testAm, err := NewAuthMethod(testCtx, "global", WithUrls(testCtx, TestConvertToUrls(t, "ldap://ldap1")...))
testAm.PublicId = "test-public-id"
require.NoError(t, err)
tests := []struct {
name string
ctx context.Context
am *AuthMethod
opType oplog.OpType
want oplog.Metadata
wantErrMatch *errors.Template
wantErrContains string
}{
{
name: "create",
ctx: testCtx,
am: testAm,
opType: oplog.OpType_OP_TYPE_CREATE,
want: oplog.Metadata{
"resource-public-id": {"test-public-id"},
"scope-id": {"global"},
"op-type": {oplog.OpType_OP_TYPE_CREATE.String()},
"resource-type": {"ldap auth method"},
},
},
{
name: "update",
ctx: testCtx,
am: testAm,
opType: oplog.OpType_OP_TYPE_UPDATE,
want: oplog.Metadata{
"resource-public-id": {"test-public-id"},
"scope-id": {"global"},
"op-type": {oplog.OpType_OP_TYPE_UPDATE.String()},
"resource-type": {"ldap auth method"},
},
},
{
name: "missing-scope-id",
ctx: testCtx,
am: func() *AuthMethod {
cp := testAm.clone()
cp.ScopeId = ""
return cp
}(),
opType: oplog.OpType_OP_TYPE_UPDATE,
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "missing scope id",
},
{
name: "missing-public-id",
ctx: testCtx,
am: func() *AuthMethod {
cp := testAm.clone()
cp.PublicId = ""
return cp
}(),
opType: oplog.OpType_OP_TYPE_UPDATE,
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "missing public id",
},
{
name: "missing-op-type",
ctx: testCtx,
am: testAm,
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "missing op type",
},
{
name: "missing-auth-method",
ctx: testCtx,
opType: oplog.OpType_OP_TYPE_UPDATE,
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "missing auth method",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
got, err := tc.am.oplog(tc.ctx, tc.opType)
if tc.wantErrMatch != nil {
require.Error(err)
assert.Nil(got)
assert.True(errors.Match(tc.wantErrMatch, err))
if tc.wantErrContains != "" {
assert.Contains(err.Error(), tc.wantErrContains)
}
return
}
require.NoError(err)
assert.Equal(tc.want, got)
})
}
}
func Test_convertValueObjects(t *testing.T) {
testCtx := context.TODO()
testPublicId := "test-id"
testLdapServers := []string{"ldaps://ldap1.alice.com", "ldaps://ldap2.alice.com"}
_, pem := TestGenerateCA(t, "localhost")
testCerts := []string{pem}
c, err := NewCertificate(testCtx, testPublicId, pem)
require.NoError(t, err)
testCertificates := []any{c}
testUrls := make([]any, 0, len(testLdapServers))
for priority, uu := range TestConvertToUrls(t, testLdapServers...) {
u, err := NewUrl(testCtx, testPublicId, priority+1, uu)
require.NoError(t, err)
testUrls = append(testUrls, u)
}
testAttrMaps := []string{"email_address=email", "display_name=fullName"}
testAccountAttributeMaps := make([]any, 0, len(testAttrMaps))
acms, err := ParseAccountAttributeMaps(testCtx, testAttrMaps...)
require.NoError(t, err)
for _, m := range acms {
toAttribute, err := ConvertToAccountToAttribute(testCtx, m.To)
require.NoError(t, err)
obj, err := NewAccountAttributeMap(testCtx, testPublicId, m.From, toAttribute)
require.NoError(t, err)
testAccountAttributeMaps = append(testAccountAttributeMaps, obj)
}
testUserSearchConf, err := NewUserEntrySearchConf(testCtx, testPublicId, WithUserDn(testCtx, "user-dn"), WithUserAttr(testCtx, "user-attr"))
require.NoError(t, err)
testGroupSearchConf, err := NewGroupEntrySearchConf(testCtx, testPublicId, WithGroupDn(testCtx, "group-dn"), WithGroupAttr(testCtx, "group-attr"))
require.NoError(t, err)
_, testClientCertEncoded := TestGenerateCA(t, "client-cert-host")
_, testPrivKey, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
testClientCertKey, err := x509.MarshalPKCS8PrivateKey(testPrivKey)
require.NoError(t, err)
testClientCertificate, err := NewClientCertificate(testCtx, testPublicId, testClientCertKey, testClientCertEncoded)
require.NoError(t, err)
testBindCredential, err := NewBindCredential(testCtx, testPublicId, "bind-dn", []byte("bind-password"))
require.NoError(t, err)
testDerefAliases, err := NewDerefAliases(testCtx, testPublicId, DerefAlways)
require.NoError(t, err)
tests := []struct {
name string
am *AuthMethod
wantValues *convertedValues
wantErrMatch *errors.Template
wantErrContains string
}{
{
name: "success",
am: &AuthMethod{
AuthMethod: &store.AuthMethod{
PublicId: testPublicId,
Certificates: testCerts,
Urls: testLdapServers,
UserDn: "user-dn",
UserAttr: "user-attr",
GroupDn: "group-dn",
GroupAttr: "group-attr",
ClientCertificateKey: testClientCertKey,
ClientCertificate: testClientCertEncoded,
BindDn: "bind-dn",
BindPassword: "bind-password",
AccountAttributeMaps: testAttrMaps,
DereferenceAliases: string(DerefAlways),
},
},
wantValues: &convertedValues{
Certs: testCertificates,
Urls: testUrls,
UserEntrySearchConf: testUserSearchConf,
GroupEntrySearchConf: testGroupSearchConf,
ClientCertificate: testClientCertificate,
BindCredential: testBindCredential,
AccountAttributeMaps: testAccountAttributeMaps,
DerefAliases: testDerefAliases,
},
},
{
name: "missing-public-id",
am: &AuthMethod{
AuthMethod: &store.AuthMethod{
Certificates: testCerts,
Urls: testLdapServers,
UserDn: "user-dn",
UserAttr: "user-attr",
GroupDn: "group-dn",
GroupAttr: "group-attr",
ClientCertificateKey: testClientCertKey,
ClientCertificate: testClientCertEncoded,
BindDn: "bind-dn",
BindPassword: "bind-password",
},
},
wantErrMatch: errors.T(errors.InvalidPublicId),
},
{
name: "invalid-to-account-attr-map",
am: &AuthMethod{
AuthMethod: &store.AuthMethod{
PublicId: testPublicId,
AccountAttributeMaps: []string{"displayName=invalid-to-attr"},
},
},
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "not a valid ToAccountAttribute value",
},
{
name: "invalid-account-attr-map-format",
am: &AuthMethod{
AuthMethod: &store.AuthMethod{
PublicId: testPublicId,
AccountAttributeMaps: []string{"not-valid"},
},
},
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "format must be key=value",
},
{
name: "invalid-cert",
am: &AuthMethod{
AuthMethod: &store.AuthMethod{
PublicId: testPublicId,
Certificates: []string{TestInvalidPem},
},
},
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "failed to parse certificate",
},
{
name: "invalid-url-scheme",
am: &AuthMethod{
AuthMethod: &store.AuthMethod{
PublicId: testPublicId,
Urls: []string{"https://ldap1"},
},
},
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "scheme \"https\" is not ldap or ldaps",
},
{
name: "invalid-url-starts-with-space",
am: &AuthMethod{
AuthMethod: &store.AuthMethod{
PublicId: testPublicId,
Urls: []string{" ldaps://ldap1"},
},
},
wantErrMatch: errors.T(errors.Unknown),
wantErrContains: "first path segment in URL cannot contain colon",
},
{
name: "invalid-client-cert",
am: &AuthMethod{
AuthMethod: &store.AuthMethod{
PublicId: testPublicId,
ClientCertificateKey: testClientCertKey,
ClientCertificate: TestInvalidPem,
},
},
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "failed to parse certificate",
},
{
name: "invalid-deref-aliases",
am: &AuthMethod{
AuthMethod: &store.AuthMethod{
PublicId: testPublicId,
DereferenceAliases: "invalid",
},
},
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: `"invalid" is not a valid ldap dereference alias type:`,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
values, err := tc.am.convertValueObjects(testCtx)
if tc.wantErrMatch != nil {
require.Error(err)
assert.Truef(errors.Match(tc.wantErrMatch, err), "wanted err %q and got: %+v", tc.wantErrMatch.Code, err)
if tc.wantErrContains != "" {
assert.Contains(err.Error(), tc.wantErrContains)
}
return
}
require.NoError(err)
testSortConverted(t, tc.wantValues)
testSortConverted(t, values)
assert.Equal(tc.wantValues, values)
})
}
t.Run("missing-public-id", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
wantErrMatch := errors.T(errors.InvalidPublicId)
am := &AuthMethod{
AuthMethod: &store.AuthMethod{
Certificates: testCerts,
Urls: testLdapServers,
UserDn: "user-dn",
UserAttr: "user-attr",
GroupDn: "group-dn",
GroupAttr: "group-attr",
ClientCertificateKey: testClientCertKey,
ClientCertificate: testClientCertEncoded,
BindDn: "bind-dn",
BindPassword: "bind-password",
AccountAttributeMaps: testAttrMaps,
},
}
convertedCerts, err := am.convertCertificates(testCtx)
require.Error(err)
assert.Nil(convertedCerts)
assert.Truef(errors.Match(wantErrMatch, err), "wanted err %q and got: %+v", wantErrMatch.Code, err)
convertedUrls, err := am.convertUrls(testCtx)
require.Error(err)
assert.Nil(convertedUrls)
assert.Truef(errors.Match(wantErrMatch, err), "wanted err %q and got: %+v", wantErrMatch.Code, err)
convertedMaps, err := am.convertAccountAttributeMaps(testCtx)
require.Error(err)
assert.Nil(convertedMaps)
assert.Truef(errors.Match(wantErrMatch, err), "wanted err %q and got: %+v", wantErrMatch.Code, err)
convertedUserSearchConf, err := am.convertUserEntrySearchConf(testCtx)
require.Error(err)
assert.Nil(convertedUserSearchConf)
assert.Truef(errors.Match(wantErrMatch, err), "wanted err %q and got: %+v", wantErrMatch.Code, err)
convertedGroupSearchConf, err := am.convertGroupEntrySearchConf(testCtx)
require.Error(err)
assert.Nil(convertedGroupSearchConf)
assert.Truef(errors.Match(wantErrMatch, err), "wanted err %q and got: %+v", wantErrMatch.Code, err)
convertedClientCertificate, err := am.convertClientCertificate(testCtx)
require.Error(err)
assert.Nil(convertedClientCertificate)
assert.Truef(errors.Match(wantErrMatch, err), "wanted err %q and got: %+v", wantErrMatch.Code, err)
convertedBindCredential, err := am.convertBindCredential(testCtx)
require.Error(err)
assert.Nil(convertedBindCredential)
assert.Truef(errors.Match(wantErrMatch, err), "wanted err %q and got: %+v", wantErrMatch.Code, err)
convertedDerefAliases, err := am.convertDerefAliases(testCtx)
require.Error(err)
assert.Nil(convertedDerefAliases)
assert.Truef(errors.Match(wantErrMatch, err), "wanted err %q and got: %+v", wantErrMatch.Code, err)
})
}
type converted []any
func (a converted) Len() int { return len(a) }
func (a converted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a converted) Less(i, j int) bool {
switch a[i].(type) {
case *Url:
return a[i].(*Url).GetServerUrl() < a[j].(*Url).GetServerUrl()
case *Certificate:
return a[i].(*Certificate).GetCert() < a[j].(*Certificate).GetCert()
}
return false
}
func testSortConverted(t *testing.T, c *convertedValues) {
sort.Sort(converted(c.Urls))
sort.Sort(converted(c.Certs))
}