internal/auth/ldap: add account pagination

pull/4202/head
Johan Brandhorst-Satzkorn 2 years ago
parent 44971b3559
commit 6281a5b73c

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/boundary/internal/auth"
"github.com/hashicorp/boundary/internal/auth/ldap/store"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/boundary/internal/types/resource"
@ -150,3 +151,13 @@ func (a *Account) oplog(ctx context.Context, opType oplog.OpType) (oplog.Metadat
}
return metadata, nil
}
type deletedAccount struct {
PublicId string `gorm:"primary_key"`
DeleteTime *timestamp.Timestamp
}
// TableName returns the tablename to override the default gorm table name
func (s *deletedAccount) TableName() string {
return "auth_ldap_account_deleted"
}

@ -11,6 +11,8 @@ import (
"net/url"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/pagination"
"github.com/hashicorp/boundary/internal/util"
)
type options struct {
@ -48,6 +50,7 @@ type options struct {
withPublicId string
withDerefAliases DerefAliasType
withMaximumPageSize uint32
withStartPageAfterItem pagination.Item
}
// Option - how options are passed as args
@ -413,3 +416,16 @@ func (d DerefAliasType) IsValid(ctx context.Context) error {
return errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%q is not a valid ldap dereference alias type", d))
}
}
// WithStartPageAfterItem is used to paginate over the results.
// The next page will start after the provided item.
func WithStartPageAfterItem(ctx context.Context, item pagination.Item) Option {
const op = "ldap.WithStartPageAfterItem"
return func(o *options) error {
if util.IsNil(item) {
return errors.New(ctx, errors.InvalidParameter, op, "item cannot be nil")
}
o.withStartPageAfterItem = item
return nil
}
}

@ -9,12 +9,34 @@ import (
"crypto/rand"
"crypto/x509"
"testing"
"time"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/pagination"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type fakeItem struct {
pagination.Item
publicId string
createTime time.Time
updateTime time.Time
}
func (p *fakeItem) GetPublicId() string {
return p.publicId
}
func (p *fakeItem) GetCreateTime() *timestamp.Timestamp {
return timestamp.New(p.createTime)
}
func (p *fakeItem) GetUpdateTime() *timestamp.Timestamp {
return timestamp.New(p.updateTime)
}
func Test_getOpts(t *testing.T) {
t.Parallel()
testCtx := context.Background()
@ -350,4 +372,18 @@ func Test_getOpts(t *testing.T) {
assert.ErrorContains(err, `"Invalid" is not a valid ldap dereference alias type`)
assert.Truef(errors.Match(errors.T(errors.InvalidParameter), err), "want err code: %q got: %q", errors.InvalidParameter, err)
})
t.Run("WithStartPageAfterItem", func(t *testing.T) {
t.Run("nil item", func(t *testing.T) {
_, err := getOpts(WithStartPageAfterItem(context.Background(), nil))
require.Error(t, err)
})
assert := assert.New(t)
updateTime := time.Now()
createTime := time.Now()
opts, err := getOpts(WithStartPageAfterItem(context.Background(), &fakeItem{nil, "s_1", createTime, updateTime}))
require.NoError(t, err)
assert.Equal(opts.withStartPageAfterItem.GetPublicId(), "s_1")
assert.Equal(opts.withStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime))
assert.Equal(opts.withStartPageAfterItem.GetCreateTime(), timestamp.New(createTime))
})
}

@ -0,0 +1,10 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package ldap
const (
estimateCountAccounts = `
select sum(reltuples::bigint) as estimate from pg_class where oid in ('auth_ldap_account'::regclass)
`
)

