@ -58,6 +58,7 @@ func Test_TokenRequest(t *testing.T) {
name string
kms * kms . Kms
atRepoFn AuthTokenRepoFactory
authMethodId string
tokenRequest string
wantNil bool
wantErrMatch * errors . Template
@ -79,14 +80,16 @@ func Test_TokenRequest(t *testing.T) {
name : "bad-wrapper" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : "bad-wrapper" ,
wantErrMatch : errors . T ( errors . Unknown ) ,
wantErrContains : "unable to decode message" ,
} ,
{
name : "missing-wrapper-scope-id" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
name : "missing-wrapper-scope-id" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : func ( ) string {
w := request . Wrapper {
AuthMethodId : testAuthMethod . PublicId ,
@ -99,9 +102,24 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains : "missing scope id" ,
} ,
{
name : "missing-wrapper-auth-method-id" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
name : "missing-auth-method-id" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : "" ,
tokenRequest : func ( ) string {
tokenPublicId , err := authtoken . NewAuthTokenId ( )
require . NoError ( t , err )
TestPendingToken ( t , testAtRepo , testUser , testAcct , tokenPublicId )
return TestTokenRequestId ( t , testAuthMethod , kmsCache , 200 * time . Second , tokenPublicId )
} ( ) ,
wantErrMatch : errors . T ( errors . Unknown ) ,
wantErrContains : "missing auth method id" ,
} ,
{
name : "missing-wrapper-auth-method-id" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : func ( ) string {
w := request . Wrapper {
ScopeId : testAuthMethod . ScopeId ,
@ -114,9 +132,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains : "missing auth method id" ,
} ,
{
name : "dek-not-found" ,
kms : kms . TestKms ( t , conn , db . TestWrapper ( t ) ) ,
atRepoFn : atRepoFn ,
name : "dek-not-found" ,
kms : kms . TestKms ( t , conn , db . TestWrapper ( t ) ) ,
atRepoFn : atRepoFn ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : func ( ) string {
tokenPublicId , err := authtoken . NewAuthTokenId ( )
require . NoError ( t , err )
@ -127,9 +146,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains : "unable to get oidc wrapper" ,
} ,
{
name : "expired" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
name : "expired" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : func ( ) string {
tokenPublicId , err := authtoken . NewAuthTokenId ( )
require . NoError ( t , err )
@ -145,6 +165,7 @@ func Test_TokenRequest(t *testing.T) {
atRepoFn : func ( ) ( * authtoken . Repository , error ) {
return nil , errors . New ( errors . Unknown , "test op" , "atRepoFn-error" )
} ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : func ( ) string {
tokenPublicId , err := authtoken . NewAuthTokenId ( )
require . NoError ( t , err )
@ -155,9 +176,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains : "atRepoFn-error" ,
} ,
{
name : "error-unmarshal" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
name : "error-unmarshal" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : func ( ) string {
blobInfo , err := testRequestWrapper . Encrypt ( ctx , [ ] byte ( "not-valid-request-token" ) , [ ] byte ( fmt . Sprintf ( "%s%s" , testAuthMethod . PublicId , testAuthMethod . ScopeId ) ) )
require . NoError ( t , err )
@ -177,9 +199,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains : "unable to unmarshal request token" ,
} ,
{
name : "error-missing-exp" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
name : "error-missing-exp" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : func ( ) string {
tokenPublicId , err := authtoken . NewAuthTokenId ( )
require . NoError ( t , err )
@ -206,9 +229,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains : "missing request token id expiration" ,
} ,
{
name : "error-missing-request-id" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
name : "error-missing-request-id" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : func ( ) string {
exp , err := ptypes . TimestampProto ( time . Now ( ) . Add ( AttemptExpiration ) . Truncate ( time . Second ) )
require . NoError ( t , err )
@ -235,9 +259,10 @@ func Test_TokenRequest(t *testing.T) {
wantErrContains : "missing token request id" ,
} ,
{
name : "error-issuing-token-forbidden-code" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
name : "error-issuing-token-forbidden-code" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : func ( ) string {
exp , err := ptypes . TimestampProto ( time . Now ( ) . Add ( AttemptExpiration ) . Truncate ( time . Second ) )
require . NoError ( t , err )
@ -264,9 +289,24 @@ func Test_TokenRequest(t *testing.T) {
wantNil : true ,
} ,
{
name : "success" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
name : "mismatched-auth-method-id" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : "not-a-match" ,
tokenRequest : func ( ) string {
tokenPublicId , err := authtoken . NewAuthTokenId ( )
require . NoError ( t , err )
TestPendingToken ( t , testAtRepo , testUser , testAcct , tokenPublicId )
return TestTokenRequestId ( t , testAuthMethod , kmsCache , 200 * time . Second , tokenPublicId )
} ( ) ,
wantErrMatch : errors . T ( errors . InvalidParameter ) ,
wantErrContains : "auth method id does not match request wrapper auth method id" ,
} ,
{
name : "success" ,
kms : kmsCache ,
atRepoFn : atRepoFn ,
authMethodId : testAuthMethod . PublicId ,
tokenRequest : func ( ) string {
tokenPublicId , err := authtoken . NewAuthTokenId ( )
require . NoError ( t , err )
@ -278,7 +318,7 @@ func Test_TokenRequest(t *testing.T) {
for _ , tt := range tests {
t . Run ( tt . name , func ( t * testing . T ) {
assert , require := assert . New ( t ) , require . New ( t )
gotTk , err := TokenRequest ( ctx , tt . kms , tt . atRepoFn , tt . tokenRequest)
gotTk , err := TokenRequest ( ctx , tt . kms , tt . atRepoFn , tt . authMethodId, tt . tokenRequest)
if tt . wantErrMatch != nil {
require . Error ( err )
assert . Truef ( errors . Match ( tt . wantErrMatch , err ) , "wanted %q and got: %+v" , tt . wantErrMatch . Code , err )