You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/auth/oidc/repository_auth_method_test.go

359 lines
13 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package oidc
import (
"context"
"testing"
"time"
"github.com/hashicorp/boundary/internal/auth/oidc/store"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
wrapping "github.com/hashicorp/go-kms-wrapping/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)
// Test_upsertAccount is a table-driven test of the repository's upsertAccount
// method. Each case supplies an auth method plus ID Token and Access Token
// claim maps. Success cases compare the returned account with an expected
// store.Account; error cases check the error template and, optionally, a
// substring of the error message.
func Test_upsertAccount(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rootWrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, rootWrapper)
	rw := db.New(conn)
	r, err := NewRepository(ctx, rw, rw, kmsCache)
	require.NoError(t, err)

	org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, rootWrapper))
	databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
	require.NoError(t, err)

	// Auth method without an account claim map; the cases using it expect the
	// account subject to equal the "sub" claim.
	amActivePriv := TestAuthMethod(
		t,
		conn, databaseWrapper, org.PublicId, ActivePrivateState,
		"alice_rp", "fido",
		WithApiUrl(TestConvertToUrls(t, "https://alice-active-priv.com/callback")[0]),
		WithSigningAlgs(RS256))

	// Auth method configured with an "oid" -> ToSubClaim account claim map;
	// the case using it expects the account subject to equal the "oid" claim.
	amWithMapping := TestAuthMethod(
		t,
		conn, databaseWrapper, org.PublicId, ActivePrivateState,
		"alice_rp", "fido",
		WithAccountClaimMap(map[string]AccountToClaim{"oid": ToSubClaim}),
		WithApiUrl(TestConvertToUrls(t, "https://alice-active-priv.com/callback")[0]),
		WithSigningAlgs(RS256))

	tests := []struct {
		name            string
		am              *AuthMethod
		idClaims        map[string]any
		atClaims        map[string]any
		wantAcct        *Account
		wantErrMatch    *errors.Template
		wantErrContains string
	}{
		{
			name:     "success-defaults",
			am:       amActivePriv,
			idClaims: map[string]any{"iss": "https://alice-active-priv.com", "sub": "success-defaults"},
			atClaims: map[string]any{},
			wantAcct: &Account{Account: &store.Account{
				AuthMethodId:   amActivePriv.PublicId,
				Issuer:         "https://alice-active-priv.com",
				Subject:        "success-defaults",
				TokenClaims:    `{"iss":"https://alice-active-priv.com","sub":"success-defaults"}`,
				UserinfoClaims: "{}",
			}},
		},
		{
			name:     "success-atTk-full-name-and-email",
			am:       amActivePriv,
			idClaims: map[string]any{"iss": "https://alice-active-priv.com", "sub": "success-atTk-full-name-and-email"},
			atClaims: map[string]any{"name": "alice eve-smith", "email": "alice@alice.com"},
			wantAcct: &Account{Account: &store.Account{
				AuthMethodId:   amActivePriv.PublicId,
				Issuer:         "https://alice-active-priv.com",
				Subject:        "success-atTk-full-name-and-email",
				Email:          "alice@alice.com",
				FullName:       "alice eve-smith",
				TokenClaims:    `{"iss":"https://alice-active-priv.com","sub":"success-atTk-full-name-and-email"}`,
				UserinfoClaims: `{"email":"alice@alice.com","name":"alice eve-smith"}`,
			}},
		},
		{
			name:     "success-idTk-full-name-and-email",
			am:       amActivePriv,
			idClaims: map[string]any{"iss": "https://alice-active-priv.com", "sub": "success-idTk-full-name-and-email", "name": "alice eve-smith", "email": "alice@alice.com"},
			atClaims: map[string]any{},
			wantAcct: &Account{Account: &store.Account{
				AuthMethodId:   amActivePriv.PublicId,
				Issuer:         "https://alice-active-priv.com",
				Subject:        "success-idTk-full-name-and-email",
				Email:          "alice@alice.com",
				FullName:       "alice eve-smith",
				TokenClaims:    `{"email":"alice@alice.com","iss":"https://alice-active-priv.com","name":"alice eve-smith","sub":"success-idTk-full-name-and-email"}`,
				UserinfoClaims: `{}`,
			}},
		},
		{
			name:     "success-map-idTk",
			am:       amWithMapping,
			idClaims: map[string]any{"iss": "https://alice-active-priv.com", "sub": "success-defaults", "oid": "success-map"},
			atClaims: map[string]any{},
			wantAcct: &Account{Account: &store.Account{
				AuthMethodId:   amWithMapping.PublicId,
				Issuer:         "https://alice-active-priv.com",
				Subject:        "success-map",
				TokenClaims:    `{"iss":"https://alice-active-priv.com","oid":"success-map","sub":"success-defaults"}`,
				UserinfoClaims: `{}`,
			}},
		},
		{
			name:            "non-existent-auth-method-scope-id",
			am:              func() *AuthMethod { cp := amActivePriv.Clone(); cp.ScopeId = "non-existent-scope-id"; return cp }(),
			idClaims:        map[string]any{"iss": "https://alice.com", "sub": "non-existent-auth-method-scope-id"},
			atClaims:        map[string]any{},
			wantErrMatch:    errors.T(errors.Unknown),
			wantErrContains: "unable to get oplog wrapper",
		},
		{
			name:            "non-parsable-issuer",
			am:              amActivePriv,
			idClaims:        map[string]any{"iss": ":::::alice.com", "sub": "non-parsable-issuer"},
			atClaims:        map[string]any{},
			wantErrMatch:    errors.T(errors.Unknown),
			wantErrContains: "unable to parse issuer",
		},
		{
			name:            "empty-sub",
			am:              amActivePriv,
			idClaims:        map[string]any{"iss": "https://alice.com", "sub": ""},
			atClaims:        map[string]any{},
			wantErrMatch:    errors.T(errors.Unknown),
			wantErrContains: "missing subject",
		},
		{
			name:            "empty-issuer",
			am:              amActivePriv,
			idClaims:        map[string]any{"iss": "", "sub": "bad-issuer"},
			atClaims:        map[string]any{},
			wantErrMatch:    errors.T(errors.Unknown),
			wantErrContains: "missing issuer",
		},
		{
			name:            "missing-id-token-issuer",
			am:              amActivePriv,
			idClaims:        map[string]any{},
			atClaims:        map[string]any{},
			wantErrMatch:    errors.T(errors.Unknown),
			wantErrContains: "issuer is not present in ID Token",
		},
		{
			name:            "missing-id-token-sub",
			am:              amActivePriv,
			idClaims:        map[string]any{"iss": "https://alice.com"},
			atClaims:        map[string]any{},
			wantErrMatch:    errors.T(errors.Unknown),
			wantErrContains: "to account subject and it is not present in ID Token",
		},
		{
			name:            "missing-id-token-claims",
			am:              amActivePriv,
			atClaims:        map[string]any{},
			wantErrMatch:    errors.T(errors.Unknown),
			wantErrContains: "missing ID Token claims",
		},
		{
			name:            "missing-access-token-claims",
			am:              amActivePriv,
			idClaims:        map[string]any{},
			wantErrMatch:    errors.T(errors.Unknown),
			wantErrContains: "missing Access Token claims",
		},
		{
			name:            "missing-auth-method",
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing auth method",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			gotAcct, err := r.upsertAccount(ctx, tt.am, tt.idClaims, tt.atClaims)
			if tt.wantErrMatch != nil {
				require.Error(err)
				// Report the wanted code first and the actual error second so
				// the labels in the failure message match their values.
				assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted err %s but got: %+v", tt.wantErrMatch.Code, err)
				if tt.wantErrContains != "" {
					assert.Contains(err.Error(), tt.wantErrContains)
				}
				return
			}
			require.NoError(err)
			// Fail cleanly instead of panicking on the dereferences below.
			require.NotNil(gotAcct)

			// These fields are not set in the expected account, so copy them
			// from the result before comparing the remaining fields.
			tt.wantAcct.PublicId = gotAcct.PublicId
			tt.wantAcct.Version = gotAcct.Version
			tt.wantAcct.CreateTime = gotAcct.CreateTime
			tt.wantAcct.UpdateTime = gotAcct.UpdateTime

			ok := proto.Equal(tt.wantAcct.Account, gotAcct.Account)
			assert.True(ok)
			if !ok {
				assert.Equalf(tt.wantAcct, gotAcct, "using assert.Equal to just show the diff between want and got")
			}
		})
	}
}
// Test_upsertOplog is a table-driven test of upsertOplog. Each case supplies a
// writer, an oplog wrapper, an operation type, a scope id, an account and
// optional field/null masks. Success cases verify that an oplog entry for the
// account and operation can be found; error cases check the error template and
// message, and that no create or update oplog entry is found for the account.
func Test_upsertOplog(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rootWrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, rootWrapper)
	rw := db.New(conn)

	org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, rootWrapper))
	databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
	require.NoError(t, err)
	oplogWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeOplog)
	require.NoError(t, err)

	testAuthMethod := TestAuthMethod(
		t, conn, databaseWrapper, org.PublicId, ActivePrivateState,
		"alice-rp", "fido",
		WithSigningAlgs(RS256),
		WithIssuer(TestConvertToUrls(t, "https://www.alice.com")[0]),
		WithApiUrl(TestConvertToUrls(t, "https://www.alice.com/callback")[0]),
	)

	tests := []struct {
		name            string
		writer          db.Writer
		wrapper         wrapping.Wrapper
		op              oplog.OpType
		scopeId         string
		acct            *Account
		fieldMasks      []string
		nullFields      []string
		wantErrMatch    *errors.Template
		wantErrContains string
	}{
		{
			name:    "create-success",
			writer:  rw,
			wrapper: oplogWrapper,
			op:      oplog.OpType_OP_TYPE_CREATE,
			scopeId: org.PublicId,
			acct:    TestAccount(t, conn, testAuthMethod, "create-success"),
		},
		{
			name:       "update-success-withFieldMask-only",
			writer:     rw,
			wrapper:    oplogWrapper,
			op:         oplog.OpType_OP_TYPE_UPDATE,
			fieldMasks: []string{"FullName"},
			scopeId:    org.PublicId,
			acct:       TestAccount(t, conn, testAuthMethod, "update-success-withFieldMask-only"),
		},
		{
			name:       "update-success-withNullFields-only",
			writer:     rw,
			wrapper:    oplogWrapper,
			op:         oplog.OpType_OP_TYPE_UPDATE,
			nullFields: []string{"FullName"},
			scopeId:    org.PublicId,
			acct:       TestAccount(t, conn, testAuthMethod, "update-success-withNullFields-only"),
		},
		{
			name:            "update-missing-fieldMask-and-nullFields",
			writer:          rw,
			wrapper:         oplogWrapper,
			op:              oplog.OpType_OP_TYPE_UPDATE,
			scopeId:         org.PublicId,
			acct:            TestAccount(t, conn, testAuthMethod, "update-missing-fieldMask-and-nullFields"),
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "update operations must specify field masks and/or null masks",
		},
		{
			name:            "missing-writer",
			wrapper:         oplogWrapper,
			op:              oplog.OpType_OP_TYPE_CREATE,
			scopeId:         org.PublicId,
			acct:            TestAccount(t, conn, testAuthMethod, "missing-writer"),
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing db writer",
		},
		{
			name:            "missing-oplog-wrapper",
			writer:          rw,
			op:              oplog.OpType_OP_TYPE_CREATE,
			scopeId:         org.PublicId,
			acct:            TestAccount(t, conn, testAuthMethod, "missing-oplog-wrapper"),
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing oplog wrapper",
		},
		{
			name:            "bad-op-type",
			writer:          rw,
			wrapper:         oplogWrapper,
			op:              oplog.OpType_OP_TYPE_DELETE,
			scopeId:         org.PublicId,
			acct:            TestAccount(t, conn, testAuthMethod, "bad-op-type"),
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "not a supported operation",
		},
		{
			name:            "missing-scope-id",
			writer:          rw,
			wrapper:         oplogWrapper,
			op:              oplog.OpType_OP_TYPE_UPDATE,
			acct:            TestAccount(t, conn, testAuthMethod, "missing-scope-id"),
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing scope id",
		},
		{
			name:            "missing-account",
			writer:          rw,
			wrapper:         oplogWrapper,
			op:              oplog.OpType_OP_TYPE_UPDATE,
			scopeId:         org.PublicId,
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing account",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)

			// start with no oplog entries
			// err is declared locally so the subtest never writes to the
			// enclosing function's err variable.
			_, err := rw.Exec(ctx, "delete from oplog_entry", nil)
			require.NoError(err)

			err = upsertOplog(ctx, tt.writer, tt.wrapper, tt.op, tt.scopeId, tt.acct, tt.fieldMasks, tt.nullFields)
			if tt.wantErrMatch != nil {
				require.Error(err)
				// Check the declared error template, not just the message.
				assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted err %s but got: %+v", tt.wantErrMatch.Code, err)
				if tt.wantErrContains != "" {
					assert.Contains(err.Error(), tt.wantErrContains)
				}
				if tt.acct != nil {
					// A failed upsert must not leave a create or update entry
					// for this account; the lookup is keyed by PublicId, so
					// that is the id reported on failure.
					// check for create
					verifyErr := db.TestVerifyOplog(t, rw, tt.acct.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_CREATE), db.WithCreateNotBefore(10*time.Second))
					require.Errorf(verifyErr, "should not have found oplog entry for %s", tt.acct.PublicId)
					// check for update
					verifyErr = db.TestVerifyOplog(t, rw, tt.acct.PublicId, db.WithOperation(oplog.OpType_OP_TYPE_UPDATE), db.WithCreateNotBefore(10*time.Second))
					require.Errorf(verifyErr, "should not have found oplog entry for %s", tt.acct.PublicId)
				}
				return
			}
			require.NoError(err)

			err = db.TestVerifyOplog(t, rw, tt.acct.PublicId, db.WithOperation(tt.op), db.WithCreateNotBefore(10*time.Second))
			require.NoErrorf(err, "unexpected error verifying oplog entry: %s", err)
		})
	}
}