@ -5,10 +5,13 @@ package ldap
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
@ -96,27 +99,103 @@ func (r *Repository) LookupAccount(ctx context.Context, withPublicId string, _ .
return a, nil
}
// ListAccounts in an auth method and supports WithLimit option.
func (r *Repository) ListAccounts(ctx context.Context, withAuthMethodId string, opt ...Option) ([]*Account, error) {
const op = "ldap.(Repository).ListAccounts"
// listAccounts returns a slice of accounts in the auth method.
// Supported options:
// - WithLimit which overrides the limit set in the Repository object
// - WithStartPageAfterItem which sets where to start listing from
func (r *Repository) listAccounts(ctx context.Context, withAuthMethodId string, opt ...Option) ([]*Account, time.Time, error) {
const op = "ldap.(Repository).listAccounts"
if withAuthMethodId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method id")
return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing auth method id")
}
opts, err := getOpts(opt...)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
return nil, time.Time{}, errors.Wrap(ctx, err, op)
}
limit := r.defaultLimit
if opts.withLimit != 0 {
// non-zero signals an override of the default limit for the repo.
limit = opts.withLimit
}
var accts []*Account
err = r.reader.SearchWhere(ctx, &accts, "auth_method_id = ?", []any{withAuthMethodId}, db.WithLimit(limit))
var args []any
whereClause := "auth_method_id = @auth_method_id"
args = append(args, sql.Named("auth_method_id", withAuthMethodId))
if opts.withStartPageAfterItem != nil {
whereClause = fmt.Sprintf("(create_time, public_id) < (@last_item_create_time, @last_item_id) and %s", whereClause)
args = append(args,
sql.Named("last_item_create_time", opts.withStartPageAfterItem.GetCreateTime()),
sql.Named("last_item_id", opts.withStartPageAfterItem.GetPublicId()),
)
}
dbOpts := []db.Option{db.WithLimit(limit), db.WithOrder("create_time desc, public_id desc")}
return r.queryAccounts(ctx, whereClause, args, dbOpts...)
}
// listAccountsRefresh returns a slice of accounts in the auth method.
// Supported options:
// - WithLimit which overrides the limit set in the Repository object
// - WithStartPageAfterItem which sets where to start listing from
func (r *Repository) listAccountsRefresh(ctx context.Context, withAuthMethodId string, updatedAfter time.Time, opt ...Option) ([]*Account, time.Time, error) {
const op = "ldap.(Repository).listAccountsRefresh"
switch {
case withAuthMethodId == "":
return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing auth method id")
case updatedAfter.IsZero():
return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing updated after time")
}
opts, err := getOpts(opt...)
if err != nil {
return nil, errors.Wrap(ctx, err, op)
return nil, time.Time{}, errors.Wrap(ctx, err, op)
}
return accts, nil
limit := r.defaultLimit
if opts.withLimit != 0 {
// non-zero signals an override of the default limit for the repo.
limit = opts.withLimit
}
var args []any
whereClause := "update_time > @updated_after_time and auth_method_id = @auth_method_id"
args = append(args,
sql.Named("updated_after_time", timestamp.New(updatedAfter)),
sql.Named("auth_method_id", withAuthMethodId),
)
if opts.withStartPageAfterItem != nil {
whereClause = fmt.Sprintf("(update_time, public_id) < (@last_item_update_time, @last_item_id) and %s", whereClause)
args = append(args,
sql.Named("last_item_update_time", opts.withStartPageAfterItem.GetUpdateTime()),
sql.Named("last_item_id", opts.withStartPageAfterItem.GetPublicId()),
)
}
dbOpts := []db.Option{db.WithLimit(limit), db.WithOrder("update_time desc, public_id desc")}
return r.queryAccounts(ctx, whereClause, args, dbOpts...)
}
func (r *Repository) queryAccounts(ctx context.Context, whereClause string, args []any, opt ...db.Option) ([]*Account, time.Time, error) {
const op = "ldap.(Repository).queryAccounts"
var accts []*Account
var transactionTimestamp time.Time
if _, err := r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(rd db.Reader, w db.Writer) error {
var inAccts []*Account
if err := rd.SearchWhere(ctx, &inAccts, whereClause, args, opt...); err != nil {
return errors.Wrap(ctx, err, op)
}
accts = inAccts
var err error
transactionTimestamp, err = rd.Now(ctx)
return err
}); err != nil {
return nil, time.Time{}, errors.Wrap(ctx, err, op)
}
return accts, transactionTimestamp, nil
}
// DeleteAccount deletes the account for the provided id from the repository returning a count of the
@ -252,3 +331,45 @@ func (r *Repository) UpdateAccount(ctx context.Context, scopeId string, a *Accou
return returnedAccount, rowsUpdated, nil
}
// listDeletedAccountIds lists the public IDs of any accounts deleted since the timestamp provided,
// and the timestamp of the transaction within which the accounts were listed.
func (r *Repository) listDeletedAccountIds(ctx context.Context, since time.Time) ([]string, time.Time, error) {
const op = "ldap.(Repository).listDeletedAccountIds"
var deleteAccounts []*deletedAccount
var transactionTimestamp time.Time
if _, err := r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, _ db.Writer) error {
if err := r.SearchWhere(ctx, &deleteAccounts, "delete_time >= ?", []any{since}); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("failed to query deleted accounts"))
}
var err error
transactionTimestamp, err = r.Now(ctx)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("failed to get transaction timestamp"))
}
return nil
}); err != nil {
return nil, time.Time{}, err
}
var accountIds []string
for _, a := range deleteAccounts {
accountIds = append(accountIds, a.PublicId)
}
return accountIds, transactionTimestamp, nil
}
// estimatedAccountCount returns an estimate of the total number of accounts.
func (r *Repository) estimatedAccountCount(ctx context.Context) (int, error) {
const op = "ldap.(Repository).estimatedAccountCount"
rows, err := r.reader.Query(ctx, estimateCountAccounts, nil)
if err != nil {
return 0, errors.Wrap(ctx, err, op, errors.WithMsg("failed to query ldap account counts"))
}
var count int
for rows.Next() {
if err := r.reader.ScanRows(ctx, rows, &count); err != nil {
return 0, errors.Wrap(ctx, err, op, errors.WithMsg("failed to query ldap account counts"))
}
}
return count, nil
}

