You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/auth/oidc/repository_auth_method_crea...

233 lines
7.7 KiB

package oidc
import (
"context"
"crypto/x509"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
)
// TestRepository_CreateAuthMethod is a table-driven test of
// Repository.CreateAuthMethod. Each case builds an *AuthMethod (or nil) via its
// am func, optionally supplies create options, and either:
//   - expects an error matching wantErrMatch, a nil result, and no oplog
//     create entry for the auth method's public id; or
//   - expects success, in which case the returned auth method must equal the
//     input (after copying over the server-assigned id, timestamps and
//     version), an oplog create entry must exist, and a subsequent
//     LookupAuthMethod must return an equivalent auth method.
func TestRepository_CreateAuthMethod(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	ctx := context.Background()
	org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))

	// convertAlg turns the typed signing algorithms into the plain strings
	// that are compared against AuthMethod.SigningAlgs below.
	convertAlg := func(alg ...Alg) []string {
		s := make([]string, 0, len(alg))
		for _, a := range alg {
			s = append(s, string(a))
		}
		return s
	}

	tests := []struct {
		name string
		// am builds the auth method handed to CreateAuthMethod; it may
		// return nil to exercise the missing-parameter path.
		am func(*testing.T) *AuthMethod
		// opt holds the options forwarded to CreateAuthMethod.
		opt []Option
		// wantErrMatch, when non-nil, is the template the returned error
		// must match.
		wantErrMatch *errors.Template
	}{
		{
			name: "valid",
			am: func(t *testing.T) *AuthMethod {
				algs := []Alg{RS256, ES256}
				cbs := TestConvertToUrls(t, "https://www.alice.com/callback")[0]
				auds := []string{"alice-rp", "bob-rp"}
				cert1, pem1 := testGenerateCA(t, "localhost")
				cert2, pem2 := testGenerateCA(t, "localhost")
				certs := []*x509.Certificate{cert1, cert2}
				pems := []string{pem1, pem2}
				am, err := NewAuthMethod(
					ctx,
					org.PublicId,
					"alice-rp",
					// NOTE(review): WithAudClaims is passed twice; the
					// AudClaims assertion below expects exactly auds, so
					// the second call is the one that must take effect.
					"alice-secret", WithAudClaims("alice-rp"),
					WithAudClaims(auds...),
					WithIssuer(TestConvertToUrls(t, "https://www.alice.com")[0]),
					WithApiUrl(cbs),
					WithSigningAlgs(algs...),
					WithCertificates(certs...),
					WithName("alice's restaurant"),
					WithDescription("it's a good place to eat"),
					WithClaimsScopes("email", "profile"),
					WithAccountClaimMap(map[string]AccountToClaim{"display_name": ToNameClaim, "oid": ToSubClaim}),
				)
				require.NoError(t, err)
				require.Equal(t, am.SigningAlgs, convertAlg(algs...))
				require.Equal(t, am.ApiUrl, cbs.String())
				require.Equal(t, "https://www.alice.com", am.Issuer)
				require.Equal(t, am.AudClaims, auds)
				require.Equal(t, am.Certificates, pems)
				require.Equal(t, am.OperationalState, string(InactiveState))
				return am
			},
		},
		{
			name: "valid with custom ID",
			am: func(t *testing.T) *AuthMethod {
				algs := []Alg{RS256, ES256}
				cbs := TestConvertToUrls(t, "https://www.alice.com/callback")[0]
				auds := []string{"alice-rp-custom", "bob-rp-custom"}
				cert1, pem1 := testGenerateCA(t, "localhost")
				cert2, pem2 := testGenerateCA(t, "localhost")
				certs := []*x509.Certificate{cert1, cert2}
				pems := []string{pem1, pem2}
				am, err := NewAuthMethod(
					ctx,
					org.PublicId,
					"alice-rp-custom",
					"alice-secret-custom", WithAudClaims("alice-rp-custom"),
					WithAudClaims(auds...),
					WithIssuer(TestConvertToUrls(t, "https://www.alice.com")[0]),
					WithApiUrl(cbs),
					WithSigningAlgs(algs...),
					WithCertificates(certs...),
					WithName("alice's restaurant with a twist"),
					WithDescription("it's an okay but kinda weird place to eat"),
					WithClaimsScopes("email", "profile"),
					WithAccountClaimMap(map[string]AccountToClaim{"display_name": ToNameClaim, "oid": ToSubClaim}),
				)
				require.NoError(t, err)
				require.Equal(t, am.SigningAlgs, convertAlg(algs...))
				require.Equal(t, am.ApiUrl, cbs.String())
				require.Equal(t, "https://www.alice.com", am.Issuer)
				require.Equal(t, am.AudClaims, auds)
				require.Equal(t, am.Certificates, pems)
				require.Equal(t, am.OperationalState, string(InactiveState))
				return am
			},
			// The created auth method is expected to carry exactly this id.
			opt: []Option{WithPublicId("amoidc_1234567890")},
		},
		{
			name: "bad custom ID",
			am: func(t *testing.T) *AuthMethod {
				algs := []Alg{RS256, ES256}
				cbs := TestConvertToUrls(t, "https://www.alice.com/callback")[0]
				auds := []string{"alice-rp-bad", "bob-rp-bad"}
				cert1, pem1 := testGenerateCA(t, "localhost")
				cert2, pem2 := testGenerateCA(t, "localhost")
				certs := []*x509.Certificate{cert1, cert2}
				pems := []string{pem1, pem2}
				am, err := NewAuthMethod(
					ctx,
					org.PublicId,
					"alice-rp-bad",
					"alice-secret-bad", WithAudClaims("alice-rp-bad"),
					WithAudClaims(auds...),
					WithIssuer(TestConvertToUrls(t, "https://www.alice.com")[0]),
					WithApiUrl(cbs),
					WithSigningAlgs(algs...),
					WithCertificates(certs...),
					WithName("alice's restaurant is bad"),
					WithDescription("their food is awful"),
					WithClaimsScopes("email", "profile"),
					WithAccountClaimMap(map[string]AccountToClaim{"display_name": ToNameClaim, "oid": ToSubClaim}),
				)
				require.NoError(t, err)
				require.Equal(t, am.SigningAlgs, convertAlg(algs...))
				require.Equal(t, am.ApiUrl, cbs.String())
				require.Equal(t, "https://www.alice.com", am.Issuer)
				require.Equal(t, am.AudClaims, auds)
				require.Equal(t, am.Certificates, pems)
				require.Equal(t, am.OperationalState, string(InactiveState))
				return am
			},
			// "amoic_" differs from the "amoidc_" prefix used by the valid
			// custom-id case above; creation must be rejected.
			opt:          []Option{WithPublicId("amoic_1234567890")},
			wantErrMatch: errors.T(errors.InvalidParameter),
		},
		{
			name: "bad-state",
			am: func(t *testing.T) *AuthMethod {
				am, err := NewAuthMethod(ctx, org.PublicId, "bad-state-rp", "alice-secret",
					WithAudClaims("alice-rp"), WithApiUrl(TestConvertToUrls(t, "https://api.com")[0]))
				require.NoError(t, err)
				am.OperationalState = "not-a-valid-state"
				return am
			},
			wantErrMatch: errors.T(errors.InvalidParameter),
		},
		{
			name: "missing-auth-method",
			am: func(t *testing.T) *AuthMethod {
				return nil
			},
			wantErrMatch: errors.T(errors.InvalidParameter),
		},
		{
			name: "bad-public-id",
			am: func(t *testing.T) *AuthMethod {
				// An auth method that already has a public id set before
				// creation is expected to be rejected.
				id, err := newAuthMethodId(ctx)
				require.NoError(t, err)
				am := AllocAuthMethod()
				am.PublicId = id
				return &am
			},
			wantErrMatch: errors.T(errors.InvalidParameter),
		},
		{
			name: "bad-version",
			am: func(t *testing.T) *AuthMethod {
				// A pre-set version on a not-yet-created auth method is
				// expected to be rejected.
				am := AllocAuthMethod()
				am.Version = 22
				return &am
			},
			wantErrMatch: errors.T(errors.InvalidParameter),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			repo, err := NewRepository(ctx, rw, rw, kmsCache)
			// Repository construction is a precondition, so stop right
			// away if it fails instead of continuing with a bad repo.
			require.NoError(err)
			require.NotNil(repo)
			am := tt.am(t)
			got, err := repo.CreateAuthMethod(ctx, am, tt.opt...)
			if tt.wantErrMatch != nil {
				require.Error(err)
				assert.Truef(errors.Match(tt.wantErrMatch, err), "want err code: %q got: %q", tt.wantErrMatch, err)
				assert.Nil(got)
				if am != nil {
					// A failed create must not leave an oplog entry, so
					// the verification itself is expected to error.
					err := db.TestVerifyOplog(t, rw, am.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_CREATE), db.WithCreateNotBefore(10*time.Second))
					require.Errorf(err, "should not have found oplog entry for %s", am.PublicId)
				}
				return
			}
			require.NoError(err)
			if tt.opt != nil {
				if opts := getOpts(tt.opt...); opts.withPublicId != "" {
					require.Equal(opts.withPublicId, got.PublicId)
				}
			}
			// Copy the fields assigned during creation onto the input so
			// the remaining fields can be compared as a whole.
			am.PublicId = got.PublicId
			am.CreateTime = got.CreateTime
			am.UpdateTime = got.UpdateTime
			am.Version = got.Version
			// Report the expected auth method value (am), not the fixture
			// constructor func stored in tt.am.
			assert.Truef(proto.Equal(am.AuthMethod, got.AuthMethod), "got %+v expected %+v", got, am)

			err = db.TestVerifyOplog(t, rw, am.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_CREATE), db.WithCreateNotBefore(10*time.Second))
			require.NoErrorf(err, "unexpected error verifying oplog entry: %s", err)

			found, err := repo.LookupAuthMethod(ctx, am.PublicId)
			require.NoError(err)
			found.CreateTime = got.CreateTime
			found.UpdateTime = got.UpdateTime
			found.Version = got.Version
			TestSortAuthMethods(t, []*AuthMethod{found, am})
			assert.Empty(cmp.Diff(found.AuthMethod, am.AuthMethod, protocmp.Transform()))
		})
	}
}