diff --git a/internal/authtoken/query.go b/internal/authtoken/query.go index c25b4947b7..cffb89c8aa 100644 --- a/internal/authtoken/query.go +++ b/internal/authtoken/query.go @@ -6,81 +6,5 @@ package authtoken const ( estimateCountAuthTokens = ` select reltuples::bigint as estimate from pg_class where oid in ('auth_token'::regclass) -` - listAuthTokensTemplate = ` - select at.public_id, - at.auth_account_id, - aa.scope_id, - aa.auth_method_id, - aa.iam_user_id, - aa.iam_user_scope_id, - at.create_time, - at.update_time, - at.approximate_last_access_time, - at.expiration_time, - at.status - from auth_token at - join auth_account aa on aa.public_id = at.auth_account_id - where aa.scope_id in @scope_ids -order by at.create_time desc, at.public_id desc - limit %d; -` - listAuthTokensPageTemplate = ` - select at.public_id, - at.auth_account_id, - aa.scope_id, - aa.auth_method_id, - aa.iam_user_id, - aa.iam_user_scope_id, - at.create_time, - at.update_time, - at.approximate_last_access_time, - at.expiration_time, - at.status - from auth_token at - join auth_account aa on aa.public_id = at.auth_account_id - where aa.scope_id in @scope_ids - and (at.create_time, at.public_id) < (@last_item_create_time, @last_item_id) -order by at.create_time desc, at.public_id desc - limit %d; -` - refreshAuthTokensTemplate = ` - select at.public_id, - at.auth_account_id, - aa.scope_id, - aa.auth_method_id, - aa.iam_user_id, - aa.iam_user_scope_id, - at.create_time, - at.update_time, - at.approximate_last_access_time, - at.expiration_time, - at.status - from auth_token at - join auth_account aa on aa.public_id = at.auth_account_id - where aa.scope_id in @scope_ids - and at.update_time > @updated_after_time -order by at.update_time desc, at.public_id desc - limit %d; -` - refreshAuthTokensPageTemplate = ` - select at.public_id, - at.auth_account_id, - aa.scope_id, - aa.auth_method_id, - aa.iam_user_id, - aa.iam_user_scope_id, - at.create_time, - at.update_time, - at.approximate_last_access_time, - at.expiration_time, - at.status - from auth_token at - join auth_account aa on aa.public_id = at.auth_account_id - where aa.scope_id in @scope_ids - and at.update_time > @updated_after_time - and (at.update_time, at.public_id) < (@last_item_update_time, @last_item_id) -order by at.update_time desc, at.public_id desc - limit %d; ` ) diff --git a/internal/authtoken/repository.go b/internal/authtoken/repository.go index 0818f8d5c5..f062730659 100644 --- a/internal/authtoken/repository.go +++ b/internal/authtoken/repository.go @@ -307,16 +307,17 @@ func (r *Repository) listAuthTokens(ctx context.Context, withScopeIds []string, } args := []any{sql.Named("scope_ids", withScopeIds)} - query := fmt.Sprintf(listAuthTokensTemplate, limit) + whereClause := "scope_id in @scope_ids" if opts.withStartPageAfterItem != nil { - query = fmt.Sprintf(listAuthTokensPageTemplate, limit) + whereClause = fmt.Sprintf("(create_time, public_id) < (@last_item_create_time, @last_item_id) and %s", whereClause) args = append(args, sql.Named("last_item_create_time", opts.withStartPageAfterItem.GetCreateTime()), sql.Named("last_item_id", opts.withStartPageAfterItem.GetPublicId()), ) } - return r.queryAuthTokens(ctx, query, args) + dbOpts := []db.Option{db.WithLimit(limit), db.WithOrder("create_time desc, public_id desc")} + return r.queryAuthTokens(ctx, whereClause, args, dbOpts...) } // listAuthTokensRefresh lists auth tokens in the given scopes and supports the @@ -344,46 +345,34 @@ func (r *Repository) listAuthTokensRefresh(ctx context.Context, updatedAfter tim sql.Named("scope_ids", withScopeIds), sql.Named("updated_after_time", timestamp.New(updatedAfter)), } - query := fmt.Sprintf(refreshAuthTokensTemplate, limit) + whereClause := "scope_id in @scope_ids and update_time > @updated_after_time" if opts.withStartPageAfterItem != nil { - query = fmt.Sprintf(refreshAuthTokensPageTemplate, limit) + whereClause = fmt.Sprintf("(update_time, public_id) < (@last_item_update_time, @last_item_id) and %s", whereClause) args = append(args, sql.Named("last_item_update_time", opts.withStartPageAfterItem.GetUpdateTime()), sql.Named("last_item_id", opts.withStartPageAfterItem.GetPublicId()), ) } - return r.queryAuthTokens(ctx, query, args) + dbOpts := []db.Option{db.WithLimit(limit), db.WithOrder("update_time desc, public_id desc")} + return r.queryAuthTokens(ctx, whereClause, args, dbOpts...) } -func (r *Repository) queryAuthTokens(ctx context.Context, query string, args []any) ([]*AuthToken, time.Time, error) { +func (r *Repository) queryAuthTokens(ctx context.Context, whereClause string, args []any, opt ...db.Option) ([]*AuthToken, time.Time, error) { const op = "authtoken.(Repository).queryAuthTokens" var transactionTimestamp time.Time var authTokens []*AuthToken if _, err := r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(rd db.Reader, w db.Writer) error { - rows, err := rd.Query(ctx, query, args) - if err != nil { - return err - } - defer rows.Close() - - // use the view, to bring in the required account columns. Just don't forget - // to convert them before returning them var atvs []*authTokenView - for rows.Next() { - var atv authTokenView - if err := rd.ScanRows(ctx, rows, &atv); err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("scan row failed")) - } - atvs = append(atvs, &atv) - } - if err := rows.Err(); err != nil { - return errors.Wrap(ctx, err, op, errors.WithMsg("rows next error")) + err := rd.SearchWhere(ctx, &atvs, whereClause, args, opt...) + if err != nil { + return errors.Wrap(ctx, err, op) } - authTokens = make([]*AuthToken, 0, len(atvs)) for _, atv := range atvs { + // Remove encrypted token value before converting + atv.CtToken = nil authTokens = append(authTokens, atv.toAuthToken()) } transactionTimestamp, err = rd.Now(ctx)