@ -6,24 +6,27 @@ package ldap
import (
"context"
"fmt"
"sort"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/auth/ldap/store"
"github.com/hashicorp/boundary/internal/db"
dbassert "github.com/hashicorp/boundary/internal/db/assert"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestRepository_CreateAccount(t *testing.T) {
@ -640,7 +643,13 @@ func TestRepository_DeleteAccount(t *testing.T) {
}
}
func TestRepository_ListAccounts(t *testing.T) {
func TestRepository_listAccounts(t *testing.T) {
oldReadTimeout := globals.RefreshReadLookbackDuration
globals.RefreshReadLookbackDuration = 0
t.Cleanup(func() {
globals.RefreshReadLookbackDuration = oldReadTimeout
})
testConn, _ := db.TestSetup(t, "postgres")
testRw := db.New(testConn)
testWrapper := db.TestWrapper(t)
@ -669,7 +678,17 @@ func TestRepository_ListAccounts(t *testing.T) {
TestAccount(t, testConn, authMethod2, "create-success2"),
TestAccount(t, testConn, authMethod2, "create-success3"),
}
_ = accounts2
slices.Reverse(accounts1)
slices.Reverse(accounts2)
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(
Account{},
store.Account{},
timestamp.Timestamp{},
timestamppb.Timestamp{},
),
}
tests := []struct {
name string
@ -717,7 +736,7 @@ func TestRepository_ListAccounts(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
got, err := tc.repo.ListAccounts(testCtx, tc.publicId, tc.opts...)
got, ttime, err := tc.repo.listAccounts(testCtx, tc.publicId, tc.opts...)
if tc.wantErrMatch != nil {
assert.Truef(errors.Match(tc.wantErrMatch, err), "Unexpected error %s", err)
if tc.wantErrContains != "" {
@ -726,13 +745,262 @@ func TestRepository_ListAccounts(t *testing.T) {
return
}
require.NoError(err)
// Transaction timestamp should be within ~10 seconds of now
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
sort.Slice(got, func(i, j int) bool {
return strings.Compare(got[i].LoginName, got[j].LoginName) < 0
})
assert.EqualValues(tc.want, got)
require.Empty(cmp.Diff(got, tc.want, cmpOpts...))
})
}
t.Run("validation", func(t *testing.T) {
t.Parallel()
t.Run("missing auth method id", func(t *testing.T) {
t.Parallel()
_, _, err := testRepo.listAccounts(testCtx, "", WithLimit(testCtx, 1))
require.ErrorContains(t, err, "missing auth method id")
})
})
t.Run("success-without-after-item", func(t *testing.T) {
t.Parallel()
resp, ttime, err := testRepo.listAccounts(testCtx, authMethod1.PublicId, WithLimit(testCtx, 10))
require.NoError(t, err)
// Transaction timestamp should be within ~10 seconds of now
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
require.Empty(t, cmp.Diff(resp, accounts1, cmpOpts...))
})
t.Run("success-with-after-item", func(t *testing.T) {
t.Parallel()
resp, ttime, err := testRepo.listAccounts(testCtx, authMethod1.PublicId, WithStartPageAfterItem(testCtx, accounts1[0]), WithLimit(testCtx, 10))
require.NoError(t, err)
// Transaction timestamp should be within ~10 seconds of now
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
require.Empty(t, cmp.Diff(resp, accounts1[1:], cmpOpts...))
})
}
func TestRepository_listAccountsRefresh(t *testing.T) {
oldReadTimeout := globals.RefreshReadLookbackDuration
globals.RefreshReadLookbackDuration = 0
t.Cleanup(func() {
globals.RefreshReadLookbackDuration = oldReadTimeout
})
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
testWrapper := db.TestWrapper(t)
ctx := context.Background()
testKms := kms.TestKms(t, conn, testWrapper)
iamRepo := iam.TestRepo(t, conn, testWrapper)
org, _ := iam.TestScopes(t, iamRepo)
databaseWrapper, err := testKms.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
require.NoError(t, err)
repo, err := NewRepository(ctx, rw, rw, testKms)
assert.NoError(t, err)
fiveDaysAgo := time.Now().AddDate(0, 0, -5)
authMethod1 := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"})
authMethod2 := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap2"})
accounts1 := []*Account{
TestAccount(t, conn, authMethod1, "create-success"),
TestAccount(t, conn, authMethod1, "create-success2"),
TestAccount(t, conn, authMethod1, "create-success3"),
}
accounts2 := []*Account{
TestAccount(t, conn, authMethod2, "create-success"),
TestAccount(t, conn, authMethod2, "create-success2"),
TestAccount(t, conn, authMethod2, "create-success3"),
}
slices.Reverse(accounts1)
_ = accounts2
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(
Account{},
store.Account{},
timestamp.Timestamp{},
timestamppb.Timestamp{},
),
cmpopts.SortSlices(func(i, j string) bool { return i < j }),
}
t.Run("validation", func(t *testing.T) {
t.Parallel()
t.Run("missing updated after", func(t *testing.T) {
t.Parallel()
_, _, err := repo.listAccountsRefresh(ctx, authMethod1.PublicId, time.Time{}, WithLimit(ctx, 1))
require.ErrorContains(t, err, "missing updated after time")
})
t.Run("missing auth method id", func(t *testing.T) {
t.Parallel()
_, _, err := repo.listAccountsRefresh(ctx, "", fiveDaysAgo, WithLimit(ctx, 1))
require.ErrorContains(t, err, "missing auth method id")
})
})
t.Run("success-without-after-item", func(t *testing.T) {
t.Parallel()
resp, ttime, err := repo.listAccountsRefresh(ctx, authMethod1.PublicId, fiveDaysAgo, WithLimit(ctx, 10))
require.NoError(t, err)
// Transaction timestamp should be within ~10 seconds of now
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
require.Empty(t, cmp.Diff(resp, accounts1, cmpOpts...))
})
t.Run("success-with-after-item", func(t *testing.T) {
t.Parallel()
resp, ttime, err := repo.listAccountsRefresh(ctx, authMethod1.PublicId, fiveDaysAgo, WithStartPageAfterItem(ctx, accounts1[0]), WithLimit(ctx, 10))
require.NoError(t, err)
// Transaction timestamp should be within ~10 seconds of now
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
require.Empty(t, cmp.Diff(resp, accounts1[1:], cmpOpts...))
})
t.Run("success-without-after-item-recent-updated-after", func(t *testing.T) {
t.Parallel()
resp, ttime, err := repo.listAccountsRefresh(ctx, authMethod1.PublicId, accounts1[len(accounts1)-1].GetUpdateTime().AsTime(), WithLimit(ctx, 10))
require.NoError(t, err)
// Transaction timestamp should be within ~10 seconds of now
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
require.Empty(t, cmp.Diff(resp, accounts1[:len(accounts1)-1], cmpOpts...))
})
t.Run("success-with-after-item-recent-updated-after", func(t *testing.T) {
t.Parallel()
resp, ttime, err := repo.listAccountsRefresh(ctx, authMethod1.PublicId, accounts1[len(accounts1)-1].GetUpdateTime().AsTime(), WithStartPageAfterItem(ctx, accounts1[0]), WithLimit(ctx, 10))
require.NoError(t, err)
// Transaction timestamp should be within ~10 seconds of now
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
require.Empty(t, cmp.Diff(resp, accounts1[1:len(accounts1)-1], cmpOpts...))
})
}
func TestRepository_estimatedCount(t *testing.T) {
oldReadTimeout := globals.RefreshReadLookbackDuration
globals.RefreshReadLookbackDuration = 0
t.Cleanup(func() {
globals.RefreshReadLookbackDuration = oldReadTimeout
})
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
testWrapper := db.TestWrapper(t)
ctx := context.Background()
testKms := kms.TestKms(t, conn, testWrapper)
iamRepo := iam.TestRepo(t, conn, testWrapper)
org, _ := iam.TestScopes(t, iamRepo)
databaseWrapper, err := testKms.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
require.NoError(t, err)
repo, err := NewRepository(ctx, rw, rw, testKms)
assert.NoError(t, err)
sqlDb, err := conn.SqlDB(ctx)
require.NoError(t, err)
// Check total entries at start, expect 0
numItems, err := repo.estimatedAccountCount(ctx)
require.NoError(t, err)
assert.Equal(t, 0, numItems)
// create account and check count, expect 1
authMethod1 := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"})
acct := TestAccount(t, conn, authMethod1, "create-success")
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
numItems, err = repo.estimatedAccountCount(ctx)
require.NoError(t, err)
assert.Equal(t, 1, numItems)
// Delete acct and check count, expect 0 again
_, err = repo.DeleteAccount(ctx, acct.GetPublicId())
require.NoError(t, err)
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
numItems, err = repo.estimatedAccountCount(ctx)
require.NoError(t, err)
assert.Equal(t, 0, numItems)
}
func TestRepository_listDeletedIds(t *testing.T) {
oldReadTimeout := globals.RefreshReadLookbackDuration
globals.RefreshReadLookbackDuration = 0
t.Cleanup(func() {
globals.RefreshReadLookbackDuration = oldReadTimeout
})
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
testWrapper := db.TestWrapper(t)
ctx := context.Background()
testKms := kms.TestKms(t, conn, testWrapper)
iamRepo := iam.TestRepo(t, conn, testWrapper)
org, _ := iam.TestScopes(t, iamRepo)
databaseWrapper, err := testKms.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
require.NoError(t, err)
repo, err := NewRepository(ctx, rw, rw, testKms)
assert.NoError(t, err)
sqlDb, err := conn.SqlDB(ctx)
require.NoError(t, err)
// Check total entries at start, expect 0
numItems, err := repo.estimatedAccountCount(ctx)
require.NoError(t, err)
assert.Equal(t, 0, numItems)
// create account and check deleted ids, should be empty
authMethod1 := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"})
acct := TestAccount(t, conn, authMethod1, "create-success")
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
deletedIds, ttime, err := repo.listDeletedAccountIds(ctx, time.Now().AddDate(-1, 0, 0))
require.NoError(t, err)
require.Empty(t, deletedIds)
// Transaction timestamp should be within ~10 seconds of now
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
// Delete acct and check count, expect 1 entry
_, err = repo.DeleteAccount(ctx, acct.GetPublicId())
require.NoError(t, err)
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
deletedIds, ttime, err = repo.listDeletedAccountIds(ctx, time.Now().AddDate(-1, 0, 0))
require.NoError(t, err)
assert.Empty(
t,
cmp.Diff(
[]string{acct.GetPublicId()},
deletedIds,
cmpopts.SortSlices(func(i, j string) bool { return i < j }),
),
)
// Transaction timestamp should be within ~10 seconds of now
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
// Try again with the time set to now, expect no entries
deletedIds, ttime, err = repo.listDeletedAccountIds(ctx, time.Now())
require.NoError(t, err)
require.Empty(t, deletedIds)
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
}
func TestRepository_ListAccounts_Limits(t *testing.T) {
@ -805,9 +1073,12 @@ func TestRepository_ListAccounts_Limits(t *testing.T) {
repo, err := NewRepository(testCtx, testRw, testRw, testKms, tc.repoOpts...)
assert.NoError(err)
require.NotNil(repo)
got, err := repo.ListAccounts(context.Background(), am.GetPublicId(), tc.listOpts...)
got, ttime, err := repo.listAccounts(context.Background(), am.GetPublicId(), tc.listOpts...)
require.NoError(err)
assert.Len(got, tc.wantLen)
// Transaction timestamp should be within ~10 seconds of now
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
})
}
}

