ensure that inbound auth method id matches request wrapper auth metho… (#1104)

pull/1106/head
Jim 5 years ago committed by GitHub
parent 90b30bad25
commit d0f36f6a74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -100,6 +100,13 @@ func Callback(
if err != nil {
return "", errors.Wrap(err, op)
}
// it appears the authentication request was started with one auth method
// and the callback came to a different auth method.
if stateWrapper.AuthMethodId != authMethodId {
return "", errors.New(errors.InvalidParameter, op, fmt.Sprintf("%s auth method id does not match request wrapper auth method id: %s", authMethodId, stateWrapper))
}
stateBytes, err := decryptMessage(ctx, requestWrapper, stateWrapper)
if err != nil {
return "", errors.Wrap(err, op)

@ -114,20 +114,21 @@ func Test_Callback(t *testing.T) {
require.NoError(t, err)
tests := []struct {
name string
setup func() // provide a simple way to do some prework before the test.
oidcRepoFn OidcRepoFactory // returns a new oidc repo
iamRepoFn IamRepoFactory // returns a new iam repo
atRepoFn AuthTokenRepoFactory // returns a new auth token repo
am *AuthMethod // the authmethod for the test
state string // state parameter for test provider and Callback(...)
code string // code parameter for test provider and Callback(...)
wantSubject string // sub claim from id token
wantInfoName string // name claim from userinfo
wantInfoEmail string // email claim from userinfo
wantFinalRedirect string // final redirect from Callback(...)
wantErrMatch *errors.Template // error template to match
wantErrContains string // error string should contain
name string
setup func() // provide a simple way to do some prework before the test.
oidcRepoFn OidcRepoFactory // returns a new oidc repo
iamRepoFn IamRepoFactory // returns a new iam repo
atRepoFn AuthTokenRepoFactory // returns a new auth token repo
am *AuthMethod // the authmethod for the test
authMethodIdOverride *string // an override of the authmethod.PublcId as a callback parameters
state string // state parameter for test provider and Callback(...)
code string // code parameter for test provider and Callback(...)
wantSubject string // sub claim from id token
wantInfoName string // name claim from userinfo
wantInfoEmail string // email claim from userinfo
wantFinalRedirect string // final redirect from Callback(...)
wantErrMatch *errors.Template // error template to match
wantErrContains string // error string should contain
}{
{
name: "simple", // must remain the first test
@ -244,6 +245,18 @@ func Test_Callback(t *testing.T) {
wantErrMatch: errors.T(errors.RecordNotFound),
wantErrContains: "auth method not-valid not found",
},
{
name: "mismatch-auth-method-id", // must remain the first test
oidcRepoFn: repoFn,
iamRepoFn: iamRepoFn,
atRepoFn: atRepoFn,
am: testAuthMethod,
authMethodIdOverride: &testAuthMethod2.PublicId,
state: testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
code: "simple",
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "auth method id does not match request wrapper auth method id",
},
{
name: "bad-state-encoding",
oidcRepoFn: repoFn,
@ -255,17 +268,6 @@ func Test_Callback(t *testing.T) {
wantErrMatch: errors.T(errors.Unknown),
wantErrContains: "unable to decode message",
},
{
name: "unable-to-decrypt",
oidcRepoFn: repoFn,
iamRepoFn: iamRepoFn,
atRepoFn: atRepoFn,
am: testAuthMethod,
state: testState(t, testAuthMethod2, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
code: "simple",
wantErrMatch: errors.T(errors.Decrypt),
wantErrContains: "unable to decrypt message",
},
{
name: "inactive-with-config-change",
oidcRepoFn: repoFn,
@ -337,12 +339,18 @@ func Test_Callback(t *testing.T) {
if len(info) > 0 {
tp.SetUserInfoReply(info)
}
var authMethodId string
if tt.authMethodIdOverride != nil {
authMethodId = *tt.authMethodIdOverride
} else {
authMethodId = tt.am.PublicId
}
gotRedirect, err := Callback(ctx,
tt.oidcRepoFn,
tt.iamRepoFn,
tt.atRepoFn,
tt.am.PublicId,
authMethodId,
tt.state,
tt.code,
)

@ -2,6 +2,7 @@ package oidc
import (
"context"
"fmt"
"time"
"github.com/hashicorp/boundary/internal/auth/oidc/request"
@ -21,7 +22,7 @@ import (
// * Use the authtoken.(Repository).IssueAuthToken to issue the request id's
// token and mark it as issued in the repo. If the token is already issue, an
// error is returned.
func TokenRequest(ctx context.Context, kms *kms.Kms, atRepoFn AuthTokenRepoFactory, tokenRequestId string) (*authtoken.AuthToken, error) {
func TokenRequest(ctx context.Context, kms *kms.Kms, atRepoFn AuthTokenRepoFactory, authMethodId, tokenRequestId string) (*authtoken.AuthToken, error) {
const op = "oidc.TokenRequest"
if kms == nil {
return nil, errors.New(errors.InvalidParameter, op, "missing kms")
@ -29,6 +30,12 @@ func TokenRequest(ctx context.Context, kms *kms.Kms, atRepoFn AuthTokenRepoFacto
if atRepoFn == nil {
return nil, errors.New(errors.InvalidParameter, op, "missing auth token repo function")
}
if authMethodId == "" {
return nil, errors.New(errors.InvalidParameter, op, "missing auth method id")
}
if tokenRequestId == "" {
return nil, errors.New(errors.InvalidParameter, op, "missing token request id")
}
reqTkWrapper, err := unwrapMessage(ctx, tokenRequestId)
if err != nil {
@ -40,6 +47,9 @@ func TokenRequest(ctx context.Context, kms *kms.Kms, atRepoFn AuthTokenRepoFacto
if reqTkWrapper.AuthMethodId == "" {
return nil, errors.New(errors.InvalidParameter, op, "request token id wrapper missing auth method id")
}
if reqTkWrapper.AuthMethodId != authMethodId {
return nil, errors.New(errors.InvalidParameter, op, fmt.Sprintf("%s auth method id does not match request wrapper auth method id: %s", authMethodId, reqTkWrapper.AuthMethodId))
}
// tokenRequestId is a proto request.Wrapper, which contains a cipher text field,
// so we need the derived wrapper that was used to encrypt it.

@ -58,6 +58,7 @@ func Test_TokenRequest(t *testing.T) {
name string
kms *kms.Kms
atRepoFn AuthTokenRepoFactory
authMethodId string
tokenRequest string
wantNil bool
wantErrMatch *errors.Template
@ -79,14 +80,16 @@ func Test_TokenRequest(t *testing.T) {
name: "bad-wrapper",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: testAuthMethod.PublicId,
tokenRequest: "bad-wrapper",
wantErrMatch: errors.T(errors.Unknown),
wantErrContains: "unable to decode message",
},
{
name: "missing-wrapper-scope-id",
kms: kmsCache,
atRepoFn: atRepoFn,
name: "missing-wrapper-scope-id",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: testAuthMethod.PublicId,
tokenRequest: func() string {
w := request.Wrapper{
AuthMethodId: testAuthMethod.PublicId,
@ -99,9 +102,24 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains: "missing scope id",
},
{
name: "missing-wrapper-auth-method-id",
kms: kmsCache,
atRepoFn: atRepoFn,
name: "missing-auth-method-id",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: "",
tokenRequest: func() string {
tokenPublicId, err := authtoken.NewAuthTokenId()
require.NoError(t, err)
TestPendingToken(t, testAtRepo, testUser, testAcct, tokenPublicId)
return TestTokenRequestId(t, testAuthMethod, kmsCache, 200*time.Second, tokenPublicId)
}(),
wantErrMatch: errors.T(errors.Unknown),
wantErrContains: "missing auth method id",
},
{
name: "missing-wrapper-auth-method-id",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: testAuthMethod.PublicId,
tokenRequest: func() string {
w := request.Wrapper{
ScopeId: testAuthMethod.ScopeId,
@ -114,9 +132,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains: "missing auth method id",
},
{
name: "dek-not-found",
kms: kms.TestKms(t, conn, db.TestWrapper(t)),
atRepoFn: atRepoFn,
name: "dek-not-found",
kms: kms.TestKms(t, conn, db.TestWrapper(t)),
atRepoFn: atRepoFn,
authMethodId: testAuthMethod.PublicId,
tokenRequest: func() string {
tokenPublicId, err := authtoken.NewAuthTokenId()
require.NoError(t, err)
@ -127,9 +146,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains: "unable to get oidc wrapper",
},
{
name: "expired",
kms: kmsCache,
atRepoFn: atRepoFn,
name: "expired",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: testAuthMethod.PublicId,
tokenRequest: func() string {
tokenPublicId, err := authtoken.NewAuthTokenId()
require.NoError(t, err)
@ -145,6 +165,7 @@ func Test_TokenRequest(t *testing.T) {
atRepoFn: func() (*authtoken.Repository, error) {
return nil, errors.New(errors.Unknown, "test op", "atRepoFn-error")
},
authMethodId: testAuthMethod.PublicId,
tokenRequest: func() string {
tokenPublicId, err := authtoken.NewAuthTokenId()
require.NoError(t, err)
@ -155,9 +176,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains: "atRepoFn-error",
},
{
name: "error-unmarshal",
kms: kmsCache,
atRepoFn: atRepoFn,
name: "error-unmarshal",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: testAuthMethod.PublicId,
tokenRequest: func() string {
blobInfo, err := testRequestWrapper.Encrypt(ctx, []byte("not-valid-request-token"), []byte(fmt.Sprintf("%s%s", testAuthMethod.PublicId, testAuthMethod.ScopeId)))
require.NoError(t, err)
@ -177,9 +199,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains: "unable to unmarshal request token",
},
{
name: "error-missing-exp",
kms: kmsCache,
atRepoFn: atRepoFn,
name: "error-missing-exp",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: testAuthMethod.PublicId,
tokenRequest: func() string {
tokenPublicId, err := authtoken.NewAuthTokenId()
require.NoError(t, err)
@ -206,9 +229,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains: "missing request token id expiration",
},
{
name: "error-missing-request-id",
kms: kmsCache,
atRepoFn: atRepoFn,
name: "error-missing-request-id",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: testAuthMethod.PublicId,
tokenRequest: func() string {
exp, err := ptypes.TimestampProto(time.Now().Add(AttemptExpiration).Truncate(time.Second))
require.NoError(t, err)
@ -235,9 +259,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains: "missing token request id",
},
{
name: "error-issuing-token-forbidden-code",
kms: kmsCache,
atRepoFn: atRepoFn,
name: "error-issuing-token-forbidden-code",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: testAuthMethod.PublicId,
tokenRequest: func() string {
exp, err := ptypes.TimestampProto(time.Now().Add(AttemptExpiration).Truncate(time.Second))
require.NoError(t, err)
@ -264,9 +289,24 @@ func Test_TokenRequest(t *testing.T) {
wantNil: true,
},
{
name: "success",
kms: kmsCache,
atRepoFn: atRepoFn,
name: "mismatched-auth-method-id",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: "not-a-match",
tokenRequest: func() string {
tokenPublicId, err := authtoken.NewAuthTokenId()
require.NoError(t, err)
TestPendingToken(t, testAtRepo, testUser, testAcct, tokenPublicId)
return TestTokenRequestId(t, testAuthMethod, kmsCache, 200*time.Second, tokenPublicId)
}(),
wantErrMatch: errors.T(errors.InvalidParameter),
wantErrContains: "auth method id does not match request wrapper auth method id",
},
{
name: "success",
kms: kmsCache,
atRepoFn: atRepoFn,
authMethodId: testAuthMethod.PublicId,
tokenRequest: func() string {
tokenPublicId, err := authtoken.NewAuthTokenId()
require.NoError(t, err)
@ -278,7 +318,7 @@ func Test_TokenRequest(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
gotTk, err := TokenRequest(ctx, tt.kms, tt.atRepoFn, tt.tokenRequest)
gotTk, err := TokenRequest(ctx, tt.kms, tt.atRepoFn, tt.authMethodId, tt.tokenRequest)
if tt.wantErrMatch != nil {
require.Error(err)
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted %q and got: %+v", tt.wantErrMatch.Code, err)

@ -258,7 +258,7 @@ func (s Service) authenticateOidcToken(ctx context.Context, req *pbs.Authenticat
return nil, errors.New(errors.InvalidParameter, op, "Empty token id request attributes.")
}
token, err := oidc.TokenRequest(ctx, s.kms, s.atRepoFn, attrs.TokenId)
token, err := oidc.TokenRequest(ctx, s.kms, s.atRepoFn, req.GetAuthMethodId(), attrs.TokenId)
if err != nil {
// TODO: Log something so we don't lose the error's context and entire msg...
switch {

@ -1443,7 +1443,7 @@ func TestAuthenticate_OIDC_Token(t *testing.T) {
tokenPublicId, err := authtoken.NewAuthTokenId()
require.NoError(t, err)
oidc.TestPendingToken(t, testAtRepo, testUser, testAcct, tokenPublicId)
return oidc.TestTokenRequestId(t, testAuthMethod, s.kmsCache, 200*time.Second, tokenPublicId)
return oidc.TestTokenRequestId(t, s.authMethod, s.kmsCache, 200*time.Second, tokenPublicId)
}(),
})
require.NoError(t, err)
@ -1462,7 +1462,7 @@ func TestAuthenticate_OIDC_Token(t *testing.T) {
tokenPublicId, err := authtoken.NewAuthTokenId()
require.NoError(t, err)
oidc.TestPendingToken(t, testAtRepo, testUser, testAcct, tokenPublicId)
return oidc.TestTokenRequestId(t, testAuthMethod, s.kmsCache, -20*time.Second, tokenPublicId)
return oidc.TestTokenRequestId(t, s.authMethod, s.kmsCache, -20*time.Second, tokenPublicId)
}(),
})
require.NoError(t, err)

Loading…
Cancel
Save