feat(oidc-prompts): Add Validation for Prompts (#4130)

Add validation to prevent "none" from being combined with multiple prompts.

- Valid `["none"]`
- Invalid `["none", "select_account"]`
- Valid `["consent", "select_account"]`
tmessi-target-list-reduce-query-params
Elim Tsiagbey 2 years ago committed by GitHub
parent c70a6ff6ba
commit 053421600d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -14,6 +14,7 @@ import (
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/cap/oidc"
"github.com/hashicorp/go-secure-stdlib/strutil"
"google.golang.org/protobuf/types/known/timestamppb"
)
@ -55,6 +56,13 @@ func StartAuth(ctx context.Context, oidcRepoFn OidcRepoFactory, authMethodId str
if am.OperationalState == string(InactiveState) {
return nil, "", errors.New(ctx, errors.AuthMethodInactive, op, "not allowed to start authentication attempt")
}
if len(am.Prompts) > 0 {
prompts := strutil.RemoveDuplicatesStable(am.Prompts, false)
if strutil.StrListContains(prompts, string(oidc.None)) && len(prompts) > 1 {
return nil, "", errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf(`prompts (%s) includes "none" with other values`, am.Prompts))
}
}
// get the provider from the cache (if possible)
provider, err := providerCache().get(ctx, am)

@ -33,7 +33,9 @@ func Test_StartAuth(t *testing.T) {
_, _, tpAlg, _ := tp.SigningKeys()
tpCert, err := ParseCertificates(ctx, tp.CACert())
require.NoError(t, err)
tpPrompt := []PromptParam{Consent, SelectAccount}
tpPrompt := []PromptParam{SelectAccount}
tpNoneWithMultiplePrompts := []PromptParam{None, SelectAccount}
tpWithMultiplePrompts := []PromptParam{Consent, SelectAccount}
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
rootWrapper := db.TestWrapper(t)
@ -88,6 +90,26 @@ func Test_StartAuth(t *testing.T) {
WithPrompts(tpPrompt...),
)
testAuthMethodWithMultiplePrompts := TestAuthMethod(
t, conn, databaseWrapper, org.PublicId, ActivePublicState,
"test-rp5", "fido",
WithIssuer(TestConvertToUrls(t, tp.Addr())[0]),
WithApiUrl(TestConvertToUrls(t, testController.URL)[0]),
WithSigningAlgs(Alg(tpAlg)),
WithCertificates(tpCert...),
WithPrompts(tpWithMultiplePrompts...),
)
testAuthMethodNoneWithMultiplePrompts := TestAuthMethod(
t, conn, databaseWrapper, org.PublicId, ActivePublicState,
"test-rp6", "fido",
WithIssuer(TestConvertToUrls(t, tp.Addr())[0]),
WithApiUrl(TestConvertToUrls(t, testController.URL)[0]),
WithSigningAlgs(Alg(tpAlg)),
WithCertificates(tpCert...),
WithPrompts(tpNoneWithMultiplePrompts...),
)
stdSetup := func(am *AuthMethod, repoFn OidcRepoFactory, apiSrv *httptest.Server) (a *AuthMethod, allowedRedirect string) {
// update the allowed redirects for the TestProvider
tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, apiSrv.URL)
@ -172,6 +194,22 @@ func Test_StartAuth(t *testing.T) {
authMethod: testAuthMethodWithPrompt,
setup: stdSetup,
},
{
name: "simple-with-multiple-prompts",
repoFn: repoFn,
apiSrv: testController,
authMethod: testAuthMethodWithMultiplePrompts,
setup: stdSetup,
},
{
name: "simple-with-none-and-multiple-prompts",
repoFn: repoFn,
apiSrv: testController,
authMethod: testAuthMethodNoneWithMultiplePrompts,
setup: stdSetup,
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "prompts ([none select_account]) includes \"none\" with other values",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

@ -1034,10 +1034,16 @@ func validateCreateRequest(ctx context.Context, req *pbs.CreateAuthMethodRequest
}
}
if len(attrs.GetPrompts()) > 0 {
for _, p := range attrs.GetPrompts() {
if !oidc.SupportedPrompt(oidc.PromptParam(p)) {
badFields[promptsField] = fmt.Sprintf("Contains unsupported prompt %q", p)
break
prompts := strutil.RemoveDuplicatesStable(attrs.GetPrompts(), false)
if strutil.StrListContains(prompts, string(oidc.None)) && len(prompts) > 1 {
badFields[promptsField] = fmt.Sprintf(`prompts (%s) includes "none" with other values`, prompts)
} else {
for _, p := range attrs.GetPrompts() {
if !oidc.SupportedPrompt(oidc.PromptParam(p)) {
badFields[promptsField] = fmt.Sprintf("Contains unsupported prompt %q", p)
break
}
}
}
}
@ -1170,10 +1176,16 @@ func validateUpdateRequest(ctx context.Context, req *pbs.UpdateAuthMethodRequest
}
}
if len(attrs.GetPrompts()) > 0 {
for _, p := range attrs.GetPrompts() {
if !oidc.SupportedPrompt(oidc.PromptParam(p)) {
badFields[promptsField] = fmt.Sprintf("Contains unsupported prompt %q", p)
break
prompts := strutil.RemoveDuplicatesStable(attrs.GetPrompts(), false)
if strutil.StrListContains(prompts, string(oidc.None)) && len(prompts) > 1 {
badFields[promptsField] = fmt.Sprintf(`prompts (%s) includes "none" with other values`, prompts)
} else {
for _, p := range attrs.GetPrompts() {
if !oidc.SupportedPrompt(oidc.PromptParam(p)) {
badFields[promptsField] = fmt.Sprintf("Contains unsupported prompt %q", p)
break
}
}
}
}

@ -1454,6 +1454,25 @@ func TestCreate(t *testing.T) {
},
},
},
{
name: "Create OIDC AuthMethod with none and other prompts",
req: &pbs.CreateAuthMethodRequest{Item: &pb.AuthMethod{
ScopeId: o.GetPublicId(),
Type: oidc.Subtype.String(),
Attrs: &pb.AuthMethod_OidcAuthMethodsAttributes{
OidcAuthMethodsAttributes: &pb.OidcAuthMethodAttributes{
Issuer: wrapperspb.String("https://example.discovery.url:4821/.well-known/openid-configuration/"),
ClientId: wrapperspb.String("exampleclientid"),
ClientSecret: wrapperspb.String("secret"),
ApiUrlPrefix: wrapperspb.String("https://callback.prefix:9281/path"),
Prompts: []string{string(oidc.None), string(oidc.SelectAccount)},
},
},
}},
idPrefix: globals.OidcAuthMethodPrefix + "_",
err: handlers.ApiErrorWithCode(codes.InvalidArgument),
errContains: "includes \\\"none\\\" with other values",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {

Loading…
Cancel
Save