@ -0,0 +1,62 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package ldap
import (
"context"
"time"
"github.com/hashicorp/boundary/internal/auth"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/pagination"
)
// ListAccounts lists up to page size ldap accounts, filtering out entries that
// do not pass the filter item function. It will automatically request
// more accounts from the database, at page size chunks, to fill the page.
// It returns a new list token used to continue pagination or refresh items.
// Accounts are ordered by create time descending (most recently created first).
func ListAccounts(
ctx context.Context,
grantsHash []byte,
pageSize int,
filterItemFn pagination.ListFilterFunc[auth.Account],
repo *Repository,
authMethodId string,
) (*pagination.ListResponse[auth.Account], error) {
const op = "ldap.ListAccounts"
switch {
case len(grantsHash) == 0:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash")
case pageSize < 1:
return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1")
case filterItemFn == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback")
case repo == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo")
case authMethodId == "":
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID")
}
listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) {
opts := []Option{
WithLimit(ctx, limit),
}
if lastPageItem != nil {
opts = append(opts, WithStartPageAfterItem(ctx, lastPageItem))
}
ldapAccts, listTime, err := repo.listAccounts(ctx, authMethodId, opts...)
if err != nil {
return nil, time.Time{}, err
}
var accounts []auth.Account
for _, acct := range ldapAccts {
accounts = append(accounts, acct)
}
return accounts, listTime, nil
}
return pagination.List(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount)
}

