feat(clientcache): Extend SearchParams and Options with sort support and validation (#6342)

pull/6383/head
Sepehr 5 months ago
parent 614abbe6e8
commit 2e9623d462

@ -5,10 +5,21 @@ package cache
import (
stderrors "errors"
"fmt"
"regexp"
"slices"
"github.com/hashicorp/go-dbw"
)
// safeSortColumnRegex contains characters that could break SQL ORDER BY clauses
var safeSortColumnRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)
var (
errInvalidSortColumn = stderrors.New("not allowed for this resource type")
errUnsafeSortColumn = stderrors.New("contains unsafe characters")
)
type testRefreshWaitChs struct {
firstSempahore chan struct{}
secondSemaphore chan struct{}
@ -26,6 +37,8 @@ type options struct {
withMaxResultSetSize int
withTestRefreshWaitChs *testRefreshWaitChs
withUseNonPagedListing bool
withSortBy SortBy // validated DB column name
withSortDirection SortDirection // "asc" or "desc"
}
// Option - how options are passed as args
@ -139,3 +152,26 @@ func WithUseNonPagedListing(b bool) Option {
return nil
}
}
// WithSort configures sorting for query results.
// Empty sortBy is silently ignored. Empty direction defaults to ascending in the repository layer.
// Validates column against sortableColumns and rejects SQL-unsafe characters.
func WithSort(sortBy SortBy, direction SortDirection, sortableColumns []SortBy) Option {
return func(o *options) error {
// ignore empty sortBy
if sortBy == SortByDefault {
return nil
}
switch {
case !slices.Contains(sortableColumns, sortBy):
return fmt.Errorf("invalid sort column %q: %w", sortBy, errInvalidSortColumn)
case !safeSortColumnRegex.MatchString(string(sortBy)):
return fmt.Errorf("invalid sort column %q: %w", sortBy, errUnsafeSortColumn)
}
o.withSortBy = sortBy
o.withSortDirection = direction
return nil
}
}

