diff --git a/internal/auth/ldap/account.go b/internal/auth/ldap/account.go index ac3633b675..e137c4d4b8 100644 --- a/internal/auth/ldap/account.go +++ b/internal/auth/ldap/account.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/boundary/internal/auth" "github.com/hashicorp/boundary/internal/auth/ldap/store" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/oplog" "github.com/hashicorp/boundary/internal/types/resource" @@ -150,3 +151,13 @@ func (a *Account) oplog(ctx context.Context, opType oplog.OpType) (oplog.Metadat } return metadata, nil } + +type deletedAccount struct { + PublicId string `gorm:"primary_key"` + DeleteTime *timestamp.Timestamp +} + +// TableName returns the tablename to override the default gorm table name +func (s *deletedAccount) TableName() string { + return "auth_ldap_account_deleted" +} diff --git a/internal/auth/ldap/options.go b/internal/auth/ldap/options.go index fbce8a4d38..203a5be6d7 100644 --- a/internal/auth/ldap/options.go +++ b/internal/auth/ldap/options.go @@ -11,6 +11,8 @@ import ( "net/url" "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/pagination" + "github.com/hashicorp/boundary/internal/util" ) type options struct { @@ -48,6 +50,7 @@ type options struct { withPublicId string withDerefAliases DerefAliasType withMaximumPageSize uint32 + withStartPageAfterItem pagination.Item } // Option - how options are passed as args @@ -413,3 +416,16 @@ func (d DerefAliasType) IsValid(ctx context.Context) error { return errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("%q is not a valid ldap dereference alias type", d)) } } + +// WithStartPageAfterItem is used to paginate over the results. +// The next page will start after the provided item. +func WithStartPageAfterItem(ctx context.Context, item pagination.Item) Option { + const op = "ldap.WithStartPageAfterItem" + return func(o *options) error { + if util.IsNil(item) { + return errors.New(ctx, errors.InvalidParameter, op, "item cannot be nil") + } + o.withStartPageAfterItem = item + return nil + } +} diff --git a/internal/auth/ldap/options_test.go b/internal/auth/ldap/options_test.go index 08f70a5b57..d05cfc962e 100644 --- a/internal/auth/ldap/options_test.go +++ b/internal/auth/ldap/options_test.go @@ -9,12 +9,34 @@ import ( "crypto/rand" "crypto/x509" "testing" + "time" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/pagination" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +type fakeItem struct { + pagination.Item + publicId string + createTime time.Time + updateTime time.Time +} + +func (p *fakeItem) GetPublicId() string { + return p.publicId +} + +func (p *fakeItem) GetCreateTime() *timestamp.Timestamp { + return timestamp.New(p.createTime) +} + +func (p *fakeItem) GetUpdateTime() *timestamp.Timestamp { + return timestamp.New(p.updateTime) +} + func Test_getOpts(t *testing.T) { t.Parallel() testCtx := context.Background() @@ -350,4 +372,18 @@ func Test_getOpts(t *testing.T) { assert.ErrorContains(err, `"Invalid" is not a valid ldap dereference alias type`) assert.Truef(errors.Match(errors.T(errors.InvalidParameter), err), "want err code: %q got: %q", errors.InvalidParameter, err) }) + t.Run("WithStartPageAfterItem", func(t *testing.T) { + t.Run("nil item", func(t *testing.T) { + _, err := getOpts(WithStartPageAfterItem(context.Background(), nil)) + require.Error(t, err) + }) + assert := assert.New(t) + updateTime := time.Now() + createTime := time.Now() + opts, err := getOpts(WithStartPageAfterItem(context.Background(), &fakeItem{nil, "s_1", createTime, updateTime})) + require.NoError(t, err) + assert.Equal(opts.withStartPageAfterItem.GetPublicId(), "s_1") + assert.Equal(opts.withStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime)) + assert.Equal(opts.withStartPageAfterItem.GetCreateTime(), timestamp.New(createTime)) + }) } diff --git a/internal/auth/ldap/query.go b/internal/auth/ldap/query.go new file mode 100644 index 0000000000..e23de6d431 --- /dev/null +++ b/internal/auth/ldap/query.go @@ -0,0 +1,10 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package ldap + +const ( + estimateCountAccounts = ` +select sum(reltuples::bigint) as estimate from pg_class where oid in ('auth_ldap_account'::regclass) +` +) diff --git a/internal/auth/ldap/repository_account.go b/internal/auth/ldap/repository_account.go index 5e262089d3..caa9bcab69 100644 --- a/internal/auth/ldap/repository_account.go +++ b/internal/auth/ldap/repository_account.go @@ -5,10 +5,13 @@ package ldap import ( "context" + "database/sql" "fmt" "strings" + "time" "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" @@ -96,27 +99,103 @@ func (r *Repository) LookupAccount(ctx context.Context, withPublicId string, _ . return a, nil } -// ListAccounts in an auth method and supports WithLimit option. -func (r *Repository) ListAccounts(ctx context.Context, withAuthMethodId string, opt ...Option) ([]*Account, error) { - const op = "ldap.(Repository).ListAccounts" +// listAccounts returns a slice of accounts in the auth method. +// Supported options: +// - WithLimit which overrides the limit set in the Repository object +// - WithStartPageAfterItem which sets where to start listing from +func (r *Repository) listAccounts(ctx context.Context, withAuthMethodId string, opt ...Option) ([]*Account, time.Time, error) { + const op = "ldap.(Repository).listAccounts" if withAuthMethodId == "" { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method id") + return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing auth method id") } opts, err := getOpts(opt...) if err != nil { - return nil, errors.Wrap(ctx, err, op) + return nil, time.Time{}, errors.Wrap(ctx, err, op) } + limit := r.defaultLimit if opts.withLimit != 0 { // non-zero signals an override of the default limit for the repo. limit = opts.withLimit } - var accts []*Account - err = r.reader.SearchWhere(ctx, &accts, "auth_method_id = ?", []any{withAuthMethodId}, db.WithLimit(limit)) + + var args []any + whereClause := "auth_method_id = @auth_method_id" + args = append(args, sql.Named("auth_method_id", withAuthMethodId)) + + if opts.withStartPageAfterItem != nil { + whereClause = fmt.Sprintf("(create_time, public_id) < (@last_item_create_time, @last_item_id) and %s", whereClause) + args = append(args, + sql.Named("last_item_create_time", opts.withStartPageAfterItem.GetCreateTime()), + sql.Named("last_item_id", opts.withStartPageAfterItem.GetPublicId()), + ) + } + + dbOpts := []db.Option{db.WithLimit(limit), db.WithOrder("create_time desc, public_id desc")} + return r.queryAccounts(ctx, whereClause, args, dbOpts...) +} + +// listAccountsRefresh returns a slice of accounts in the auth method. +// Supported options: +// - WithLimit which overrides the limit set in the Repository object +// - WithStartPageAfterItem which sets where to start listing from +func (r *Repository) listAccountsRefresh(ctx context.Context, withAuthMethodId string, updatedAfter time.Time, opt ...Option) ([]*Account, time.Time, error) { + const op = "ldap.(Repository).listAccountsRefresh" + switch { + case withAuthMethodId == "": + return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing auth method id") + case updatedAfter.IsZero(): + return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing updated after time") + } + + opts, err := getOpts(opt...) if err != nil { - return nil, errors.Wrap(ctx, err, op) + return nil, time.Time{}, errors.Wrap(ctx, err, op) } - return accts, nil + + limit := r.defaultLimit + if opts.withLimit != 0 { + // non-zero signals an override of the default limit for the repo. + limit = opts.withLimit + } + + var args []any + whereClause := "update_time > @updated_after_time and auth_method_id = @auth_method_id" + args = append(args, + sql.Named("updated_after_time", timestamp.New(updatedAfter)), + sql.Named("auth_method_id", withAuthMethodId), + ) + + if opts.withStartPageAfterItem != nil { + whereClause = fmt.Sprintf("(update_time, public_id) < (@last_item_update_time, @last_item_id) and %s", whereClause) + args = append(args, + sql.Named("last_item_update_time", opts.withStartPageAfterItem.GetUpdateTime()), + sql.Named("last_item_id", opts.withStartPageAfterItem.GetPublicId()), + ) + } + + dbOpts := []db.Option{db.WithLimit(limit), db.WithOrder("update_time desc, public_id desc")} + return r.queryAccounts(ctx, whereClause, args, dbOpts...) +} + +func (r *Repository) queryAccounts(ctx context.Context, whereClause string, args []any, opt ...db.Option) ([]*Account, time.Time, error) { + const op = "ldap.(Repository).queryAccounts" + + var accts []*Account + var transactionTimestamp time.Time + if _, err := r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(rd db.Reader, w db.Writer) error { + var inAccts []*Account + if err := rd.SearchWhere(ctx, &inAccts, whereClause, args, opt...); err != nil { + return errors.Wrap(ctx, err, op) + } + accts = inAccts + var err error + transactionTimestamp, err = rd.Now(ctx) + return err + }); err != nil { + return nil, time.Time{}, errors.Wrap(ctx, err, op) + } + return accts, transactionTimestamp, nil } // DeleteAccount deletes the account for the provided id from the repository returning a count of the @@ -252,3 +331,45 @@ func (r *Repository) UpdateAccount(ctx context.Context, scopeId string, a *Accou return returnedAccount, rowsUpdated, nil } + +// listDeletedAccountIds lists the public IDs of any accounts deleted since the timestamp provided, +// and the timestamp of the transaction within which the accounts were listed. +func (r *Repository) listDeletedAccountIds(ctx context.Context, since time.Time) ([]string, time.Time, error) { + const op = "ldap.(Repository).listDeletedAccountIds" + var deleteAccounts []*deletedAccount + var transactionTimestamp time.Time + if _, err := r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, _ db.Writer) error { + if err := r.SearchWhere(ctx, &deleteAccounts, "delete_time >= ?", []any{since}); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("failed to query deleted accounts")) + } + var err error + transactionTimestamp, err = r.Now(ctx) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("failed to get transaction timestamp")) + } + return nil + }); err != nil { + return nil, time.Time{}, err + } + var accountIds []string + for _, a := range deleteAccounts { + accountIds = append(accountIds, a.PublicId) + } + return accountIds, transactionTimestamp, nil +} + +// estimatedAccountCount returns an estimate of the total number of accounts. +func (r *Repository) estimatedAccountCount(ctx context.Context) (int, error) { + const op = "ldap.(Repository).estimatedAccountCount" + rows, err := r.reader.Query(ctx, estimateCountAccounts, nil) + if err != nil { + return 0, errors.Wrap(ctx, err, op, errors.WithMsg("failed to query ldap account counts")) + } + var count int + for rows.Next() { + if err := r.reader.ScanRows(ctx, rows, &count); err != nil { + return 0, errors.Wrap(ctx, err, op, errors.WithMsg("failed to query ldap account counts")) + } + } + return count, nil +} diff --git a/internal/auth/ldap/repository_account_test.go b/internal/auth/ldap/repository_account_test.go index 895465d39d..35ac17bf83 100644 --- a/internal/auth/ldap/repository_account_test.go +++ b/internal/auth/ldap/repository_account_test.go @@ -6,24 +6,27 @@ package ldap import ( "context" "fmt" - "sort" "strings" "testing" "time" "github.com/DATA-DOG/go-sqlmock" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/auth/ldap/store" "github.com/hashicorp/boundary/internal/db" dbassert "github.com/hashicorp/boundary/internal/db/assert" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" "google.golang.org/protobuf/testing/protocmp" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestRepository_CreateAccount(t *testing.T) { @@ -640,7 +643,13 @@ func TestRepository_DeleteAccount(t *testing.T) { } } -func TestRepository_ListAccounts(t *testing.T) { +func TestRepository_listAccounts(t *testing.T) { + oldReadTimeout := globals.RefreshReadLookbackDuration + globals.RefreshReadLookbackDuration = 0 + t.Cleanup(func() { + globals.RefreshReadLookbackDuration = oldReadTimeout + }) + testConn, _ := db.TestSetup(t, "postgres") testRw := db.New(testConn) testWrapper := db.TestWrapper(t) @@ -669,7 +678,17 @@ func TestRepository_ListAccounts(t *testing.T) { TestAccount(t, testConn, authMethod2, "create-success2"), TestAccount(t, testConn, authMethod2, "create-success3"), } - _ = accounts2 + slices.Reverse(accounts1) + slices.Reverse(accounts2) + + cmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported( + Account{}, + store.Account{}, + timestamp.Timestamp{}, + timestamppb.Timestamp{}, + ), + } tests := []struct { name string @@ -717,7 +736,7 @@ func TestRepository_ListAccounts(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { assert, require := assert.New(t), require.New(t) - got, err := tc.repo.ListAccounts(testCtx, tc.publicId, tc.opts...) + got, ttime, err := tc.repo.listAccounts(testCtx, tc.publicId, tc.opts...) if tc.wantErrMatch != nil { assert.Truef(errors.Match(tc.wantErrMatch, err), "Unexpected error %s", err) if tc.wantErrContains != "" { @@ -726,13 +745,262 @@ func TestRepository_ListAccounts(t *testing.T) { return } require.NoError(err) + // Transaction timestamp should be within ~10 seconds of now + assert.True(time.Now().Before(ttime.Add(10 * time.Second))) + assert.True(time.Now().After(ttime.Add(-10 * time.Second))) - sort.Slice(got, func(i, j int) bool { - return strings.Compare(got[i].LoginName, got[j].LoginName) < 0 - }) - assert.EqualValues(tc.want, got) + require.Empty(cmp.Diff(got, tc.want, cmpOpts...)) }) } + + t.Run("validation", func(t *testing.T) { + t.Parallel() + t.Run("missing auth method id", func(t *testing.T) { + t.Parallel() + _, _, err := testRepo.listAccounts(testCtx, "", WithLimit(testCtx, 1)) + require.ErrorContains(t, err, "missing auth method id") + }) + }) + + t.Run("success-without-after-item", func(t *testing.T) { + t.Parallel() + resp, ttime, err := testRepo.listAccounts(testCtx, authMethod1.PublicId, WithLimit(testCtx, 10)) + require.NoError(t, err) + // Transaction timestamp should be within ~10 seconds of now + assert.True(t, time.Now().Before(ttime.Add(10*time.Second))) + assert.True(t, time.Now().After(ttime.Add(-10*time.Second))) + require.Empty(t, cmp.Diff(resp, accounts1, cmpOpts...)) + }) + t.Run("success-with-after-item", func(t *testing.T) { + t.Parallel() + resp, ttime, err := testRepo.listAccounts(testCtx, authMethod1.PublicId, WithStartPageAfterItem(testCtx, accounts1[0]), WithLimit(testCtx, 10)) + require.NoError(t, err) + // Transaction timestamp should be within ~10 seconds of now + assert.True(t, time.Now().Before(ttime.Add(10*time.Second))) + assert.True(t, time.Now().After(ttime.Add(-10*time.Second))) + require.Empty(t, cmp.Diff(resp, accounts1[1:], cmpOpts...)) + }) +} + +func TestRepository_listAccountsRefresh(t *testing.T) { + oldReadTimeout := globals.RefreshReadLookbackDuration + globals.RefreshReadLookbackDuration = 0 + t.Cleanup(func() { + globals.RefreshReadLookbackDuration = oldReadTimeout + }) + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + testWrapper := db.TestWrapper(t) + + ctx := context.Background() + testKms := kms.TestKms(t, conn, testWrapper) + iamRepo := iam.TestRepo(t, conn, testWrapper) + org, _ := iam.TestScopes(t, iamRepo) + + databaseWrapper, err := testKms.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase) + require.NoError(t, err) + + repo, err := NewRepository(ctx, rw, rw, testKms) + assert.NoError(t, err) + + fiveDaysAgo := time.Now().AddDate(0, 0, -5) + + authMethod1 := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + authMethod2 := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap2"}) + accounts1 := []*Account{ + TestAccount(t, conn, authMethod1, "create-success"), + TestAccount(t, conn, authMethod1, "create-success2"), + TestAccount(t, conn, authMethod1, "create-success3"), + } + accounts2 := []*Account{ + TestAccount(t, conn, authMethod2, "create-success"), + TestAccount(t, conn, authMethod2, "create-success2"), + TestAccount(t, conn, authMethod2, "create-success3"), + } + + slices.Reverse(accounts1) + _ = accounts2 + + cmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported( + Account{}, + store.Account{}, + timestamp.Timestamp{}, + timestamppb.Timestamp{}, + ), + cmpopts.SortSlices(func(i, j string) bool { return i < j }), + } + + t.Run("validation", func(t *testing.T) { + t.Parallel() + t.Run("missing updated after", func(t *testing.T) { + t.Parallel() + _, _, err := repo.listAccountsRefresh(ctx, authMethod1.PublicId, time.Time{}, WithLimit(ctx, 1)) + require.ErrorContains(t, err, "missing updated after time") + }) + t.Run("missing auth method id", func(t *testing.T) { + t.Parallel() + _, _, err := repo.listAccountsRefresh(ctx, "", fiveDaysAgo, WithLimit(ctx, 1)) + require.ErrorContains(t, err, "missing auth method id") + }) + }) + + t.Run("success-without-after-item", func(t *testing.T) { + t.Parallel() + resp, ttime, err := repo.listAccountsRefresh(ctx, authMethod1.PublicId, fiveDaysAgo, WithLimit(ctx, 10)) + require.NoError(t, err) + // Transaction timestamp should be within ~10 seconds of now + assert.True(t, time.Now().Before(ttime.Add(10*time.Second))) + assert.True(t, time.Now().After(ttime.Add(-10*time.Second))) + require.Empty(t, cmp.Diff(resp, accounts1, cmpOpts...)) + }) + t.Run("success-with-after-item", func(t *testing.T) { + t.Parallel() + resp, ttime, err := repo.listAccountsRefresh(ctx, authMethod1.PublicId, fiveDaysAgo, WithStartPageAfterItem(ctx, accounts1[0]), WithLimit(ctx, 10)) + require.NoError(t, err) + // Transaction timestamp should be within ~10 seconds of now + assert.True(t, time.Now().Before(ttime.Add(10*time.Second))) + assert.True(t, time.Now().After(ttime.Add(-10*time.Second))) + require.Empty(t, cmp.Diff(resp, accounts1[1:], cmpOpts...)) + }) + t.Run("success-without-after-item-recent-updated-after", func(t *testing.T) { + t.Parallel() + resp, ttime, err := repo.listAccountsRefresh(ctx, authMethod1.PublicId, accounts1[len(accounts1)-1].GetUpdateTime().AsTime(), WithLimit(ctx, 10)) + require.NoError(t, err) + // Transaction timestamp should be within ~10 seconds of now + assert.True(t, time.Now().Before(ttime.Add(10*time.Second))) + assert.True(t, time.Now().After(ttime.Add(-10*time.Second))) + require.Empty(t, cmp.Diff(resp, accounts1[:len(accounts1)-1], cmpOpts...)) + }) + t.Run("success-with-after-item-recent-updated-after", func(t *testing.T) { + t.Parallel() + resp, ttime, err := repo.listAccountsRefresh(ctx, authMethod1.PublicId, accounts1[len(accounts1)-1].GetUpdateTime().AsTime(), WithStartPageAfterItem(ctx, accounts1[0]), WithLimit(ctx, 10)) + require.NoError(t, err) + // Transaction timestamp should be within ~10 seconds of now + assert.True(t, time.Now().Before(ttime.Add(10*time.Second))) + assert.True(t, time.Now().After(ttime.Add(-10*time.Second))) + require.Empty(t, cmp.Diff(resp, accounts1[1:len(accounts1)-1], cmpOpts...)) + }) +} + +func TestRepository_estimatedCount(t *testing.T) { + oldReadTimeout := globals.RefreshReadLookbackDuration + globals.RefreshReadLookbackDuration = 0 + t.Cleanup(func() { + globals.RefreshReadLookbackDuration = oldReadTimeout + }) + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + testWrapper := db.TestWrapper(t) + + ctx := context.Background() + testKms := kms.TestKms(t, conn, testWrapper) + iamRepo := iam.TestRepo(t, conn, testWrapper) + org, _ := iam.TestScopes(t, iamRepo) + + databaseWrapper, err := testKms.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase) + require.NoError(t, err) + + repo, err := NewRepository(ctx, rw, rw, testKms) + assert.NoError(t, err) + + sqlDb, err := conn.SqlDB(ctx) + require.NoError(t, err) + + // Check total entries at start, expect 0 + numItems, err := repo.estimatedAccountCount(ctx) + require.NoError(t, err) + assert.Equal(t, 0, numItems) + + // create account and check count, expect 1 + authMethod1 := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + acct := TestAccount(t, conn, authMethod1, "create-success") + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + numItems, err = repo.estimatedAccountCount(ctx) + require.NoError(t, err) + assert.Equal(t, 1, numItems) + + // Delete acct and check count, expect 0 again + _, err = repo.DeleteAccount(ctx, acct.GetPublicId()) + require.NoError(t, err) + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + numItems, err = repo.estimatedAccountCount(ctx) + require.NoError(t, err) + assert.Equal(t, 0, numItems) +} + +func TestRepository_listDeletedIds(t *testing.T) { + oldReadTimeout := globals.RefreshReadLookbackDuration + globals.RefreshReadLookbackDuration = 0 + t.Cleanup(func() { + globals.RefreshReadLookbackDuration = oldReadTimeout + }) + conn, _ := db.TestSetup(t, "postgres") + rw := db.New(conn) + testWrapper := db.TestWrapper(t) + + ctx := context.Background() + testKms := kms.TestKms(t, conn, testWrapper) + iamRepo := iam.TestRepo(t, conn, testWrapper) + org, _ := iam.TestScopes(t, iamRepo) + + databaseWrapper, err := testKms.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase) + require.NoError(t, err) + + repo, err := NewRepository(ctx, rw, rw, testKms) + assert.NoError(t, err) + + sqlDb, err := conn.SqlDB(ctx) + require.NoError(t, err) + + // Check total entries at start, expect 0 + numItems, err := repo.estimatedAccountCount(ctx) + require.NoError(t, err) + assert.Equal(t, 0, numItems) + + // create account and check deleted ids, should be empty + authMethod1 := TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + acct := TestAccount(t, conn, authMethod1, "create-success") + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + deletedIds, ttime, err := repo.listDeletedAccountIds(ctx, time.Now().AddDate(-1, 0, 0)) + require.NoError(t, err) + require.Empty(t, deletedIds) + // Transaction timestamp should be within ~10 seconds of now + assert.True(t, time.Now().Before(ttime.Add(10*time.Second))) + assert.True(t, time.Now().After(ttime.Add(-10*time.Second))) + + // Delete acct and check count, expect 1 entry + _, err = repo.DeleteAccount(ctx, acct.GetPublicId()) + require.NoError(t, err) + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + deletedIds, ttime, err = repo.listDeletedAccountIds(ctx, time.Now().AddDate(-1, 0, 0)) + require.NoError(t, err) + assert.Empty( + t, + cmp.Diff( + []string{acct.GetPublicId()}, + deletedIds, + cmpopts.SortSlices(func(i, j string) bool { return i < j }), + ), + ) + // Transaction timestamp should be within ~10 seconds of now + assert.True(t, time.Now().Before(ttime.Add(10*time.Second))) + assert.True(t, time.Now().After(ttime.Add(-10*time.Second))) + + // Try again with the time set to now, expect no entries + deletedIds, ttime, err = repo.listDeletedAccountIds(ctx, time.Now()) + require.NoError(t, err) + require.Empty(t, deletedIds) + assert.True(t, time.Now().Before(ttime.Add(10*time.Second))) + assert.True(t, time.Now().After(ttime.Add(-10*time.Second))) } func TestRepository_ListAccounts_Limits(t *testing.T) { @@ -805,9 +1073,12 @@ func TestRepository_ListAccounts_Limits(t *testing.T) { repo, err := NewRepository(testCtx, testRw, testRw, testKms, tc.repoOpts...) assert.NoError(err) require.NotNil(repo) - got, err := repo.ListAccounts(context.Background(), am.GetPublicId(), tc.listOpts...) + got, ttime, err := repo.listAccounts(context.Background(), am.GetPublicId(), tc.listOpts...) require.NoError(err) assert.Len(got, tc.wantLen) + // Transaction timestamp should be within ~10 seconds of now + assert.True(time.Now().Before(ttime.Add(10 * time.Second))) + assert.True(time.Now().After(ttime.Add(-10 * time.Second))) }) } } diff --git a/internal/auth/ldap/service_list_accounts.go b/internal/auth/ldap/service_list_accounts.go new file mode 100644 index 0000000000..f41cae1656 --- /dev/null +++ b/internal/auth/ldap/service_list_accounts.go @@ -0,0 +1,62 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package ldap + +import ( + "context" + "time" + + "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/pagination" +) + +// ListAccounts lists up to page size ldap accounts, filtering out entries that +// do not pass the filter item function. It will automatically request +// more accounts from the database, at page size chunks, to fill the page. +// It returns a new list token used to continue pagination or refresh items. +// Accounts are ordered by create time descending (most recently created first). +func ListAccounts( + ctx context.Context, + grantsHash []byte, + pageSize int, + filterItemFn pagination.ListFilterFunc[auth.Account], + repo *Repository, + authMethodId string, +) (*pagination.ListResponse[auth.Account], error) { + const op = "ldap.ListAccounts" + + switch { + case len(grantsHash) == 0: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash") + case pageSize < 1: + return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1") + case filterItemFn == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback") + case repo == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo") + case authMethodId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID") + } + + listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) { + opts := []Option{ + WithLimit(ctx, limit), + } + if lastPageItem != nil { + opts = append(opts, WithStartPageAfterItem(ctx, lastPageItem)) + } + ldapAccts, listTime, err := repo.listAccounts(ctx, authMethodId, opts...) + if err != nil { + return nil, time.Time{}, err + } + var accounts []auth.Account + for _, acct := range ldapAccts { + accounts = append(accounts, acct) + } + return accounts, listTime, nil + } + + return pagination.List(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount) +} diff --git a/internal/auth/ldap/service_list_accounts_ext_test.go b/internal/auth/ldap/service_list_accounts_ext_test.go new file mode 100644 index 0000000000..0d79332e8b --- /dev/null +++ b/internal/auth/ldap/service_list_accounts_ext_test.go @@ -0,0 +1,623 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package ldap_test + +import ( + "context" + "slices" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/hashicorp/boundary/globals" + "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/auth/ldap" + "github.com/hashicorp/boundary/internal/auth/ldap/store" + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/timestamp" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/listtoken" + "github.com/hashicorp/boundary/internal/types/resource" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestService_ListAccounts(t *testing.T) { + // Set database read timeout to avoid duplicates in response + oldReadTimeout := globals.RefreshReadLookbackDuration + globals.RefreshReadLookbackDuration = 0 + t.Cleanup(func() { + globals.RefreshReadLookbackDuration = oldReadTimeout + }) + + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + sqlDb, err := conn.SqlDB(ctx) + require.NoError(t, err) + wrapper := db.TestWrapper(t) + fiveDaysAgo := time.Now().AddDate(0, 0, -5) + rw := db.New(conn) + + testKms := kms.TestKms(t, conn, wrapper) + iamRepo := iam.TestRepo(t, conn, wrapper) + org, _ := iam.TestScopes(t, iamRepo) + + databaseWrapper, err := testKms.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase) + require.NoError(t, err) + + repo, err := ldap.NewRepository(context.Background(), rw, rw, testKms) + require.NoError(t, err) + + authMethod := ldap.TestAuthMethod(t, conn, databaseWrapper, org.PublicId, []string{"ldaps://ldap1"}) + ldapAccts := []*ldap.Account{ + ldap.TestAccount(t, conn, authMethod, "create-success"), + ldap.TestAccount(t, conn, authMethod, "create-success2"), + ldap.TestAccount(t, conn, authMethod, "create-success3"), + ldap.TestAccount(t, conn, authMethod, "create-success4"), + ldap.TestAccount(t, conn, authMethod, "create-success5"), + } + + var accounts []auth.Account + for _, acct := range ldapAccts { + accounts = append(accounts, acct) + } + // since we sort by create time descending, we need to reverse the slice + slices.Reverse(accounts) + + // Run analyze to update host estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + cmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported( + ldap.Account{}, + store.Account{}, + timestamp.Timestamp{}, + timestamppb.Timestamp{}, + ), + } + + t.Run("ListAccounts validation", func(t *testing.T) { + t.Parallel() + t.Run("missing grants hash", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err := ldap.ListAccounts(ctx, nil, 1, filterFunc, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing grants hash") + }) + t.Run("zero page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err := ldap.ListAccounts(ctx, []byte("some hash"), 0, filterFunc, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("negative page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err := ldap.ListAccounts(ctx, []byte("some hash"), -1, filterFunc, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("nil filter func", func(t *testing.T) { + t.Parallel() + _, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, nil, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing filter item callback") + }) + t.Run("nil repo", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, nil, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing repo") + }) + t.Run("missing auth method ID", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, "") + require.ErrorContains(t, err, "missing auth method ID") + }) + }) + t.Run("ListAccountsPage validation", func(t *testing.T) { + t.Parallel() + t.Run("missing grants hash", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsPage(ctx, nil, 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing grants hash") + }) + t.Run("zero page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 0, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("negative page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsPage(ctx, []byte("some hash"), -1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("nil filter func", func(t *testing.T) { + t.Parallel() + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, nil, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing filter item callback") + }) + t.Run("nil token", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing token") + }) + t.Run("wrong token type", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "token did not have a pagination token component") + }) + t.Run("nil repo", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, nil, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing repo") + }) + t.Run("missing auth method ID", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, "") + require.ErrorContains(t, err, "missing auth method ID") + }) + t.Run("wrong token resource type", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Target, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "token did not have an account resource type") + }) + }) + t.Run("ListAccountsRefresh validation", func(t *testing.T) { + t.Parallel() + t.Run("missing grants hash", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefresh(ctx, nil, 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing grants hash") + }) + t.Run("zero page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 0, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("negative page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), -1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("nil filter func", func(t *testing.T) { + t.Parallel() + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, nil, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing filter item callback") + }) + t.Run("nil token", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing token") + }) + t.Run("nil repo", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, tok, nil, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing repo") + }) + t.Run("missing auth method ID", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, tok, repo, "") + require.ErrorContains(t, err, "missing auth method ID") + }) + t.Run("wrong token resource type", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Target, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "token did not have an account resource type") + }) + }) + t.Run("ListAccountsRefreshPage validation", func(t *testing.T) { + t.Parallel() + t.Run("missing grants hash", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefreshPage(ctx, nil, 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing grants hash") + }) + t.Run("zero page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 0, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("negative page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), -1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("nil filter func", func(t *testing.T) { + t.Parallel() + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, nil, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing filter item callback") + }) + t.Run("nil token", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing token") + }) + t.Run("wrong token type", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "token did not have a refresh token component") + }) + t.Run("nil repo", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, nil, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing repo") + }) + t.Run("missing credential store id", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, "") + require.ErrorContains(t, err, "missing auth method ID") + }) + t.Run("wrong token resource type", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Target, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "token did not have an account resource type") + }) + }) + + t.Run("simple pagination", func(t *testing.T) { + filterFunc := func(context.Context, auth.Account) (bool, error) { + return true, nil + } + resp, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.NotNil(t, resp.ListToken) + require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp.CompleteListing) + require.Equal(t, resp.EstimatedItemCount, 5) + require.Empty(t, resp.DeletedIds) + require.Len(t, resp.Items, 1) + require.Empty(t, cmp.Diff(resp.Items[0], accounts[0], cmpOpts...)) + + resp2, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp2.CompleteListing) + require.Equal(t, resp2.EstimatedItemCount, 5) + require.Empty(t, resp2.DeletedIds) + require.Len(t, resp2.Items, 1) + require.Empty(t, cmp.Diff(resp2.Items[0], accounts[1], cmpOpts...)) + + resp3, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp2.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp3.CompleteListing) + require.Equal(t, resp3.EstimatedItemCount, 5) + require.Empty(t, resp3.DeletedIds) + require.Len(t, resp3.Items, 1) + require.Empty(t, cmp.Diff(resp3.Items[0], accounts[2], cmpOpts...)) + + resp4, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp3.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp4.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp4.CompleteListing) + require.Equal(t, resp4.EstimatedItemCount, 5) + require.Empty(t, resp4.DeletedIds) + require.Len(t, resp4.Items, 1) + require.Empty(t, cmp.Diff(resp4.Items[0], accounts[3], cmpOpts...)) + + resp5, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp4.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp5.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp5.CompleteListing) + require.Equal(t, resp5.EstimatedItemCount, 5) + require.Empty(t, resp5.DeletedIds) + require.Len(t, resp5.Items, 1) + require.Empty(t, cmp.Diff(resp5.Items[0], accounts[4], cmpOpts...)) + + // Finished initial pagination phase, request refresh + // Expect no results. + resp6, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp5.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp6.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp6.CompleteListing) + require.Equal(t, resp6.EstimatedItemCount, 5) + require.Empty(t, resp6.DeletedIds) + require.Empty(t, resp6.Items) + + // Create some new accounts + account1 := ldap.TestAccount(t, conn, authMethod, "new-success") + account2 := ldap.TestAccount(t, conn, authMethod, "new-success2") + t.Cleanup(func() { + repo.DeleteAccount(ctx, account1.GetPublicId()) + require.NoError(t, err) + repo.DeleteAccount(ctx, account2.GetPublicId()) + require.NoError(t, err) + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + }) + + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + // Refresh again, should get account2 + resp7, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp6.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp7.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp7.CompleteListing) + require.Equal(t, resp7.EstimatedItemCount, 7) + require.Empty(t, resp7.DeletedIds) + require.Len(t, resp7.Items, 1) + require.Empty(t, cmp.Diff(resp7.Items[0], account2, cmpOpts...)) + + // Refresh again, should get account1 + resp8, err := ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, resp7.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp8.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp8.CompleteListing) + require.Equal(t, resp8.EstimatedItemCount, 7) + require.Empty(t, resp8.DeletedIds) + require.Len(t, resp8.Items, 1) + require.Empty(t, cmp.Diff(resp8.Items[0], account1, cmpOpts...)) + + // Refresh again, should get no results + resp9, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp8.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp9.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp9.CompleteListing) + require.Equal(t, resp9.EstimatedItemCount, 7) + require.Empty(t, resp9.DeletedIds) + require.Empty(t, resp9.Items) + }) + + t.Run("simple pagination with aggressive filtering", func(t *testing.T) { + filterFunc := func(ctx context.Context, a auth.Account) (bool, error) { + return a.GetPublicId() == accounts[1].GetPublicId() || + a.GetPublicId() == accounts[len(accounts)-1].GetPublicId(), nil + } + resp, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.NotNil(t, resp.ListToken) + require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp.CompleteListing) + require.Equal(t, resp.EstimatedItemCount, 5) + require.Empty(t, resp.DeletedIds) + require.Len(t, resp.Items, 1) + require.Empty(t, cmp.Diff(resp.Items[0], accounts[1], cmpOpts...)) + + resp2, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.NotNil(t, resp2.ListToken) + require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp2.CompleteListing) + require.Equal(t, resp2.EstimatedItemCount, 5) + require.Empty(t, resp2.DeletedIds) + require.Len(t, resp2.Items, 1) + require.Empty(t, cmp.Diff(resp2.Items[0], accounts[len(accounts)-1], cmpOpts...)) + + // request a refresh, nothing should be returned + resp3, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp3.CompleteListing) + require.Equal(t, resp3.EstimatedItemCount, 5) + require.Empty(t, resp3.DeletedIds) + require.Empty(t, resp3.Items) + + // Create some new accounts + account1 := ldap.TestAccount(t, conn, authMethod, "new-success") + account2 := ldap.TestAccount(t, conn, authMethod, "new-success2") + account3 := ldap.TestAccount(t, conn, authMethod, "new-success3") + t.Cleanup(func() { + repo.DeleteAccount(ctx, account1.GetPublicId()) + require.NoError(t, err) + repo.DeleteAccount(ctx, account2.GetPublicId()) + require.NoError(t, err) + repo.DeleteAccount(ctx, account3.GetPublicId()) + require.NoError(t, err) + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + }) + + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + filterFunc = func(_ context.Context, a auth.Account) (bool, error) { + return a.GetPublicId() == account3.GetPublicId() || + a.GetPublicId() == account1.GetPublicId(), nil + } + // Refresh again, should get account3 + resp4, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp3.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp4.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp4.CompleteListing) + require.Equal(t, resp4.EstimatedItemCount, 8) + require.Empty(t, resp4.DeletedIds) + require.Len(t, resp4.Items, 1) + require.Empty(t, cmp.Diff(resp4.Items[0], account3, cmpOpts...)) + + // Refresh again, should get account1 + resp5, err := ldap.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, resp4.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp5.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp5.CompleteListing) + require.Equal(t, resp5.EstimatedItemCount, 8) + require.Empty(t, resp5.DeletedIds) + require.Len(t, resp5.Items, 1) + require.Empty(t, cmp.Diff(resp5.Items[0], account1, cmpOpts...)) + }) + + t.Run("simple pagination with deletion", func(t *testing.T) { + filterFunc := func(context.Context, auth.Account) (bool, error) { + return true, nil + } + deletedAccountId := accounts[0].GetPublicId() + repo.DeleteAccount(ctx, deletedAccountId) + require.NoError(t, err) + accounts = accounts[1:] + + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + resp, err := ldap.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.NotNil(t, resp.ListToken) + require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp.CompleteListing) + require.Equal(t, resp.EstimatedItemCount, 4) + require.Empty(t, resp.DeletedIds) + require.Len(t, resp.Items, 1) + require.Empty(t, cmp.Diff(resp.Items[0], accounts[0], cmpOpts...)) + + // request remaining results + resp2, err := ldap.ListAccountsPage(ctx, []byte("some hash"), 3, filterFunc, resp.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp2.CompleteListing) + require.Equal(t, resp2.EstimatedItemCount, 4) + require.Empty(t, resp2.DeletedIds) + require.Len(t, resp2.Items, 3) + require.Empty(t, cmp.Diff(resp2.Items, accounts[1:], cmpOpts...)) + + deletedAccountId = accounts[0].GetPublicId() + repo.DeleteAccount(ctx, deletedAccountId) + require.NoError(t, err) + accounts = accounts[1:] + + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + // request a refresh, nothing should be returned except the deleted id + resp3, err := ldap.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp2.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp3.CompleteListing) + require.Equal(t, resp3.EstimatedItemCount, 3) + require.Contains(t, resp3.DeletedIds, deletedAccountId) + require.Empty(t, resp3.Items) + }) +} diff --git a/internal/auth/ldap/service_list_accounts_page.go b/internal/auth/ldap/service_list_accounts_page.go new file mode 100644 index 0000000000..e16db5d20a --- /dev/null +++ b/internal/auth/ldap/service_list_accounts_page.go @@ -0,0 +1,79 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package ldap + +import ( + "context" + "time" + + "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/listtoken" + "github.com/hashicorp/boundary/internal/pagination" + "github.com/hashicorp/boundary/internal/types/resource" +) + +// ListAccountsPage lists up to page size ldap accounts, filtering out entries that +// do not pass the filter item function. It will automatically request +// more ldap accounts from the database, at page size chunks, to fill the page. +// It will start its paging based on the information in the token. +// It returns a new list token used to continue pagination or refresh items. +// Accounts are ordered by create time descending (most recently created first). +func ListAccountsPage( + ctx context.Context, + grantsHash []byte, + pageSize int, + filterItemFn pagination.ListFilterFunc[auth.Account], + tok *listtoken.Token, + repo *Repository, + authMethodId string, +) (*pagination.ListResponse[auth.Account], error) { + const op = "ldap.ListAccountsPage" + + switch { + case len(grantsHash) == 0: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash") + case pageSize < 1: + return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1") + case filterItemFn == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback") + case tok == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token") + case repo == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo") + case authMethodId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID") + case tok.ResourceType != resource.Account: + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have an account resource type") + } + if _, ok := tok.Subtype.(*listtoken.PaginationToken); !ok { + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have a pagination token component") + } + + listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) { + opts := []Option{ + WithLimit(ctx, limit), + } + if lastPageItem != nil { + opts = append(opts, WithStartPageAfterItem(ctx, lastPageItem)) + } else { + lastItem, err := tok.LastItem(ctx) + if err != nil { + return nil, time.Time{}, err + } + opts = append(opts, WithStartPageAfterItem(ctx, lastItem)) + } + ldapAccounts, listTime, err := repo.listAccounts(ctx, authMethodId, opts...) + if err != nil { + return nil, time.Time{}, err + } + var accts []auth.Account + for _, acct := range ldapAccounts { + accts = append(accts, acct) + } + return accts, listTime, nil + } + + return pagination.ListPage(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount, tok) +} diff --git a/internal/auth/ldap/service_list_accounts_refresh.go b/internal/auth/ldap/service_list_accounts_refresh.go new file mode 100644 index 0000000000..9fed00734a --- /dev/null +++ b/internal/auth/ldap/service_list_accounts_refresh.go @@ -0,0 +1,81 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package ldap + +import ( + "context" + "time" + + "github.com/hashicorp/boundary/globals" + "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/listtoken" + "github.com/hashicorp/boundary/internal/pagination" + "github.com/hashicorp/boundary/internal/types/resource" +) + +// ListAccountsRefresh lists ldap accounts according to the page size +// and list token, filtering out entries that do not +// pass the filter item fn. It returns a new list token +// based on the old one, the grants hash, and the returned +// ldap accounts. +func ListAccountsRefresh( + ctx context.Context, + grantsHash []byte, + pageSize int, + filterItemFn pagination.ListFilterFunc[auth.Account], + tok *listtoken.Token, + repo *Repository, + authMethodId string, +) (*pagination.ListResponse[auth.Account], error) { + const op = "ldap.ListAccountsRefresh" + + switch { + case len(grantsHash) == 0: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash") + case pageSize < 1: + return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1") + case filterItemFn == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback") + case tok == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token") + case repo == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo") + case authMethodId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID") + case tok.ResourceType != resource.Account: + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have an account resource type") + } + rt, ok := tok.Subtype.(*listtoken.StartRefreshToken) + if !ok { + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have a start-refresh token component") + } + + listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) { + opts := []Option{ + WithLimit(ctx, limit), + } + if lastPageItem != nil { + opts = append(opts, WithStartPageAfterItem(ctx, lastPageItem)) + } + // Add the database read timeout to account for any creations missed due to concurrent + // transactions in the initial pagination phase. + ldapAccounts, listTime, err := repo.listAccountsRefresh(ctx, authMethodId, rt.PreviousPhaseUpperBound.Add(-globals.RefreshReadLookbackDuration), opts...) + if err != nil { + return nil, time.Time{}, err + } + var accounts []auth.Account + for _, account := range ldapAccounts { + accounts = append(accounts, account) + } + return accounts, listTime, nil + } + listDeletedIdsFn := func(ctx context.Context, since time.Time) ([]string, time.Time, error) { + // Add the database read timeout to account for any deletes missed due to concurrent + // transactions in the original list pagination phase. + return repo.listDeletedAccountIds(ctx, since.Add(-globals.RefreshReadLookbackDuration)) + } + + return pagination.ListRefresh(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount, listDeletedIdsFn, tok) +} diff --git a/internal/auth/ldap/service_list_accounts_refresh_page.go b/internal/auth/ldap/service_list_accounts_refresh_page.go new file mode 100644 index 0000000000..194a6f7baa --- /dev/null +++ b/internal/auth/ldap/service_list_accounts_refresh_page.go @@ -0,0 +1,91 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package ldap + +import ( + "context" + "time" + + "github.com/hashicorp/boundary/globals" + "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/listtoken" + "github.com/hashicorp/boundary/internal/pagination" + "github.com/hashicorp/boundary/internal/types/resource" +) + +// ListAccountsRefreshPage lists up to page size accounts, filtering out entries that +// do not pass the filter item function. It will automatically request +// more accounts from the database, at page size chunks, to fill the page. +// It will start its paging based on the information in the token. +// It returns a new list token used to continue pagination or refresh items. +// Accounts are ordered by update time descending (most recently updated first). +// Accounts may contain items that were already returned during the initial +// pagination phase. It also returns a list of any accounts deleted since the +// last response. +func ListAccountsRefreshPage( + ctx context.Context, + grantsHash []byte, + pageSize int, + filterItemFn pagination.ListFilterFunc[auth.Account], + tok *listtoken.Token, + repo *Repository, + authMethodId string, +) (*pagination.ListResponse[auth.Account], error) { + const op = "ldap.ListAccountsRefreshPage" + + switch { + case len(grantsHash) == 0: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash") + case pageSize < 1: + return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1") + case filterItemFn == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback") + case tok == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token") + case repo == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo") + case authMethodId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID") + case tok.ResourceType != resource.Account: + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have an account resource type") + } + rt, ok := tok.Subtype.(*listtoken.RefreshToken) + if !ok { + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have a refresh token component") + } + + listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) { + opts := []Option{ + WithLimit(ctx, limit), + } + if lastPageItem != nil { + opts = append(opts, WithStartPageAfterItem(ctx, lastPageItem)) + } else { + lastItem, err := tok.LastItem(ctx) + if err != nil { + return nil, time.Time{}, err + } + opts = append(opts, WithStartPageAfterItem(ctx, lastItem)) + } + // Add the database read timeout to account for any creations missed due to concurrent + // transactions in the original list pagination phase. + sAccounts, listTime, err := repo.listAccountsRefresh(ctx, authMethodId, rt.PhaseLowerBound.Add(-globals.RefreshReadLookbackDuration), opts...) + if err != nil { + return nil, time.Time{}, err + } + var accounts []auth.Account + for _, account := range sAccounts { + accounts = append(accounts, account) + } + return accounts, listTime, nil + } + listDeletedIdsFn := func(ctx context.Context, since time.Time) ([]string, time.Time, error) { + // Add the database read timeout to account for any deletes missed due to concurrent + // transactions in the original list pagination phase. + return repo.listDeletedAccountIds(ctx, since.Add(-globals.RefreshReadLookbackDuration)) + } + + return pagination.ListRefreshPage(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount, listDeletedIdsFn, tok) +}