@ -0,0 +1,623 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package ldap_test
import (
"context"
"slices"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/auth"
"github.com/hashicorp/boundary/internal/auth/ldap"
"github.com/hashicorp/boundary/internal/auth/ldap/store"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/listtoken"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestService_ListAccounts(t *testing.T) {
// Set database read timeout to avoid duplicates in response
oldReadTimeout := globals.RefreshReadLookbackDuration
globals.RefreshReadLookbackDuration = 0
t.Cleanup(func() {
globals.RefreshReadLookbackDuration = oldReadTimeout
})
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
sqlDb, err := conn.SqlDB(ctx)
require.NoError(t, err)
wrapper := db.TestWrapper(t)
fiveDaysAgo := time.Now().AddDate(0, 0, -5)
rw := db.New(conn)
testKms := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
org, _ := iam.TestScopes(t, iamRepo)
databaseWrapper, err := testKms.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
require.NoError(t, err)
repo, err := ldap.NewRepository(context.Background(), rw, rw, testKms)
require.NoError(t, err)
authMethod := ldap.TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"})
ldapAccts := []*ldap.Account{
ldap.TestAccount(t, conn, authMethod, "create-success"),
ldap.TestAccount(t, conn, authMethod, "create-success2"),
ldap.TestAccount(t, conn, authMethod, "create-success3"),
ldap.TestAccount(t, conn, authMethod, "create-success4"),
ldap.TestAccount(t, conn, authMethod, "create-success5"),
}
var accounts []auth.Account
for _, acct := range ldapAccts {
accounts = append(accounts, acct)
}
// since we sort by create time descending, we need to reverse the slice
slices.Reverse(accounts)
// Run analyze to update host estimate
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(
ldap.Account{},
store.Account{},
timestamp.Timestamp{},
timestamppb.Timestamp{},
),
}
t.Run("ListAccounts validation", func(t *testing.T) {
t.Parallel()
t.Run("missing grants hash", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err := ldap.ListAccounts(ctx, nil, 1, filterFunc, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing grants hash")
})
t.Run("zero page size", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err := ldap.ListAccounts(ctx, []byte("some hash"), 0, filterFunc, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "page size must be at least 1")
})
t.Run("negative page size", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err := ldap.ListAccounts(ctx, []byte("some hash"), -1, filterFunc, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "page size must be at least 1")
})
t.Run("nil filter func", func(t *testing.T) {
t.Parallel()
_, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, nil, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing filter item callback")
})
t.Run("nil repo", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, nil, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing repo")
})
t.Run("missing auth method ID", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, "")
require.ErrorContains(t, err, "missing auth method ID")
})
})
t.Run("ListAccountsPage validation", func(t *testing.T) {
t.Parallel()
t.Run("missing grants hash", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsPage(ctx, nil, 1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing grants hash")
})
t.Run("zero page size", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 0, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "page size must be at least 1")
})
t.Run("negative page size", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsPage(ctx, []byte("some hash"), -1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "page size must be at least 1")
})
t.Run("nil filter func", func(t *testing.T) {
t.Parallel()
tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, nil, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing filter item callback")
})
t.Run("nil token", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "token did not have a pagination token component")
})
t.Run("nil repo", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, nil, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing repo")
})
t.Run("missing auth method ID", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, "")
require.ErrorContains(t, err, "missing auth method ID")
})
t.Run("wrong token resource type", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Target, []byte("some hash"), "some-id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "token did not have an account resource type")
})
})
t.Run("ListAccountsRefresh validation", func(t *testing.T) {
t.Parallel()
t.Run("missing grants hash", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefresh(ctx, nil, 1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing grants hash")
})
t.Run("zero page size", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 0, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "page size must be at least 1")
})
t.Run("negative page size", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), -1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "page size must be at least 1")
})
t.Run("nil filter func", func(t *testing.T) {
t.Parallel()
tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, nil, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing filter item callback")
})
t.Run("nil token", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("nil repo", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, tok, nil, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing repo")
})
t.Run("missing auth method ID", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, tok, repo, "")
require.ErrorContains(t, err, "missing auth method ID")
})
t.Run("wrong token resource type", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Target, []byte("some hash"), fiveDaysAgo, fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "token did not have an account resource type")
})
})
t.Run("ListAccountsRefreshPage validation", func(t *testing.T) {
t.Parallel()
t.Run("missing grants hash", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefreshPage(ctx, nil, 1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing grants hash")
})
t.Run("zero page size", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 0, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "page size must be at least 1")
})
t.Run("negative page size", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), -1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "page size must be at least 1")
})
t.Run("nil filter func", func(t *testing.T) {
t.Parallel()
tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, nil, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing filter item callback")
})
t.Run("nil token", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
_, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing token")
})
t.Run("wrong token type", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "token did not have a refresh token component")
})
t.Run("nil repo", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, nil, authMethod.GetPublicId())
require.ErrorContains(t, err, "missing repo")
})
t.Run("missing credential store id", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, "")
require.ErrorContains(t, err, "missing auth method ID")
})
t.Run("wrong token resource type", func(t *testing.T) {
t.Parallel()
filterFunc := func(_ context.Context, a auth.Account) (bool, error) {
return true, nil
}
tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Target, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo)
require.NoError(t, err)
_, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId())
require.ErrorContains(t, err, "token did not have an account resource type")
})
})
t.Run("simple pagination", func(t *testing.T) {
filterFunc := func(context.Context, auth.Account) (bool, error) {
return true, nil
}
resp, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.NotNil(t, resp.ListToken)
require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash"))
require.False(t, resp.CompleteListing)
require.Equal(t, resp.EstimatedItemCount, 5)
require.Empty(t, resp.DeletedIds)
require.Len(t, resp.Items, 1)
require.Empty(t, cmp.Diff(resp.Items[0], accounts[0], cmpOpts...))
resp2, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash"))
require.False(t, resp2.CompleteListing)
require.Equal(t, resp2.EstimatedItemCount, 5)
require.Empty(t, resp2.DeletedIds)
require.Len(t, resp2.Items, 1)
require.Empty(t, cmp.Diff(resp2.Items[0], accounts[1], cmpOpts...))
resp3, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp2.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash"))
require.False(t, resp3.CompleteListing)
require.Equal(t, resp3.EstimatedItemCount, 5)
require.Empty(t, resp3.DeletedIds)
require.Len(t, resp3.Items, 1)
require.Empty(t, cmp.Diff(resp3.Items[0], accounts[2], cmpOpts...))
resp4, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp3.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp4.ListToken.GrantsHash, []byte("some hash"))
require.False(t, resp4.CompleteListing)
require.Equal(t, resp4.EstimatedItemCount, 5)
require.Empty(t, resp4.DeletedIds)
require.Len(t, resp4.Items, 1)
require.Empty(t, cmp.Diff(resp4.Items[0], accounts[3], cmpOpts...))
resp5, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp4.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp5.ListToken.GrantsHash, []byte("some hash"))
require.True(t, resp5.CompleteListing)
require.Equal(t, resp5.EstimatedItemCount, 5)
require.Empty(t, resp5.DeletedIds)
require.Len(t, resp5.Items, 1)
require.Empty(t, cmp.Diff(resp5.Items[0], accounts[4], cmpOpts...))
// Finished initial pagination phase, request refresh
// Expect no results.
resp6, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp5.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp6.ListToken.GrantsHash, []byte("some hash"))
require.True(t, resp6.CompleteListing)
require.Equal(t, resp6.EstimatedItemCount, 5)
require.Empty(t, resp6.DeletedIds)
require.Empty(t, resp6.Items)
// Create some new accounts
account1 := ldap.TestAccount(t, conn, authMethod, "new-success")
account2 := ldap.TestAccount(t, conn, authMethod, "new-success2")
t.Cleanup(func() {
repo.DeleteAccount(ctx, account1.GetPublicId())
require.NoError(t, err)
repo.DeleteAccount(ctx, account2.GetPublicId())
require.NoError(t, err)
// Run analyze to update count estimate
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
})
// Run analyze to update count estimate
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
// Refresh again, should get account2
resp7, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp6.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp7.ListToken.GrantsHash, []byte("some hash"))
require.False(t, resp7.CompleteListing)
require.Equal(t, resp7.EstimatedItemCount, 7)
require.Empty(t, resp7.DeletedIds)
require.Len(t, resp7.Items, 1)
require.Empty(t, cmp.Diff(resp7.Items[0], account2, cmpOpts...))
// Refresh again, should get account1
resp8, err := ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, resp7.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp8.ListToken.GrantsHash, []byte("some hash"))
require.True(t, resp8.CompleteListing)
require.Equal(t, resp8.EstimatedItemCount, 7)
require.Empty(t, resp8.DeletedIds)
require.Len(t, resp8.Items, 1)
require.Empty(t, cmp.Diff(resp8.Items[0], account1, cmpOpts...))
// Refresh again, should get no results
resp9, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp8.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp9.ListToken.GrantsHash, []byte("some hash"))
require.True(t, resp9.CompleteListing)
require.Equal(t, resp9.EstimatedItemCount, 7)
require.Empty(t, resp9.DeletedIds)
require.Empty(t, resp9.Items)
})
t.Run("simple pagination with aggressive filtering", func(t *testing.T) {
filterFunc := func(ctx context.Context, a auth.Account) (bool, error) {
return a.GetPublicId() == accounts[1].GetPublicId() ||
a.GetPublicId() == accounts[len(accounts)-1].GetPublicId(), nil
}
resp, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.NotNil(t, resp.ListToken)
require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash"))
require.False(t, resp.CompleteListing)
require.Equal(t, resp.EstimatedItemCount, 5)
require.Empty(t, resp.DeletedIds)
require.Len(t, resp.Items, 1)
require.Empty(t, cmp.Diff(resp.Items[0], accounts[1], cmpOpts...))
resp2, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.NotNil(t, resp2.ListToken)
require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash"))
require.True(t, resp2.CompleteListing)
require.Equal(t, resp2.EstimatedItemCount, 5)
require.Empty(t, resp2.DeletedIds)
require.Len(t, resp2.Items, 1)
require.Empty(t, cmp.Diff(resp2.Items[0], accounts[len(accounts)-1], cmpOpts...))
// request a refresh, nothing should be returned
resp3, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash"))
require.True(t, resp3.CompleteListing)
require.Equal(t, resp3.EstimatedItemCount, 5)
require.Empty(t, resp3.DeletedIds)
require.Empty(t, resp3.Items)
// Create some new accounts
account1 := ldap.TestAccount(t, conn, authMethod, "new-success")
account2 := ldap.TestAccount(t, conn, authMethod, "new-success2")
account3 := ldap.TestAccount(t, conn, authMethod, "new-success3")
t.Cleanup(func() {
repo.DeleteAccount(ctx, account1.GetPublicId())
require.NoError(t, err)
repo.DeleteAccount(ctx, account2.GetPublicId())
require.NoError(t, err)
repo.DeleteAccount(ctx, account3.GetPublicId())
require.NoError(t, err)
// Run analyze to update count estimate
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
})
// Run analyze to update count estimate
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
filterFunc = func(_ context.Context, a auth.Account) (bool, error) {
return a.GetPublicId() == account3.GetPublicId() ||
a.GetPublicId() == account1.GetPublicId(), nil
}
// Refresh again, should get account3
resp4, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp3.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp4.ListToken.GrantsHash, []byte("some hash"))
require.False(t, resp4.CompleteListing)
require.Equal(t, resp4.EstimatedItemCount, 8)
require.Empty(t, resp4.DeletedIds)
require.Len(t, resp4.Items, 1)
require.Empty(t, cmp.Diff(resp4.Items[0], account3, cmpOpts...))
// Refresh again, should get account1
resp5, err := ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, resp4.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp5.ListToken.GrantsHash, []byte("some hash"))
require.True(t, resp5.CompleteListing)
require.Equal(t, resp5.EstimatedItemCount, 8)
require.Empty(t, resp5.DeletedIds)
require.Len(t, resp5.Items, 1)
require.Empty(t, cmp.Diff(resp5.Items[0], account1, cmpOpts...))
})
t.Run("simple pagination with deletion", func(t *testing.T) {
filterFunc := func(context.Context, auth.Account) (bool, error) {
return true, nil
}
deletedAccountId := accounts[0].GetPublicId()
repo.DeleteAccount(ctx, deletedAccountId)
require.NoError(t, err)
accounts = accounts[1:]
// Run analyze to update count estimate
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
resp, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.NotNil(t, resp.ListToken)
require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash"))
require.False(t, resp.CompleteListing)
require.Equal(t, resp.EstimatedItemCount, 4)
require.Empty(t, resp.DeletedIds)
require.Len(t, resp.Items, 1)
require.Empty(t, cmp.Diff(resp.Items[0], accounts[0], cmpOpts...))
// request remaining results
resp2, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 3, filterFunc, resp.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash"))
require.True(t, resp2.CompleteListing)
require.Equal(t, resp2.EstimatedItemCount, 4)
require.Empty(t, resp2.DeletedIds)
require.Len(t, resp2.Items, 3)
require.Empty(t, cmp.Diff(resp2.Items, accounts[1:], cmpOpts...))
deletedAccountId = accounts[0].GetPublicId()
repo.DeleteAccount(ctx, deletedAccountId)
require.NoError(t, err)
accounts = accounts[1:]
// Run analyze to update count estimate
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
// request a refresh, nothing should be returned except the deleted id
resp3, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp2.ListToken, repo, authMethod.GetPublicId())
require.NoError(t, err)
require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash"))
require.True(t, resp3.CompleteListing)
require.Equal(t, resp3.EstimatedItemCount, 3)
require.Contains(t, resp3.DeletedIds, deletedAccountId)
require.Empty(t, resp3.Items)
})
}

