mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
877 lines
32 KiB
877 lines
32 KiB
// Copyright IBM Corp. 2020, 2025
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package oidc
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"os"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/boundary/globals"
|
|
"github.com/hashicorp/boundary/internal/auth/oidc/store"
|
|
authStore "github.com/hashicorp/boundary/internal/auth/store"
|
|
"github.com/hashicorp/boundary/internal/authtoken"
|
|
"github.com/hashicorp/boundary/internal/db"
|
|
"github.com/hashicorp/boundary/internal/errors"
|
|
"github.com/hashicorp/boundary/internal/event"
|
|
"github.com/hashicorp/boundary/internal/iam"
|
|
iamStore "github.com/hashicorp/boundary/internal/iam/store"
|
|
"github.com/hashicorp/boundary/internal/kms"
|
|
"github.com/hashicorp/boundary/internal/oplog"
|
|
"github.com/hashicorp/cap/oidc"
|
|
"github.com/hashicorp/eventlogger/formatter_filters/cloudevents"
|
|
"github.com/hashicorp/go-hclog"
|
|
"github.com/hashicorp/go-secure-stdlib/strutil"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func Test_Callback(t *testing.T) {
|
|
// DO NOT run these tests under t.Parallel(), there be dragons because of dependencies on the
|
|
// Database and TestProvider state
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
rootWrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, rootWrapper)
|
|
|
|
testCtx := context.Background()
|
|
opt := event.TestWithObservationSink(t)
|
|
c := event.TestEventerConfig(t, "Test_StartAuth_to_Callback", opt)
|
|
testLock := &sync.Mutex{}
|
|
testLogger := hclog.New(&hclog.LoggerOptions{
|
|
Mutex: testLock,
|
|
Name: "test",
|
|
})
|
|
c.EventerConfig.TelemetryEnabled = true
|
|
require.NoError(t, event.InitSysEventer(testLogger, testLock, "use-Test_Callback", event.WithEventerConfig(&c.EventerConfig)))
|
|
// some standard factories for unit tests which
|
|
// are used in the Callback(...) call
|
|
iamRepoFn := func() (*iam.Repository, error) {
|
|
return iam.NewRepository(ctx, rw, rw, kmsCache)
|
|
}
|
|
repoFn := func() (*Repository, error) {
|
|
return NewRepository(ctx, rw, rw, kmsCache)
|
|
}
|
|
atRepoFn := func() (*authtoken.Repository, error) {
|
|
return authtoken.NewRepository(ctx, rw, rw, kmsCache)
|
|
}
|
|
atRepo, err := atRepoFn()
|
|
require.NoError(t, err)
|
|
|
|
iamRepo := iam.TestRepo(t, conn, rootWrapper)
|
|
org, _ := iam.TestScopes(t, iamRepo)
|
|
|
|
databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
|
|
// a very simple test mock controller, that simply responds with a 200 OK to every
|
|
// request.
|
|
testController := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
w.WriteHeader(200)
|
|
}))
|
|
defer testController.Close()
|
|
|
|
// test provider for the tests (see the oidc package docs for more info)
|
|
// it will provide discovery, JWKs, a token endpoint, etc for these tests.
|
|
tp := oidc.StartTestProvider(t)
|
|
tpCert, err := ParseCertificates(ctx, tp.CACert())
|
|
require.NoError(t, err)
|
|
_, _, tpAlg, _ := tp.SigningKeys()
|
|
|
|
// a reusable test authmethod for the unit tests
|
|
testAuthMethod := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, ActivePublicState,
|
|
"alice-rp", "fido",
|
|
WithCertificates(tpCert...),
|
|
WithSigningAlgs(Alg(tpAlg)),
|
|
WithIssuer(TestConvertToUrls(t, tp.Addr())[0]),
|
|
WithApiUrl(TestConvertToUrls(t, testController.URL)[0]))
|
|
|
|
// set this as the primary so users will be created on first login
|
|
iam.TestSetPrimaryAuthMethod(t, iamRepo, org, testAuthMethod.PublicId)
|
|
|
|
// a reusable oidc.Provider for the tests
|
|
testProvider, err := convertToProvider(ctx, testAuthMethod)
|
|
require.NoError(t, err)
|
|
testConfigHash, err := testProvider.ConfigHash()
|
|
require.NoError(t, err)
|
|
|
|
// a reusable token request id for the tests.
|
|
testTokenRequestId, err := authtoken.NewAuthTokenId(ctx)
|
|
require.NoError(t, err)
|
|
|
|
// usuable nonce for the unit tests
|
|
testNonce := "nonce"
|
|
|
|
// define the audiences the test provider will accept for the unit tests.
|
|
tp.SetCustomAudience("foo", "alice-rp")
|
|
|
|
// define a second test auth method, which is in an InactiveState for the unit tests.
|
|
org2, _ := iam.TestScopes(t, iam.TestRepo(t, conn, rootWrapper))
|
|
databaseWrapper2, err := kmsCache.GetWrapper(ctx, org2.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
testAuthMethod2 := TestAuthMethod(t, conn, databaseWrapper2, org2.PublicId, InactiveState,
|
|
"alice-rp", "fido",
|
|
WithAudClaims("foo"),
|
|
WithMaxAge(-1), // oidc library has a 1 min leeway
|
|
WithCertificates(tpCert...),
|
|
WithSigningAlgs(Alg(tpAlg)),
|
|
WithIssuer(TestConvertToUrls(t, tp.Addr())[0]),
|
|
WithApiUrl(TestConvertToUrls(t, testController.URL)[0]))
|
|
// define a second test provider based on the inactive test auth method
|
|
testProvider2, err := convertToProvider(ctx, testAuthMethod2)
|
|
require.NoError(t, err)
|
|
testConfigHash2, err := testProvider2.ConfigHash()
|
|
require.NoError(t, err)
|
|
|
|
tests := []struct {
|
|
name string
|
|
setup func() // provide a simple way to do some prework before the test.
|
|
oidcRepoFn OidcRepoFactory // returns a new oidc repo
|
|
iamRepoFn IamRepoFactory // returns a new iam repo
|
|
atRepoFn AuthTokenRepoFactory // returns a new auth token repo
|
|
am *AuthMethod // the authmethod for the test
|
|
state string // state parameter for test provider and Callback(...)
|
|
code string // code parameter for test provider and Callback(...)
|
|
wantSubject string // sub claim from id token
|
|
wantInfoName string // name claim from userinfo
|
|
wantInfoEmail string // email claim from userinfo
|
|
wantFinalRedirect string // final redirect from Callback(...)
|
|
wantErrMatch *errors.Template // error template to match
|
|
wantErrContains string // error string should contain
|
|
}{
|
|
{
|
|
name: "simple", // must remain the first test
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod,
|
|
state: testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
|
|
code: "simple",
|
|
wantFinalRedirect: "https://testcontroler.com/hi-alice",
|
|
wantSubject: "simple@example.com",
|
|
wantInfoName: "alice doe-eve",
|
|
wantInfoEmail: "alice@example.com",
|
|
},
|
|
{
|
|
name: "dup", // must follow "simple" unit test
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod,
|
|
state: testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
|
|
code: "simple",
|
|
wantFinalRedirect: "https://testcontroler.com/hi-alice",
|
|
wantSubject: "dup@example.com",
|
|
wantInfoName: "alice doe-eve",
|
|
wantInfoEmail: "alice@example.com",
|
|
},
|
|
{
|
|
name: "inactive-valid",
|
|
setup: func() {
|
|
acct := TestAccount(t, conn, testAuthMethod2, "inactive-valid@example.com")
|
|
_ = iam.TestUser(t, iamRepo, org2.PublicId, iam.WithAccountIds(acct.PublicId))
|
|
},
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod2,
|
|
state: testState(t, testAuthMethod2, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash2, testNonce),
|
|
code: "simple",
|
|
wantFinalRedirect: "https://testcontroler.com/hi-alice",
|
|
wantSubject: "inactive-valid@example.com",
|
|
wantInfoName: "alice doe-eve",
|
|
wantInfoEmail: "alice@example.com",
|
|
},
|
|
{
|
|
name: "missing-oidc-repo-fn",
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod,
|
|
state: testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
|
|
code: "simple",
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing oidc repository",
|
|
},
|
|
{
|
|
name: "missing-iam-repo-fn",
|
|
oidcRepoFn: repoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod,
|
|
state: testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
|
|
code: "simple",
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing iam repository",
|
|
},
|
|
{
|
|
name: "missing-at-repo-fn",
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
am: testAuthMethod,
|
|
state: testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
|
|
code: "simple",
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing auth token repository",
|
|
},
|
|
{
|
|
name: "missing-state",
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod,
|
|
code: "simple",
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing state",
|
|
},
|
|
{
|
|
name: "missing-code",
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod,
|
|
state: testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing code",
|
|
},
|
|
{
|
|
name: "missing-auth-method",
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: func() *AuthMethod { return nil }(),
|
|
state: testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
|
|
code: "simple",
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing auth method",
|
|
},
|
|
{
|
|
name: "mismatch-auth-method",
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod,
|
|
state: testState(t, testAuthMethod2, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
|
|
code: "simple",
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "auth method id does not match request wrapper auth method id",
|
|
},
|
|
{
|
|
name: "bad-state-encoding",
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod,
|
|
state: "unable to decode message",
|
|
code: "simple",
|
|
wantErrMatch: errors.T(errors.Unknown),
|
|
wantErrContains: "unable to decode message",
|
|
},
|
|
{
|
|
name: "inactive-with-config-change",
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod2,
|
|
state: testState(t, testAuthMethod2, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", 1, testNonce),
|
|
code: "simple",
|
|
wantErrMatch: errors.T(errors.AuthMethodInactive),
|
|
wantErrContains: "configuration changed during in-flight authentication attempt",
|
|
},
|
|
{
|
|
name: "expired-attempt",
|
|
oidcRepoFn: repoFn,
|
|
iamRepoFn: iamRepoFn,
|
|
atRepoFn: atRepoFn,
|
|
am: testAuthMethod,
|
|
state: testState(t, testAuthMethod, kmsCache, testTokenRequestId, -20*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce),
|
|
code: "simple",
|
|
wantErrMatch: errors.T(errors.AuthAttemptExpired),
|
|
wantErrContains: "request state has expired",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
|
|
// start with no tokens in the db
|
|
_, err := rw.Exec(ctx, "delete from auth_token", nil)
|
|
require.NoError(err)
|
|
// start with no users in the db
|
|
excludeUsers := []any{globals.AnonymousUserId, globals.AnyAuthenticatedUserId, globals.RecoveryUserId}
|
|
_, err = rw.Exec(ctx, "delete from iam_user where public_id not in(?, ?, ?)", excludeUsers)
|
|
require.NoError(err)
|
|
// start with no oplog entries
|
|
_, err = rw.Exec(ctx, "delete from oplog_entry", nil)
|
|
require.NoError(err)
|
|
|
|
if tt.setup != nil {
|
|
tt.setup()
|
|
}
|
|
|
|
// the test provider is stateful, so we need to configure
|
|
// it for this unit test.
|
|
tp.SetExpectedAuthNonce(testNonce)
|
|
if tt.am != nil {
|
|
tp.SetClientCreds(tt.am.ClientId, tt.am.ClientSecret)
|
|
tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, tt.am.ApiUrl)
|
|
tp.SetAllowedRedirectURIs([]string{tpAllowedRedirect})
|
|
}
|
|
if tt.code != "" {
|
|
tp.SetExpectedAuthCode(tt.code)
|
|
}
|
|
if tt.state != "" {
|
|
tp.SetExpectedState(tt.state)
|
|
}
|
|
tp.SetExpectedSubject(tt.wantSubject)
|
|
|
|
info := map[string]any{}
|
|
if tt.wantSubject != "" {
|
|
info["sub"] = tt.wantSubject
|
|
}
|
|
if tt.wantInfoEmail != "" {
|
|
info["email"] = tt.wantInfoEmail
|
|
}
|
|
if tt.wantInfoName != "" {
|
|
info["name"] = tt.wantInfoName
|
|
}
|
|
if len(info) > 0 {
|
|
tp.SetUserInfoReply(info)
|
|
}
|
|
gotRedirect, err := Callback(ctx,
|
|
tt.oidcRepoFn,
|
|
tt.iamRepoFn,
|
|
tt.atRepoFn,
|
|
tt.am,
|
|
tt.state,
|
|
tt.code,
|
|
)
|
|
if tt.wantErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Empty(gotRedirect)
|
|
assert.Truef(errors.Match(tt.wantErrMatch, err), "want err code: %q got: %q", tt.wantErrMatch.Code, err)
|
|
assert.Contains(err.Error(), tt.wantErrContains)
|
|
|
|
// make sure there are no tokens in the db..
|
|
var tokens []*authtoken.AuthToken
|
|
err := rw.SearchWhere(ctx, &tokens, "1=?", []any{1})
|
|
require.NoError(err)
|
|
assert.Equal(0, len(tokens))
|
|
|
|
// make sure there weren't any oplog entries written.
|
|
var entries []*oplog.Entry
|
|
err = rw.SearchWhere(ctx, &entries, "1=?", []any{1})
|
|
require.NoError(err)
|
|
amId := ""
|
|
if tt.am != nil {
|
|
amId = tt.am.PublicId
|
|
}
|
|
require.Equalf(0, len(entries), "should not have found oplog entry for %s", amId)
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
assert.Equal(tt.wantFinalRedirect, gotRedirect)
|
|
|
|
sinkFileName := c.ObservationEvents.Name()
|
|
defer func() { _ = os.WriteFile(sinkFileName, nil, 0o666) }()
|
|
b, err := os.ReadFile(sinkFileName)
|
|
require.NoError(err)
|
|
got := &cloudevents.Event{}
|
|
err = json.Unmarshal(b, got)
|
|
require.NoErrorf(err, "json: %s", string(b))
|
|
details, ok := got.Data.(map[string]any)["details"]
|
|
require.True(ok)
|
|
for _, key := range details.([]any) {
|
|
assert.Contains(key.(map[string]any)["payload"], "user_id")
|
|
assert.Contains(key.(map[string]any)["payload"], "auth_token_start")
|
|
assert.Contains(key.(map[string]any)["payload"], "auth_token_end")
|
|
}
|
|
|
|
// make sure a pending token was created.
|
|
var tokens []*authtoken.AuthToken
|
|
err = rw.SearchWhere(ctx, &tokens, "1=?", []any{1})
|
|
require.NoError(err)
|
|
require.Equal(1, len(tokens))
|
|
tk, err := atRepo.LookupAuthToken(ctx, tokens[0].PublicId)
|
|
require.NoError(err)
|
|
assert.Equal(tk.Status, string(authtoken.PendingStatus))
|
|
|
|
// make sure the account was updated properly
|
|
var acct Account
|
|
err = rw.LookupWhere(ctx, &acct, "auth_method_id = ? and subject = ?", []any{tt.am.PublicId, tt.wantSubject})
|
|
require.NoError(err)
|
|
assert.Equal(tt.wantInfoEmail, acct.Email)
|
|
assert.Equal(tt.wantInfoName, acct.FullName)
|
|
assert.Equal(tk.AuthAccountId, acct.PublicId)
|
|
|
|
// make sure the token is properly assoc with the
|
|
// logged in user
|
|
var users []*iam.User
|
|
err = rw.SearchWhere(ctx, &users, "public_id not in(?, ?, ?)", excludeUsers)
|
|
require.NoError(err)
|
|
require.Equal(1, len(users))
|
|
assert.Equal(tk.IamUserId, users[0].PublicId)
|
|
|
|
// check the oplog entries.
|
|
var entries []*oplog.Entry
|
|
err = rw.SearchWhere(ctx, &entries, "1=?", []any{1})
|
|
require.NoError(err)
|
|
oplogWrapper, err := kmsCache.GetWrapper(ctx, tt.am.ScopeId, kms.KeyPurposeOplog)
|
|
require.NoError(err)
|
|
types, err := oplog.NewTypeCatalog(testCtx,
|
|
oplog.Type{Interface: new(store.Account), Name: "auth_oidc_account"},
|
|
oplog.Type{Interface: new(iamStore.User), Name: "iam_user"},
|
|
oplog.Type{Interface: new(authStore.Account), Name: "auth_account"},
|
|
)
|
|
require.NoError(err)
|
|
|
|
cnt := 0
|
|
foundAcct, foundOidcAcct := false, false
|
|
for _, e := range entries {
|
|
cnt += 1
|
|
e.Wrapper = oplogWrapper
|
|
err := e.DecryptData(ctx)
|
|
require.NoError(err)
|
|
msgs, err := e.UnmarshalData(ctx, types)
|
|
require.NoError(err)
|
|
for _, m := range msgs {
|
|
switch m.TypeName {
|
|
case "auth_oidc_account":
|
|
foundOidcAcct = true
|
|
t.Log("test: ", tt.name, "found oidc: ", m)
|
|
case "iam_user":
|
|
t.Log("test: ", tt.name, "found iam user: ", m)
|
|
case "auth_account":
|
|
foundAcct = true
|
|
t.Log("test: ", tt.name, "found acct", m)
|
|
}
|
|
}
|
|
}
|
|
assert.Truef(foundAcct, "expected to find auth account oplog entry")
|
|
assert.Truef(foundOidcAcct, "expected to find auth oidc account oplog entry")
|
|
})
|
|
}
|
|
t.Run("replay-attack-with-dup-state", func(t *testing.T) {
|
|
// a test to ensure that replays of duplicate states
|
|
// are rejected and produce an appropriate error.
|
|
|
|
assert, require := assert.New(t), require.New(t)
|
|
|
|
// start with no tokens in the db
|
|
_, err := rw.Exec(ctx, "delete from auth_token", nil)
|
|
require.NoError(err)
|
|
// start with no users in the db
|
|
excludeUsers := []any{globals.AnonymousUserId, globals.AnyAuthenticatedUserId, globals.RecoveryUserId}
|
|
_, err = rw.Exec(ctx, "delete from iam_user where public_id not in(?, ?, ?)", excludeUsers)
|
|
require.NoError(err)
|
|
// start with no oplog entries
|
|
_, err = rw.Exec(ctx, "delete from oplog_entry", nil)
|
|
require.NoError(err)
|
|
|
|
// prime the test provider's state for the test
|
|
tp.SetClientCreds(testAuthMethod.ClientId, testAuthMethod.ClientSecret)
|
|
tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, testController.URL)
|
|
tp.SetAllowedRedirectURIs([]string{tpAllowedRedirect})
|
|
state := testState(t, testAuthMethod, kmsCache, testTokenRequestId, 20*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce)
|
|
tp.SetExpectedAuthCode("simple")
|
|
tp.SetExpectedState(state)
|
|
|
|
wantSubject := "replay-attack-with-dup-state@example.com"
|
|
tp.SetExpectedSubject(wantSubject)
|
|
|
|
tp.SetUserInfoReply(map[string]any{"sub": wantSubject})
|
|
tp.SetExpectedAuthNonce(testNonce)
|
|
config := event.EventerConfig{
|
|
ObservationsEnabled: true,
|
|
}
|
|
testLock := &sync.Mutex{}
|
|
testLogger := hclog.New(&hclog.LoggerOptions{
|
|
Mutex: testLock,
|
|
Name: "test",
|
|
})
|
|
e, err := event.NewEventer(testLogger, testLock, "replay-attack-with-dup-state", config)
|
|
require.NoError(err)
|
|
ctx, err := event.NewEventerContext(ctx, e)
|
|
require.NoError(err)
|
|
// the first request should succeed.
|
|
gotRedirect, err := Callback(ctx,
|
|
repoFn,
|
|
iamRepoFn,
|
|
atRepoFn,
|
|
testAuthMethod,
|
|
state,
|
|
"simple",
|
|
)
|
|
require.NoError(err)
|
|
require.NotNil(gotRedirect)
|
|
|
|
// the replay should raise an error.
|
|
gotRedirect2, err := Callback(ctx,
|
|
repoFn,
|
|
iamRepoFn,
|
|
atRepoFn,
|
|
testAuthMethod,
|
|
state,
|
|
"simple",
|
|
)
|
|
require.Error(err)
|
|
require.Empty(gotRedirect2)
|
|
assert.Truef(errors.Match(errors.T(errors.Forbidden), err), "want err code: %q got: %q", errors.InvalidParameter, err)
|
|
assert.Contains(err.Error(), "not a unique request")
|
|
})
|
|
}
|
|
|
|
// Test_StartAuth_to_Callback will test if we can successfully
|
|
// test StartAuth(...) through Callback(...)
|
|
func Test_StartAuth_to_Callback(t *testing.T) {
|
|
t.Run("startAuth-to-Callback", func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
ctx := context.Background()
|
|
c := event.TestEventerConfig(t, "Test_StartAuth_to_Callback")
|
|
testLock := &sync.Mutex{}
|
|
testLogger := hclog.New(&hclog.LoggerOptions{
|
|
Mutex: testLock,
|
|
Name: "test",
|
|
})
|
|
require.NoError(event.InitSysEventer(testLogger, testLock, "use-Test_StartAuth_to_Callback", event.WithEventerConfig(&c.EventerConfig)))
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
// start with no tokens in the db
|
|
_, err := rw.Exec(ctx, "delete from auth_token", nil)
|
|
require.NoError(err)
|
|
// start with no users in the db
|
|
excludeUsers := []any{globals.AnonymousUserId, globals.AnyAuthenticatedUserId, globals.RecoveryUserId}
|
|
_, err = rw.Exec(ctx, "delete from iam_user where public_id not in(?, ?, ?)", excludeUsers)
|
|
require.NoError(err)
|
|
// start with no oplog entries
|
|
_, err = rw.Exec(ctx, "delete from oplog_entry", nil)
|
|
require.NoError(err)
|
|
|
|
rootWrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, rootWrapper)
|
|
|
|
// func pointers for the test controller.
|
|
iamRepoFn := func() (*iam.Repository, error) {
|
|
return iam.NewRepository(ctx, rw, rw, kmsCache)
|
|
}
|
|
repoFn := func() (*Repository, error) {
|
|
return NewRepository(ctx, rw, rw, kmsCache)
|
|
}
|
|
atRepoFn := func() (*authtoken.Repository, error) {
|
|
return authtoken.NewRepository(ctx, rw, rw, kmsCache)
|
|
}
|
|
atRepo, err := atRepoFn()
|
|
require.NoError(err)
|
|
|
|
controller := startTestControllerSrv(t, repoFn, iamRepoFn, atRepoFn)
|
|
|
|
iamRepo := iam.TestRepo(t, conn, rootWrapper)
|
|
org, _ := iam.TestScopes(t, iamRepo)
|
|
|
|
databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(err)
|
|
|
|
// the testing OIDC provider (it's a functional fake, see the package docs)
|
|
tp := oidc.StartTestProvider(t)
|
|
tpCert, err := ParseCertificates(ctx, tp.CACert())
|
|
require.NoError(err)
|
|
_, _, tpAlg, _ := tp.SigningKeys()
|
|
|
|
endToEndAuthMethod := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, ActivePublicState,
|
|
"end-to-end-rp", "fido",
|
|
WithCertificates(tpCert...),
|
|
WithSigningAlgs(Alg(tpAlg)),
|
|
WithIssuer(TestConvertToUrls(t, tp.Addr())[0]),
|
|
WithApiUrl(TestConvertToUrls(t, controller.Addr())[0]))
|
|
|
|
// need the updated org version, so we can set the primary auth method id
|
|
org, _ = iamRepo.LookupScope(ctx, org.PublicId)
|
|
require.NoError(err)
|
|
iam.TestSetPrimaryAuthMethod(t, iamRepo, org, endToEndAuthMethod.PublicId)
|
|
|
|
// the test controller is stateful and needs to know what auth method id
|
|
// it's suppose to operate on
|
|
controller.SetAuthMethod(endToEndAuthMethod)
|
|
|
|
authUrl, _, err := StartAuth(ctx, repoFn, endToEndAuthMethod.PublicId)
|
|
require.NoError(err)
|
|
|
|
authParams, err := url.ParseQuery(authUrl.RawQuery)
|
|
require.NoError(err)
|
|
require.Equal(1, len(authParams["nonce"]))
|
|
require.Equal(1, len(authParams["state"]))
|
|
|
|
// the TestProvider is stateful and needs to be configured for the upcoming requests.
|
|
tp.SetExpectedState(authParams["state"][0])
|
|
tp.SetExpectedAuthNonce(authParams["nonce"][0])
|
|
tp.SetExpectedAuthCode("simple")
|
|
tp.SetClientCreds(endToEndAuthMethod.ClientId, endToEndAuthMethod.ClientSecret)
|
|
tpAllowedRedirect := controller.CallbackUrl()
|
|
tp.SetAllowedRedirectURIs([]string{tpAllowedRedirect})
|
|
|
|
client := tp.HTTPClient()
|
|
// this http request will:
|
|
// * 1) go to the TestProvider, where auth/authz will be faked.
|
|
// * 2) TestProvider will send 302 back to the client for callback
|
|
// * 3) 302 redirected to testControllerSrv's callback handler
|
|
// * 4) callback handler will exchange the token with the TestProvider
|
|
// * 5) callback will send 302 back to client with final redirect
|
|
// * 6) 302 redirected to testControllerSrv's final redirect handler.
|
|
// * 7) final redirect handler sends "Congratulations" msg back to client.
|
|
//
|
|
// If this succeeds, then the service functions of StartAuth(...) and
|
|
// Callback(...) are successfully working together.
|
|
resp, err := client.Get(authUrl.String())
|
|
require.NoError(err)
|
|
defer resp.Body.Close()
|
|
contents, err := io.ReadAll(resp.Body)
|
|
require.NoError(err)
|
|
require.Containsf(string(contents), "Congratulations", "expected \"Congratulations\" on successful oidc authentication and got: %s", string(contents))
|
|
|
|
// check to make sure there's a pending token, after the successful callback
|
|
var tokens []*authtoken.AuthToken
|
|
err = rw.SearchWhere(ctx, &tokens, "1=?", []any{1})
|
|
require.NoError(err)
|
|
require.Equal(1, len(tokens))
|
|
tk, err := atRepo.LookupAuthToken(ctx, tokens[0].PublicId)
|
|
require.NoError(err)
|
|
assert.Equal(tk.Status, string(authtoken.PendingStatus))
|
|
})
|
|
}
|
|
|
|
func Test_ManagedGroupFiltering(t *testing.T) {
|
|
// DO NOT run these tests under t.Parallel()
|
|
|
|
// A note about this test: other tests handle checking managed group
|
|
// membership, creation, etc. This test is only scoped to checking that
|
|
// given reasonable data in the jwt/userinfo, the result of a callback call
|
|
// results in association with the proper managed groups.
|
|
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
rootWrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, rootWrapper)
|
|
opt := event.TestWithObservationSink(t)
|
|
c := event.TestEventerConfig(t, "Test_StartAuth_to_Callback", opt)
|
|
testLock := &sync.Mutex{}
|
|
testLogger := hclog.New(&hclog.LoggerOptions{
|
|
Mutex: testLock,
|
|
Name: "test",
|
|
})
|
|
c.EventerConfig.TelemetryEnabled = true
|
|
require.NoError(t, event.InitSysEventer(testLogger, testLock, "use-Test_ManagedGroupFiltering", event.WithEventerConfig(&c.EventerConfig)))
|
|
// some standard factories for unit tests which
|
|
// are used in the Callback(...) call
|
|
iamRepoFn := func() (*iam.Repository, error) {
|
|
return iam.NewRepository(ctx, rw, rw, kmsCache)
|
|
}
|
|
repoFn := func() (*Repository, error) {
|
|
// Set a low limit to test that the managed group listing overrides the limit
|
|
return NewRepository(ctx, rw, rw, kmsCache, WithLimit(1))
|
|
}
|
|
atRepoFn := func() (*authtoken.Repository, error) {
|
|
return authtoken.NewRepository(ctx, rw, rw, kmsCache)
|
|
}
|
|
|
|
iamRepo := iam.TestRepo(t, conn, rootWrapper)
|
|
org, _ := iam.TestScopes(t, iamRepo)
|
|
|
|
databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
|
|
// a very simple test mock controller, that simply responds with a 200 OK to
|
|
// every request.
|
|
testController := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
|
|
w.WriteHeader(200)
|
|
}))
|
|
defer testController.Close()
|
|
|
|
// test provider for the tests (see the oidc package docs for more info)
|
|
// it will provide discovery, JWKs, a token endpoint, etc for these tests.
|
|
tp := oidc.StartTestProvider(t)
|
|
tpCert, err := ParseCertificates(ctx, tp.CACert())
|
|
require.NoError(t, err)
|
|
_, _, tpAlg, _ := tp.SigningKeys()
|
|
|
|
// a reusable test authmethod for the unit tests
|
|
testAuthMethod := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, ActivePublicState,
|
|
"alice-rp", "fido",
|
|
WithCertificates(tpCert...),
|
|
WithSigningAlgs(Alg(tpAlg)),
|
|
WithIssuer(TestConvertToUrls(t, tp.Addr())[0]),
|
|
WithApiUrl(TestConvertToUrls(t, testController.URL)[0]))
|
|
|
|
// set this as the primary so users will be created on first login
|
|
iam.TestSetPrimaryAuthMethod(t, iamRepo, org, testAuthMethod.PublicId)
|
|
|
|
// Create an account so it's a known account id
|
|
sub := "alice@example.com"
|
|
account := TestAccount(t, conn, testAuthMethod, sub)
|
|
|
|
// Create two managed groups. We'll use these in tests to set filters and
|
|
// then check that the user belongs to the ones we expect.
|
|
mgs := []*ManagedGroup{
|
|
TestManagedGroup(t, conn, testAuthMethod, TestFakeManagedGroupFilter),
|
|
TestManagedGroup(t, conn, testAuthMethod, TestFakeManagedGroupFilter),
|
|
}
|
|
|
|
// A reusable oidc.Provider for the tests
|
|
testProvider, err := convertToProvider(ctx, testAuthMethod)
|
|
require.NoError(t, err)
|
|
testConfigHash, err := testProvider.ConfigHash()
|
|
require.NoError(t, err)
|
|
|
|
// Set up the provider a bit
|
|
testNonce := "nonce"
|
|
tp.SetExpectedAuthNonce(testNonce)
|
|
code := "simple"
|
|
tp.SetExpectedAuthCode(code)
|
|
tp.SetExpectedSubject(sub)
|
|
tp.SetCustomAudience("foo", "alice-rp")
|
|
info := map[string]any{
|
|
"roles": []string{"user", "operator"},
|
|
"sub": "alice@example.com",
|
|
"email": "alice-alias@example.com",
|
|
"name": "alice doe joe foe",
|
|
"co:lon": "colon",
|
|
}
|
|
tp.SetUserInfoReply(info)
|
|
|
|
repo, err := repoFn()
|
|
require.NoError(t, err)
|
|
|
|
tests := []struct {
|
|
name string
|
|
filters []string // Should always be length 2, and specify the filters to use for the test
|
|
matchingMgs []*ManagedGroup // The public IDs of the managed groups we expect the user to be in
|
|
}{
|
|
{
|
|
name: "no match",
|
|
filters: []string{
|
|
TestFakeManagedGroupFilter,
|
|
TestFakeManagedGroupFilter,
|
|
},
|
|
},
|
|
{
|
|
name: "token match",
|
|
filters: []string{
|
|
`"/token/nonce" == "nonce"`,
|
|
TestFakeManagedGroupFilter,
|
|
},
|
|
matchingMgs: mgs[0:1],
|
|
},
|
|
{
|
|
name: "token double match",
|
|
filters: []string{
|
|
`"/token/nonce" == "nonce"`,
|
|
`"/token/name" == "Alice Doe Smith"`,
|
|
},
|
|
matchingMgs: mgs[0:2],
|
|
},
|
|
{
|
|
name: "userinfo double match",
|
|
filters: []string{
|
|
`"user" in "/userinfo/roles"`,
|
|
`"/userinfo/email" == "alice-alias@example.com"`,
|
|
},
|
|
matchingMgs: mgs[0:2],
|
|
},
|
|
{
|
|
name: "userinfo match only",
|
|
filters: []string{
|
|
`"/token/nonce" == "not-nonce"`,
|
|
`"/userinfo/email" == "alice-alias@example.com"`,
|
|
},
|
|
matchingMgs: mgs[1:2],
|
|
},
|
|
{
|
|
name: "colon test",
|
|
filters: []string{
|
|
TestFakeManagedGroupFilter,
|
|
`"/userinfo/co:lon" == "colon"`,
|
|
},
|
|
matchingMgs: mgs[1:2],
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
|
|
// A unique token ID for each test
|
|
testTokenRequestId, err := authtoken.NewAuthTokenId(ctx)
|
|
require.NoError(err)
|
|
|
|
// the test provider is stateful, so we need to configure
|
|
// it for this unit test.
|
|
tp.SetClientCreds(testAuthMethod.ClientId, testAuthMethod.ClientSecret)
|
|
tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, testAuthMethod.ApiUrl)
|
|
tp.SetAllowedRedirectURIs([]string{tpAllowedRedirect})
|
|
|
|
state := testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce)
|
|
tp.SetExpectedState(state)
|
|
|
|
// Set the filters on the MGs for this test. First we need to get the current versions.
|
|
currMgs, ttime, err := repo.ListManagedGroups(ctx, testAuthMethod.PublicId, WithLimit(-1))
|
|
require.NoError(err)
|
|
// Transaction timestamp should be within ~10 seconds of now
|
|
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
|
|
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
|
|
require.Len(currMgs, 2)
|
|
currVersionMap := map[string]uint32{
|
|
currMgs[0].PublicId: currMgs[0].Version,
|
|
currMgs[1].PublicId: currMgs[1].Version,
|
|
}
|
|
for i, filter := range tt.filters {
|
|
mgs[i].Filter = filter
|
|
_, numUpdated, err := repo.UpdateManagedGroup(ctx, org.PublicId, mgs[i], currVersionMap[mgs[i].PublicId], []string{"Filter"})
|
|
require.Equal(numUpdated, 1)
|
|
require.NoError(err)
|
|
}
|
|
// Run the callback
|
|
_, err = Callback(ctx,
|
|
repoFn,
|
|
iamRepoFn,
|
|
atRepoFn,
|
|
testAuthMethod,
|
|
state,
|
|
code,
|
|
)
|
|
require.NoError(err)
|
|
sinkFileName := c.ObservationEvents.Name()
|
|
defer func() { _ = os.WriteFile(sinkFileName, nil, 0o666) }()
|
|
b, err := os.ReadFile(sinkFileName)
|
|
require.NoError(err)
|
|
got := &cloudevents.Event{}
|
|
err = json.Unmarshal(b, got)
|
|
require.NoErrorf(err, "json: %s", string(b))
|
|
details, ok := got.Data.(map[string]any)["details"]
|
|
require.True(ok)
|
|
for _, key := range details.([]any) {
|
|
assert.Contains(key.(map[string]any)["payload"], "user_id")
|
|
assert.Contains(key.(map[string]any)["payload"], "auth_token_start")
|
|
assert.Contains(key.(map[string]any)["payload"], "auth_token_end")
|
|
}
|
|
// Ensure that we get the expected groups
|
|
memberships, err := repo.ListManagedGroupMembershipsByMember(ctx, account.PublicId, WithLimit(-1))
|
|
require.NoError(err)
|
|
assert.Equal(len(tt.matchingMgs), len(memberships))
|
|
var matchingIds []string
|
|
for _, mg := range tt.matchingMgs {
|
|
matchingIds = append(matchingIds, mg.PublicId)
|
|
}
|
|
for _, mg := range memberships {
|
|
assert.True(strutil.StrListContains(matchingIds, mg.ManagedGroupId))
|
|
}
|
|
})
|
|
}
|
|
}
|