diff --git a/internal/session/repository.go b/internal/session/repository.go index 9c0dce9102..32713c85db 100644 --- a/internal/session/repository.go +++ b/internal/session/repository.go @@ -52,8 +52,9 @@ func NewRepository(r db.Reader, w db.Writer, kms *kms.Kms, opt ...Option) (*Repo // list will return a listing of resources and honor the WithLimit option or the // repo defaultLimit. Supports WithOrder option. -func (r *Repository) list(ctx context.Context, resources interface{}, where string, args []interface{}, opts options) error { +func (r *Repository) list(ctx context.Context, resources interface{}, where string, args []interface{}, opt ...Option) error { const op = "session.(Repository).list" + opts := getOpts(opt...) limit := r.defaultLimit var dbOpts []db.Option if opts.withLimit != 0 { diff --git a/internal/session/repository_connection.go b/internal/session/repository_connection.go index 43d1e1cf09..c145c14baa 100644 --- a/internal/session/repository_connection.go +++ b/internal/session/repository_connection.go @@ -44,21 +44,15 @@ func (r *Repository) LookupConnection(ctx context.Context, connectionId string, return &connection, states, nil } -// ListConnections will connections by session ID. If session ID is *, all -// connections are returned. Supports the WithLimit and WithOrder options. +// ListConnectionsBySessionId will list connections by session ID. Supports the +// WithLimit and WithOrder options. func (r *Repository) ListConnectionsBySessionId(ctx context.Context, sessionId string, opt ...Option) ([]*Connection, error) { - const op = "session.(Repository).ListConnections" - opts := getOpts(opt...) - var connections []*Connection - var args []interface{} - var where string - if sessionId != "*" { - where = "session_id = ?" - args = append(args, sessionId) + const op = "session.(Repository).ListConnectionsBySessionId" + if sessionId == "" { + return nil, errors.New(errors.InvalidParameter, op, "no session ID supplied") } - // The where clause will be ignored if args is empty, if we don't want to - // scope to a specific session ID - err := r.list(ctx, &connections, where, args, opts) // pass options, so WithLimit and WithOrder are supported + var connections []*Connection + err := r.list(ctx, &connections, "session_id = ?", []interface{}{sessionId}, opt...) // pass options, so WithLimit and WithOrder are supported if err != nil { return nil, errors.Wrap(err, op) } diff --git a/internal/session/repository_connection_test.go b/internal/session/repository_connection_test.go index 73068b8a35..3a4636090f 100644 --- a/internal/session/repository_connection_test.go +++ b/internal/session/repository_connection_test.go @@ -292,8 +292,8 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) { // expect all others to be open. shouldBeClosed := worker2ConnIds[2:] - conns, err := repo.ListConnectionsBySessionId(ctx, "*") - require.NoError(err) + var conns []*Connection + require.NoError(repo.list(ctx, &conns, "", nil)) for _, conn := range conns { _, states, err := repo.LookupConnection(ctx, conn.PublicId) require.NoError(err) @@ -315,8 +315,8 @@ func TestRepository_CloseDeadConnectionsOnWorkerReport(t *testing.T) { assert.Equal(6, count) // We now expect all but those blessed few to be closed - conns, err = repo.ListConnectionsBySessionId(ctx, "*") - require.NoError(err) + conns = nil + require.NoError(repo.list(ctx, &conns, "", nil)) for _, conn := range conns { _, states, err := repo.LookupConnection(ctx, conn.PublicId) require.NoError(err)