@ -0,0 +1,79 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package ldap
import (
"context"
"time"
"github.com/hashicorp/boundary/internal/auth"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/listtoken"
"github.com/hashicorp/boundary/internal/pagination"
"github.com/hashicorp/boundary/internal/types/resource"
)
// ListAccountsPage lists up to page size ldap accounts, filtering out entries that
// do not pass the filter item function. It will automatically request
// more ldap accounts from the database, at page size chunks, to fill the page.
// It will start its paging based on the information in the token.
// It returns a new list token used to continue pagination or refresh items.
// Accounts are ordered by create time descending (most recently created first).
func ListAccountsPage(
ctx context.Context,
grantsHash []byte,
pageSize int,
filterItemFn pagination.ListFilterFunc[auth.Account],
tok *listtoken.Token,
repo *Repository,
authMethodId string,
) (*pagination.ListResponse[auth.Account], error) {
const op = "ldap.ListAccountsPage"
switch {
case len(grantsHash) == 0:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash")
case pageSize < 1:
return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1")
case filterItemFn == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback")
case tok == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token")
case repo == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo")
case authMethodId == "":
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID")
case tok.ResourceType != resource.Account:
return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have an account resource type")
}
if _, ok := tok.Subtype.(*listtoken.PaginationToken); !ok {
return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have a pagination token component")
}
listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) {
opts := []Option{
WithLimit(ctx, limit),
}
if lastPageItem != nil {
opts = append(opts, WithStartPageAfterItem(ctx, lastPageItem))
} else {
lastItem, err := tok.LastItem(ctx)
if err != nil {
return nil, time.Time{}, err
}
opts = append(opts, WithStartPageAfterItem(ctx, lastItem))
}
ldapAccounts, listTime, err := repo.listAccounts(ctx, authMethodId, opts...)
if err != nil {
return nil, time.Time{}, err
}
var accts []auth.Account
for _, acct := range ldapAccounts {
accts = append(accts, acct)
}
return accts, listTime, nil
}
return pagination.ListPage(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount, tok)
}

