diff --git a/internal/clientcache/internal/cache/search.go b/internal/clientcache/internal/cache/search.go index b1a9370402..305801facb 100644 --- a/internal/clientcache/internal/cache/search.go +++ b/internal/clientcache/internal/cache/search.go @@ -77,6 +77,10 @@ type SearchParams struct { Filter string // Max result set size is an override to the default max result set size MaxResultSetSize int + // Which column to sort results by, default is resource specific + SortBy SortBy + // Which direction to sort results by (asc, desc), default is resource specific + SortDirection SortDirection } // SearchResult returns the results from searching the cache. diff --git a/internal/clientcache/internal/daemon/search_handler.go b/internal/clientcache/internal/daemon/search_handler.go index 2d7e706f1c..56af6faf8d 100644 --- a/internal/clientcache/internal/daemon/search_handler.go +++ b/internal/clientcache/internal/daemon/search_handler.go @@ -9,7 +9,9 @@ import ( stderrors "errors" "fmt" "net/http" + "slices" "strconv" + "strings" "github.com/hashicorp/boundary/api" "github.com/hashicorp/boundary/api/aliases" @@ -57,6 +59,15 @@ const ( forceRefreshKey = "force_refresh" authTokenIdKey = "auth_token_id" maxResultSetSizeKey = "max_result_set_size" + sortByKey = "sort_by" + sortDirectionKey = "sort_direction" +) + +var ( + sortableColumnsForResource = map[cache.SearchableResource][]cache.SortBy{ + cache.Targets: []cache.SortBy{cache.SortByName}, + cache.Sessions: []cache.SortBy{cache.SortByCreatedTime}, + } ) func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshService *cache.RefreshService, logger hclog.Logger) (http.HandlerFunc, error) { @@ -84,6 +95,8 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe maxResultSetSizeInt, maxResultSetSizeIntErr := strconv.Atoi(maxResultSetSizeStr) query := q.Get(queryKey) filter := q.Get(filterKey) + sb := q.Get(sortByKey) + sd := q.Get(sortDirectionKey) searchableResource := cache.ToSearchableResource(resource) switch { @@ -121,6 +134,20 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe return } + sortBy, valid := parseSortBy(sb, searchableResource) + if !valid { + event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("sort_by parameter %q not valid for resource %q", sb, searchableResource))) + writeError(w, fmt.Sprintf("sort_by parameter %q not valid for resource %q", sb, searchableResource), http.StatusBadRequest) + return + } + + sortDirection, valid := parseSortDirection(sd) + if !valid { + event.WriteError(ctx, op, errors.New(ctx, errors.InvalidParameter, op, fmt.Sprintf("sort_direction parameter %q not valid", sd))) + writeError(w, fmt.Sprintf("sort_direction parameter %q not valid ", sd), http.StatusBadRequest) + return + } + t, err := repo.LookupToken(reqCtx, authTokenId, cache.WithUpdateLastAccessedTime(true)) if err != nil || t == nil { if err != nil { @@ -175,6 +202,8 @@ func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshSe Query: query, Filter: filter, MaxResultSetSize: maxResultSetSizeInt, + SortBy: sortBy, + SortDirection: sortDirection, }) if err != nil { event.WriteError(ctx, op, err, event.WithInfoMsg("when performing search", "auth_token_id", authTokenId, "resource", searchableResource, "query", query, "filter", filter)) @@ -237,3 +266,36 @@ func writeUnsupportedError(w http.ResponseWriter) { } http.Error(w, string(b), http.StatusBadRequest) } + +// Parses a raw sort direction string into a cache.SortDirection +// Returns the sort direction and whether the provided direction was valid or not +func parseSortDirection(sd string) (cache.SortDirection, bool) { + sd = strings.ToLower(sd) + switch sd { + case "asc", "ascending": + return cache.Ascending, true + case "desc", "descending": + return cache.Descending, true + case "": + return cache.SortDirectionDefault, true + default: + return cache.SortDirectionDefault, false + } +} + +// Parses a raw column name to sort by into a cache.SortBy +// Returns the column to sort by and whether the provided column was valid or not +func parseSortBy(sb string, sr cache.SearchableResource) (cache.SortBy, bool) { + sb = strings.ToLower(sb) + by := cache.SortBy(sb) + + if by == cache.SortByDefault { + return cache.SortByDefault, true + } + + sortableBys, ok := sortableColumnsForResource[sr] + if !ok || !slices.Contains(sortableBys, by) { + return cache.SortByDefault, false + } + return by, true +} diff --git a/internal/clientcache/internal/daemon/search_handler_test.go b/internal/clientcache/internal/daemon/search_handler_test.go new file mode 100644 index 0000000000..57a49e9d1a --- /dev/null +++ b/internal/clientcache/internal/daemon/search_handler_test.go @@ -0,0 +1,60 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package daemon + +import ( + "testing" + + "github.com/hashicorp/boundary/internal/clientcache/internal/cache" + "github.com/stretchr/testify/assert" +) + +func TestParseSortBy(t *testing.T) { + testCases := []struct { + inputSb string + inputSr cache.SearchableResource + expectedValid bool + expectedSortBy cache.SortBy + }{ + {"name", cache.Targets, true, cache.SortByName}, + {"name", cache.Sessions, false, cache.SortByDefault}, + {"created_time", cache.Targets, false, cache.SortByDefault}, + {"created_time", cache.Sessions, true, cache.SortByCreatedTime}, + {"", cache.Targets, true, cache.SortByDefault}, + {"", cache.Sessions, true, cache.SortByDefault}, + {"ljkdhnsfg", cache.Targets, false, cache.SortByDefault}, + {"xcvbxcvb", cache.Sessions, false, cache.SortByDefault}, + {"nameĀ ", cache.Targets, false, cache.SortByDefault}, // Unicode no break space + {"name\u202e", cache.Targets, false, cache.SortByDefault}, // Unicode RtL override + {"\u202ename", cache.Targets, false, cache.SortByDefault}, // Unicode RtL override + } + for _, tc := range testCases { + actualSortBy, actualValid := parseSortBy(tc.inputSb, tc.inputSr) + assert.Equal(t, tc.expectedSortBy, actualSortBy) + assert.Equal(t, tc.expectedValid, actualValid) + } +} + +func TestParseSortDirection(t *testing.T) { + testCases := []struct { + inputSd string + expectedValid bool + expectedSortDirection cache.SortDirection + }{ + {"asc", true, cache.Ascending}, + {"ascending", true, cache.Ascending}, + {"desc", true, cache.Descending}, + {"descending", true, cache.Descending}, + {"", true, cache.SortDirectionDefault}, + {"asdasd", false, cache.SortDirectionDefault}, + {"ascĀ ", false, cache.SortDirectionDefault}, + {"name\u202e", false, cache.SortDirectionDefault}, + {"\u202ename", false, cache.SortDirectionDefault}, + } + for _, tc := range testCases { + actualSortDirection, actualValid := parseSortDirection(tc.inputSd) + assert.Equal(t, tc.expectedSortDirection, actualSortDirection) + assert.Equal(t, tc.expectedValid, actualValid) + } +}