From d0f36f6a7432e870b4149308095f6f264972f5ae Mon Sep 17 00:00:00 2001 From: Jim Date: Mon, 12 Apr 2021 17:17:28 -0400 Subject: [PATCH] =?UTF-8?q?ensure=20that=20inbound=20auth=20method=20id=20?= =?UTF-8?q?matches=20request=20wrapper=20auth=20metho=E2=80=A6=20(#1104)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/auth/oidc/service_callback.go | 7 ++ internal/auth/oidc/service_callback_test.go | 60 +++++++----- internal/auth/oidc/service_token_request.go | 12 ++- .../auth/oidc/service_token_request_test.go | 96 +++++++++++++------ .../controller/handlers/authmethods/oidc.go | 2 +- .../handlers/authmethods/oidc_test.go | 4 +- 6 files changed, 123 insertions(+), 58 deletions(-) diff --git a/internal/auth/oidc/service_callback.go b/internal/auth/oidc/service_callback.go index 235568f7d3..841bbfcc12 100644 --- a/internal/auth/oidc/service_callback.go +++ b/internal/auth/oidc/service_callback.go @@ -100,6 +100,13 @@ func Callback( if err != nil { return "", errors.Wrap(err, op) } + + // it appears the authentication request was started with one auth method + // and the callback came to a different auth method. + if stateWrapper.AuthMethodId != authMethodId { + return "", errors.New(errors.InvalidParameter, op, fmt.Sprintf("%s auth method id does not match request wrapper auth method id: %s", authMethodId, stateWrapper)) + } + stateBytes, err := decryptMessage(ctx, requestWrapper, stateWrapper) if err != nil { return "", errors.Wrap(err, op) diff --git a/internal/auth/oidc/service_callback_test.go b/internal/auth/oidc/service_callback_test.go index 29925a0cae..fc46fbe8e0 100644 --- a/internal/auth/oidc/service_callback_test.go +++ b/internal/auth/oidc/service_callback_test.go @@ -114,20 +114,21 @@ func Test_Callback(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - setup func() // provide a simple way to do some prework before the test. - oidcRepoFn OidcRepoFactory // returns a new oidc repo - iamRepoFn IamRepoFactory // returns a new iam repo - atRepoFn AuthTokenRepoFactory // returns a new auth token repo - am *AuthMethod // the authmethod for the test - state string // state parameter for test provider and Callback(...) - code string // code parameter for test provider and Callback(...) - wantSubject string // sub claim from id token - wantInfoName string // name claim from userinfo - wantInfoEmail string // email claim from userinfo - wantFinalRedirect string // final redirect from Callback(...) - wantErrMatch *errors.Template // error template to match - wantErrContains string // error string should contain + name string + setup func() // provide a simple way to do some prework before the test. + oidcRepoFn OidcRepoFactory // returns a new oidc repo + iamRepoFn IamRepoFactory // returns a new iam repo + atRepoFn AuthTokenRepoFactory // returns a new auth token repo + am *AuthMethod // the authmethod for the test + authMethodIdOverride *string // an override of the authmethod.PublcId as a callback parameters + state string // state parameter for test provider and Callback(...) + code string // code parameter for test provider and Callback(...) + wantSubject string // sub claim from id token + wantInfoName string // name claim from userinfo + wantInfoEmail string // email claim from userinfo + wantFinalRedirect string // final redirect from Callback(...) + wantErrMatch *errors.Template // error template to match + wantErrContains string // error string should contain }{ { name: "simple", // must remain the first test @@ -244,6 +245,18 @@ func Test_Callback(t *testing.T) { wantErrMatch: errors.T(errors.RecordNotFound), wantErrContains: "auth method not-valid not found", }, + { + name: "mismatch-auth-method-id", // must remain the first test + oidcRepoFn: repoFn, + iamRepoFn: iamRepoFn, + atRepoFn: atRepoFn, + am: testAuthMethod, + authMethodIdOverride: &testAuthMethod2.PublicId, + state: testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce), + code: "simple", + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "auth method id does not match request wrapper auth method id", + }, { name: "bad-state-encoding", oidcRepoFn: repoFn, @@ -255,17 +268,6 @@ func Test_Callback(t *testing.T) { wantErrMatch: errors.T(errors.Unknown), wantErrContains: "unable to decode message", }, - { - name: "unable-to-decrypt", - oidcRepoFn: repoFn, - iamRepoFn: iamRepoFn, - atRepoFn: atRepoFn, - am: testAuthMethod, - state: testState(t, testAuthMethod2, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce), - code: "simple", - wantErrMatch: errors.T(errors.Decrypt), - wantErrContains: "unable to decrypt message", - }, { name: "inactive-with-config-change", oidcRepoFn: repoFn, @@ -337,12 +339,18 @@ func Test_Callback(t *testing.T) { if len(info) > 0 { tp.SetUserInfoReply(info) } + var authMethodId string + if tt.authMethodIdOverride != nil { + authMethodId = *tt.authMethodIdOverride + } else { + authMethodId = tt.am.PublicId + } gotRedirect, err := Callback(ctx, tt.oidcRepoFn, tt.iamRepoFn, tt.atRepoFn, - tt.am.PublicId, + authMethodId, tt.state, tt.code, ) diff --git a/internal/auth/oidc/service_token_request.go b/internal/auth/oidc/service_token_request.go index b23b025267..6d7ca18b90 100644 --- a/internal/auth/oidc/service_token_request.go +++ b/internal/auth/oidc/service_token_request.go @@ -2,6 +2,7 @@ package oidc import ( "context" + "fmt" "time" "github.com/hashicorp/boundary/internal/auth/oidc/request" @@ -21,7 +22,7 @@ import ( // * Use the authtoken.(Repository).IssueAuthToken to issue the request id's // token and mark it as issued in the repo. If the token is already issue, an // error is returned. -func TokenRequest(ctx context.Context, kms *kms.Kms, atRepoFn AuthTokenRepoFactory, tokenRequestId string) (*authtoken.AuthToken, error) { +func TokenRequest(ctx context.Context, kms *kms.Kms, atRepoFn AuthTokenRepoFactory, authMethodId, tokenRequestId string) (*authtoken.AuthToken, error) { const op = "oidc.TokenRequest" if kms == nil { return nil, errors.New(errors.InvalidParameter, op, "missing kms") @@ -29,6 +30,12 @@ func TokenRequest(ctx context.Context, kms *kms.Kms, atRepoFn AuthTokenRepoFacto if atRepoFn == nil { return nil, errors.New(errors.InvalidParameter, op, "missing auth token repo function") } + if authMethodId == "" { + return nil, errors.New(errors.InvalidParameter, op, "missing auth method id") + } + if tokenRequestId == "" { + return nil, errors.New(errors.InvalidParameter, op, "missing token request id") + } reqTkWrapper, err := unwrapMessage(ctx, tokenRequestId) if err != nil { @@ -40,6 +47,9 @@ func TokenRequest(ctx context.Context, kms *kms.Kms, atRepoFn AuthTokenRepoFacto if reqTkWrapper.AuthMethodId == "" { return nil, errors.New(errors.InvalidParameter, op, "request token id wrapper missing auth method id") } + if reqTkWrapper.AuthMethodId != authMethodId { + return nil, errors.New(errors.InvalidParameter, op, fmt.Sprintf("%s auth method id does not match request wrapper auth method id: %s", authMethodId, reqTkWrapper.AuthMethodId)) + } // tokenRequestId is a proto request.Wrapper, which contains a cipher text field, // so we need the derived wrapper that was used to encrypt it. diff --git a/internal/auth/oidc/service_token_request_test.go b/internal/auth/oidc/service_token_request_test.go index cf39c03936..29e00d4103 100644 --- a/internal/auth/oidc/service_token_request_test.go +++ b/internal/auth/oidc/service_token_request_test.go @@ -58,6 +58,7 @@ func Test_TokenRequest(t *testing.T) { name string kms *kms.Kms atRepoFn AuthTokenRepoFactory + authMethodId string tokenRequest string wantNil bool wantErrMatch *errors.Template @@ -79,14 +80,16 @@ func Test_TokenRequest(t *testing.T) { name: "bad-wrapper", kms: kmsCache, atRepoFn: atRepoFn, + authMethodId: testAuthMethod.PublicId, tokenRequest: "bad-wrapper", wantErrMatch: errors.T(errors.Unknown), wantErrContains: "unable to decode message", }, { - name: "missing-wrapper-scope-id", - kms: kmsCache, - atRepoFn: atRepoFn, + name: "missing-wrapper-scope-id", + kms: kmsCache, + atRepoFn: atRepoFn, + authMethodId: testAuthMethod.PublicId, tokenRequest: func() string { w := request.Wrapper{ AuthMethodId: testAuthMethod.PublicId, @@ -99,9 +102,24 @@ func Test_TokenRequest(t *testing.T) { wantErrContains: "missing scope id", }, { - name: "missing-wrapper-auth-method-id", - kms: kmsCache, - atRepoFn: atRepoFn, + name: "missing-auth-method-id", + kms: kmsCache, + atRepoFn: atRepoFn, + authMethodId: "", + tokenRequest: func() string { + tokenPublicId, err := authtoken.NewAuthTokenId() + require.NoError(t, err) + TestPendingToken(t, testAtRepo, testUser, testAcct, tokenPublicId) + return TestTokenRequestId(t, testAuthMethod, kmsCache, 200*time.Second, tokenPublicId) + }(), + wantErrMatch: errors.T(errors.Unknown), + wantErrContains: "missing auth method id", + }, + { + name: "missing-wrapper-auth-method-id", + kms: kmsCache, + atRepoFn: atRepoFn, + authMethodId: testAuthMethod.PublicId, tokenRequest: func() string { w := request.Wrapper{ ScopeId: testAuthMethod.ScopeId, @@ -114,9 +132,10 @@ func Test_TokenRequest(t *testing.T) { wantErrContains: "missing auth method id", }, { - name: "dek-not-found", - kms: kms.TestKms(t, conn, db.TestWrapper(t)), - atRepoFn: atRepoFn, + name: "dek-not-found", + kms: kms.TestKms(t, conn, db.TestWrapper(t)), + atRepoFn: atRepoFn, + authMethodId: testAuthMethod.PublicId, tokenRequest: func() string { tokenPublicId, err := authtoken.NewAuthTokenId() require.NoError(t, err) @@ -127,9 +146,10 @@ func Test_TokenRequest(t *testing.T) { wantErrContains: "unable to get oidc wrapper", }, { - name: "expired", - kms: kmsCache, - atRepoFn: atRepoFn, + name: "expired", + kms: kmsCache, + atRepoFn: atRepoFn, + authMethodId: testAuthMethod.PublicId, tokenRequest: func() string { tokenPublicId, err := authtoken.NewAuthTokenId() require.NoError(t, err) @@ -145,6 +165,7 @@ func Test_TokenRequest(t *testing.T) { atRepoFn: func() (*authtoken.Repository, error) { return nil, errors.New(errors.Unknown, "test op", "atRepoFn-error") }, + authMethodId: testAuthMethod.PublicId, tokenRequest: func() string { tokenPublicId, err := authtoken.NewAuthTokenId() require.NoError(t, err) @@ -155,9 +176,10 @@ func Test_TokenRequest(t *testing.T) { wantErrContains: "atRepoFn-error", }, { - name: "error-unmarshal", - kms: kmsCache, - atRepoFn: atRepoFn, + name: "error-unmarshal", + kms: kmsCache, + atRepoFn: atRepoFn, + authMethodId: testAuthMethod.PublicId, tokenRequest: func() string { blobInfo, err := testRequestWrapper.Encrypt(ctx, []byte("not-valid-request-token"), []byte(fmt.Sprintf("%s%s", testAuthMethod.PublicId, testAuthMethod.ScopeId))) require.NoError(t, err) @@ -177,9 +199,10 @@ func Test_TokenRequest(t *testing.T) { wantErrContains: "unable to unmarshal request token", }, { - name: "error-missing-exp", - kms: kmsCache, - atRepoFn: atRepoFn, + name: "error-missing-exp", + kms: kmsCache, + atRepoFn: atRepoFn, + authMethodId: testAuthMethod.PublicId, tokenRequest: func() string { tokenPublicId, err := authtoken.NewAuthTokenId() require.NoError(t, err) @@ -206,9 +229,10 @@ func Test_TokenRequest(t *testing.T) { wantErrContains: "missing request token id expiration", }, { - name: "error-missing-request-id", - kms: kmsCache, - atRepoFn: atRepoFn, + name: "error-missing-request-id", + kms: kmsCache, + atRepoFn: atRepoFn, + authMethodId: testAuthMethod.PublicId, tokenRequest: func() string { exp, err := ptypes.TimestampProto(time.Now().Add(AttemptExpiration).Truncate(time.Second)) require.NoError(t, err) @@ -235,9 +259,10 @@ func Test_TokenRequest(t *testing.T) { wantErrContains: "missing token request id", }, { - name: "error-issuing-token-forbidden-code", - kms: kmsCache, - atRepoFn: atRepoFn, + name: "error-issuing-token-forbidden-code", + kms: kmsCache, + atRepoFn: atRepoFn, + authMethodId: testAuthMethod.PublicId, tokenRequest: func() string { exp, err := ptypes.TimestampProto(time.Now().Add(AttemptExpiration).Truncate(time.Second)) require.NoError(t, err) @@ -264,9 +289,24 @@ func Test_TokenRequest(t *testing.T) { wantNil: true, }, { - name: "success", - kms: kmsCache, - atRepoFn: atRepoFn, + name: "mismatched-auth-method-id", + kms: kmsCache, + atRepoFn: atRepoFn, + authMethodId: "not-a-match", + tokenRequest: func() string { + tokenPublicId, err := authtoken.NewAuthTokenId() + require.NoError(t, err) + TestPendingToken(t, testAtRepo, testUser, testAcct, tokenPublicId) + return TestTokenRequestId(t, testAuthMethod, kmsCache, 200*time.Second, tokenPublicId) + }(), + wantErrMatch: errors.T(errors.InvalidParameter), + wantErrContains: "auth method id does not match request wrapper auth method id", + }, + { + name: "success", + kms: kmsCache, + atRepoFn: atRepoFn, + authMethodId: testAuthMethod.PublicId, tokenRequest: func() string { tokenPublicId, err := authtoken.NewAuthTokenId() require.NoError(t, err) @@ -278,7 +318,7 @@ func Test_TokenRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - gotTk, err := TokenRequest(ctx, tt.kms, tt.atRepoFn, tt.tokenRequest) + gotTk, err := TokenRequest(ctx, tt.kms, tt.atRepoFn, tt.authMethodId, tt.tokenRequest) if tt.wantErrMatch != nil { require.Error(err) assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted %q and got: %+v", tt.wantErrMatch.Code, err) diff --git a/internal/servers/controller/handlers/authmethods/oidc.go b/internal/servers/controller/handlers/authmethods/oidc.go index 3338c449ec..5b3964245d 100644 --- a/internal/servers/controller/handlers/authmethods/oidc.go +++ b/internal/servers/controller/handlers/authmethods/oidc.go @@ -258,7 +258,7 @@ func (s Service) authenticateOidcToken(ctx context.Context, req *pbs.Authenticat return nil, errors.New(errors.InvalidParameter, op, "Empty token id request attributes.") } - token, err := oidc.TokenRequest(ctx, s.kms, s.atRepoFn, attrs.TokenId) + token, err := oidc.TokenRequest(ctx, s.kms, s.atRepoFn, req.GetAuthMethodId(), attrs.TokenId) if err != nil { // TODO: Log something so we don't lose the error's context and entire msg... switch { diff --git a/internal/servers/controller/handlers/authmethods/oidc_test.go b/internal/servers/controller/handlers/authmethods/oidc_test.go index f341857c09..03b765e4a7 100644 --- a/internal/servers/controller/handlers/authmethods/oidc_test.go +++ b/internal/servers/controller/handlers/authmethods/oidc_test.go @@ -1443,7 +1443,7 @@ func TestAuthenticate_OIDC_Token(t *testing.T) { tokenPublicId, err := authtoken.NewAuthTokenId() require.NoError(t, err) oidc.TestPendingToken(t, testAtRepo, testUser, testAcct, tokenPublicId) - return oidc.TestTokenRequestId(t, testAuthMethod, s.kmsCache, 200*time.Second, tokenPublicId) + return oidc.TestTokenRequestId(t, s.authMethod, s.kmsCache, 200*time.Second, tokenPublicId) }(), }) require.NoError(t, err) @@ -1462,7 +1462,7 @@ func TestAuthenticate_OIDC_Token(t *testing.T) { tokenPublicId, err := authtoken.NewAuthTokenId() require.NoError(t, err) oidc.TestPendingToken(t, testAtRepo, testUser, testAcct, tokenPublicId) - return oidc.TestTokenRequestId(t, testAuthMethod, s.kmsCache, -20*time.Second, tokenPublicId) + return oidc.TestTokenRequestId(t, s.authMethod, s.kmsCache, -20*time.Second, tokenPublicId) }(), }) require.NoError(t, err)