@ -0,0 +1,81 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package ldap
import (
"context"
"time"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/auth"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/listtoken"
"github.com/hashicorp/boundary/internal/pagination"
"github.com/hashicorp/boundary/internal/types/resource"
)
// ListAccountsRefresh lists ldap accounts according to the page size
// and list token, filtering out entries that do not
// pass the filter item fn. It returns a new list token
// based on the old one, the grants hash, and the returned
// ldap accounts.
func ListAccountsRefresh(
ctx context.Context,
grantsHash []byte,
pageSize int,
filterItemFn pagination.ListFilterFunc[auth.Account],
tok *listtoken.Token,
repo *Repository,
authMethodId string,
) (*pagination.ListResponse[auth.Account], error) {
const op = "ldap.ListAccountsRefresh"
switch {
case len(grantsHash) == 0:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash")
case pageSize < 1:
return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1")
case filterItemFn == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback")
case tok == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token")
case repo == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo")
case authMethodId == "":
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID")
case tok.ResourceType != resource.Account:
return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have an account resource type")
}
rt, ok := tok.Subtype.(*listtoken.StartRefreshToken)
if !ok {
return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have a start-refresh token component")
}
listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) {
opts := []Option{
WithLimit(ctx, limit),
}
if lastPageItem != nil {
opts = append(opts, WithStartPageAfterItem(ctx, lastPageItem))
}
// Add the database read timeout to account for any creations missed due to concurrent
// transactions in the initial pagination phase.
ldapAccounts, listTime, err := repo.listAccountsRefresh(ctx, authMethodId, rt.PreviousPhaseUpperBound.Add(-globals.RefreshReadLookbackDuration), opts...)
if err != nil {
return nil, time.Time{}, err
}
var accounts []auth.Account
for _, account := range ldapAccounts {
accounts = append(accounts, account)
}
return accounts, listTime, nil
}
listDeletedIdsFn := func(ctx context.Context, since time.Time) ([]string, time.Time, error) {
// Add the database read timeout to account for any deletes missed due to concurrent
// transactions in the original list pagination phase.
return repo.listDeletedAccountIds(ctx, since.Add(-globals.RefreshReadLookbackDuration))
}
return pagination.ListRefresh(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount, listDeletedIdsFn, tok)
}