@ -126,4 +126,105 @@ func Test_GetOpts(t *testing.T) {
testOpts.withUseNonPagedListing = true
assert.Equal(t, opts, testOpts)
})
t.Run("WithSort-default-sortby-ignored", func(t *testing.T) {
opts, err := getOpts(WithSort(SortByDefault, Ascending, []SortBy{SortByName}))
require.NoError(t, err)
testOpts := getDefaultOptions()
assert.Equal(t, opts, testOpts)
})
t.Run("WithSort-empty-sortby-ignored", func(t *testing.T) {
opts, err := getOpts(WithSort("", Ascending, []SortBy{SortByName}))
require.NoError(t, err)
testOpts := getDefaultOptions()
assert.Equal(t, opts, testOpts)
})
t.Run("WithSort-valid-name-ascending", func(t *testing.T) {
opts, err := getOpts(WithSort(SortByName, Ascending, []SortBy{SortByName, SortByCreatedAt}))
require.NoError(t, err)
testOpts := getDefaultOptions()
testOpts.withSortBy = SortByName
testOpts.withSortDirection = Ascending
assert.Equal(t, opts, testOpts)
})
t.Run("WithSort-valid-created_at-descending", func(t *testing.T) {
opts, err := getOpts(WithSort(SortByCreatedAt, Descending, []SortBy{SortByCreatedAt}))
require.NoError(t, err)
testOpts := getDefaultOptions()
testOpts.withSortBy = SortByCreatedAt
testOpts.withSortDirection = Descending
assert.Equal(t, opts, testOpts)
})
t.Run("WithSort-column-not-in-sortable-list", func(t *testing.T) {
_, err := getOpts(WithSort(SortByName, Ascending, []SortBy{SortByCreatedAt}))
require.Error(t, err)
assert.ErrorContains(t, err, errInvalidSortColumn.Error())
})
t.Run("WithSort-empty-sortable-columns", func(t *testing.T) {
_, err := getOpts(WithSort(SortByName, Ascending, []SortBy{}))
require.Error(t, err)
assert.ErrorContains(t, err, errInvalidSortColumn.Error())
})
t.Run("WithSort-nil-sortable-columns", func(t *testing.T) {
_, err := getOpts(WithSort(SortByName, Ascending, nil))
require.Error(t, err)
assert.ErrorContains(t, err, errInvalidSortColumn.Error())
})
t.Run("WithSort-unsafe-chars-semicolon", func(t *testing.T) {
_, err := getOpts(WithSort(SortBy("name; DROP TABLE"), Ascending, []SortBy{SortBy("name; DROP TABLE")}))
require.Error(t, err)
assert.ErrorContains(t, err, errUnsafeSortColumn.Error())
})
t.Run("WithSort-unsafe-chars-quote", func(t *testing.T) {
_, err := getOpts(WithSort(SortBy("name'--"), Ascending, []SortBy{SortBy("name'--")}))
require.Error(t, err)
assert.ErrorContains(t, err, errUnsafeSortColumn.Error())
})
t.Run("WithSort-unsafe-chars-double-quote", func(t *testing.T) {
_, err := getOpts(WithSort(SortBy("name\"--"), Ascending, []SortBy{SortBy("name\"--")}))
require.Error(t, err)
assert.ErrorContains(t, err, errUnsafeSortColumn.Error())
})
t.Run("WithSort-unsafe-chars-backslash", func(t *testing.T) {
_, err := getOpts(WithSort(SortBy("name\\x00"), Ascending, []SortBy{SortBy("name\\x00")}))
require.Error(t, err)
assert.ErrorContains(t, err, errUnsafeSortColumn.Error())
})
t.Run("WithSort-unsafe-chars-comma", func(t *testing.T) {
_, err := getOpts(WithSort(SortBy("name,other"), Ascending, []SortBy{SortBy("name,other")}))
require.Error(t, err)
assert.ErrorContains(t, err, errUnsafeSortColumn.Error())
})
t.Run("WithSort-unsafe-chars-parenthesis", func(t *testing.T) {
_, err := getOpts(WithSort(SortBy("name("), Ascending, []SortBy{SortBy("name(")}))
require.Error(t, err)
assert.ErrorContains(t, err, errUnsafeSortColumn.Error())
})
t.Run("WithSort-unsafe-chars-space", func(t *testing.T) {
_, err := getOpts(WithSort(SortBy("name "), Ascending, []SortBy{SortBy("name ")}))
require.Error(t, err)
assert.ErrorContains(t, err, errUnsafeSortColumn.Error())
})
t.Run("WithSort-unsafe-chars-tab", func(t *testing.T) {
_, err := getOpts(WithSort(SortBy("name\t"), Ascending, []SortBy{SortBy("name\t")}))
require.Error(t, err)
assert.ErrorContains(t, err, errUnsafeSortColumn.Error())
})
t.Run("WithSort-unsafe-chars-newline", func(t *testing.T) {
_, err := getOpts(WithSort(SortBy("name\n"), Ascending, []SortBy{SortBy("name\n")}))
require.Error(t, err)
assert.ErrorContains(t, err, errUnsafeSortColumn.Error())
})
t.Run("WithSort-unsafe-chars-dash", func(t *testing.T) {
_, err := getOpts(WithSort(SortBy("name-col"), Ascending, []SortBy{SortBy("name-col")}))
require.Error(t, err)
assert.ErrorContains(t, err, errUnsafeSortColumn.Error())
})
t.Run("WithSort-default-direction", func(t *testing.T) {
opts, err := getOpts(WithSort(SortByName, SortDirectionDefault, []SortBy{SortByName}))
require.NoError(t, err)
testOpts := getDefaultOptions()
testOpts.withSortBy = SortByName
testOpts.withSortDirection = SortDirectionDefault
assert.Equal(t, opts, testOpts)
})
}

