mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
299 lines
8.9 KiB
299 lines
8.9 KiB
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package oidc_test
|
|
|
|
import (
|
|
"context"
|
|
"math/rand"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/hashicorp/boundary/internal/auth/oidc"
|
|
"github.com/hashicorp/boundary/internal/auth/oidc/store"
|
|
"github.com/hashicorp/boundary/internal/daemon/controller"
|
|
"github.com/hashicorp/boundary/internal/db"
|
|
"github.com/hashicorp/boundary/internal/errors"
|
|
"github.com/hashicorp/boundary/internal/iam"
|
|
"github.com/hashicorp/boundary/internal/kms"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func Test_ManagedGroupMemberships(t *testing.T) {
|
|
// This tests both managed group membership functions (set/list) as list is
|
|
// always called as a return from set and we are validating the values that
|
|
// come back against what we expect.
|
|
|
|
// This test can be run in parallel; the subtests *cannot*.
|
|
t.Parallel()
|
|
|
|
// Note: using a test controller here for ease of setup as we need a working
|
|
// dev OIDC auth method and associated accounts. This test is not making API
|
|
// calls! It's accessing the repo directly via the test controller's
|
|
// exposure of the underlying DB primitives.
|
|
tc := controller.NewTestController(t, nil)
|
|
defer tc.Shutdown()
|
|
|
|
conn := tc.DbConn()
|
|
rw := db.New(conn)
|
|
|
|
kmsCache := tc.Kms()
|
|
iamRepo := tc.IamRepo()
|
|
org, _ := iam.TestScopes(t, iamRepo)
|
|
|
|
ctx := context.Background()
|
|
databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
|
|
authMethod := oidc.TestAuthMethod(
|
|
t, conn, databaseWrapper, org.GetPublicId(), oidc.ActivePrivateState,
|
|
"alice-rp", "fido",
|
|
oidc.WithSigningAlgs(oidc.RS256),
|
|
oidc.WithIssuer(oidc.TestConvertToUrls(t, "https://www.alice.com")[0]),
|
|
oidc.WithApiUrl(oidc.TestConvertToUrls(t, "https://www.alice.com/callback")[0]),
|
|
)
|
|
|
|
repo, err := oidc.NewRepository(ctx, rw, rw, kmsCache)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, repo)
|
|
|
|
mgs := make([]*oidc.ManagedGroup, 0, 10)
|
|
|
|
for i := 0; i < 100; i++ {
|
|
mg := oidc.AllocManagedGroup()
|
|
mg.AuthMethodId = authMethod.PublicId
|
|
mg.Filter = oidc.TestFakeManagedGroupFilter
|
|
got, err := repo.CreateManagedGroup(context.Background(), org.GetPublicId(), mg)
|
|
require.NoError(t, err)
|
|
mgs = append(mgs, got)
|
|
}
|
|
|
|
// Fetch valid OIDC accounts. One will be "static" where we will simply
|
|
// ensure modifying the groups for the other doesn't affect it; the other
|
|
// will be used for testing.
|
|
var accts []*oidc.Account
|
|
err = rw.SearchWhere(ctx, &accts, "", nil, db.WithLimit(2))
|
|
require.NoError(t, err)
|
|
require.Len(t, accts, 2)
|
|
staticAccountId := accts[0].PublicId
|
|
staticMembershipCount := 20
|
|
accountId := accts[1].PublicId
|
|
|
|
account := oidc.AllocAccount()
|
|
account.PublicId = accountId
|
|
|
|
tests := []struct {
|
|
name string
|
|
// If true, we will auto populate necessary values into the function
|
|
validPrereqs bool
|
|
|
|
// Else these can be used for testing
|
|
authMethod *oidc.AuthMethod
|
|
authMethodId string
|
|
account *oidc.Account
|
|
accountId string
|
|
authMethodScopeId string
|
|
wantPreseededMgsCount int
|
|
wantMgsCount int
|
|
specificMgs []*oidc.ManagedGroup
|
|
|
|
wantErr errors.Code
|
|
wantErrContains string
|
|
}{
|
|
{
|
|
name: "nil auth method",
|
|
wantErr: errors.InvalidParameter,
|
|
wantErrContains: "missing auth method",
|
|
},
|
|
{
|
|
name: "missing auth method store",
|
|
authMethod: &oidc.AuthMethod{},
|
|
wantErr: errors.InvalidParameter,
|
|
wantErrContains: "missing auth method store",
|
|
},
|
|
{
|
|
name: "missing auth method id",
|
|
authMethod: &oidc.AuthMethod{AuthMethod: &store.AuthMethod{}},
|
|
wantErr: errors.InvalidParameter,
|
|
wantErrContains: "missing auth method id",
|
|
},
|
|
{
|
|
name: "missing auth method scope id",
|
|
authMethod: &oidc.AuthMethod{AuthMethod: &store.AuthMethod{PublicId: authMethod.PublicId}},
|
|
wantErr: errors.InvalidParameter,
|
|
wantErrContains: "missing auth method scope id",
|
|
},
|
|
{
|
|
name: "missing account",
|
|
authMethod: authMethod,
|
|
wantErr: errors.InvalidParameter,
|
|
wantErrContains: "missing account",
|
|
},
|
|
{
|
|
name: "missing account store",
|
|
authMethod: authMethod,
|
|
account: &oidc.Account{},
|
|
wantErr: errors.InvalidParameter,
|
|
wantErrContains: "missing account store",
|
|
},
|
|
{
|
|
name: "missing account id",
|
|
authMethod: authMethod,
|
|
account: &oidc.Account{Account: &store.Account{}},
|
|
wantErr: errors.InvalidParameter,
|
|
wantErrContains: "missing account id",
|
|
},
|
|
{
|
|
name: "valid fixed, static",
|
|
validPrereqs: true,
|
|
accountId: staticAccountId,
|
|
specificMgs: mgs[0:staticMembershipCount],
|
|
},
|
|
{
|
|
name: "valid fixed",
|
|
validPrereqs: true,
|
|
specificMgs: mgs[0:20],
|
|
},
|
|
{
|
|
name: "valid fixed, same values",
|
|
validPrereqs: true,
|
|
specificMgs: mgs[0:20],
|
|
},
|
|
{
|
|
name: "valid fixed, new values",
|
|
validPrereqs: true,
|
|
specificMgs: mgs[20:40],
|
|
},
|
|
{
|
|
name: "valid none",
|
|
validPrereqs: true,
|
|
wantMgsCount: 0,
|
|
},
|
|
{
|
|
name: "valid none, second test, testing gracefully aborting",
|
|
validPrereqs: true,
|
|
wantMgsCount: 0,
|
|
},
|
|
{
|
|
name: "valid fixed, prep for random",
|
|
validPrereqs: true,
|
|
specificMgs: mgs[20:50],
|
|
},
|
|
{
|
|
name: "valid random",
|
|
validPrereqs: true,
|
|
wantMgsCount: 30,
|
|
},
|
|
{
|
|
name: "valid random, second test",
|
|
validPrereqs: true,
|
|
wantMgsCount: 20,
|
|
},
|
|
{
|
|
name: "valid with duplicates",
|
|
validPrereqs: true,
|
|
specificMgs: append(mgs[0:20], mgs[0:20]...),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
|
|
// We are intentionally carrying things over between tests to be
|
|
// more realistic but that means we need correct versions, so update
|
|
// them first.
|
|
currMgs, ttime, err := repo.ListManagedGroups(ctx, authMethod.PublicId)
|
|
require.NoError(err)
|
|
// Transaction timestamp should be within ~10 seconds of now
|
|
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
|
|
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
|
|
require.Len(currMgs, 100)
|
|
currVersionMap := make(map[string]uint32, len(currMgs))
|
|
for _, currMg := range currMgs {
|
|
currVersionMap[currMg.PublicId] = currMg.Version
|
|
}
|
|
for _, mg := range mgs {
|
|
mg.Version = currVersionMap[mg.PublicId]
|
|
}
|
|
|
|
var mgsToTest []*oidc.ManagedGroup
|
|
var finalMgs map[string]*oidc.ManagedGroup
|
|
// If we know the inputs are sane, create the test data
|
|
if tt.validPrereqs {
|
|
tt.authMethod = authMethod
|
|
tt.account = oidc.AllocAccount()
|
|
tt.account.PublicId = accountId
|
|
if tt.accountId != "" {
|
|
// This is for the test where we initially populate the
|
|
// static account
|
|
tt.account.PublicId = tt.accountId
|
|
}
|
|
mgsToTest = tt.specificMgs
|
|
if mgsToTest == nil {
|
|
// Select at random
|
|
mgsToTest = make([]*oidc.ManagedGroup, tt.wantMgsCount)
|
|
for i := 0; i < tt.wantMgsCount; i++ {
|
|
mg := mgs[rand.Int()%len(mgs)]
|
|
mgsToTest[i] = mg
|
|
}
|
|
}
|
|
finalMgs = make(map[string]*oidc.ManagedGroup)
|
|
for _, v := range mgsToTest {
|
|
finalMgs[v.PublicId] = v
|
|
}
|
|
}
|
|
|
|
memberships, _, err := repo.SetManagedGroupMemberships(ctx, tt.authMethod, tt.account, mgsToTest)
|
|
if tt.wantErr != 0 {
|
|
assert.Truef(errors.Match(errors.T(tt.wantErr), err), "Unexpected error %s", err)
|
|
if tt.wantErrContains != "" {
|
|
assert.True(strings.Contains(err.Error(), tt.wantErrContains))
|
|
}
|
|
return
|
|
}
|
|
|
|
require.NoError(err)
|
|
assert.Len(memberships, len(finalMgs))
|
|
|
|
// Ensure the same set was found; all memberships found should have
|
|
// been in the finalMgs map, and when they are all removed there
|
|
// should be nothing left.
|
|
for _, mship := range memberships {
|
|
// Randomly check a few to ensure the MembershipsByGroup function works
|
|
members, err := repo.ListManagedGroupMembershipsByGroup(ctx, mship.ManagedGroupId)
|
|
require.NoError(err)
|
|
require.NotEmpty(members)
|
|
var found bool
|
|
for _, v := range members {
|
|
if v.MemberId == tt.account.GetPublicId() {
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
assert.True(found)
|
|
assert.Contains(finalMgs, mship.ManagedGroupId)
|
|
delete(finalMgs, mship.ManagedGroupId)
|
|
}
|
|
assert.Len(finalMgs, 0)
|
|
|
|
// Now check that the static account still has the same memberships
|
|
memberships, err = repo.ListManagedGroupMembershipsByMember(ctx, staticAccountId)
|
|
require.NoError(err)
|
|
assert.Len(memberships, staticMembershipCount)
|
|
finalMgs = make(map[string]*oidc.ManagedGroup, staticMembershipCount)
|
|
for _, mg := range mgs[0:staticMembershipCount] {
|
|
finalMgs[mg.PublicId] = mg
|
|
}
|
|
for _, mship := range memberships {
|
|
assert.Contains(finalMgs, mship.ManagedGroupId)
|
|
delete(finalMgs, mship.ManagedGroupId)
|
|
}
|
|
assert.Len(finalMgs, 0)
|
|
})
|
|
}
|
|
}
|