@ -0,0 +1,91 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package ldap
import (
"context"
"time"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/auth"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/listtoken"
"github.com/hashicorp/boundary/internal/pagination"
"github.com/hashicorp/boundary/internal/types/resource"
)
// ListAccountsRefreshPage lists up to page size accounts, filtering out entries that
// do not pass the filter item function. It will automatically request
// more accounts from the database, at page size chunks, to fill the page.
// It will start its paging based on the information in the token.
// It returns a new list token used to continue pagination or refresh items.
// Accounts are ordered by update time descending (most recently updated first).
// Accounts may contain items that were already returned during the initial
// pagination phase. It also returns a list of any accounts deleted since the
// last response.
func ListAccountsRefreshPage(
ctx context.Context,
grantsHash []byte,
pageSize int,
filterItemFn pagination.ListFilterFunc[auth.Account],
tok *listtoken.Token,
repo *Repository,
authMethodId string,
) (*pagination.ListResponse[auth.Account], error) {
const op = "ldap.ListAccountsRefreshPage"
switch {
case len(grantsHash) == 0:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash")
case pageSize < 1:
return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1")
case filterItemFn == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback")
case tok == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token")
case repo == nil:
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo")
case authMethodId == "":
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID")
case tok.ResourceType != resource.Account:
return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have an account resource type")
}
rt, ok := tok.Subtype.(*listtoken.RefreshToken)
if !ok {
return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have a refresh token component")
}
listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) {
opts := []Option{
WithLimit(ctx, limit),
}
if lastPageItem != nil {
opts = append(opts, WithStartPageAfterItem(ctx, lastPageItem))
} else {
lastItem, err := tok.LastItem(ctx)
if err != nil {
return nil, time.Time{}, err
}
opts = append(opts, WithStartPageAfterItem(ctx, lastItem))
}
// Add the database read timeout to account for any creations missed due to concurrent
// transactions in the original list pagination phase.
sAccounts, listTime, err := repo.listAccountsRefresh(ctx, authMethodId, rt.PhaseLowerBound.Add(-globals.RefreshReadLookbackDuration), opts...)
if err != nil {
return nil, time.Time{}, err
}
var accounts []auth.Account
for _, account := range sAccounts {
accounts = append(accounts, account)
}
return accounts, listTime, nil
}
listDeletedIdsFn := func(ctx context.Context, since time.Time) ([]string, time.Time, error) {
// Add the database read timeout to account for any deletes missed due to concurrent
// transactions in the original list pagination phase.
return repo.listDeletedAccountIds(ctx, since.Add(-globals.RefreshReadLookbackDuration))
}
return pagination.ListRefreshPage(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount, listDeletedIdsFn, tok)
}
Loading…
Cancel
Save