From 053421600d0f7e20ff68feac1b9cc7fdb4b80ecf Mon Sep 17 00:00:00 2001 From: Elim Tsiagbey Date: Fri, 8 Dec 2023 09:51:54 -0500 Subject: [PATCH] feat(oidc-prompts): Add Validation for Prompts (#4130) Add validation to prevent "none" from being combined with multiple prompts. - Valid `["none"]` - Invalid `["none", "select_account"]` - Valid `["consent", "select_account"]` --- internal/auth/oidc/service_start_auth.go | 8 ++++ internal/auth/oidc/service_start_auth_test.go | 40 ++++++++++++++++++- .../authmethods/authmethod_service.go | 28 +++++++++---- .../authmethods/authmethod_service_test.go | 19 +++++++++ 4 files changed, 86 insertions(+), 9 deletions(-) diff --git a/internal/auth/oidc/service_start_auth.go b/internal/auth/oidc/service_start_auth.go index 78adf46381..bc2840b86f 100644 --- a/internal/auth/oidc/service_start_auth.go +++ b/internal/auth/oidc/service_start_auth.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/cap/oidc" + "github.com/hashicorp/go-secure-stdlib/strutil" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -55,6 +56,13 @@ func StartAuth(ctx context.Context, oidcRepoFn OidcRepoFactory, authMethodId str if am.OperationalState == string(InactiveState) { return nil, "", errors.New(ctx, errors.AuthMethodInactive, op, "not allowed to start authentication attempt") } + if len(am.Prompts) > 0 { + prompts := strutil.RemoveDuplicatesStable(am.Prompts, false) + + if strutil.StrListContains(prompts, string(oidc.None)) && len(prompts) > 1 { + return nil, "", errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf(`prompts (%s) includes "none" with other values`, am.Prompts)) + } + } // get the provider from the cache (if possible) provider, err := providerCache().get(ctx, am) diff --git a/internal/auth/oidc/service_start_auth_test.go b/internal/auth/oidc/service_start_auth_test.go index 9864dfa1ce..cdec972194 100644 --- a/internal/auth/oidc/service_start_auth_test.go +++ b/internal/auth/oidc/service_start_auth_test.go @@ -33,7 +33,9 @@ func Test_StartAuth(t *testing.T) { _, _, tpAlg, _ := tp.SigningKeys() tpCert, err := ParseCertificates(ctx, tp.CACert()) require.NoError(t, err) - tpPrompt := []PromptParam{Consent, SelectAccount} + tpPrompt := []PromptParam{SelectAccount} + tpNoneWithMultiplePrompts := []PromptParam{None, SelectAccount} + tpWithMultiplePrompts := []PromptParam{Consent, SelectAccount} conn, _ := db.TestSetup(t, "postgres") rw := db.New(conn) rootWrapper := db.TestWrapper(t) @@ -88,6 +90,26 @@ func Test_StartAuth(t *testing.T) { WithPrompts(tpPrompt...), ) + testAuthMethodWithMultiplePrompts := TestAuthMethod( + t, conn, databaseWrapper, org.PublicId, ActivePublicState, + "test-rp5", "fido", + WithIssuer(TestConvertToUrls(t, tp.Addr())[0]), + WithApiUrl(TestConvertToUrls(t, testController.URL)[0]), + WithSigningAlgs(Alg(tpAlg)), + WithCertificates(tpCert...), + WithPrompts(tpWithMultiplePrompts...), + ) + + testAuthMethodNoneWithMultiplePrompts := TestAuthMethod( + t, conn, databaseWrapper, org.PublicId, ActivePublicState, + "test-rp6", "fido", + WithIssuer(TestConvertToUrls(t, tp.Addr())[0]), + WithApiUrl(TestConvertToUrls(t, testController.URL)[0]), + WithSigningAlgs(Alg(tpAlg)), + WithCertificates(tpCert...), + WithPrompts(tpNoneWithMultiplePrompts...), + ) + stdSetup := func(am *AuthMethod, repoFn OidcRepoFactory, apiSrv *httptest.Server) (a *AuthMethod, allowedRedirect string) { // update the allowed redirects for the TestProvider tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, apiSrv.URL) @@ -172,6 +194,22 @@ func Test_StartAuth(t *testing.T) { authMethod: testAuthMethodWithPrompt, setup: stdSetup, }, + { + name: "simple-with-multiple-prompts", + repoFn: repoFn, + apiSrv: testController, + authMethod: testAuthMethodWithMultiplePrompts, + setup: stdSetup, + }, + { + name: "simple-with-none-and-multiple-prompts", + repoFn: repoFn, + apiSrv: testController, + authMethod: testAuthMethodNoneWithMultiplePrompts, + setup: stdSetup, + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "prompts ([none select_account]) includes \"none\" with other values", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/daemon/controller/handlers/authmethods/authmethod_service.go b/internal/daemon/controller/handlers/authmethods/authmethod_service.go index aae2d8ab2c..f5ac58a8a4 100644 --- a/internal/daemon/controller/handlers/authmethods/authmethod_service.go +++ b/internal/daemon/controller/handlers/authmethods/authmethod_service.go @@ -1034,10 +1034,16 @@ func validateCreateRequest(ctx context.Context, req *pbs.CreateAuthMethodRequest } } if len(attrs.GetPrompts()) > 0 { - for _, p := range attrs.GetPrompts() { - if !oidc.SupportedPrompt(oidc.PromptParam(p)) { - badFields[promptsField] = fmt.Sprintf("Contains unsupported prompt %q", p) - break + prompts := strutil.RemoveDuplicatesStable(attrs.GetPrompts(), false) + + if strutil.StrListContains(prompts, string(oidc.None)) && len(prompts) > 1 { + badFields[promptsField] = fmt.Sprintf(`prompts (%s) includes "none" with other values`, prompts) + } else { + for _, p := range attrs.GetPrompts() { + if !oidc.SupportedPrompt(oidc.PromptParam(p)) { + badFields[promptsField] = fmt.Sprintf("Contains unsupported prompt %q", p) + break + } } } } @@ -1170,10 +1176,16 @@ func validateUpdateRequest(ctx context.Context, req *pbs.UpdateAuthMethodRequest } } if len(attrs.GetPrompts()) > 0 { - for _, p := range attrs.GetPrompts() { - if !oidc.SupportedPrompt(oidc.PromptParam(p)) { - badFields[promptsField] = fmt.Sprintf("Contains unsupported prompt %q", p) - break + prompts := strutil.RemoveDuplicatesStable(attrs.GetPrompts(), false) + + if strutil.StrListContains(prompts, string(oidc.None)) && len(prompts) > 1 { + badFields[promptsField] = fmt.Sprintf(`prompts (%s) includes "none" with other values`, prompts) + } else { + for _, p := range attrs.GetPrompts() { + if !oidc.SupportedPrompt(oidc.PromptParam(p)) { + badFields[promptsField] = fmt.Sprintf("Contains unsupported prompt %q", p) + break + } } } } diff --git a/internal/daemon/controller/handlers/authmethods/authmethod_service_test.go b/internal/daemon/controller/handlers/authmethods/authmethod_service_test.go index 03542ee95d..c208c1ca34 100644 --- a/internal/daemon/controller/handlers/authmethods/authmethod_service_test.go +++ b/internal/daemon/controller/handlers/authmethods/authmethod_service_test.go @@ -1454,6 +1454,25 @@ func TestCreate(t *testing.T) { }, }, }, + { + name: "Create OIDC AuthMethod with none and other prompts", + req: &pbs.CreateAuthMethodRequest{Item: &pb.AuthMethod{ + ScopeId: o.GetPublicId(), + Type: oidc.Subtype.String(), + Attrs: &pb.AuthMethod_OidcAuthMethodsAttributes{ + OidcAuthMethodsAttributes: &pb.OidcAuthMethodAttributes{ + Issuer: wrapperspb.String("https://example.discovery.url:4821/.well-known/openid-configuration/"), + ClientId: wrapperspb.String("exampleclientid"), + ClientSecret: wrapperspb.String("secret"), + ApiUrlPrefix: wrapperspb.String("https://callback.prefix:9281/path"), + Prompts: []string{string(oidc.None), string(oidc.SelectAccount)}, + }, + }, + }}, + idPrefix: globals.OidcAuthMethodPrefix + "_", + err: handlers.ApiErrorWithCode(codes.InvalidArgument), + errContains: "includes \\\"none\\\" with other values", + }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) {