diff --git a/go.mod b/go.mod index fa2cdde24c..d74403eec7 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/hashicorp/cap v0.0.0-20210518163718-e72205e8eaae github.com/hashicorp/dbassert v0.0.0-20200930125617-6218396928df github.com/hashicorp/errwrap v1.1.0 - github.com/hashicorp/go-bexpr v0.1.7 + github.com/hashicorp/go-bexpr v0.1.8 github.com/hashicorp/go-cleanhttp v0.5.2 github.com/hashicorp/go-hclog v0.16.1 github.com/hashicorp/go-kms-wrapping v0.6.1 diff --git a/go.sum b/go.sum index 8b3f3a771c..cbf29dc6a2 100644 --- a/go.sum +++ b/go.sum @@ -449,8 +449,8 @@ github.com/hashicorp/dbassert v0.0.0-20200930125617-6218396928df/go.mod h1:+B5eZ github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-bexpr v0.1.7 h1:z48qzCgJkvdnMO/LDy3XHNyCyxnHiFGx9uTKLv0jW2Y= -github.com/hashicorp/go-bexpr v0.1.7/go.mod h1:oxlubA2vC/gFVfX1A6JGp7ls7uCDlfJn732ehYYg+g0= +github.com/hashicorp/go-bexpr v0.1.8 h1:ETfuLF1bBAuHW/Qg6l1xCdV8WJ7lfatLtJ1N1w0IsfE= +github.com/hashicorp/go-bexpr v0.1.8/go.mod h1:oxlubA2vC/gFVfX1A6JGp7ls7uCDlfJn732ehYYg+g0= github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= diff --git a/internal/auth/oidc/managed_group_test.go b/internal/auth/oidc/managed_group_test.go index 35f518c1c2..33e7b94585 100644 --- a/internal/auth/oidc/managed_group_test.go +++ b/internal/auth/oidc/managed_group_test.go @@ -26,7 +26,7 @@ func Test_ManagedGroups_RepoValidate(t *testing.T) { }) t.Run("valid", func(t *testing.T) { mg.AuthMethodId = "amoidc_1234567890" - mg.Filter = testFakeFilter + mg.Filter = testFakeManagedGroupFilter assert.NoError(mg.validate(op)) }) } diff --git a/internal/auth/oidc/service_callback.go b/internal/auth/oidc/service_callback.go index fc9b66827c..9762d58ea7 100644 --- a/internal/auth/oidc/service_callback.go +++ b/internal/auth/oidc/service_callback.go @@ -11,6 +11,8 @@ import ( "github.com/hashicorp/boundary/internal/authtoken" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/cap/oidc" + "github.com/hashicorp/go-bexpr" + "github.com/mitchellh/pointerstructure" ) // Callback is an oidc domain service function for processing a successful OIDC @@ -185,6 +187,40 @@ func Callback( return "", errors.Wrap(err, op) } + // Get the set of all managed groups so we can filter + mgs, err := r.ListManagedGroups(ctx, am.GetPublicId()) + if err != nil { + return "", errors.Wrap(err, op) + } + if len(mgs) > 0 { + matchedMgs := make([]*ManagedGroup, 0, len(mgs)) + evalData := map[string]interface{}{ + "token": idTkClaims, + "userinfo": userInfoClaims, + } + // Iterate through and check claims against filters + for _, mg := range mgs { + eval, err := bexpr.CreateEvaluator(mg.Filter) + if err != nil { + // We check all filters on ingress so this should never happen, + // but we validate anyways + return "", errors.Wrap(err, op) + } + match, err := eval.Evaluate(evalData) + if err != nil && !errors.Is(err, pointerstructure.ErrNotFound) { + return "", errors.Wrap(err, op) + } + if match { + matchedMgs = append(matchedMgs, mg) + } + } + // We always pass it in, even if none match, because in that case we + // need to remove any mappings that exist + if _, _, err := r.SetManagedGroupMemberships(ctx, am, acct, matchedMgs); err != nil { + return "", errors.Wrap(err, op) + } + } + // before searching for the iam.User associated with the account, // we need to see if this particular auth method is allowed to // autovivify users for the scope. diff --git a/internal/auth/oidc/service_callback_test.go b/internal/auth/oidc/service_callback_test.go index 4b44af4fac..3fca2f9257 100644 --- a/internal/auth/oidc/service_callback_test.go +++ b/internal/auth/oidc/service_callback_test.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/iam" iamStore "github.com/hashicorp/boundary/internal/iam/store" + "github.com/hashicorp/boundary/sdk/strutil" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" @@ -595,3 +596,198 @@ func Test_StartAuth_to_Callback(t *testing.T) { assert.Equal(tk.Status, string(authtoken.PendingStatus)) }) } + +func Test_ManagedGroupFiltering(t *testing.T) { + // DO NOT run these tests under t.Parallel() + + // A note about this test: other tests handle checking managed group + // membership, creation, etc. This test is only scoped to checking that + // given reasonable data in the jwt/userinfo, the result of a callback call + // results in association with the proper managed groups. + + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + rootWrapper := db.TestWrapper(t) + kmsCache := kms.TestKms(t, conn, rootWrapper) + + // some standard factories for unit tests which + // are used in the Callback(...) call + iamRepoFn := func() (*iam.Repository, error) { + return iam.NewRepository(rw, rw, kmsCache) + } + repoFn := func() (*Repository, error) { + return NewRepository(rw, rw, kmsCache) + } + atRepoFn := func() (*authtoken.Repository, error) { + return authtoken.NewRepository(rw, rw, kmsCache) + } + + iamRepo := iam.TestRepo(t, conn, rootWrapper) + org, _ := iam.TestScopes(t, iamRepo) + + databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase) + require.NoError(t, err) + + // a very simple test mock controller, that simply responds with a 200 OK to + // every request. + testController := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(200) + })) + defer testController.Close() + + // test provider for the tests (see the oidc package docs for more info) + // it will provide discovery, JWKs, a token endpoint, etc for these tests. + tp := oidc.StartTestProvider(t) + tpCert, err := ParseCertificates(tp.CACert()) + require.NoError(t, err) + _, _, tpAlg, _ := tp.SigningKeys() + + // a reusable test authmethod for the unit tests + testAuthMethod := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, ActivePublicState, + "alice-rp", "fido", + WithCertificates(tpCert...), + WithSigningAlgs(Alg(tpAlg)), + WithIssuer(TestConvertToUrls(t, tp.Addr())[0]), + WithApiUrl(TestConvertToUrls(t, testController.URL)[0])) + + // set this as the primary so users will be created on first login + iam.TestSetPrimaryAuthMethod(t, iamRepo, org, testAuthMethod.PublicId) + + // Create an account so it's a known account id + sub := "alice@example.com" + account := TestAccount(t, conn, testAuthMethod, sub) + + // Create two managed groups. We'll use these in tests to set filters and + // then check that the user belongs to the ones we expect. + mgs := []*ManagedGroup{ + TestManagedGroup(t, conn, testAuthMethod, testFakeManagedGroupFilter), + TestManagedGroup(t, conn, testAuthMethod, testFakeManagedGroupFilter), + } + + // A reusable oidc.Provider for the tests + testProvider, err := convertToProvider(ctx, testAuthMethod) + require.NoError(t, err) + testConfigHash, err := testProvider.ConfigHash() + require.NoError(t, err) + + // Set up the provider a bit + testNonce := "nonce" + tp.SetExpectedAuthNonce(testNonce) + code := "simple" + tp.SetExpectedAuthCode(code) + tp.SetExpectedSubject(sub) + tp.SetCustomAudience("foo", "alice-rp") + info := map[string]interface{}{ + "roles": []string{"user", "operator"}, + "sub": "alice@example.com", + "email": "alice-alias@example.com", + "name": "alice doe joe foe", + } + tp.SetUserInfoReply(info) + + repo, err := repoFn() + require.NoError(t, err) + + tests := []struct { + name string + filters []string // Should always be length 2, and specify the filters to use for the test + matchingMgs []*ManagedGroup // The public IDs of the managed groups we expect the user to be in + }{ + { + name: "no match", + filters: []string{ + testFakeManagedGroupFilter, + testFakeManagedGroupFilter, + }, + }, + { + name: "token match", + filters: []string{ + `"/token/nonce" == "nonce"`, + testFakeManagedGroupFilter, + }, + matchingMgs: mgs[0:1], + }, + { + name: "token double match", + filters: []string{ + `"/token/nonce" == "nonce"`, + `"/token/name" == "Alice Doe Smith"`, + }, + matchingMgs: mgs[0:2], + }, + { + name: "userinfo double match", + filters: []string{ + `"user" in "/userinfo/roles"`, + `"/userinfo/email" == "alice-alias@example.com"`, + }, + matchingMgs: mgs[0:2], + }, + { + name: "userinfo match only", + filters: []string{ + `"/token/nonce" == "not-nonce"`, + `"/userinfo/email" == "alice-alias@example.com"`, + }, + matchingMgs: mgs[1:2], + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + // A unique token ID for each test + testTokenRequestId, err := authtoken.NewAuthTokenId() + require.NoError(err) + + // the test provider is stateful, so we need to configure + // it for this unit test. + tp.SetClientCreds(testAuthMethod.ClientId, testAuthMethod.ClientSecret) + tpAllowedRedirect := fmt.Sprintf(CallbackEndpoint, testAuthMethod.ApiUrl) + tp.SetAllowedRedirectURIs([]string{tpAllowedRedirect}) + + state := testState(t, testAuthMethod, kmsCache, testTokenRequestId, 2000*time.Second, "https://testcontroler.com/hi-alice", testConfigHash, testNonce) + tp.SetExpectedState(state) + + // Set the filters on the MGs for this test. First we need to get the current versions. + currMgs, err := repo.ListManagedGroups(ctx, testAuthMethod.PublicId) + require.NoError(err) + require.Len(currMgs, 2) + currVersionMap := map[string]uint32{ + currMgs[0].PublicId: currMgs[0].Version, + currMgs[1].PublicId: currMgs[1].Version, + } + for i, filter := range tt.filters { + mgs[i].Filter = filter + _, numUpdated, err := repo.UpdateManagedGroup(ctx, org.PublicId, mgs[i], currVersionMap[mgs[i].PublicId], []string{"Filter"}) + require.Equal(numUpdated, 1) + require.NoError(err) + } + + // Run the callback + _, err = Callback(ctx, + repoFn, + iamRepoFn, + atRepoFn, + testAuthMethod, + state, + code, + ) + require.NoError(err) + + // Ensure that we get the expected groups + memberships, err := repo.ListManagedGroupMembershipsByMember(ctx, account.PublicId) + require.NoError(err) + assert.Equal(len(tt.matchingMgs), len(memberships)) + var matchingIds []string + for _, mg := range tt.matchingMgs { + matchingIds = append(matchingIds, mg.PublicId) + } + for _, mg := range memberships { + assert.True(strutil.StrListContains(matchingIds, mg.ManagedGroupId)) + } + }) + } +}