diff --git a/internal/clientcache/internal/cache/options.go b/internal/clientcache/internal/cache/options.go index 41554ca3bf..61455fc01b 100644 --- a/internal/clientcache/internal/cache/options.go +++ b/internal/clientcache/internal/cache/options.go @@ -5,10 +5,21 @@ package cache import ( stderrors "errors" + "fmt" + "regexp" + "slices" "github.com/hashicorp/go-dbw" ) +// safeSortColumnRegex contains characters that could break SQL ORDER BY clauses +var safeSortColumnRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + +var ( + errInvalidSortColumn = stderrors.New("not allowed for this resource type") + errUnsafeSortColumn = stderrors.New("contains unsafe characters") +) + type testRefreshWaitChs struct { firstSempahore chan struct{} secondSemaphore chan struct{} @@ -26,6 +37,8 @@ type options struct { withMaxResultSetSize int withTestRefreshWaitChs *testRefreshWaitChs withUseNonPagedListing bool + withSortBy SortBy // validated DB column name + withSortDirection SortDirection // "asc" or "desc" } // Option - how options are passed as args @@ -139,3 +152,26 @@ func WithUseNonPagedListing(b bool) Option { return nil } } + +// WithSort configures sorting for query results. +// Empty sortBy is silently ignored. Empty direction defaults to ascending in the repository layer. +// Validates column against sortableColumns and rejects SQL-unsafe characters. +func WithSort(sortBy SortBy, direction SortDirection, sortableColumns []SortBy) Option { + return func(o *options) error { + // ignore empty sortBy + if sortBy == SortByDefault { + return nil + } + + switch { + case !slices.Contains(sortableColumns, sortBy): + return fmt.Errorf("invalid sort column %q: %w", sortBy, errInvalidSortColumn) + case !safeSortColumnRegex.MatchString(string(sortBy)): + return fmt.Errorf("invalid sort column %q: %w", sortBy, errUnsafeSortColumn) + } + + o.withSortBy = sortBy + o.withSortDirection = direction + return nil + } +} diff --git a/internal/clientcache/internal/cache/options_test.go b/internal/clientcache/internal/cache/options_test.go index 7a511b9270..3f1a08823b 100644 --- a/internal/clientcache/internal/cache/options_test.go +++ b/internal/clientcache/internal/cache/options_test.go @@ -126,4 +126,105 @@ func Test_GetOpts(t *testing.T) { testOpts.withUseNonPagedListing = true assert.Equal(t, opts, testOpts) }) + t.Run("WithSort-default-sortby-ignored", func(t *testing.T) { + opts, err := getOpts(WithSort(SortByDefault, Ascending, []SortBy{SortByName})) + require.NoError(t, err) + testOpts := getDefaultOptions() + assert.Equal(t, opts, testOpts) + }) + t.Run("WithSort-empty-sortby-ignored", func(t *testing.T) { + opts, err := getOpts(WithSort("", Ascending, []SortBy{SortByName})) + require.NoError(t, err) + testOpts := getDefaultOptions() + assert.Equal(t, opts, testOpts) + }) + t.Run("WithSort-valid-name-ascending", func(t *testing.T) { + opts, err := getOpts(WithSort(SortByName, Ascending, []SortBy{SortByName, SortByCreatedAt})) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withSortBy = SortByName + testOpts.withSortDirection = Ascending + assert.Equal(t, opts, testOpts) + }) + t.Run("WithSort-valid-created_at-descending", func(t *testing.T) { + opts, err := getOpts(WithSort(SortByCreatedAt, Descending, []SortBy{SortByCreatedAt})) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withSortBy = SortByCreatedAt + testOpts.withSortDirection = Descending + assert.Equal(t, opts, testOpts) + }) + t.Run("WithSort-column-not-in-sortable-list", func(t *testing.T) { + _, err := getOpts(WithSort(SortByName, Ascending, []SortBy{SortByCreatedAt})) + require.Error(t, err) + assert.ErrorContains(t, err, errInvalidSortColumn.Error()) + }) + t.Run("WithSort-empty-sortable-columns", func(t *testing.T) { + _, err := getOpts(WithSort(SortByName, Ascending, []SortBy{})) + require.Error(t, err) + assert.ErrorContains(t, err, errInvalidSortColumn.Error()) + }) + t.Run("WithSort-nil-sortable-columns", func(t *testing.T) { + _, err := getOpts(WithSort(SortByName, Ascending, nil)) + require.Error(t, err) + assert.ErrorContains(t, err, errInvalidSortColumn.Error()) + }) + t.Run("WithSort-unsafe-chars-semicolon", func(t *testing.T) { + _, err := getOpts(WithSort(SortBy("name; DROP TABLE"), Ascending, []SortBy{SortBy("name; DROP TABLE")})) + require.Error(t, err) + assert.ErrorContains(t, err, errUnsafeSortColumn.Error()) + }) + t.Run("WithSort-unsafe-chars-quote", func(t *testing.T) { + _, err := getOpts(WithSort(SortBy("name'--"), Ascending, []SortBy{SortBy("name'--")})) + require.Error(t, err) + assert.ErrorContains(t, err, errUnsafeSortColumn.Error()) + }) + t.Run("WithSort-unsafe-chars-double-quote", func(t *testing.T) { + _, err := getOpts(WithSort(SortBy("name\"--"), Ascending, []SortBy{SortBy("name\"--")})) + require.Error(t, err) + assert.ErrorContains(t, err, errUnsafeSortColumn.Error()) + }) + t.Run("WithSort-unsafe-chars-backslash", func(t *testing.T) { + _, err := getOpts(WithSort(SortBy("name\\x00"), Ascending, []SortBy{SortBy("name\\x00")})) + require.Error(t, err) + assert.ErrorContains(t, err, errUnsafeSortColumn.Error()) + }) + t.Run("WithSort-unsafe-chars-comma", func(t *testing.T) { + _, err := getOpts(WithSort(SortBy("name,other"), Ascending, []SortBy{SortBy("name,other")})) + require.Error(t, err) + assert.ErrorContains(t, err, errUnsafeSortColumn.Error()) + }) + t.Run("WithSort-unsafe-chars-parenthesis", func(t *testing.T) { + _, err := getOpts(WithSort(SortBy("name("), Ascending, []SortBy{SortBy("name(")})) + require.Error(t, err) + assert.ErrorContains(t, err, errUnsafeSortColumn.Error()) + }) + t.Run("WithSort-unsafe-chars-space", func(t *testing.T) { + _, err := getOpts(WithSort(SortBy("name "), Ascending, []SortBy{SortBy("name ")})) + require.Error(t, err) + assert.ErrorContains(t, err, errUnsafeSortColumn.Error()) + }) + t.Run("WithSort-unsafe-chars-tab", func(t *testing.T) { + _, err := getOpts(WithSort(SortBy("name\t"), Ascending, []SortBy{SortBy("name\t")})) + require.Error(t, err) + assert.ErrorContains(t, err, errUnsafeSortColumn.Error()) + }) + t.Run("WithSort-unsafe-chars-newline", func(t *testing.T) { + _, err := getOpts(WithSort(SortBy("name\n"), Ascending, []SortBy{SortBy("name\n")})) + require.Error(t, err) + assert.ErrorContains(t, err, errUnsafeSortColumn.Error()) + }) + t.Run("WithSort-unsafe-chars-dash", func(t *testing.T) { + _, err := getOpts(WithSort(SortBy("name-col"), Ascending, []SortBy{SortBy("name-col")})) + require.Error(t, err) + assert.ErrorContains(t, err, errUnsafeSortColumn.Error()) + }) + t.Run("WithSort-default-direction", func(t *testing.T) { + opts, err := getOpts(WithSort(SortByName, SortDirectionDefault, []SortBy{SortByName})) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.withSortBy = SortByName + testOpts.withSortDirection = SortDirectionDefault + assert.Equal(t, opts, testOpts) + }) } diff --git a/internal/clientcache/internal/cache/search.go b/internal/clientcache/internal/cache/search.go index 305801facb..285b7bbc02 100644 --- a/internal/clientcache/internal/cache/search.go +++ b/internal/clientcache/internal/cache/search.go @@ -25,6 +25,15 @@ const ( SortByCreatedTime SortBy = "created_time" ) +// Valid returns true if the SortBy value is a known good value +func (s SortBy) Valid() bool { + switch s { + case SortByDefault, SortByName, SortByCreatedAt: + return true + } + return false +} + type SortDirection string const ( @@ -33,6 +42,15 @@ const ( Descending SortDirection = "desc" ) +// Valid returns true if the SortDirection value is a known good value +func (d SortDirection) Valid() bool { + switch d { + case SortDirectionDefault, Ascending, Descending: + return true + } + return false +} + type SearchableResource string const ( @@ -136,6 +154,7 @@ func NewSearchService(ctx context.Context, repo *Repository) (*SearchService, er } in.Targets = finalResults }, + sortableColumns: []SortBy{SortByName, SortByCreatedAt}, }, Sessions: &resourceSearchFns[*sessions.Session]{ list: repo.ListSessions, @@ -149,6 +168,7 @@ func NewSearchService(ctx context.Context, repo *Repository) (*SearchService, er } in.Sessions = finalResults }, + sortableColumns: []SortBy{SortByCreatedAt}, }, ImplicitScopes: &resourceSearchFns[*scopes.Scope]{ list: repo.ListImplicitScopes, @@ -204,6 +224,10 @@ func (s *SearchService) Search(ctx context.Context, params SearchParams) (*Searc return nil, errors.New(ctx, errors.InvalidParameter, op, "invalid resource") case params.AuthTokenId == "": return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth token id") + case !params.SortBy.Valid(): + return nil, errors.New(ctx, errors.InvalidParameter, op, "invalid sort by value") + case !params.SortDirection.Valid(): + return nil, errors.New(ctx, errors.InvalidParameter, op, "invalid sort direction value") } rSearcher, ok := s.searchableResources[params.Resource] if !ok { @@ -231,6 +255,8 @@ type resourceSearchFns[T any] struct { // filter takes results and a ready-to-use evaluator and filters the items // in the result filter func(*SearchResult, *bexpr.Evaluator) + // sortableColumns is a list of columns that can be used for sorting + sortableColumns []SortBy } // resourceSearcher is an interface that only resourceSearchFns[T] is expected @@ -249,13 +275,18 @@ type resourceSearcher interface { func (l *resourceSearchFns[T]) search(ctx context.Context, p SearchParams) (*SearchResult, error) { const op = "cache.(resourceSearchFns).search" + opts := []Option{WithMaxResultSetSize(p.MaxResultSetSize)} + if p.SortBy != SortByDefault { + opts = append(opts, WithSort(p.SortBy, p.SortDirection, l.sortableColumns)) + } + var found *SearchResult var err error switch p.Query { case "": - found, err = l.list(ctx, p.AuthTokenId, WithMaxResultSetSize(p.MaxResultSetSize)) + found, err = l.list(ctx, p.AuthTokenId, opts...) default: - found, err = l.query(ctx, p.AuthTokenId, p.Query, WithMaxResultSetSize(p.MaxResultSetSize)) + found, err = l.query(ctx, p.AuthTokenId, p.Query, opts...) } if err != nil { return nil, errors.Wrap(ctx, err, op) diff --git a/internal/clientcache/internal/cache/search_test.go b/internal/clientcache/internal/cache/search_test.go index 1440ae3399..8ff4b097cd 100644 --- a/internal/clientcache/internal/cache/search_test.go +++ b/internal/clientcache/internal/cache/search_test.go @@ -442,3 +442,166 @@ func TestSearch(t *testing.T) { assert.Equal(t, &SearchResult{Targets: []*targets.Target{}}, got) }) } + +func TestSortByValid(t *testing.T) { + cases := []struct { + sortBy SortBy + valid bool + }{ + {SortByDefault, true}, + {SortBy(""), true}, + {SortByName, true}, + {SortByCreatedAt, true}, + {SortBy("unknown"), false}, + {SortBy("id"), false}, + {SortBy("invalid_column"), false}, + {SortBy("name; DROP TABLE"), false}, + {SortBy("name'--"), false}, + {SortBy("name\"--"), false}, + {SortBy("name\\x00"), false}, + {SortBy("name,other"), false}, + {SortBy("name ("), false}, + {SortBy("name)"), false}, + {SortBy("name\t"), false}, + {SortBy("name\n"), false}, + {SortBy("name\r"), false}, + {SortBy("name "), false}, + } + + for _, tc := range cases { + t.Run(string(tc.sortBy), func(t *testing.T) { + assert.Equal(t, tc.valid, tc.sortBy.Valid()) + }) + } +} + +func TestSortDirectionValid(t *testing.T) { + cases := []struct { + direction SortDirection + valid bool + }{ + {SortDirectionDefault, true}, + {SortDirection(""), true}, + {Ascending, true}, + {Descending, true}, + {SortDirection("ASC"), false}, + {SortDirection("DESC"), false}, + {SortDirection("invalid"), false}, + } + + for _, tc := range cases { + t.Run(string(tc.direction), func(t *testing.T) { + assert.Equal(t, tc.valid, tc.direction.Valid()) + }) + } +} + +func TestSearch_Sorting(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + at := &AuthToken{ + Id: "at_sort", + UserId: "u_sort", + } + { + u := &user{Id: at.UserId, Address: "address"} + rw := db.New(s) + require.NoError(t, rw.Create(ctx, u)) + require.NoError(t, rw.Create(ctx, at)) + + targets := []*Target{ + {FkUserId: u.Id, Id: "t_1", Name: "alpha", Type: "tcp", Item: `{"id": "t_1", "name": "alpha", "type": "tcp"}`}, + {FkUserId: u.Id, Id: "t_2", Name: "charlie", Type: "tcp", Item: `{"id": "t_2", "name": "charlie", "type": "tcp"}`}, + {FkUserId: u.Id, Id: "t_3", Name: "bravo", Type: "tcp", Item: `{"id": "t_3", "name": "bravo", "type": "tcp"}`}, + } + require.NoError(t, rw.CreateItems(ctx, targets)) + + sessions := []*Session{ + {FkUserId: u.Id, Id: "s_1", Endpoint: "one", Type: "tcp", UserId: "u123", Item: `{"id": "s_1", "endpoint": "one", "type": "tcp", "user_id": "u123"}`}, + {FkUserId: u.Id, Id: "s_2", Endpoint: "two", Type: "ssh", UserId: "u321", Item: `{"id": "s_2", "endpoint": "two", "type": "ssh", "user_id": "u321"}`}, + } + require.NoError(t, rw.CreateItems(ctx, sessions)) + + aliases := []*ResolvableAlias{ + {FkUserId: u.Id, Id: "alt_1", Value: "one", Type: "target", Item: `{"id": "alt_1", "value": "one", "type": "target"}`}, + } + require.NoError(t, rw.CreateItems(ctx, aliases)) + } + + r, err := NewRepository(ctx, s, &sync.Map{}, + mapBasedAuthTokenKeyringLookup(nil), + sliceBasedAuthTokenBoundaryReader(nil)) + require.NoError(t, err) + + ss, err := NewSearchService(ctx, r) + require.NoError(t, err) + + t.Run("no sort specified returns results", func(t *testing.T) { + got, err := ss.Search(ctx, SearchParams{ + Resource: Targets, + AuthTokenId: at.Id, + }) + require.NoError(t, err) + require.Len(t, got.Targets, 3) + }) + + t.Run("invalid sort by value", func(t *testing.T) { + got, err := ss.Search(ctx, SearchParams{ + Resource: Targets, + AuthTokenId: at.Id, + SortBy: SortBy("invalid_column"), + }) + assert.Error(t, err) + assert.ErrorContains(t, err, "invalid sort by value") + assert.Nil(t, got) + }) + + t.Run("invalid sort direction value", func(t *testing.T) { + got, err := ss.Search(ctx, SearchParams{ + Resource: Targets, + AuthTokenId: at.Id, + SortBy: SortByName, + SortDirection: SortDirection("invalid"), + }) + assert.Error(t, err) + assert.ErrorContains(t, err, "invalid sort direction value") + assert.Nil(t, got) + }) + + t.Run("sort column not allowed for resolvable aliases", func(t *testing.T) { + got, err := ss.Search(ctx, SearchParams{ + Resource: ResolvableAliases, + AuthTokenId: at.Id, + SortBy: SortByName, + SortDirection: Ascending, + }) + assert.Error(t, err) + assert.ErrorContains(t, err, "not allowed for this resource type") + assert.Nil(t, got) + }) + + t.Run("sessions reject name sort", func(t *testing.T) { + got, err := ss.Search(ctx, SearchParams{ + Resource: Sessions, + AuthTokenId: at.Id, + SortBy: SortByName, + SortDirection: Ascending, + }) + assert.Error(t, err) + assert.ErrorContains(t, err, "not allowed for this resource type") + assert.Nil(t, got) + }) + + t.Run("sessions accept created_at sort", func(t *testing.T) { + got, err := ss.Search(ctx, SearchParams{ + Resource: Sessions, + AuthTokenId: at.Id, + SortBy: SortByCreatedAt, + SortDirection: Descending, + }) + require.NoError(t, err) + require.Len(t, got.Sessions, 2) + }) +} diff --git a/internal/clientcache/internal/daemon/search_handler.go b/internal/clientcache/internal/daemon/search_handler.go index 56af6faf8d..64668d113e 100644 --- a/internal/clientcache/internal/daemon/search_handler.go +++ b/internal/clientcache/internal/daemon/search_handler.go @@ -63,12 +63,10 @@ const ( sortDirectionKey = "sort_direction" ) -var ( - sortableColumnsForResource = map[cache.SearchableResource][]cache.SortBy{ - cache.Targets: []cache.SortBy{cache.SortByName}, - cache.Sessions: []cache.SortBy{cache.SortByCreatedTime}, - } -) +var sortableColumnsForResource = map[cache.SearchableResource][]cache.SortBy{ + cache.Targets: {cache.SortByName}, + cache.Sessions: {cache.SortByCreatedTime}, +} func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshService *cache.RefreshService, logger hclog.Logger) (http.HandlerFunc, error) { const op = "daemon.newSearchHandlerFunc"