mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
737 lines
22 KiB
737 lines
22 KiB
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"net/url"
|
|
"sort"
|
|
"testing"
|
|
|
|
"github.com/hashicorp/boundary/internal/auth/oidc/store"
|
|
"github.com/hashicorp/boundary/internal/db"
|
|
"github.com/hashicorp/boundary/internal/errors"
|
|
"github.com/hashicorp/boundary/internal/iam"
|
|
"github.com/hashicorp/boundary/internal/kms"
|
|
wrapping "github.com/hashicorp/go-kms-wrapping/v2"
|
|
aead "github.com/hashicorp/go-kms-wrapping/v2/aead"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/proto"
|
|
)
|
|
|
|
func TestAuthMethod_Create(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.TODO()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
wrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
|
|
rw := db.New(conn)
|
|
|
|
databaseWrapper, err := kmsCache.GetWrapper(context.Background(), org.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
|
|
type args struct {
|
|
scopeId string
|
|
issuer *url.URL
|
|
clientId string
|
|
clientSecret ClientSecret
|
|
opt []Option
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want *AuthMethod
|
|
wantErr bool
|
|
wantIsErr errors.Code
|
|
create bool
|
|
wantNewErr bool
|
|
wantNewIsErr errors.Code
|
|
wantCreateErr bool
|
|
wantCreateIsErr errors.Code
|
|
}{
|
|
{
|
|
name: "valid",
|
|
args: args{
|
|
scopeId: org.PublicId,
|
|
clientId: "alice_rp",
|
|
clientSecret: ClientSecret("rp-secret"),
|
|
opt: []Option{
|
|
WithIssuer(TestConvertToUrls(t, "http://alice.com")[0]),
|
|
WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]),
|
|
WithDescription("alice's restaurant rp"),
|
|
WithName("alice.com"),
|
|
WithMaxAge(-1),
|
|
},
|
|
},
|
|
create: true,
|
|
want: func() *AuthMethod {
|
|
a := AllocAuthMethod()
|
|
a.ScopeId = org.PublicId
|
|
a.OperationalState = string(InactiveState)
|
|
a.Issuer = "http://alice.com"
|
|
a.ClientId = "alice_rp"
|
|
a.ClientSecret = "rp-secret"
|
|
a.MaxAge = -1
|
|
a.Name = "alice.com"
|
|
a.Description = "alice's restaurant rp"
|
|
a.ApiUrl = "https://api.com"
|
|
return &a
|
|
}(),
|
|
},
|
|
{
|
|
name: "valid with operational state set active",
|
|
args: args{
|
|
scopeId: org.PublicId,
|
|
clientId: "alice_rp2",
|
|
clientSecret: ClientSecret("rp-secret2"),
|
|
opt: []Option{
|
|
WithIssuer(TestConvertToUrls(t, "http://alice.com")[0]),
|
|
WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]),
|
|
WithDescription("alice's restaurant rp"),
|
|
WithName("alice2.com"),
|
|
WithMaxAge(-1),
|
|
WithOperationalState(ActivePublicState),
|
|
WithSigningAlgs(EdDSA),
|
|
},
|
|
},
|
|
want: func() *AuthMethod {
|
|
a := AllocAuthMethod()
|
|
a.ScopeId = org.PublicId
|
|
a.OperationalState = string(ActivePublicState)
|
|
a.Issuer = "http://alice.com"
|
|
a.ClientId = "alice_rp2"
|
|
a.ClientSecret = "rp-secret2"
|
|
a.MaxAge = -1
|
|
a.Name = "alice2.com"
|
|
a.Description = "alice's restaurant rp"
|
|
a.ApiUrl = "https://api.com"
|
|
a.SigningAlgs = []string{string(EdDSA)}
|
|
return &a
|
|
}(),
|
|
},
|
|
{
|
|
name: "incomplete with operational state set active",
|
|
args: args{
|
|
scopeId: org.PublicId,
|
|
clientId: "alice_rp3",
|
|
clientSecret: ClientSecret("rp-secret3"),
|
|
opt: []Option{
|
|
WithIssuer(TestConvertToUrls(t, "http://alice.com")[0]),
|
|
WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]),
|
|
WithDescription("alice's restaurant rp"),
|
|
WithName("alice3.com"),
|
|
WithMaxAge(-1),
|
|
WithOperationalState(ActivePublicState),
|
|
},
|
|
},
|
|
wantNewErr: true,
|
|
wantNewIsErr: errors.InvalidParameter,
|
|
},
|
|
{
|
|
name: "dup", // must follow "valid" test. combination of ScopeId, Issuer and ClientId must be unique.
|
|
args: args{
|
|
scopeId: org.PublicId,
|
|
clientId: "alice_rp",
|
|
clientSecret: ClientSecret("rp-secret"),
|
|
opt: []Option{WithIssuer(TestConvertToUrls(t, "http://alice.com")[0]), WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]), WithDescription("alice's restaurant rp"), WithName("alice.com"), WithMaxAge(-1)},
|
|
},
|
|
create: true,
|
|
want: func() *AuthMethod {
|
|
a := AllocAuthMethod()
|
|
a.ScopeId = org.PublicId
|
|
a.OperationalState = string(InactiveState)
|
|
a.Issuer = "http://alice.com"
|
|
a.ClientId = "alice_rp"
|
|
a.ClientSecret = "rp-secret"
|
|
a.MaxAge = -1
|
|
a.Name = "alice.com"
|
|
a.Description = "alice's restaurant rp"
|
|
a.ApiUrl = "https://api.com"
|
|
return &a
|
|
}(),
|
|
wantCreateErr: true,
|
|
wantCreateIsErr: errors.NotUnique,
|
|
},
|
|
{
|
|
name: "valid-with-no-options",
|
|
args: args{
|
|
scopeId: org.PublicId,
|
|
clientId: "eve_rp",
|
|
clientSecret: ClientSecret("rp-secret"),
|
|
},
|
|
want: func() *AuthMethod {
|
|
a := AllocAuthMethod()
|
|
a.ScopeId = org.PublicId
|
|
a.OperationalState = string(InactiveState)
|
|
a.ClientId = "eve_rp"
|
|
a.ClientSecret = "rp-secret"
|
|
return &a
|
|
}(),
|
|
},
|
|
{
|
|
name: "empty-scope-id",
|
|
args: args{
|
|
scopeId: "",
|
|
clientId: "alice_rp",
|
|
clientSecret: ClientSecret("rp-secret"),
|
|
opt: []Option{WithIssuer(TestConvertToUrls(t, "https://alice.com")[0]), WithDescription("alice's restaurant rp"), WithName("alice.com")},
|
|
},
|
|
wantErr: true,
|
|
wantIsErr: errors.InvalidParameter,
|
|
},
|
|
{
|
|
name: "nil-issuer", // should succeed.
|
|
args: args{
|
|
scopeId: org.PublicId,
|
|
clientId: "alice_rp",
|
|
clientSecret: ClientSecret("rp-secret"),
|
|
opt: []Option{WithApiUrl(TestConvertToUrls(t, "https://api.com")[0])},
|
|
},
|
|
create: true,
|
|
want: func() *AuthMethod {
|
|
a := AllocAuthMethod()
|
|
a.ScopeId = org.PublicId
|
|
a.OperationalState = string(InactiveState)
|
|
a.Issuer = ""
|
|
a.ClientId = "alice_rp"
|
|
a.ClientSecret = "rp-secret"
|
|
a.MaxAge = 0
|
|
a.ApiUrl = "https://api.com"
|
|
return &a
|
|
}(),
|
|
},
|
|
{
|
|
name: "missing-client-id", // should succeed.
|
|
args: args{
|
|
scopeId: org.PublicId,
|
|
clientId: "",
|
|
clientSecret: ClientSecret("rp-secret"),
|
|
opt: []Option{WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]), WithIssuer(TestConvertToUrls(t, "http://alice.com")[0])},
|
|
},
|
|
want: func() *AuthMethod {
|
|
a := AllocAuthMethod()
|
|
a.ScopeId = org.PublicId
|
|
a.OperationalState = string(InactiveState)
|
|
a.Issuer = "http://alice.com"
|
|
a.ClientId = ""
|
|
a.ClientSecret = "rp-secret"
|
|
a.MaxAge = 0
|
|
a.ApiUrl = "https://api.com"
|
|
return &a
|
|
}(),
|
|
},
|
|
{
|
|
name: "missing-client-secret", // should succeed
|
|
args: args{
|
|
scopeId: org.PublicId,
|
|
clientId: "alice_rp",
|
|
clientSecret: ClientSecret(""),
|
|
opt: []Option{WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]), WithIssuer(TestConvertToUrls(t, "http://alice.com")[0])},
|
|
},
|
|
want: func() *AuthMethod {
|
|
a := AllocAuthMethod()
|
|
a.ScopeId = org.PublicId
|
|
a.OperationalState = string(InactiveState)
|
|
a.Issuer = "http://alice.com"
|
|
a.ClientId = "alice_rp"
|
|
a.ClientSecret = ""
|
|
a.MaxAge = 0
|
|
a.ApiUrl = "https://api.com"
|
|
return &a
|
|
}(),
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
got, err := NewAuthMethod(ctx, tt.args.scopeId, tt.args.clientId, tt.args.clientSecret, tt.args.opt...)
|
|
if tt.wantErr {
|
|
require.Error(err)
|
|
assert.True(errors.Match(errors.T(tt.wantIsErr), err))
|
|
return
|
|
}
|
|
if tt.wantNewErr {
|
|
assert.Error(err)
|
|
assert.True(errors.Match(errors.T(tt.wantNewIsErr), err))
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
assert.Equal(tt.want, got)
|
|
if tt.create {
|
|
ctx := context.Background()
|
|
id, err := newAuthMethodId(ctx)
|
|
require.NoError(err)
|
|
got.PublicId = id
|
|
err = got.encrypt(ctx, databaseWrapper)
|
|
require.NoError(err)
|
|
err = rw.Create(ctx, got)
|
|
if tt.wantCreateErr {
|
|
assert.Error(err)
|
|
assert.True(errors.Match(errors.T(tt.wantCreateIsErr), err))
|
|
return
|
|
} else {
|
|
assert.NoError(err)
|
|
}
|
|
|
|
found := AllocAuthMethod()
|
|
found.PublicId = got.PublicId
|
|
require.NoError(rw.LookupByPublicId(ctx, &found))
|
|
require.NoError(found.decrypt(ctx, databaseWrapper))
|
|
assert.Equal(got, &found)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthMethod_Delete(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.TODO()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
wrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
|
|
rw := db.New(conn)
|
|
|
|
databaseWrapper, err := kmsCache.GetWrapper(context.Background(), org.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
|
|
testResource := func(issuer string, clientId, clientSecret string) *AuthMethod {
|
|
got, err := NewAuthMethod(ctx, org.PublicId, clientId, ClientSecret(clientSecret),
|
|
WithIssuer(TestConvertToUrls(t, issuer)[0]), WithApiUrl(TestConvertToUrls(t, "http://api.com")[0]))
|
|
require.NoError(t, err)
|
|
id, err := newAuthMethodId(ctx)
|
|
require.NoError(t, err)
|
|
got.PublicId = id
|
|
err = got.encrypt(ctx, databaseWrapper)
|
|
require.NoError(t, err)
|
|
return got
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
authMethod *AuthMethod
|
|
overrides func(*AuthMethod)
|
|
wantRowsDeleted int
|
|
wantErr bool
|
|
wantErrMsg string
|
|
}{
|
|
{
|
|
name: "valid",
|
|
authMethod: testResource("https://alice.com", "alice-rp", "alice's dog's name"),
|
|
wantErr: false,
|
|
wantRowsDeleted: 1,
|
|
},
|
|
{
|
|
name: "bad-id",
|
|
authMethod: testResource("https://alice.com", "alice-rp", "alice's dog's name"),
|
|
overrides: func(a *AuthMethod) { a.PublicId = "bad-id" },
|
|
wantErr: false,
|
|
wantRowsDeleted: 0,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
ctx := context.Background()
|
|
cp := tt.authMethod.Clone()
|
|
require.NoError(rw.Create(ctx, &cp))
|
|
|
|
if tt.overrides != nil {
|
|
tt.overrides(cp)
|
|
}
|
|
deletedRows, err := rw.Delete(context.Background(), &cp)
|
|
if tt.wantErr {
|
|
require.Error(err)
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
if tt.wantRowsDeleted == 0 {
|
|
assert.Equal(tt.wantRowsDeleted, deletedRows)
|
|
return
|
|
}
|
|
assert.Equal(tt.wantRowsDeleted, deletedRows)
|
|
foundAuthMethod := AllocAuthMethod()
|
|
foundAuthMethod.PublicId = tt.authMethod.PublicId
|
|
err = rw.LookupById(context.Background(), &foundAuthMethod)
|
|
require.Error(err)
|
|
assert.True(errors.IsNotFoundError(err))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthMethod_Clone(t *testing.T) {
|
|
t.Parallel()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
wrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
|
|
t.Run("valid", func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
|
|
databaseWrapper, err := kmsCache.GetWrapper(context.Background(), org.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
m := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, InactiveState, "alice_rp", "my-dogs-name",
|
|
WithIssuer(TestConvertToUrls(t, "https://alice.com")[0]), WithApiUrl(TestConvertToUrls(t, "https://alice.com")[0]))
|
|
m.DisableDiscoveredConfigValidation = true
|
|
|
|
cp := m.Clone()
|
|
assert.Equal(cp.DisableDiscoveredConfigValidation, true)
|
|
assert.True(proto.Equal(cp.AuthMethod, m.AuthMethod))
|
|
})
|
|
t.Run("not-equal", func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
|
|
databaseWrapper, err := kmsCache.GetWrapper(context.Background(), org.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
m := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, InactiveState, "alice_rp", "my-dogs-name",
|
|
WithIssuer(TestConvertToUrls(t, "https://alice.com")[0]), WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]))
|
|
m2 := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, InactiveState, "alice_rp2", "my-dogs-name",
|
|
WithIssuer(TestConvertToUrls(t, "https://alice.com")[0]), WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]))
|
|
|
|
cp := m.Clone()
|
|
assert.True(!proto.Equal(cp.AuthMethod, m2.AuthMethod))
|
|
})
|
|
}
|
|
|
|
func TestAuthMethod_SetTableName(t *testing.T) {
|
|
t.Parallel()
|
|
defaultTableName := defaultAuthMethodTableName
|
|
tests := []struct {
|
|
name string
|
|
setNameTo string
|
|
want string
|
|
}{
|
|
{
|
|
name: "new-name",
|
|
setNameTo: "new-name",
|
|
want: "new-name",
|
|
},
|
|
{
|
|
name: "reset to default",
|
|
setNameTo: "",
|
|
want: defaultTableName,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
def := AllocAuthMethod()
|
|
require.Equal(defaultTableName, def.TableName())
|
|
m := AllocAuthMethod()
|
|
m.SetTableName(tt.setNameTo)
|
|
assert.Equal(tt.want, m.TableName())
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_encrypt_decrypt_hmac(t *testing.T) {
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rootWrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, rootWrapper)
|
|
org, proj := iam.TestScopes(t, iam.TestRepo(t, conn, rootWrapper))
|
|
databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
projDatabaseWrapper, err := kmsCache.GetWrapper(ctx, proj.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
m := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, InactiveState, "alice_rp", "my-dogs-name",
|
|
WithIssuer(TestConvertToUrls(t, "https://alice.com")[0]), WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]))
|
|
|
|
tests := []struct {
|
|
name string
|
|
am *AuthMethod
|
|
hmacWrapper wrapping.Wrapper
|
|
wantHmacErrMatch *errors.Template
|
|
encryptWrapper wrapping.Wrapper
|
|
wantEncryptErrMatch *errors.Template
|
|
decryptWrapper wrapping.Wrapper
|
|
wantDecryptErrMatch *errors.Template
|
|
}{
|
|
{
|
|
name: "success",
|
|
am: m,
|
|
hmacWrapper: databaseWrapper,
|
|
encryptWrapper: databaseWrapper,
|
|
decryptWrapper: databaseWrapper,
|
|
},
|
|
{
|
|
name: "encrypt-missing-wrapper",
|
|
am: m,
|
|
hmacWrapper: databaseWrapper,
|
|
wantEncryptErrMatch: errors.T(errors.InvalidParameter),
|
|
},
|
|
{
|
|
name: "encrypt-bad-wrapper",
|
|
am: m,
|
|
hmacWrapper: databaseWrapper,
|
|
encryptWrapper: &aead.Wrapper{},
|
|
wantEncryptErrMatch: errors.T(errors.Encrypt),
|
|
},
|
|
{
|
|
name: "encrypt-missing-wrapper",
|
|
am: m,
|
|
hmacWrapper: databaseWrapper,
|
|
encryptWrapper: databaseWrapper,
|
|
wantDecryptErrMatch: errors.T(errors.InvalidParameter),
|
|
},
|
|
{
|
|
name: "decrypt-bad-wrapper",
|
|
am: m,
|
|
hmacWrapper: databaseWrapper,
|
|
encryptWrapper: databaseWrapper,
|
|
decryptWrapper: &aead.Wrapper{},
|
|
wantDecryptErrMatch: errors.T(errors.Decrypt),
|
|
},
|
|
{
|
|
name: "decrypt-wrong-wrapper",
|
|
am: m,
|
|
hmacWrapper: databaseWrapper,
|
|
encryptWrapper: databaseWrapper,
|
|
decryptWrapper: projDatabaseWrapper,
|
|
wantDecryptErrMatch: errors.T(errors.Decrypt),
|
|
},
|
|
{
|
|
name: "hmac-missing-wrapper",
|
|
am: m,
|
|
encryptWrapper: databaseWrapper,
|
|
decryptWrapper: databaseWrapper,
|
|
wantHmacErrMatch: errors.T(errors.InvalidParameter),
|
|
},
|
|
{
|
|
name: "hmac-bad-wrapper",
|
|
am: m,
|
|
hmacWrapper: &aead.Wrapper{},
|
|
encryptWrapper: databaseWrapper,
|
|
decryptWrapper: databaseWrapper,
|
|
wantHmacErrMatch: errors.T(errors.InvalidParameter),
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
|
|
hmacAuthMethod := tt.am.Clone()
|
|
err = hmacAuthMethod.hmacClientSecret(ctx, tt.hmacWrapper)
|
|
if tt.wantHmacErrMatch != nil {
|
|
require.Error(err)
|
|
} else {
|
|
require.NoError(err)
|
|
}
|
|
|
|
encryptedAuthMethod := tt.am.Clone()
|
|
err = encryptedAuthMethod.encrypt(ctx, tt.encryptWrapper)
|
|
if tt.wantEncryptErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Truef(errors.Match(tt.wantEncryptErrMatch, err), "expected %q and got err: %+v", tt.wantEncryptErrMatch.Code, err)
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
assert.NotEmpty(encryptedAuthMethod.CtClientSecret)
|
|
assert.NotEmpty(encryptedAuthMethod.ClientSecretHmac)
|
|
|
|
decryptedAuthMethod := encryptedAuthMethod.Clone()
|
|
decryptedAuthMethod.ClientSecret = ""
|
|
err = decryptedAuthMethod.decrypt(ctx, tt.decryptWrapper)
|
|
if tt.wantDecryptErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Truef(errors.Match(tt.wantDecryptErrMatch, err), "expected %q and got err: %+v", tt.wantDecryptErrMatch.Code, err)
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
assert.Equal(tt.am.ClientSecret, decryptedAuthMethod.ClientSecret)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_convertValueObjects(t *testing.T) {
|
|
ctx := context.TODO()
|
|
testPublicId := "test-id"
|
|
|
|
testAlgs := []string{string(RS256), string(RS384)}
|
|
var testSigningAlgs []interface{}
|
|
for _, a := range []Alg{RS256, RS384} {
|
|
obj, err := NewSigningAlg(ctx, testPublicId, Alg(a))
|
|
require.NoError(t, err)
|
|
testSigningAlgs = append(testSigningAlgs, obj)
|
|
}
|
|
|
|
testAuds := []string{"alice", "eve"}
|
|
var testAudiences []interface{}
|
|
for _, a := range testAuds {
|
|
obj, err := NewAudClaim(ctx, testPublicId, a)
|
|
require.NoError(t, err)
|
|
testAudiences = append(testAudiences, obj)
|
|
}
|
|
|
|
_, pem := testGenerateCA(t, "localhost")
|
|
testCerts := []string{pem}
|
|
c, err := NewCertificate(ctx, testPublicId, pem)
|
|
require.NoError(t, err)
|
|
testCertificates := []interface{}{c}
|
|
|
|
testScopes := []string{"profile", "email"}
|
|
testClaimsScopes := make([]interface{}, 0, len(testScopes))
|
|
for _, s := range testScopes {
|
|
obj, err := NewClaimsScope(ctx, testPublicId, s)
|
|
require.NoError(t, err)
|
|
testClaimsScopes = append(testClaimsScopes, obj)
|
|
}
|
|
|
|
testClaimMaps := []string{"oid=sub", "display_name=name"}
|
|
testAccountClaimMaps := make([]interface{}, 0, len(testClaimMaps))
|
|
acms, err := ParseAccountClaimMaps(ctx, testClaimMaps...)
|
|
require.NoError(t, err)
|
|
for _, m := range acms {
|
|
toClaim, err := ConvertToAccountToClaim(ctx, m.To)
|
|
require.NoError(t, err)
|
|
obj, err := NewAccountClaimMap(ctx, testPublicId, m.From, toClaim)
|
|
require.NoError(t, err)
|
|
testAccountClaimMaps = append(testAccountClaimMaps, obj)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
authMethodId string
|
|
algs []string
|
|
auds []string
|
|
certs []string
|
|
scopes []string
|
|
maps []string
|
|
wantValues *convertedValues
|
|
wantErrMatch *errors.Template
|
|
wantErrContains string
|
|
}{
|
|
{
|
|
name: "success",
|
|
authMethodId: testPublicId,
|
|
algs: testAlgs,
|
|
auds: testAuds,
|
|
certs: testCerts,
|
|
scopes: testScopes,
|
|
maps: testClaimMaps,
|
|
wantValues: &convertedValues{
|
|
Algs: testSigningAlgs,
|
|
Auds: testAudiences,
|
|
Certs: testCertificates,
|
|
ClaimsScopes: testClaimsScopes,
|
|
AccountClaimMaps: testAccountClaimMaps,
|
|
},
|
|
},
|
|
{
|
|
name: "missing-public-id",
|
|
algs: testAlgs,
|
|
wantErrMatch: errors.T(errors.InvalidPublicId),
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
am := &AuthMethod{
|
|
AuthMethod: &store.AuthMethod{
|
|
PublicId: tt.authMethodId,
|
|
SigningAlgs: tt.algs,
|
|
AudClaims: tt.auds,
|
|
Certificates: tt.certs,
|
|
ClaimsScopes: tt.scopes,
|
|
AccountClaimMaps: tt.maps,
|
|
},
|
|
}
|
|
|
|
convertedAlgs, err := am.convertSigningAlgs(ctx)
|
|
if tt.wantErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted err %q and got: %+v", tt.wantErrMatch.Code, err)
|
|
} else {
|
|
assert.Equal(tt.wantValues.Algs, convertedAlgs)
|
|
}
|
|
|
|
convertedAuds, err := am.convertAudClaims(ctx)
|
|
if tt.wantErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted err %q and got: %+v", tt.wantErrMatch.Code, err)
|
|
} else {
|
|
assert.Equal(tt.wantValues.Auds, convertedAuds)
|
|
}
|
|
|
|
convertedCerts, err := am.convertCertificates(ctx)
|
|
if tt.wantErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted err %q and got: %+v", tt.wantErrMatch.Code, err)
|
|
} else {
|
|
assert.Equal(tt.wantValues.Certs, convertedCerts)
|
|
}
|
|
|
|
convertedScopes, err := am.convertClaimsScopes(ctx)
|
|
if tt.wantErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted err %q and got: %+v", tt.wantErrMatch.Code, err)
|
|
} else {
|
|
assert.Equal(tt.wantValues.ClaimsScopes, convertedScopes)
|
|
}
|
|
|
|
convertedMaps, err := am.convertAccountClaimMaps(ctx)
|
|
if tt.wantErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted err %q and got: %+v", tt.wantErrMatch.Code, err)
|
|
} else {
|
|
want := make([]*AccountClaimMap, 0, len(tt.wantValues.AccountClaimMaps))
|
|
for _, v := range tt.wantValues.AccountClaimMaps {
|
|
want = append(want, v.(*AccountClaimMap))
|
|
}
|
|
got := make([]*AccountClaimMap, 0, len(convertedMaps))
|
|
for _, v := range convertedMaps {
|
|
got = append(got, v.(*AccountClaimMap))
|
|
}
|
|
sort.Slice(want, func(a, b int) bool {
|
|
return want[a].ToClaim < want[b].ToClaim
|
|
})
|
|
sort.Slice(got, func(a, b int) bool {
|
|
return got[a].ToClaim < got[b].ToClaim
|
|
})
|
|
assert.Equal(want, got)
|
|
}
|
|
|
|
values, err := am.convertValueObjects(ctx)
|
|
if tt.wantErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted err %q and got: %+v", tt.wantErrMatch.Code, err)
|
|
if tt.wantErrContains != "" {
|
|
assert.Contains(err.Error(), tt.wantErrContains)
|
|
}
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
testSortConverted(t, tt.wantValues)
|
|
testSortConverted(t, values)
|
|
assert.Equal(tt.wantValues, values)
|
|
})
|
|
}
|
|
}
|
|
|
|
type converted []interface{}
|
|
|
|
func (a converted) Len() int { return len(a) }
|
|
func (a converted) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
|
func (a converted) Less(i, j int) bool {
|
|
switch a[i].(type) {
|
|
case *SigningAlg:
|
|
return a[i].(*SigningAlg).GetAlg() < a[j].(*SigningAlg).GetAlg()
|
|
case *Certificate:
|
|
return a[i].(*Certificate).GetCert() < a[j].(*Certificate).GetCert()
|
|
case *AccountClaimMap:
|
|
return a[i].(*AccountClaimMap).GetToClaim() < a[j].(*AccountClaimMap).GetToClaim()
|
|
case *AudClaim:
|
|
return a[i].(*AudClaim).GetAud() < a[j].(*AudClaim).GetAud()
|
|
case *ClaimsScope:
|
|
return a[i].(*ClaimsScope).GetScope() < a[j].(*ClaimsScope).GetScope()
|
|
}
|
|
return false
|
|
}
|
|
|
|
func testSortConverted(t *testing.T, c *convertedValues) {
|
|
sort.Sort(converted(c.Algs))
|
|
sort.Sort(converted(c.Certs))
|
|
sort.Sort(converted(c.AccountClaimMaps))
|
|
sort.Sort(converted(c.Auds))
|
|
sort.Sort(converted(c.ClaimsScopes))
|
|
}
|