diff --git a/internal/auth/oidc/provider.go b/internal/auth/oidc/provider.go index 629060ef9e..344e52c67e 100644 --- a/internal/auth/oidc/provider.go +++ b/internal/auth/oidc/provider.go @@ -114,7 +114,7 @@ func convertToProvider(ctx context.Context, am *AuthMethod) (*oidc.Provider, err am.ClientId, oidc.ClientSecret(am.ClientSecret), algs, - []string{fmt.Sprintf(CallbackEndpoint, am.GetApiUrl(), am.PublicId)}, + []string{fmt.Sprintf(CallbackEndpoint, am.GetApiUrl())}, oidc.WithAudiences(am.AudClaims...), oidc.WithProviderCA(strings.Join(am.Certificates, "\n")), ) diff --git a/internal/auth/oidc/provider_test.go b/internal/auth/oidc/provider_test.go index 0a8432295f..7a0ddee20a 100644 --- a/internal/auth/oidc/provider_test.go +++ b/internal/auth/oidc/provider_test.go @@ -25,7 +25,7 @@ func Test_ProviderCaching(t *testing.T) { require.NoError(t, err) id := authMethodId secret := authMethodId - p1 := testProvider(t, id, secret, fmt.Sprintf(CallbackEndpoint, allowedRedirect, authMethodId), tp) // provider needs the complete callback URL + p1 := testProvider(t, id, secret, fmt.Sprintf(CallbackEndpoint, allowedRedirect), tp) // provider needs the complete callback URL testAm, err := NewAuthMethod("fake-org", id, ClientSecret(secret), WithIssuer(issuer), WithApiUrl(TestConvertToUrls(t, allowedRedirect)[0])) @@ -94,7 +94,7 @@ func Test_convertToProvider(t *testing.T) { require.NoError(t, err) id := authMethodId secret := authMethodId - p := testProvider(t, id, secret, fmt.Sprintf(CallbackEndpoint, allowedRedirect, authMethodId), tp) // provider callback needs the complete URL + p := testProvider(t, id, secret, fmt.Sprintf(CallbackEndpoint, allowedRedirect), tp) // provider callback needs the complete URL testAm, err := NewAuthMethod("fake-org", id, ClientSecret(secret), WithIssuer(issuer), WithApiUrl(TestConvertToUrls(t, allowedRedirect)[0])) require.NoError(t, err) diff --git a/internal/auth/oidc/service.go b/internal/auth/oidc/service.go index c170022040..37ee1b2a9b 100644 --- a/internal/auth/oidc/service.go +++ b/internal/auth/oidc/service.go @@ -32,7 +32,7 @@ const ( // CallbackEndpoint is the endpoint for the oidc callback which will be // included in the auth URL returned when an authen attempted is kicked off. - CallbackEndpoint = "%s/v1/auth-methods/%s:authenticate:callback" + CallbackEndpoint = "%s/v1/auth-methods/oidc:authenticate:callback" ) type ( @@ -139,8 +139,8 @@ func decryptMessage(ctx context.Context, wrappingWrapper wrapping.Wrapper, wrapp return decryptedMsg, nil } -// unwrapMessage does just that, it unwraps the encoded request.Wrapper proto message -func unwrapMessage(ctx context.Context, encodedWrappedMsg string) (*request.Wrapper, error) { +// UnwrapMessage does just that, it unwraps the encoded request.Wrapper proto message +func UnwrapMessage(ctx context.Context, encodedWrappedMsg string) (*request.Wrapper, error) { const op = "" decoded, err := base58.FastBase58Decoding(encodedWrappedMsg) if err != nil { diff --git a/internal/auth/oidc/service_callback.go b/internal/auth/oidc/service_callback.go index 79f92c84e9..074c3bf3cf 100644 --- a/internal/auth/oidc/service_callback.go +++ b/internal/auth/oidc/service_callback.go @@ -89,7 +89,7 @@ func Callback( if err != nil { return "", errors.Wrap(err, op) } - stateWrapper, err := unwrapMessage(ctx, state) + stateWrapper, err := UnwrapMessage(ctx, state) if err != nil { return "", errors.Wrap(err, op) } @@ -151,7 +151,7 @@ func Callback( if strings.TrimSpace(am.ApiUrl) == "" { return "", errors.New(errors.InvalidParameter, op, "empty api URL") } - oidcRequest, err := oidc.NewRequest(AttemptExpiration, fmt.Sprintf(CallbackEndpoint, am.ApiUrl, am.PublicId), opts...) + oidcRequest, err := oidc.NewRequest(AttemptExpiration, fmt.Sprintf(CallbackEndpoint, am.ApiUrl), opts...) if err != nil { return "", errors.New(errors.Unknown, op, "unable to create oidc request for token exchange", errors.WithWrap(err)) } diff --git a/internal/auth/oidc/service_callback_test.go b/internal/auth/oidc/service_callback_test.go index 687098a30b..4b44af4fac 100644 --- a/internal/auth/oidc/service_callback_test.go +++ b/internal/auth/oidc/service_callback_test.go @@ -302,7 +302,7 @@ func Test_Callback(t *testing.T) { tp.SetExpectedAuthNonce(testNonce) if tt.am != nil { tp.SetClientCreds(tt.am.ClientId, tt.am.ClientSecret) - tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, tt.am.ApiUrl, tt.am.PublicId) + tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, tt.am.ApiUrl) tp.SetAllowedRedirectURIs([]string{tpAllowedRedirect}) } if tt.code != "" { @@ -444,7 +444,7 @@ func Test_Callback(t *testing.T) { // prime the test provider's state for the test tp.SetClientCreds(testAuthMethod.ClientId, testAuthMethod.ClientSecret) - tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, testController.URL, testAuthMethod.PublicId) + tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, testController.URL) tp.SetAllowedRedirectURIs([]string{tpAllowedRedirect}) state := testState(t, testAuthMethod, kmsCache, testTokenRequestId, 20*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce) tp.SetExpectedAuthCode("simple") diff --git a/internal/auth/oidc/service_start_auth.go b/internal/auth/oidc/service_start_auth.go index a76ff97732..7bcf9a04db 100644 --- a/internal/auth/oidc/service_start_auth.go +++ b/internal/auth/oidc/service_start_auth.go @@ -58,7 +58,7 @@ func StartAuth(ctx context.Context, oidcRepoFn OidcRepoFactory, authMethodId str if err != nil { return nil, "", errors.Wrap(err, op) } - callbackRedirect := fmt.Sprintf(CallbackEndpoint, am.GetApiUrl(), authMethodId) + callbackRedirect := fmt.Sprintf(CallbackEndpoint, am.GetApiUrl()) opts := getOpts(opt...) finalRedirect := fmt.Sprintf(FinalRedirectEndpoint, am.GetApiUrl()) diff --git a/internal/auth/oidc/service_start_auth_test.go b/internal/auth/oidc/service_start_auth_test.go index eccb302d98..d2ba56b404 100644 --- a/internal/auth/oidc/service_start_auth_test.go +++ b/internal/auth/oidc/service_start_auth_test.go @@ -76,7 +76,7 @@ func Test_StartAuth(t *testing.T) { stdSetup := func(am *AuthMethod, repoFn OidcRepoFactory, apiSrv *httptest.Server) (a *AuthMethod, allowedRedirect string) { // update the allowed redirects for the TestProvider - tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, apiSrv.URL, am.PublicId) + tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, apiSrv.URL) tp.SetAllowedRedirectURIs([]string{tpAllowedRedirect}) r, err := repoFn() require.NoError(t, err) @@ -206,7 +206,7 @@ func Test_StartAuth(t *testing.T) { // verify the state in the authUrl can be decrypted and it's correct state := authParams["state"][0] - wrappedStReq, err := unwrapMessage(ctx, state) + wrappedStReq, err := UnwrapMessage(ctx, state) require.NoError(err) repo, err := tt.repoFn() require.NoError(err) @@ -240,7 +240,7 @@ func Test_StartAuth(t *testing.T) { assert.Equal(configHash, reqState.ProviderConfigHash) // verify the token_id in the tokenUrl can be decrypted and it's correct - wrappedTkReq, err := unwrapMessage(ctx, tokenId) + wrappedTkReq, err := UnwrapMessage(ctx, tokenId) require.NoError(err) wrappingWrapper, err = requestWrappingWrapper(ctx, repo.kms, wrappedTkReq.ScopeId, wrappedTkReq.AuthMethodId) require.NoError(err) diff --git a/internal/auth/oidc/service_test.go b/internal/auth/oidc/service_test.go index 243eb6a0bc..0013da1a70 100644 --- a/internal/auth/oidc/service_test.go +++ b/internal/auth/oidc/service_test.go @@ -118,7 +118,7 @@ func Test_encryptMessage_decryptMessage(t *testing.T) { require.NoError(err) assert.NotEmpty(encrypted) - wrappedMsg, err := unwrapMessage(ctx, encrypted) + wrappedMsg, err := UnwrapMessage(ctx, encrypted) assert.Equalf(tt.authMethod.PublicId, wrappedMsg.AuthMethodId, "expected auth method %s and got: %s", tt.authMethod.PublicId, wrappedMsg.AuthMethodId) assert.Equalf(tt.authMethod.ScopeId, wrappedMsg.ScopeId, "expected scope id %s and got: %s", tt.authMethod.ScopeId, wrappedMsg.ScopeId) diff --git a/internal/auth/oidc/service_token_request.go b/internal/auth/oidc/service_token_request.go index 6d7ca18b90..7d246b35aa 100644 --- a/internal/auth/oidc/service_token_request.go +++ b/internal/auth/oidc/service_token_request.go @@ -37,7 +37,7 @@ func TokenRequest(ctx context.Context, kms *kms.Kms, atRepoFn AuthTokenRepoFacto return nil, errors.New(errors.InvalidParameter, op, "missing token request id") } - reqTkWrapper, err := unwrapMessage(ctx, tokenRequestId) + reqTkWrapper, err := UnwrapMessage(ctx, tokenRequestId) if err != nil { return nil, errors.Wrap(err, op) } diff --git a/internal/auth/oidc/testing.go b/internal/auth/oidc/testing.go index 84e11c9989..dc9a45f05a 100644 --- a/internal/auth/oidc/testing.go +++ b/internal/auth/oidc/testing.go @@ -413,7 +413,7 @@ func (s *testControllerSrv) CallbackUrl() string { s.t.Helper() require := require.New(s.t) require.NotNil(s.authMethod, "auth method was missing") - return fmt.Sprintf(CallbackEndpoint, s.Addr(), s.authMethod.GetPublicId()) + return fmt.Sprintf(CallbackEndpoint, s.Addr()) } // ServeHTTP satisfies the http.Handler interface @@ -421,7 +421,7 @@ func (s *testControllerSrv) ServeHTTP(w http.ResponseWriter, req *http.Request) s.t.Helper() require := require.New(s.t) switch req.URL.Path { - case fmt.Sprintf("/v1/auth-methods/%s:authenticate:callback", s.authMethod.GetPublicId()): + case "/v1/auth-methods/oidc:authenticate:callback": err := req.ParseForm() require.NoErrorf(err, "%s: internal error: %w", "callback", err) state := req.FormValue("state") diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index 5f9c653338..d5365b9b15 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -684,7 +684,7 @@ func (c *Command) startDevOidcAuthMethod() error { PubKey: ed25519.PublicKey(c.oidcSetup.pubKey), Alg: capoidc.EdDSA, }, - AllowedRedirectURIs: []string{fmt.Sprintf("%s/v1/auth-methods/%s:authenticate:callback", c.oidcSetup.callbackUrl.String(), c.DevOidcAuthMethodId)}, + AllowedRedirectURIs: []string{fmt.Sprintf("%s/v1/auth-methods/oidc:authenticate:callback", c.oidcSetup.callbackUrl.String())}, ClientID: &c.oidcSetup.clientId, ClientSecret: &clientSecret, })) diff --git a/internal/servers/controller/handler.go b/internal/servers/controller/handler.go index 1c3ce76c44..653d80b820 100644 --- a/internal/servers/controller/handler.go +++ b/internal/servers/controller/handler.go @@ -15,6 +15,7 @@ import ( "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/auth/oidc" "github.com/hashicorp/boundary/internal/gen/controller/api/services" "github.com/hashicorp/boundary/internal/servers/controller/handlers/accounts" "github.com/hashicorp/boundary/internal/servers/controller/handlers/authmethods" @@ -361,6 +362,28 @@ func wrapHandlerWithCallbackInterceptor(h http.Handler, c *Controller) http.Hand for k := range req.Form { values[k] = req.Form.Get(k) } + + if strings.HasSuffix(req.URL.Path, "oidc:authenticate") { + if s, ok := values["state"].(string); ok { + stateWrapper, err := oidc.UnwrapMessage(context.Background(), s) + if err != nil { + c.logger.Trace("callback error marshaling state", "method", req.Method, "url", req.URL.RequestURI(), "error", err) + w.WriteHeader(http.StatusInternalServerError) + return + } + if stateWrapper.AuthMethodId == "" { + c.logger.Trace("callback error: missing auth method id", "method", req.Method, "url", req.URL.RequestURI()) + w.WriteHeader(http.StatusInternalServerError) + return + } + stripped := strings.TrimSuffix(req.URL.Path, "oidc:authenticate") + req.URL.Path = fmt.Sprintf("%s%s:authenticate", stripped, stateWrapper.AuthMethodId) + } else { + c.logger.Trace("callback error: missing state parameter", "method", req.Method, "url", req.URL.RequestURI()) + w.WriteHeader(http.StatusInternalServerError) + return + } + } attrs.Attributes = values } diff --git a/internal/servers/controller/handlers/authmethods/authmethod_service.go b/internal/servers/controller/handlers/authmethods/authmethod_service.go index 62043a209b..87a2dbf933 100644 --- a/internal/servers/controller/handlers/authmethods/authmethod_service.go +++ b/internal/servers/controller/handlers/authmethods/authmethod_service.go @@ -687,7 +687,7 @@ func toAuthMethodProto(in auth.AuthMethod, opt ...handlers.Option) (*pb.AuthMeth } if len(i.GetApiUrl()) > 0 { attrs.ApiUrlPrefix = wrapperspb.String(i.GetApiUrl()) - attrs.CallbackUrl = fmt.Sprintf("%s/v1/auth-methods/%s:authenticate:callback", i.GetApiUrl(), i.GetPublicId()) + attrs.CallbackUrl = fmt.Sprintf("%s/v1/auth-methods/oidc:authenticate:callback", i.GetApiUrl()) } switch i.GetMaxAge() { case 0: diff --git a/internal/servers/controller/handlers/authmethods/authmethod_service_test.go b/internal/servers/controller/handlers/authmethods/authmethod_service_test.go index 1057d3a351..e48f1d1c93 100644 --- a/internal/servers/controller/handlers/authmethods/authmethod_service_test.go +++ b/internal/servers/controller/handlers/authmethods/authmethod_service_test.go @@ -126,7 +126,7 @@ func TestGet(t *testing.T) { "client_secret_hmac": structpb.NewStringValue(""), "state": structpb.NewStringValue(string(oidc.InactiveState)), "api_url_prefix": structpb.NewStringValue("https://api.com"), - "callback_url": structpb.NewStringValue(fmt.Sprintf(oidc.CallbackEndpoint, "https://api.com", oidcam.GetPublicId())), + "callback_url": structpb.NewStringValue(fmt.Sprintf(oidc.CallbackEndpoint, "https://api.com")), }}, Version: 1, Scope: &scopepb.ScopeInfo{ @@ -246,7 +246,7 @@ func TestList(t *testing.T) { "client_secret_hmac": structpb.NewStringValue(""), "state": structpb.NewStringValue(string(oidc.ActivePublicState)), "api_url_prefix": structpb.NewStringValue("https://api.com"), - "callback_url": structpb.NewStringValue(fmt.Sprintf(oidc.CallbackEndpoint, "https://api.com", oidcam.GetPublicId())), + "callback_url": structpb.NewStringValue(fmt.Sprintf(oidc.CallbackEndpoint, "https://api.com")), "signing_algorithms": func() *structpb.Value { lv, _ := structpb.NewList([]interface{}{string(oidc.EdDSA)}) return structpb.NewListValue(lv) @@ -591,7 +591,7 @@ func TestCreate(t *testing.T) { "client_secret_hmac": structpb.NewStringValue(""), "state": structpb.NewStringValue(string(oidc.InactiveState)), "api_url_prefix": structpb.NewStringValue("https://callback.prefix:9281/path"), - "callback_url": structpb.NewStringValue(fmt.Sprintf("https://callback.prefix:9281/path/v1/auth-methods/%s_[0-9A-z]*:authenticate:callback", oidc.AuthMethodPrefix)), + "callback_url": structpb.NewStringValue("https://callback.prefix:9281/path/v1/auth-methods/oidc:authenticate:callback"), "allowed_audiences": func() *structpb.Value { lv, _ := structpb.NewList([]interface{}{"foo", "bar"}) return structpb.NewListValue(lv) diff --git a/internal/servers/controller/handlers/authmethods/oidc_test.go b/internal/servers/controller/handlers/authmethods/oidc_test.go index 4bb888d928..e212de7fcd 100644 --- a/internal/servers/controller/handlers/authmethods/oidc_test.go +++ b/internal/servers/controller/handlers/authmethods/oidc_test.go @@ -118,7 +118,7 @@ func getSetup(t *testing.T) setup { oidc.WithCertificates(ret.testProviderCaCert...), ) - ret.testProviderAllowedRedirect = fmt.Sprintf(oidc.CallbackEndpoint, ret.testController.URL, ret.authMethod.PublicId) + ret.testProviderAllowedRedirect = fmt.Sprintf(oidc.CallbackEndpoint, ret.testController.URL) ret.testProvider.SetAllowedRedirectURIs([]string{ret.testProviderAllowedRedirect}) r, err := ret.oidcRepoFn() @@ -288,7 +288,7 @@ func TestUpdate_OIDC(t *testing.T) { "client_secret_hmac": structpb.NewStringValue(""), "state": structpb.NewStringValue(string(oidc.ActivePrivateState)), "api_url_prefix": structpb.NewStringValue("http://example.com"), - "callback_url": structpb.NewStringValue(fmt.Sprintf("http://example.com/v1/auth-methods/%s_[0-9A-z]*:authenticate:callback", oidc.AuthMethodPrefix)), + "callback_url": structpb.NewStringValue("http://example.com/v1/auth-methods/oidc:authenticate:callback"), "idp_ca_certs": func() *structpb.Value { lv, _ := structpb.NewList([]interface{}{tp.CACert()}) return structpb.NewListValue(lv) @@ -817,7 +817,7 @@ func TestUpdate_OIDC(t *testing.T) { Fields: func() map[string]*structpb.Value { f := defaultReadAttributeFields() f["api_url_prefix"] = structpb.NewStringValue("https://callback.prefix:9281/path") - f["callback_url"] = structpb.NewStringValue(fmt.Sprintf("https://callback.prefix:9281/path/v1/auth-methods/%s_[0-9A-z]*:authenticate:callback", oidc.AuthMethodPrefix)) + f["callback_url"] = structpb.NewStringValue("https://callback.prefix:9281/path/v1/auth-methods/oidc:authenticate:callback") return f }(), }, @@ -1102,7 +1102,7 @@ func TestUpdate_OIDCDryRun(t *testing.T) { AuthorizedCollectionActions: authorizedCollectionActions, Attributes: &structpb.Struct{Fields: map[string]*structpb.Value{ "api_url_prefix": structpb.NewStringValue(am.GetApiUrl()), - "callback_url": structpb.NewStringValue(fmt.Sprintf("%s/v1/auth-methods/%s:authenticate:callback", am.GetApiUrl(), am.GetPublicId())), + "callback_url": structpb.NewStringValue(fmt.Sprintf("%s/v1/auth-methods/oidc:authenticate:callback", am.GetApiUrl())), "client_id": structpb.NewStringValue(am.GetClientId()), "client_secret_hmac": structpb.NewStringValue(am.GetClientSecretHmac()), "issuer": structpb.NewStringValue(am.GetIssuer()), @@ -1284,7 +1284,7 @@ func TestChangeState_OIDC(t *testing.T) { "client_id": structpb.NewStringValue(tpClientId), "client_secret_hmac": structpb.NewStringValue(""), "state": structpb.NewStringValue(string(oidc.InactiveState)), - "callback_url": structpb.NewStringValue("https://example.callback:58/v1/auth-methods/amoidc_[0-9A-z]*:authenticate:callback"), + "callback_url": structpb.NewStringValue("https://example.callback:58/v1/auth-methods/oidc:authenticate:callback"), "api_url_prefix": structpb.NewStringValue("https://example.callback:58"), "signing_algorithms": signingAlg, "idp_ca_certs": certs,