@ -25,6 +25,15 @@ const (
SortByCreatedTime SortBy = "created_time"
)
// Valid returns true if the SortBy value is a known good value
func (s SortBy) Valid() bool {
switch s {
case SortByDefault, SortByName, SortByCreatedAt:
return true
}
return false
}
type SortDirection string
const (
@ -33,6 +42,15 @@ const (
Descending SortDirection = "desc"
)
// Valid returns true if the SortDirection value is a known good value
func (d SortDirection) Valid() bool {
switch d {
case SortDirectionDefault, Ascending, Descending:
return true
}
return false
}
type SearchableResource string
const (
@ -136,6 +154,7 @@ func NewSearchService(ctx context.Context, repo *Repository) (*SearchService, er
}
in.Targets = finalResults
},
sortableColumns: []SortBy{SortByName, SortByCreatedAt},
},
Sessions: &resourceSearchFns[*sessions.Session]{
list: repo.ListSessions,
@ -149,6 +168,7 @@ func NewSearchService(ctx context.Context, repo *Repository) (*SearchService, er
}
in.Sessions = finalResults
},
sortableColumns: []SortBy{SortByCreatedAt},
},
ImplicitScopes: &resourceSearchFns[*scopes.Scope]{
list: repo.ListImplicitScopes,
@ -204,6 +224,10 @@ func (s *SearchService) Search(ctx context.Context, params SearchParams) (*Searc
return nil, errors.New(ctx, errors.InvalidParameter, op, "invalid resource")
case params.AuthTokenId == "":
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing auth token id")
case !params.SortBy.Valid():
return nil, errors.New(ctx, errors.InvalidParameter, op, "invalid sort by value")
case !params.SortDirection.Valid():
return nil, errors.New(ctx, errors.InvalidParameter, op, "invalid sort direction value")
}
rSearcher, ok := s.searchableResources[params.Resource]
if !ok {
@ -231,6 +255,8 @@ type resourceSearchFns[T any] struct {
// filter takes results and a ready-to-use evaluator and filters the items
// in the result
filter func(*SearchResult, *bexpr.Evaluator)
// sortableColumns is a list of columns that can be used for sorting
sortableColumns []SortBy
}
// resourceSearcher is an interface that only resourceSearchFns[T] is expected
@ -249,13 +275,18 @@ type resourceSearcher interface {
func (l *resourceSearchFns[T]) search(ctx context.Context, p SearchParams) (*SearchResult, error) {
const op = "cache.(resourceSearchFns).search"
opts := []Option{WithMaxResultSetSize(p.MaxResultSetSize)}
if p.SortBy != SortByDefault {
opts = append(opts, WithSort(p.SortBy, p.SortDirection, l.sortableColumns))
}
var found *SearchResult
var err error
switch p.Query {
case "":
found, err = l.list(ctx, p.AuthTokenId, WithMaxResultSetSize(p.MaxResultSetSize))
found, err = l.list(ctx, p.AuthTokenId, opts...)
default:
found, err = l.query(ctx, p.AuthTokenId, p.Query, WithMaxResultSetSize(p.MaxResultSetSize))
found, err = l.query(ctx, p.AuthTokenId, p.Query, opts...)
}
if err != nil {
return nil, errors.Wrap(ctx, err, op)

@ -442,3 +442,166 @@ func TestSearch(t *testing.T) {
assert.Equal(t, &SearchResult{Targets: []*targets.Target{}}, got)
})
}
func TestSortByValid(t *testing.T) {
cases := []struct {
sortBy SortBy
valid bool
}{
{SortByDefault, true},
{SortBy(""), true},
{SortByName, true},
{SortByCreatedAt, true},
{SortBy("unknown"), false},
{SortBy("id"), false},
{SortBy("invalid_column"), false},
{SortBy("name; DROP TABLE"), false},
{SortBy("name'--"), false},
{SortBy("name\"--"), false},
{SortBy("name\\x00"), false},
{SortBy("name,other"), false},
{SortBy("name ("), false},
{SortBy("name)"), false},
{SortBy("name\t"), false},
{SortBy("name\n"), false},
{SortBy("name\r"), false},
{SortBy("name "), false},
}
for _, tc := range cases {
t.Run(string(tc.sortBy), func(t *testing.T) {
assert.Equal(t, tc.valid, tc.sortBy.Valid())
})
}
}
func TestSortDirectionValid(t *testing.T) {
cases := []struct {
direction SortDirection
valid bool
}{
{SortDirectionDefault, true},
{SortDirection(""), true},
{Ascending, true},
{Descending, true},
{SortDirection("ASC"), false},
{SortDirection("DESC"), false},
{SortDirection("invalid"), false},
}
for _, tc := range cases {
t.Run(string(tc.direction), func(t *testing.T) {
assert.Equal(t, tc.valid, tc.direction.Valid())
})
}
}
func TestSearch_Sorting(t *testing.T) {
ctx := context.Background()
s, err := cachedb.Open(ctx)
require.NoError(t, err)
at := &AuthToken{
Id: "at_sort",
UserId: "u_sort",
}
{
u := &user{Id: at.UserId, Address: "address"}
rw := db.New(s)
require.NoError(t, rw.Create(ctx, u))
require.NoError(t, rw.Create(ctx, at))
targets := []*Target{
{FkUserId: u.Id, Id: "t_1", Name: "alpha", Type: "tcp", Item: `{"id": "t_1", "name": "alpha", "type": "tcp"}`},
{FkUserId: u.Id, Id: "t_2", Name: "charlie", Type: "tcp", Item: `{"id": "t_2", "name": "charlie", "type": "tcp"}`},
{FkUserId: u.Id, Id: "t_3", Name: "bravo", Type: "tcp", Item: `{"id": "t_3", "name": "bravo", "type": "tcp"}`},
}
require.NoError(t, rw.CreateItems(ctx, targets))
sessions := []*Session{
{FkUserId: u.Id, Id: "s_1", Endpoint: "one", Type: "tcp", UserId: "u123", Item: `{"id": "s_1", "endpoint": "one", "type": "tcp", "user_id": "u123"}`},
{FkUserId: u.Id, Id: "s_2", Endpoint: "two", Type: "ssh", UserId: "u321", Item: `{"id": "s_2", "endpoint": "two", "type": "ssh", "user_id": "u321"}`},
}
require.NoError(t, rw.CreateItems(ctx, sessions))
aliases := []*ResolvableAlias{
{FkUserId: u.Id, Id: "alt_1", Value: "one", Type: "target", Item: `{"id": "alt_1", "value": "one", "type": "target"}`},
}
require.NoError(t, rw.CreateItems(ctx, aliases))
}
r, err := NewRepository(ctx, s, &sync.Map{},
mapBasedAuthTokenKeyringLookup(nil),
sliceBasedAuthTokenBoundaryReader(nil))
require.NoError(t, err)
ss, err := NewSearchService(ctx, r)
require.NoError(t, err)
t.Run("no sort specified returns results", func(t *testing.T) {
got, err := ss.Search(ctx, SearchParams{
Resource: Targets,
AuthTokenId: at.Id,
})
require.NoError(t, err)
require.Len(t, got.Targets, 3)
})
t.Run("invalid sort by value", func(t *testing.T) {
got, err := ss.Search(ctx, SearchParams{
Resource: Targets,
AuthTokenId: at.Id,
SortBy: SortBy("invalid_column"),
})
assert.Error(t, err)
assert.ErrorContains(t, err, "invalid sort by value")
assert.Nil(t, got)
})
t.Run("invalid sort direction value", func(t *testing.T) {
got, err := ss.Search(ctx, SearchParams{
Resource: Targets,
AuthTokenId: at.Id,
SortBy: SortByName,
SortDirection: SortDirection("invalid"),
})
assert.Error(t, err)
assert.ErrorContains(t, err, "invalid sort direction value")
assert.Nil(t, got)
})
t.Run("sort column not allowed for resolvable aliases", func(t *testing.T) {
got, err := ss.Search(ctx, SearchParams{
Resource: ResolvableAliases,
AuthTokenId: at.Id,
SortBy: SortByName,
SortDirection: Ascending,
})
assert.Error(t, err)
assert.ErrorContains(t, err, "not allowed for this resource type")
assert.Nil(t, got)
})
t.Run("sessions reject name sort", func(t *testing.T) {
got, err := ss.Search(ctx, SearchParams{
Resource: Sessions,
AuthTokenId: at.Id,
SortBy: SortByName,
SortDirection: Ascending,
})
assert.Error(t, err)
assert.ErrorContains(t, err, "not allowed for this resource type")
assert.Nil(t, got)
})
t.Run("sessions accept created_at sort", func(t *testing.T) {
got, err := ss.Search(ctx, SearchParams{
Resource: Sessions,
AuthTokenId: at.Id,
SortBy: SortByCreatedAt,
SortDirection: Descending,
})
require.NoError(t, err)
require.Len(t, got.Sessions, 2)
})
}

@ -63,12 +63,10 @@ const (
sortDirectionKey = "sort_direction"
)
var (
sortableColumnsForResource = map[cache.SearchableResource][]cache.SortBy{
cache.Targets: []cache.SortBy{cache.SortByName},
cache.Sessions: []cache.SortBy{cache.SortByCreatedTime},
}
)
var sortableColumnsForResource = map[cache.SearchableResource][]cache.SortBy{
cache.Targets: {cache.SortByName},
cache.Sessions: {cache.SortByCreatedTime},
}
func newSearchHandlerFunc(ctx context.Context, repo *cache.Repository, refreshService *cache.RefreshService, logger hclog.Logger) (http.HandlerFunc, error) {
const op = "daemon.newSearchHandlerFunc"

Loading…
Cancel
Save