From d5e8877f2cb8530d6517485ce88723e801a48aed Mon Sep 17 00:00:00 2001 From: Johan Brandhorst-Satzkorn Date: Thu, 21 Dec 2023 20:44:16 -0800 Subject: [PATCH] internal/auth/password: add account pagination --- internal/auth/password/account.go | 11 + internal/auth/password/options.go | 31 +- internal/auth/password/query.go | 3 + internal/auth/password/repository_account.go | 136 +++- .../auth/password/repository_account_test.go | 20 +- .../auth/password/service_list_accounts.go | 62 ++ .../service_list_accounts_ext_test.go | 613 ++++++++++++++++++ .../password/service_list_accounts_page.go | 79 +++ .../password/service_list_accounts_refresh.go | 81 +++ .../service_list_accounts_refresh_page.go | 91 +++ 10 files changed, 1104 insertions(+), 23 deletions(-) create mode 100644 internal/auth/password/service_list_accounts.go create mode 100644 internal/auth/password/service_list_accounts_ext_test.go create mode 100644 internal/auth/password/service_list_accounts_page.go create mode 100644 internal/auth/password/service_list_accounts_refresh.go create mode 100644 internal/auth/password/service_list_accounts_refresh_page.go diff --git a/internal/auth/password/account.go b/internal/auth/password/account.go index aa346a12b0..0adb1c7d98 100644 --- a/internal/auth/password/account.go +++ b/internal/auth/password/account.go @@ -7,6 +7,7 @@ import ( "context" "github.com/hashicorp/boundary/internal/auth/password/store" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/oplog" "github.com/hashicorp/boundary/internal/types/resource" @@ -99,3 +100,13 @@ func (a *Account) oplog(op oplog.OpType) oplog.Metadata { } return metadata } + +type deletedAccount struct { + PublicId string `gorm:"primary_key"` + DeleteTime *timestamp.Timestamp +} + +// TableName returns the tablename to override the default gorm table name +func (s *deletedAccount) TableName() string { + return "auth_password_account_deleted" +} diff --git a/internal/auth/password/options.go b/internal/auth/password/options.go index 5e1582abd8..b07a53fdc0 100644 --- a/internal/auth/password/options.go +++ b/internal/auth/password/options.go @@ -3,6 +3,8 @@ package password +import "github.com/hashicorp/boundary/internal/pagination" + // GetOpts - iterate the inbound Options and return a struct. func GetOpts(opt ...Option) options { opts := getDefaultOptions() @@ -17,16 +19,17 @@ type Option func(*options) // options = how options are represented type options struct { - withName string - withDescription string - WithLoginName string - withLimit int - withConfig Configuration - withPublicId string - password string - withPassword bool - withOrderByCreateTime bool - ascending bool + withName string + withDescription string + WithLoginName string + withLimit int + withConfig Configuration + withPublicId string + password string + withPassword bool + withOrderByCreateTime bool + ascending bool + withStartPageAfterItem pagination.Item } func getDefaultOptions() options { @@ -95,3 +98,11 @@ func WithOrderByCreateTime(ascending bool) Option { o.ascending = ascending } } + +// WithStartPageAfterItem is used to paginate over the results. +// The next page will start after the provided item. +func WithStartPageAfterItem(item pagination.Item) Option { + return func(o *options) { + o.withStartPageAfterItem = item + } +} diff --git a/internal/auth/password/query.go b/internal/auth/password/query.go index 07371a8e4d..3bfe6e5195 100644 --- a/internal/auth/password/query.go +++ b/internal/auth/password/query.go @@ -43,5 +43,8 @@ select * from auth_password_account where public_id = @public_id ); +` + estimateCountAccounts = ` +select sum(reltuples::bigint) as estimate from pg_class where oid in ('auth_password_account'::regclass) ` ) diff --git a/internal/auth/password/repository_account.go b/internal/auth/password/repository_account.go index 6bc11447e1..ebab9fe0f2 100644 --- a/internal/auth/password/repository_account.go +++ b/internal/auth/password/repository_account.go @@ -5,12 +5,15 @@ package password import ( "context" + "database/sql" "fmt" "regexp" "strings" + "time" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" @@ -145,24 +148,97 @@ func (r *Repository) LookupAccount(ctx context.Context, withPublicId string, opt return a, nil } -// ListAccounts in an auth method and supports WithLimit option. -func (r *Repository) ListAccounts(ctx context.Context, withAuthMethodId string, opt ...Option) ([]*Account, error) { - const op = "password.(Repository).ListAccounts" +// listAccounts returns a slice of accounts in the auth method. +// Supported options: +// - WithLimit which overrides the limit set in the Repository object +// - WithStartPageAfterItem which sets where to start listing from +func (r *Repository) listAccounts(ctx context.Context, withAuthMethodId string, opt ...Option) ([]*Account, time.Time, error) { + const op = "password.(Repository).listAccounts" if withAuthMethodId == "" { - return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method id") + return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing auth method id") } opts := GetOpts(opt...) + limit := r.defaultLimit if opts.withLimit != 0 { // non-zero signals an override of the default limit for the repo. limit = opts.withLimit } - var accts []*Account - err := r.reader.SearchWhere(ctx, &accts, "auth_method_id = ?", []any{withAuthMethodId}, db.WithLimit(limit)) - if err != nil { - return nil, errors.Wrap(ctx, err, op) + + var args []any + whereClause := "auth_method_id = @auth_method_id" + args = append(args, sql.Named("auth_method_id", withAuthMethodId)) + + if opts.withStartPageAfterItem != nil { + whereClause = fmt.Sprintf("(create_time, public_id) < (@last_item_create_time, @last_item_id) and %s", whereClause) + args = append(args, + sql.Named("last_item_create_time", opts.withStartPageAfterItem.GetCreateTime()), + sql.Named("last_item_id", opts.withStartPageAfterItem.GetPublicId()), + ) } - return accts, nil + + dbOpts := []db.Option{db.WithLimit(limit), db.WithOrder("create_time desc, public_id desc")} + return r.queryAccounts(ctx, whereClause, args, dbOpts...) +} + +// listAccountsRefresh returns a slice of accounts in the auth method. +// Supported options: +// - WithLimit which overrides the limit set in the Repository object +// - WithStartPageAfterItem which sets where to start listing from +func (r *Repository) listAccountsRefresh(ctx context.Context, withAuthMethodId string, updatedAfter time.Time, opt ...Option) ([]*Account, time.Time, error) { + const op = "password.(Repository).listAccountsRefresh" + switch { + case withAuthMethodId == "": + return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing auth method id") + case updatedAfter.IsZero(): + return nil, time.Time{}, errors.New(ctx, errors.InvalidParameter, op, "missing updated after time") + } + + opts := GetOpts(opt...) + + limit := r.defaultLimit + if opts.withLimit != 0 { + // non-zero signals an override of the default limit for the repo. + limit = opts.withLimit + } + + var args []any + whereClause := "update_time > @updated_after_time and auth_method_id = @auth_method_id" + args = append(args, + sql.Named("updated_after_time", timestamp.New(updatedAfter)), + sql.Named("auth_method_id", withAuthMethodId), + ) + + if opts.withStartPageAfterItem != nil { + whereClause = fmt.Sprintf("(update_time, public_id) < (@last_item_update_time, @last_item_id) and %s", whereClause) + args = append(args, + sql.Named("last_item_update_time", opts.withStartPageAfterItem.GetUpdateTime()), + sql.Named("last_item_id", opts.withStartPageAfterItem.GetPublicId()), + ) + } + + dbOpts := []db.Option{db.WithLimit(limit), db.WithOrder("update_time desc, public_id desc")} + return r.queryAccounts(ctx, whereClause, args, dbOpts...) +} + +func (r *Repository) queryAccounts(ctx context.Context, whereClause string, args []any, opt ...db.Option) ([]*Account, time.Time, error) { + const op = "password.(Repository).queryAccounts" + + var accts []*Account + var transactionTimestamp time.Time + if _, err := r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(rd db.Reader, w db.Writer) error { + var inAccts []*Account + if err := rd.SearchWhere(ctx, &inAccts, whereClause, args, opt...); err != nil { + return errors.Wrap(ctx, err, op) + } + accts = inAccts + var err error + transactionTimestamp, err = rd.Now(ctx) + return err + }); err != nil { + return nil, time.Time{}, errors.Wrap(ctx, err, op) + } + return accts, transactionTimestamp, nil } // DeleteAccount deletes the account for the provided id from the repository returning a count of the @@ -326,3 +402,45 @@ func (r *Repository) UpdateAccount(ctx context.Context, scopeId string, a *Accou return returnedAccount, rowsUpdated, nil } + +// listDeletedAccountIds lists the public IDs of any accounts deleted since the timestamp provided, +// and the timestamp of the transaction within which the accounts were listed. +func (r *Repository) listDeletedAccountIds(ctx context.Context, since time.Time) ([]string, time.Time, error) { + const op = "password.(Repository).listDeletedAccountIds" + var deleteAccounts []*deletedAccount + var transactionTimestamp time.Time + if _, err := r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, _ db.Writer) error { + if err := r.SearchWhere(ctx, &deleteAccounts, "delete_time >= ?", []any{since}); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("failed to query deleted accounts")) + } + var err error + transactionTimestamp, err = r.Now(ctx) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("failed to get transaction timestamp")) + } + return nil + }); err != nil { + return nil, time.Time{}, err + } + var accountIds []string + for _, a := range deleteAccounts { + accountIds = append(accountIds, a.PublicId) + } + return accountIds, transactionTimestamp, nil +} + +// estimatedAccountCount returns an estimate of the total number of accounts. +func (r *Repository) estimatedAccountCount(ctx context.Context) (int, error) { + const op = "password.(Repository).estimatedAccountCount" + rows, err := r.reader.Query(ctx, estimateCountAccounts, nil) + if err != nil { + return 0, errors.Wrap(ctx, err, op, errors.WithMsg("failed to query ldap account counts")) + } + var count int + for rows.Next() { + if err := r.reader.ScanRows(ctx, rows, &count); err != nil { + return 0, errors.Wrap(ctx, err, op, errors.WithMsg("failed to query ldap account counts")) + } + } + return count, nil +} diff --git a/internal/auth/password/repository_account_test.go b/internal/auth/password/repository_account_test.go index 92f1cd16a6..a56c8342c8 100644 --- a/internal/auth/password/repository_account_test.go +++ b/internal/auth/password/repository_account_test.go @@ -5,19 +5,24 @@ package password import ( "context" + "slices" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/hashicorp/boundary/globals" "github.com/hashicorp/boundary/internal/auth/password/store" "github.com/hashicorp/boundary/internal/db" dbassert "github.com/hashicorp/boundary/internal/db/assert" + "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/oplog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestCheckLoginName(t *testing.T) { @@ -471,6 +476,7 @@ func TestRepository_ListAccounts(t *testing.T) { accounts1 := TestMultipleAccounts(t, conn, authMethods[0].GetPublicId(), 3) accounts2 := TestMultipleAccounts(t, conn, authMethods[1].GetPublicId(), 4) _ = accounts2 + slices.Reverse(accounts1) tests := []struct { name string @@ -483,7 +489,7 @@ func TestRepository_ListAccounts(t *testing.T) { { name: "With no auth method id", wantIsErr: errors.InvalidParameter, - wantErrMsg: "password.(Repository).ListAccounts: missing auth method id: parameter violation: error #100", + wantErrMsg: "password.(Repository).listAccounts: missing auth method id: parameter violation: error #100", }, { name: "With no accounts id", @@ -504,14 +510,17 @@ func TestRepository_ListAccounts(t *testing.T) { repo, err := NewRepository(context.Background(), rw, rw, kms) assert.NoError(err) require.NotNil(repo) - got, err := repo.ListAccounts(context.Background(), tt.in, tt.opts...) + got, ttime, err := repo.listAccounts(context.Background(), tt.in, tt.opts...) if tt.wantIsErr != 0 { assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err) assert.Equal(tt.wantErrMsg, err.Error()) return } require.NoError(err) - assert.EqualValues(tt.want, got) + assert.Empty(cmp.Diff(tt.want, got, cmpopts.IgnoreUnexported(Account{}, store.Account{}, timestamp.Timestamp{}, timestamppb.Timestamp{}))) + // Transaction timestamp should be within ~10 seconds of now + assert.True(time.Now().Before(ttime.Add(10 * time.Second))) + assert.True(time.Now().After(ttime.Add(-10 * time.Second))) }) } } @@ -581,9 +590,12 @@ func TestRepository_ListAccounts_Limits(t *testing.T) { repo, err := NewRepository(context.Background(), rw, rw, kms, tt.repoOpts...) assert.NoError(err) require.NotNil(repo) - got, err := repo.ListAccounts(context.Background(), am.GetPublicId(), tt.listOpts...) + got, ttime, err := repo.listAccounts(context.Background(), am.GetPublicId(), tt.listOpts...) require.NoError(err) assert.Len(got, tt.wantLen) + // Transaction timestamp should be within ~10 seconds of now + assert.True(time.Now().Before(ttime.Add(10 * time.Second))) + assert.True(time.Now().After(ttime.Add(-10 * time.Second))) }) } } diff --git a/internal/auth/password/service_list_accounts.go b/internal/auth/password/service_list_accounts.go new file mode 100644 index 0000000000..fd693a26a7 --- /dev/null +++ b/internal/auth/password/service_list_accounts.go @@ -0,0 +1,62 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package password + +import ( + "context" + "time" + + "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/pagination" +) + +// ListAccounts lists up to page size password accounts, filtering out entries that +// do not pass the filter item function. It will automatically request +// more accounts from the database, at page size chunks, to fill the page. +// It returns a new list token used to continue pagination or refresh items. +// Accounts are ordered by create time descending (most recently created first). +func ListAccounts( + ctx context.Context, + grantsHash []byte, + pageSize int, + filterItemFn pagination.ListFilterFunc[auth.Account], + repo *Repository, + authMethodId string, +) (*pagination.ListResponse[auth.Account], error) { + const op = "password.ListAccounts" + + switch { + case len(grantsHash) == 0: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash") + case pageSize < 1: + return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1") + case filterItemFn == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback") + case repo == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo") + case authMethodId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID") + } + + listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) { + opts := []Option{ + WithLimit(limit), + } + if lastPageItem != nil { + opts = append(opts, WithStartPageAfterItem(lastPageItem)) + } + passwordAccts, listTime, err := repo.listAccounts(ctx, authMethodId, opts...) + if err != nil { + return nil, time.Time{}, err + } + var accounts []auth.Account + for _, acct := range passwordAccts { + accounts = append(accounts, acct) + } + return accounts, listTime, nil + } + + return pagination.List(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount) +} diff --git a/internal/auth/password/service_list_accounts_ext_test.go b/internal/auth/password/service_list_accounts_ext_test.go new file mode 100644 index 0000000000..31fb7ca162 --- /dev/null +++ b/internal/auth/password/service_list_accounts_ext_test.go @@ -0,0 +1,613 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package password_test + +import ( + "context" + "slices" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/hashicorp/boundary/globals" + "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/auth/password" + "github.com/hashicorp/boundary/internal/auth/password/store" + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/db/timestamp" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/listtoken" + "github.com/hashicorp/boundary/internal/types/resource" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" +) + +func TestService_ListAccounts(t *testing.T) { + // Set database read timeout to avoid duplicates in response + oldReadTimeout := globals.RefreshReadLookbackDuration + globals.RefreshReadLookbackDuration = 0 + t.Cleanup(func() { + globals.RefreshReadLookbackDuration = oldReadTimeout + }) + + ctx := context.Background() + conn, _ := db.TestSetup(t, "postgres") + sqlDb, err := conn.SqlDB(ctx) + require.NoError(t, err) + wrapper := db.TestWrapper(t) + fiveDaysAgo := time.Now().AddDate(0, 0, -5) + rw := db.New(conn) + + kms := kms.TestKms(t, conn, wrapper) + iamRepo := iam.TestRepo(t, conn, wrapper) + org, _ := iam.TestScopes(t, iamRepo) + + repo, err := password.NewRepository(context.Background(), rw, rw, kms) + require.NoError(t, err) + + authMethod := password.TestAuthMethod(t, conn, org.GetPublicId()) + passAccts := password.TestMultipleAccounts(t, conn, authMethod.GetPublicId(), 5) + var accounts []auth.Account + for _, acct := range passAccts { + accounts = append(accounts, acct) + } + // since we sort by create time descending, we need to reverse the slice + slices.Reverse(accounts) + + // Run analyze to update host estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + cmpOpts := []cmp.Option{ + cmpopts.IgnoreUnexported( + password.Account{}, + store.Account{}, + timestamp.Timestamp{}, + timestamppb.Timestamp{}, + ), + } + + t.Run("ListAccounts validation", func(t *testing.T) { + t.Parallel() + t.Run("missing grants hash", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err := password.ListAccounts(ctx, nil, 1, filterFunc, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing grants hash") + }) + t.Run("zero page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err := password.ListAccounts(ctx, []byte("some hash"), 0, filterFunc, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("negative page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err := password.ListAccounts(ctx, []byte("some hash"), -1, filterFunc, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("nil filter func", func(t *testing.T) { + t.Parallel() + _, err := password.ListAccounts(ctx, []byte("some hash"), 1, nil, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing filter item callback") + }) + t.Run("nil repo", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err := password.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, nil, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing repo") + }) + t.Run("missing auth method ID", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err := password.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, "") + require.ErrorContains(t, err, "missing auth method ID") + }) + }) + t.Run("ListAccountsPage validation", func(t *testing.T) { + t.Parallel() + t.Run("missing grants hash", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsPage(ctx, nil, 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing grants hash") + }) + t.Run("zero page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsPage(ctx, []byte("some hash"), 0, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("negative page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsPage(ctx, []byte("some hash"), -1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("nil filter func", func(t *testing.T) { + t.Parallel() + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsPage(ctx, []byte("some hash"), 1, nil, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing filter item callback") + }) + t.Run("nil token", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err = password.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing token") + }) + t.Run("wrong token type", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "token did not have a pagination token component") + }) + t.Run("nil repo", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, nil, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing repo") + }) + t.Run("missing auth method ID", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, "") + require.ErrorContains(t, err, "missing auth method ID") + }) + t.Run("wrong token resource type", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Target, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "token did not have an account resource type") + }) + }) + t.Run("ListAccountsRefresh validation", func(t *testing.T) { + t.Parallel() + t.Run("missing grants hash", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefresh(ctx, nil, 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing grants hash") + }) + t.Run("zero page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefresh(ctx, []byte("some hash"), 0, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("negative page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefresh(ctx, []byte("some hash"), -1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("nil filter func", func(t *testing.T) { + t.Parallel() + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefresh(ctx, []byte("some hash"), 1, nil, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing filter item callback") + }) + t.Run("nil token", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err = password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing token") + }) + t.Run("nil repo", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, tok, nil, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing repo") + }) + t.Run("missing auth method ID", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, tok, repo, "") + require.ErrorContains(t, err, "missing auth method ID") + }) + t.Run("wrong token resource type", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewStartRefresh(ctx, fiveDaysAgo, resource.Target, []byte("some hash"), fiveDaysAgo, fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "token did not have an account resource type") + }) + }) + t.Run("ListAccountsRefreshPage validation", func(t *testing.T) { + t.Parallel() + t.Run("missing grants hash", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefreshPage(ctx, nil, 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing grants hash") + }) + t.Run("zero page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefreshPage(ctx, []byte("some hash"), 0, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("negative page size", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefreshPage(ctx, []byte("some hash"), -1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "page size must be at least 1") + }) + t.Run("nil filter func", func(t *testing.T) { + t.Parallel() + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, nil, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing filter item callback") + }) + t.Run("nil token", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + _, err = password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, nil, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing token") + }) + t.Run("wrong token type", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewPagination(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), "some-id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "token did not have a refresh token component") + }) + t.Run("nil repo", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, nil, authMethod.GetPublicId()) + require.ErrorContains(t, err, "missing repo") + }) + t.Run("missing credential store id", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Account, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, "") + require.ErrorContains(t, err, "missing auth method ID") + }) + t.Run("wrong token resource type", func(t *testing.T) { + t.Parallel() + filterFunc := func(_ context.Context, a auth.Account) (bool, error) { + return true, nil + } + tok, err := listtoken.NewRefresh(ctx, fiveDaysAgo, resource.Target, []byte("some hash"), fiveDaysAgo, fiveDaysAgo, fiveDaysAgo, "some other id", fiveDaysAgo) + require.NoError(t, err) + _, err = password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, tok, repo, authMethod.GetPublicId()) + require.ErrorContains(t, err, "token did not have an account resource type") + }) + }) + + t.Run("simple pagination", func(t *testing.T) { + filterFunc := func(context.Context, auth.Account) (bool, error) { + return true, nil + } + resp, err := password.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.NotNil(t, resp.ListToken) + require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp.CompleteListing) + require.Equal(t, resp.EstimatedItemCount, 5) + require.Empty(t, resp.DeletedIds) + require.Len(t, resp.Items, 1) + require.Empty(t, cmp.Diff(resp.Items[0], accounts[0], cmpOpts...)) + + resp2, err := password.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp2.CompleteListing) + require.Equal(t, resp2.EstimatedItemCount, 5) + require.Empty(t, resp2.DeletedIds) + require.Len(t, resp2.Items, 1) + require.Empty(t, cmp.Diff(resp2.Items[0], accounts[1], cmpOpts...)) + + resp3, err := password.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp2.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp3.CompleteListing) + require.Equal(t, resp3.EstimatedItemCount, 5) + require.Empty(t, resp3.DeletedIds) + require.Len(t, resp3.Items, 1) + require.Empty(t, cmp.Diff(resp3.Items[0], accounts[2], cmpOpts...)) + + resp4, err := password.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp3.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp4.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp4.CompleteListing) + require.Equal(t, resp4.EstimatedItemCount, 5) + require.Empty(t, resp4.DeletedIds) + require.Len(t, resp4.Items, 1) + require.Empty(t, cmp.Diff(resp4.Items[0], accounts[3], cmpOpts...)) + + resp5, err := password.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp4.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp5.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp5.CompleteListing) + require.Equal(t, resp5.EstimatedItemCount, 5) + require.Empty(t, resp5.DeletedIds) + require.Len(t, resp5.Items, 1) + require.Empty(t, cmp.Diff(resp5.Items[0], accounts[4], cmpOpts...)) + + // Finished initial pagination phase, request refresh + // Expect no results. + resp6, err := password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp5.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp6.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp6.CompleteListing) + require.Equal(t, resp6.EstimatedItemCount, 5) + require.Empty(t, resp6.DeletedIds) + require.Empty(t, resp6.Items) + + // Create some new accounts + account1 := password.TestAccount(t, conn, authMethod.GetPublicId(), "some-id-1") + account2 := password.TestAccount(t, conn, authMethod.GetPublicId(), "some-id-2") + t.Cleanup(func() { + repo.DeleteAccount(ctx, org.PublicId, account1.GetPublicId()) + require.NoError(t, err) + repo.DeleteAccount(ctx, org.PublicId, account2.GetPublicId()) + require.NoError(t, err) + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + }) + + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + // Refresh again, should get account2 + resp7, err := password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp6.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp7.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp7.CompleteListing) + require.Equal(t, resp7.EstimatedItemCount, 7) + require.Empty(t, resp7.DeletedIds) + require.Len(t, resp7.Items, 1) + require.Empty(t, cmp.Diff(resp7.Items[0], account2, cmpOpts...)) + + // Refresh again, should get account1 + resp8, err := password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, resp7.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp8.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp8.CompleteListing) + require.Equal(t, resp8.EstimatedItemCount, 7) + require.Empty(t, resp8.DeletedIds) + require.Len(t, resp8.Items, 1) + require.Empty(t, cmp.Diff(resp8.Items[0], account1, cmpOpts...)) + + // Refresh again, should get no results + resp9, err := password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp8.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp9.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp9.CompleteListing) + require.Equal(t, resp9.EstimatedItemCount, 7) + require.Empty(t, resp9.DeletedIds) + require.Empty(t, resp9.Items) + }) + + t.Run("simple pagination with aggressive filtering", func(t *testing.T) { + filterFunc := func(ctx context.Context, a auth.Account) (bool, error) { + return a.GetPublicId() == accounts[1].GetPublicId() || + a.GetPublicId() == accounts[len(accounts)-1].GetPublicId(), nil + } + resp, err := password.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.NotNil(t, resp.ListToken) + require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp.CompleteListing) + require.Equal(t, resp.EstimatedItemCount, 5) + require.Empty(t, resp.DeletedIds) + require.Len(t, resp.Items, 1) + require.Empty(t, cmp.Diff(resp.Items[0], accounts[1], cmpOpts...)) + + resp2, err := password.ListAccountsPage(ctx, []byte("some hash"), 1, filterFunc, resp.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.NotNil(t, resp2.ListToken) + require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp2.CompleteListing) + require.Equal(t, resp2.EstimatedItemCount, 5) + require.Empty(t, resp2.DeletedIds) + require.Len(t, resp2.Items, 1) + require.Empty(t, cmp.Diff(resp2.Items[0], accounts[len(accounts)-1], cmpOpts...)) + + // request a refresh, nothing should be returned + resp3, err := password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp3.CompleteListing) + require.Equal(t, resp3.EstimatedItemCount, 5) + require.Empty(t, resp3.DeletedIds) + require.Empty(t, resp3.Items) + + // Create some new accounts + account1 := password.TestAccount(t, conn, authMethod.GetPublicId(), "some-id-1") + account2 := password.TestAccount(t, conn, authMethod.GetPublicId(), "some-id-2") + account3 := password.TestAccount(t, conn, authMethod.GetPublicId(), "some-id-3") + t.Cleanup(func() { + repo.DeleteAccount(ctx, org.PublicId, account1.GetPublicId()) + require.NoError(t, err) + repo.DeleteAccount(ctx, org.PublicId, account2.GetPublicId()) + require.NoError(t, err) + repo.DeleteAccount(ctx, org.PublicId, account3.GetPublicId()) + require.NoError(t, err) + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + }) + + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + filterFunc = func(_ context.Context, a auth.Account) (bool, error) { + return a.GetPublicId() == account3.GetPublicId() || + a.GetPublicId() == account1.GetPublicId(), nil + } + // Refresh again, should get account3 + resp4, err := password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp3.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp4.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp4.CompleteListing) + require.Equal(t, resp4.EstimatedItemCount, 8) + require.Empty(t, resp4.DeletedIds) + require.Len(t, resp4.Items, 1) + require.Empty(t, cmp.Diff(resp4.Items[0], account3, cmpOpts...)) + + // Refresh again, should get account1 + resp5, err := password.ListAccountsRefreshPage(ctx, []byte("some hash"), 1, filterFunc, resp4.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp5.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp5.CompleteListing) + require.Equal(t, resp5.EstimatedItemCount, 8) + require.Empty(t, resp5.DeletedIds) + require.Len(t, resp5.Items, 1) + require.Empty(t, cmp.Diff(resp5.Items[0], account1, cmpOpts...)) + }) + + t.Run("simple pagination with deletion", func(t *testing.T) { + filterFunc := func(context.Context, auth.Account) (bool, error) { + return true, nil + } + deletedAccountId := accounts[0].GetPublicId() + repo.DeleteAccount(ctx, org.PublicId, deletedAccountId) + require.NoError(t, err) + accounts = accounts[1:] + + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + resp, err := password.ListAccounts(ctx, []byte("some hash"), 1, filterFunc, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.NotNil(t, resp.ListToken) + require.Equal(t, resp.ListToken.GrantsHash, []byte("some hash")) + require.False(t, resp.CompleteListing) + require.Equal(t, resp.EstimatedItemCount, 4) + require.Empty(t, resp.DeletedIds) + require.Len(t, resp.Items, 1) + require.Empty(t, cmp.Diff(resp.Items[0], accounts[0], cmpOpts...)) + + // request remaining results + resp2, err := password.ListAccountsPage(ctx, []byte("some hash"), 3, filterFunc, resp.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp2.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp2.CompleteListing) + require.Equal(t, resp2.EstimatedItemCount, 4) + require.Empty(t, resp2.DeletedIds) + require.Len(t, resp2.Items, 3) + require.Empty(t, cmp.Diff(resp2.Items, accounts[1:], cmpOpts...)) + + deletedAccountId = accounts[0].GetPublicId() + repo.DeleteAccount(ctx, org.PublicId, deletedAccountId) + require.NoError(t, err) + accounts = accounts[1:] + + // Run analyze to update count estimate + _, err = sqlDb.ExecContext(ctx, "analyze") + require.NoError(t, err) + + // request a refresh, nothing should be returned except the deleted id + resp3, err := password.ListAccountsRefresh(ctx, []byte("some hash"), 1, filterFunc, resp2.ListToken, repo, authMethod.GetPublicId()) + require.NoError(t, err) + require.Equal(t, resp3.ListToken.GrantsHash, []byte("some hash")) + require.True(t, resp3.CompleteListing) + require.Equal(t, resp3.EstimatedItemCount, 3) + require.Contains(t, resp3.DeletedIds, deletedAccountId) + require.Empty(t, resp3.Items) + }) +} diff --git a/internal/auth/password/service_list_accounts_page.go b/internal/auth/password/service_list_accounts_page.go new file mode 100644 index 0000000000..f4844d1db6 --- /dev/null +++ b/internal/auth/password/service_list_accounts_page.go @@ -0,0 +1,79 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package password + +import ( + "context" + "time" + + "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/listtoken" + "github.com/hashicorp/boundary/internal/pagination" + "github.com/hashicorp/boundary/internal/types/resource" +) + +// ListAccountsPage lists up to page size password accounts, filtering out entries that +// do not pass the filter item function. It will automatically request +// more password accounts from the database, at page size chunks, to fill the page. +// It will start its paging based on the information in the token. +// It returns a new list token used to continue pagination or refresh items. +// Accounts are ordered by create time descending (most recently created first). +func ListAccountsPage( + ctx context.Context, + grantsHash []byte, + pageSize int, + filterItemFn pagination.ListFilterFunc[auth.Account], + tok *listtoken.Token, + repo *Repository, + authMethodId string, +) (*pagination.ListResponse[auth.Account], error) { + const op = "password.ListAccountsPage" + + switch { + case len(grantsHash) == 0: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash") + case pageSize < 1: + return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1") + case filterItemFn == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback") + case tok == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token") + case repo == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo") + case authMethodId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID") + case tok.ResourceType != resource.Account: + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have an account resource type") + } + if _, ok := tok.Subtype.(*listtoken.PaginationToken); !ok { + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have a pagination token component") + } + + listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) { + opts := []Option{ + WithLimit(limit), + } + if lastPageItem != nil { + opts = append(opts, WithStartPageAfterItem(lastPageItem)) + } else { + lastItem, err := tok.LastItem(ctx) + if err != nil { + return nil, time.Time{}, err + } + opts = append(opts, WithStartPageAfterItem(lastItem)) + } + passwordAccounts, listTime, err := repo.listAccounts(ctx, authMethodId, opts...) + if err != nil { + return nil, time.Time{}, err + } + var accts []auth.Account + for _, acct := range passwordAccounts { + accts = append(accts, acct) + } + return accts, listTime, nil + } + + return pagination.ListPage(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount, tok) +} diff --git a/internal/auth/password/service_list_accounts_refresh.go b/internal/auth/password/service_list_accounts_refresh.go new file mode 100644 index 0000000000..59a14af0b8 --- /dev/null +++ b/internal/auth/password/service_list_accounts_refresh.go @@ -0,0 +1,81 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package password + +import ( + "context" + "time" + + "github.com/hashicorp/boundary/globals" + "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/listtoken" + "github.com/hashicorp/boundary/internal/pagination" + "github.com/hashicorp/boundary/internal/types/resource" +) + +// ListAccountsRefresh lists password accounts according to the page size +// and list token, filtering out entries that do not +// pass the filter item fn. It returns a new list token +// based on the old one, the grants hash, and the returned +// password accounts. +func ListAccountsRefresh( + ctx context.Context, + grantsHash []byte, + pageSize int, + filterItemFn pagination.ListFilterFunc[auth.Account], + tok *listtoken.Token, + repo *Repository, + authMethodId string, +) (*pagination.ListResponse[auth.Account], error) { + const op = "password.ListAccountsRefresh" + + switch { + case len(grantsHash) == 0: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash") + case pageSize < 1: + return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1") + case filterItemFn == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback") + case tok == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token") + case repo == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo") + case authMethodId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID") + case tok.ResourceType != resource.Account: + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have an account resource type") + } + rt, ok := tok.Subtype.(*listtoken.StartRefreshToken) + if !ok { + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have a start-refresh token component") + } + + listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) { + opts := []Option{ + WithLimit(limit), + } + if lastPageItem != nil { + opts = append(opts, WithStartPageAfterItem(lastPageItem)) + } + // Add the database read timeout to account for any creations missed due to concurrent + // transactions in the initial pagination phase. + passwordAccounts, listTime, err := repo.listAccountsRefresh(ctx, authMethodId, rt.PreviousPhaseUpperBound.Add(-globals.RefreshReadLookbackDuration), opts...) + if err != nil { + return nil, time.Time{}, err + } + var accounts []auth.Account + for _, account := range passwordAccounts { + accounts = append(accounts, account) + } + return accounts, listTime, nil + } + listDeletedIdsFn := func(ctx context.Context, since time.Time) ([]string, time.Time, error) { + // Add the database read timeout to account for any deletes missed due to concurrent + // transactions in the original list pagination phase. + return repo.listDeletedAccountIds(ctx, since.Add(-globals.RefreshReadLookbackDuration)) + } + + return pagination.ListRefresh(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount, listDeletedIdsFn, tok) +} diff --git a/internal/auth/password/service_list_accounts_refresh_page.go b/internal/auth/password/service_list_accounts_refresh_page.go new file mode 100644 index 0000000000..c7a06b857a --- /dev/null +++ b/internal/auth/password/service_list_accounts_refresh_page.go @@ -0,0 +1,91 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package password + +import ( + "context" + "time" + + "github.com/hashicorp/boundary/globals" + "github.com/hashicorp/boundary/internal/auth" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/listtoken" + "github.com/hashicorp/boundary/internal/pagination" + "github.com/hashicorp/boundary/internal/types/resource" +) + +// ListAccountsRefreshPage lists up to page size accounts, filtering out entries that +// do not pass the filter item function. It will automatically request +// more accounts from the database, at page size chunks, to fill the page. +// It will start its paging based on the information in the token. +// It returns a new list token used to continue pagination or refresh items. +// Accounts are ordered by update time descending (most recently updated first). +// Accounts may contain items that were already returned during the initial +// pagination phase. It also returns a list of any accounts deleted since the +// last response. +func ListAccountsRefreshPage( + ctx context.Context, + grantsHash []byte, + pageSize int, + filterItemFn pagination.ListFilterFunc[auth.Account], + tok *listtoken.Token, + repo *Repository, + authMethodId string, +) (*pagination.ListResponse[auth.Account], error) { + const op = "password.ListAccountsRefreshPage" + + switch { + case len(grantsHash) == 0: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash") + case pageSize < 1: + return nil, errors.New(ctx, errors.InvalidParameter, op, "page size must be at least 1") + case filterItemFn == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing filter item callback") + case tok == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing token") + case repo == nil: + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing repo") + case authMethodId == "": + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth method ID") + case tok.ResourceType != resource.Account: + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have an account resource type") + } + rt, ok := tok.Subtype.(*listtoken.RefreshToken) + if !ok { + return nil, errors.New(ctx, errors.InvalidParameter, op, "token did not have a refresh token component") + } + + listItemsFn := func(ctx context.Context, lastPageItem auth.Account, limit int) ([]auth.Account, time.Time, error) { + opts := []Option{ + WithLimit(limit), + } + if lastPageItem != nil { + opts = append(opts, WithStartPageAfterItem(lastPageItem)) + } else { + lastItem, err := tok.LastItem(ctx) + if err != nil { + return nil, time.Time{}, err + } + opts = append(opts, WithStartPageAfterItem(lastItem)) + } + // Add the database read timeout to account for any creations missed due to concurrent + // transactions in the original list pagination phase. + sAccounts, listTime, err := repo.listAccountsRefresh(ctx, authMethodId, rt.PhaseLowerBound.Add(-globals.RefreshReadLookbackDuration), opts...) + if err != nil { + return nil, time.Time{}, err + } + var accounts []auth.Account + for _, account := range sAccounts { + accounts = append(accounts, account) + } + return accounts, listTime, nil + } + listDeletedIdsFn := func(ctx context.Context, since time.Time) ([]string, time.Time, error) { + // Add the database read timeout to account for any deletes missed due to concurrent + // transactions in the original list pagination phase. + return repo.listDeletedAccountIds(ctx, since.Add(-globals.RefreshReadLookbackDuration)) + } + + return pagination.ListRefreshPage(ctx, grantsHash, pageSize, filterItemFn, listItemsFn, repo.estimatedAccountCount, listDeletedIdsFn, tok) +}