You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
boundary/internal/daemon/controller/handlers/sessions/session_service_test.go

783 lines
30 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package sessions_test
import (
"context"
"fmt"
"net/http/httptest"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/authtoken"
"github.com/hashicorp/boundary/internal/daemon/controller/auth"
"github.com/hashicorp/boundary/internal/daemon/controller/handlers"
"github.com/hashicorp/boundary/internal/daemon/controller/handlers/sessions"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services"
authpb "github.com/hashicorp/boundary/internal/gen/controller/auth"
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/requests"
"github.com/hashicorp/boundary/internal/server"
"github.com/hashicorp/boundary/internal/session"
"github.com/hashicorp/boundary/internal/target"
"github.com/hashicorp/boundary/internal/target/tcp"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/scopes"
pb "github.com/hashicorp/boundary/sdk/pbs/controller/api/resources/sessions"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/protobuf/testing/protocmp"
)
// testAuthorizedActions is the AuthorizedActions value placed on every
// expected session fixture in this file: the ":self" variants of read and
// cancel.
var testAuthorizedActions = []string{
	"read:self",
	"cancel:self",
}
// TestGetSession exercises Service.GetSession against a test database. It
// creates one pending session for the auth token's user and then checks the
// successful lookup plus three rejected IDs: an unknown ID, an ID with the
// wrong prefix, and an ID containing a space.
func TestGetSession(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	wrap := db.TestWrapper(t)
	kms := kms.TestKms(t, conn, wrap)
	iamRepo := iam.TestRepo(t, conn, wrap)
	rw := db.New(conn)
	ctx := context.Background()

	// Repository factories handed to the service and to the auth verifier.
	iamRepoFn := func() (*iam.Repository, error) {
		return iamRepo, nil
	}
	sessRepoFn := func(opt ...session.Option) (*session.Repository, error) {
		return session.NewRepository(ctx, rw, rw, kms, opt...)
	}
	tokenRepoFn := func() (*authtoken.Repository, error) {
		return authtoken.NewRepository(ctx, rw, rw, kms)
	}
	serversRepoFn := func() (*server.Repository, error) {
		return server.NewRepository(ctx, rw, rw, kms)
	}

	// Fixtures: org/project scopes, an auth token, and a static host/target
	// that the session is composed of.
	o, p := iam.TestScopes(t, iamRepo)
	at := authtoken.TestAuthToken(t, conn, kms, o.GetPublicId())
	uId := at.GetIamUserId()
	hc := static.TestCatalogs(t, conn, p.GetPublicId(), 1)[0]
	hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0]
	h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0]
	static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h})
	tar := tcp.TestTarget(context.Background(), t, conn, p.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()}))

	sess := session.TestSession(t, conn, wrap, session.ComposedOf{
		UserId:      uId,
		HostId:      h.GetPublicId(),
		TargetId:    tar.GetPublicId(),
		HostSetId:   hs.GetPublicId(),
		AuthTokenId: at.GetPublicId(),
		ProjectId:   p.GetPublicId(),
		Endpoint:    "tcp://127.0.0.1:22",
	})

	// wireSess is the API representation expected for sess.
	wireSess := &pb.Session{
		Id:                sess.GetPublicId(),
		ScopeId:           p.GetPublicId(),
		AuthTokenId:       at.GetPublicId(),
		Endpoint:          sess.Endpoint,
		UserId:            at.GetIamUserId(),
		TargetId:          sess.TargetId,
		HostSetId:         sess.HostSetId,
		HostId:            sess.HostId,
		Version:           sess.Version,
		Status:            session.StatusPending.String(),
		UpdatedTime:       sess.UpdateTime.GetTimestamp(),
		CreatedTime:       sess.CreateTime.GetTimestamp(),
		ExpirationTime:    sess.ExpirationTime.GetTimestamp(),
		Scope:             &scopes.ScopeInfo{Id: p.GetPublicId(), Type: scope.Project.String(), ParentScopeId: o.GetPublicId()},
		States:            []*pb.SessionState{{Status: session.StatusPending.String(), StartTime: sess.CreateTime.GetTimestamp()}},
		Certificate:       sess.Certificate,
		Type:              tcp.Subtype.String(),
		AuthorizedActions: testAuthorizedActions,
	}

	cases := []struct {
		name string
		req  *pbs.GetSessionRequest
		res  *pbs.GetSessionResponse
		err  error
	}{
		{
			name: "Get a session",
			req:  &pbs.GetSessionRequest{Id: sess.GetPublicId()},
			res:  &pbs.GetSessionResponse{Item: wireSess},
		},
		{
			name: "Get a non existent Session",
			req:  &pbs.GetSessionRequest{Id: globals.SessionPrefix + "_DoesntExis"},
			res:  nil,
			err:  handlers.ApiErrorWithCode(codes.NotFound),
		},
		{
			name: "Wrong id prefix",
			req:  &pbs.GetSessionRequest{Id: "j_1234567890"},
			res:  nil,
			err:  handlers.ApiErrorWithCode(codes.InvalidArgument),
		},
		{
			name: "space in id",
			req:  &pbs.GetSessionRequest{Id: globals.SessionPrefix + "_1 23456789"},
			res:  nil,
			err:  handlers.ApiErrorWithCode(codes.InvalidArgument),
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			s, err := sessions.NewService(ctx, sessRepoFn, iamRepoFn)
			require.NoError(err, "Couldn't create new session service.")

			requestInfo := authpb.RequestInfo{
				TokenFormat: uint32(auth.AuthTokenTypeBearer),
				PublicId:    at.GetPublicId(),
				Token:       at.GetToken(),
			}
			requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{})
			verifierCtx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo)

			got, gErr := s.GetSession(verifierCtx, tc.req)
			if tc.err != nil {
				require.Error(gErr)
				assert.True(errors.Is(gErr, tc.err), "GetSession(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
			} else {
				// A response accompanied by an unexpected error must not
				// pass silently.
				require.NoError(gErr)
			}

			if tc.res != nil {
				// The expected expiration is checked against a 10ms bound
				// and then overwritten with the returned value so the full
				// diff below ignores that difference.
				assert.True(got.GetItem().GetExpirationTime().AsTime().Sub(tc.res.GetItem().GetExpirationTime().AsTime()) < 10*time.Millisecond)
				tc.res.GetItem().ExpirationTime = got.GetItem().GetExpirationTime()
			}
			// For error cases tc.res is nil, so this also asserts that no
			// response was returned.
			assert.Empty(cmp.Diff(tc.res, got, protocmp.Transform()), "GetSession(%q) got response\n%q, wanted\n%q", tc.req, got, tc.res)
		})
	}
}
// TestList_Self checks how ":self" grants interact with ListSessions. A
// single session is created for one user, and three requesters then list the
// project's sessions: the session's owner, a user holding a wildcard role on
// the project, and a user holding only the self-scoped session role.
func TestList_Self(t *testing.T) {
	dbConn, _ := db.TestSetup(t, "postgres")
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, dbConn, wrapper)
	iamRepo := iam.TestRepo(t, dbConn, wrapper)
	rw := db.New(dbConn)
	baseCtx := context.Background()

	// Repository factories used by the service and the auth verifier.
	iamRepoFn := func() (*iam.Repository, error) { return iamRepo, nil }
	sessRepoFn := func(opt ...session.Option) (*session.Repository, error) {
		return session.NewRepository(baseCtx, rw, rw, kmsCache, opt...)
	}
	tokenRepoFn := func() (*authtoken.Repository, error) {
		return authtoken.NewRepository(baseCtx, rw, rw, kmsCache)
	}
	serversRepoFn := func() (*server.Repository, error) {
		return server.NewRepository(baseCtx, rw, rw, kmsCache)
	}

	org, proj := iam.TestScopes(t, iamRepo)
	orgId, projId := org.GetPublicId(), proj.GetPublicId()

	// Three tokens in the same org: the session owner, a privileged user
	// and an unprivileged user.
	ownerToken := authtoken.TestAuthToken(t, dbConn, kmsCache, orgId)
	adminToken := authtoken.TestAuthToken(t, dbConn, kmsCache, orgId)
	selfOnlyToken := authtoken.TestAuthToken(t, dbConn, kmsCache, orgId)

	// See https://github.com/hashicorp/boundary/pull/2448 -- the two roles
	// below exercise the feature and double as a regression test.

	// "Privileged" role: wildcard grant on the project, held by adminToken's
	// user only.
	adminRole := iam.TestRole(t, dbConn, projId)
	iam.TestRoleGrant(t, dbConn, adminRole.GetPublicId(), "id=*;type=*;actions=*")
	iam.TestUserRole(t, dbConn, adminRole.GetPublicId(), adminToken.GetIamUserId())

	// "Unprivileged" role: only list plus the self variants, held by both
	// the unprivileged and the privileged user.
	selfRole := iam.TestRole(t, dbConn, projId)
	iam.TestRoleGrant(t, dbConn, selfRole.GetPublicId(), "id=*;type=session;actions=read:self,list,cancel:self")
	iam.TestUserRole(t, dbConn, selfRole.GetPublicId(), selfOnlyToken.GetIamUserId())
	iam.TestUserRole(t, dbConn, selfRole.GetPublicId(), adminToken.GetIamUserId())

	catalog := static.TestCatalogs(t, dbConn, projId, 1)[0]
	hostSet := static.TestSets(t, dbConn, catalog.GetPublicId(), 1)[0]
	host := static.TestHosts(t, dbConn, catalog.GetPublicId(), 1)[0]
	static.TestSetMembers(t, dbConn, hostSet.GetPublicId(), []*static.Host{host})
	tcpTarget := tcp.TestTarget(context.Background(), t, dbConn, projId, "test", target.WithHostSources([]string{hostSet.GetPublicId()}))

	// The one session under test belongs to ownerToken's user; by default a
	// user can read/cancel their own sessions.
	session.TestSession(t, dbConn, wrapper, session.ComposedOf{
		UserId:      ownerToken.GetIamUserId(),
		HostId:      host.GetPublicId(),
		TargetId:    tcpTarget.GetPublicId(),
		HostSetId:   hostSet.GetPublicId(),
		AuthTokenId: ownerToken.GetPublicId(),
		ProjectId:   projId,
		Endpoint:    "tcp://127.0.0.1:22",
	})

	svc, err := sessions.NewService(baseCtx, sessRepoFn, iamRepoFn)
	require.NoError(t, err, "Couldn't create new session service.")

	tests := []struct {
		name      string
		token     *authtoken.AuthToken
		wantCount int
	}{
		{name: "List Self Sessions", token: ownerToken, wantCount: 1},
		{name: "Can List Others Sessions when Authorized", token: adminToken, wantCount: 1},
		{name: "Can't List Others Sessions When Not Authorized", token: selfOnlyToken, wantCount: 0},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Build the auth request information from a synthetic list
			// request against the project.
			listURL := fmt.Sprintf("http://127.0.0.1/v1/sessions?scope_id=%s", projId)
			httpReq := httptest.NewRequest("GET", listURL, nil)
			info := authpb.RequestInfo{
				Path:        httpReq.URL.Path,
				Method:      httpReq.Method,
				TokenFormat: uint32(auth.AuthTokenTypeBearer),
				PublicId:    tt.token.GetPublicId(),
				Token:       tt.token.GetToken(),
			}
			verifierCtx := auth.NewVerifierContext(context.Background(), iamRepoFn, tokenRepoFn, serversRepoFn, kmsCache, &info)

			resp, listErr := svc.ListSessions(verifierCtx, &pbs.ListSessionsRequest{ScopeId: projId})
			require.NoError(t, listErr)
			items := resp.GetItems()
			assert.Equal(t, tt.wantCount, len(items), items)
		})
	}
}
// TestList exercises Service.ListSessions with three kinds of requester: the
// user owning the sessions in pWithSessions, a user from a different org
// owning the sessions in pWithOtherSessions, and a recovery-KMS request.
//
// Fixtures: ten sessions (each with one connection) in each of the two
// projects, one additional cancelled-and-terminated session in
// pWithSessions, and a third project with no sessions at all.
//
// Only the first requester's response is diffed field by field; the other
// two are checked for item count and expiration time.
func TestList(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	wrap := db.TestWrapper(t)
	kms := kms.TestKms(t, conn, wrap)
	ctx := context.Background()
	iamRepo := iam.TestRepo(t, conn, wrap)
	rw := db.New(conn)
	sessRepo, err := session.NewRepository(ctx, rw, rw, kms)
	require.NoError(t, err)

	// Repository factories handed to the service and to the auth verifier.
	iamRepoFn := func() (*iam.Repository, error) {
		return iamRepo, nil
	}
	sessRepoFn := func(opt ...session.Option) (*session.Repository, error) {
		return session.NewRepository(ctx, rw, rw, kms, opt...)
	}
	tokenRepoFn := func() (*authtoken.Repository, error) {
		return authtoken.NewRepository(ctx, rw, rw, kms)
	}
	serversRepoFn := func() (*server.Repository, error) {
		return server.NewRepository(ctx, rw, rw, kms)
	}

	_, pNoSessions := iam.TestScopes(t, iamRepo)
	o, pWithSessions := iam.TestScopes(t, iamRepo)
	oOther, pWithOtherSessions := iam.TestScopes(t, iamRepo)

	at := authtoken.TestAuthToken(t, conn, kms, o.GetPublicId())
	uId := at.GetIamUserId()
	atOther := authtoken.TestAuthToken(t, conn, kms, oOther.GetPublicId())
	uIdOther := atOther.GetIamUserId()

	hc := static.TestCatalogs(t, conn, pWithSessions.GetPublicId(), 1)[0]
	hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0]
	h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0]
	static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h})
	tar := tcp.TestTarget(ctx, t, conn, pWithSessions.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()}))

	hcOther := static.TestCatalogs(t, conn, pWithOtherSessions.GetPublicId(), 1)[0]
	hsOther := static.TestSets(t, conn, hcOther.GetPublicId(), 1)[0]
	hOther := static.TestHosts(t, conn, hcOther.GetPublicId(), 1)[0]
	static.TestSetMembers(t, conn, hsOther.GetPublicId(), []*static.Host{hOther})
	tarOther := tcp.TestTarget(ctx, t, conn, pWithOtherSessions.GetPublicId(), "test", target.WithHostSources([]string{hsOther.GetPublicId()}))

	// wireSession builds the API representation expected from a list call
	// for sess, which lives in project projId (child of org orgId) and was
	// created with auth token tok. TerminationReason is left unset; callers
	// that need it set it on the result.
	wireSession := func(sess *session.Session, orgId, projId string, tok *authtoken.AuthToken) *pb.Session {
		status, states := convertStates(sess.States)
		return &pb.Session{
			Id:                sess.GetPublicId(),
			ScopeId:           projId,
			AuthTokenId:       tok.GetPublicId(),
			UserId:            tok.GetIamUserId(),
			TargetId:          sess.TargetId,
			Endpoint:          sess.Endpoint,
			HostSetId:         sess.HostSetId,
			HostId:            sess.HostId,
			Version:           sess.Version,
			UpdatedTime:       sess.UpdateTime.GetTimestamp(),
			CreatedTime:       sess.CreateTime.GetTimestamp(),
			ExpirationTime:    sess.ExpirationTime.GetTimestamp(),
			Scope:             &scopes.ScopeInfo{Id: projId, Type: scope.Project.String(), ParentScopeId: orgId},
			Status:            status,
			States:            states,
			Certificate:       sess.Certificate,
			Type:              tcp.Subtype.String(),
			AuthorizedActions: testAuthorizedActions,
			Connections:       []*pb.Connection{}, // connections should not be returned for list
		}
	}

	var wantSession []*pb.Session
	var wantOtherSession []*pb.Session
	var wantAllSessions []*pb.Session
	var wantIncludeTerminatedSessions []*pb.Session
	for i := 0; i < 10; i++ {
		sess := session.TestSession(t, conn, wrap, session.ComposedOf{
			UserId:      uId,
			HostId:      h.GetPublicId(),
			TargetId:    tar.GetPublicId(),
			HostSetId:   hs.GetPublicId(),
			AuthTokenId: at.GetPublicId(),
			ProjectId:   pWithSessions.GetPublicId(),
			Endpoint:    "tcp://127.0.0.1:22",
		})
		session.TestConnection(t, conn, sess.PublicId, "127.0.0.1", 22, "127.0.0.2", 23, "127.0.0.1")
		firstOrgSession := wireSession(sess, o.GetPublicId(), pWithSessions.GetPublicId(), at)
		wantSession = append(wantSession, firstOrgSession)
		wantAllSessions = append(wantAllSessions, firstOrgSession)
		wantIncludeTerminatedSessions = append(wantIncludeTerminatedSessions, firstOrgSession)

		sess = session.TestSession(t, conn, wrap, session.ComposedOf{
			UserId:      uIdOther,
			HostId:      hOther.GetPublicId(),
			TargetId:    tarOther.GetPublicId(),
			HostSetId:   hsOther.GetPublicId(),
			AuthTokenId: atOther.GetPublicId(),
			ProjectId:   pWithOtherSessions.GetPublicId(),
			Endpoint:    "tcp://127.0.0.1:22",
		})
		session.TestConnection(t, conn, sess.PublicId, "127.0.0.1", 22, "127.0.0.2", 23, "127.0.0.1")
		// This session belongs to the other org's project, token and user,
		// so the expectation is built from those rather than from the first
		// org's values.
		otherOrgSession := wireSession(sess, oOther.GetPublicId(), pWithOtherSessions.GetPublicId(), atOther)
		wantOtherSession = append(wantOtherSession, otherOrgSession)
		wantAllSessions = append(wantAllSessions, otherOrgSession)
	}

	// One extra session in pWithSessions that is cancelled and then
	// terminated; it is only expected when IncludeTerminated is requested.
	{
		sess := session.TestSession(t, conn, wrap, session.ComposedOf{
			UserId:      uId,
			HostId:      h.GetPublicId(),
			TargetId:    tar.GetPublicId(),
			HostSetId:   hs.GetPublicId(),
			AuthTokenId: at.GetPublicId(),
			ProjectId:   pWithSessions.GetPublicId(),
			Endpoint:    "tcp://127.0.0.1:22",
		})
		sess, err := sessRepo.CancelSession(ctx, sess.PublicId, sess.Version)
		require.NoError(t, err)
		terminated, err := sessRepo.TerminateCompletedSessions(ctx)
		require.NoError(t, err)
		require.Equal(t, 1, terminated)
		// Re-read the session so the expectation reflects its final state.
		sess, _, err = sessRepo.LookupSession(ctx, sess.PublicId)
		require.NoError(t, err)

		expected := wireSession(sess, o.GetPublicId(), pWithSessions.GetPublicId(), at)
		expected.TerminationReason = sess.TerminationReason
		wantIncludeTerminatedSessions = append(wantIncludeTerminatedSessions, expected)
	}

	// res is what the owning user should see, otherRes what the user from
	// the other org should see, and allSessionRes what the recovery request
	// should see.
	cases := []struct {
		name          string
		req           *pbs.ListSessionsRequest
		res           *pbs.ListSessionsResponse
		otherRes      *pbs.ListSessionsResponse
		allSessionRes *pbs.ListSessionsResponse
		err           error
	}{
		{
			name:          "List Many Sessions",
			req:           &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId()},
			res:           &pbs.ListSessionsResponse{Items: wantSession},
			otherRes:      &pbs.ListSessionsResponse{Items: []*pb.Session{}},
			allSessionRes: &pbs.ListSessionsResponse{Items: wantSession},
		},
		{
			name:          "List Many Include Terminated",
			req:           &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), IncludeTerminated: true},
			res:           &pbs.ListSessionsResponse{Items: wantIncludeTerminatedSessions},
			otherRes:      &pbs.ListSessionsResponse{Items: []*pb.Session{}},
			allSessionRes: &pbs.ListSessionsResponse{Items: wantIncludeTerminatedSessions},
		},
		{
			name:          "List No Sessions",
			req:           &pbs.ListSessionsRequest{ScopeId: pNoSessions.GetPublicId()},
			res:           &pbs.ListSessionsResponse{},
			otherRes:      &pbs.ListSessionsResponse{Items: []*pb.Session{}},
			allSessionRes: &pbs.ListSessionsResponse{},
		},
		{
			name:          "List Sessions Recursively",
			req:           &pbs.ListSessionsRequest{ScopeId: scope.Global.String(), Recursive: true},
			res:           &pbs.ListSessionsResponse{Items: wantSession},
			otherRes:      &pbs.ListSessionsResponse{Items: wantOtherSession},
			allSessionRes: &pbs.ListSessionsResponse{Items: wantAllSessions},
		},
		{
			name:          "Filter To Single Sessions",
			req:           &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: fmt.Sprintf(`"/item/id"==%q`, wantSession[4].Id)},
			res:           &pbs.ListSessionsResponse{Items: wantSession[4:5]},
			otherRes:      &pbs.ListSessionsResponse{Items: []*pb.Session{}},
			allSessionRes: &pbs.ListSessionsResponse{Items: wantSession[4:5]},
		},
		{
			name: "Filter To Many Sessions",
			req: &pbs.ListSessionsRequest{
				ScopeId: scope.Global.String(), Recursive: true,
				Filter: fmt.Sprintf(`"/item/scope/id" matches "^%s"`, pWithSessions.GetPublicId()[:8]),
			},
			res:           &pbs.ListSessionsResponse{Items: wantSession},
			otherRes:      &pbs.ListSessionsResponse{Items: []*pb.Session{}},
			allSessionRes: &pbs.ListSessionsResponse{Items: wantSession},
		},
		{
			name:          "Filter To Nothing",
			req:           &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: `"/item/id" == ""`},
			res:           &pbs.ListSessionsResponse{},
			otherRes:      &pbs.ListSessionsResponse{Items: []*pb.Session{}},
			allSessionRes: &pbs.ListSessionsResponse{},
		},
		{
			name:     "Filter Bad Format",
			req:      &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: `//badformat/`},
			err:      handlers.InvalidArgumentErrorf("bad format", nil),
			otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			require, assert := require.New(t), assert.New(t)
			s, err := sessions.NewService(ctx, sessRepoFn, iamRepoFn)
			require.NoError(err, "Couldn't create new session service.")

			// alignExpirations stops the subtest if the item counts differ
			// (so the indexing below cannot go out of range), checks each
			// returned expiration against a 10ms bound, and then copies the
			// returned expiration into the expectation so a later diff
			// ignores that difference.
			alignExpirations := func(want, got []*pb.Session) {
				require.Equal(len(want), len(got), "Didn't get expected number of sessions: %v", got)
				for i, wantSess := range want {
					assert.True(got[i].GetExpirationTime().AsTime().Sub(wantSess.GetExpirationTime().AsTime()) < 10*time.Millisecond)
					assert.Equal(0, len(wantSess.GetConnections())) // no connections on list
					wantSess.ExpirationTime = got[i].GetExpirationTime()
				}
			}

			// Test without anon user
			requestInfo := authpb.RequestInfo{
				TokenFormat: uint32(auth.AuthTokenTypeBearer),
				PublicId:    at.GetPublicId(),
				Token:       at.GetToken(),
			}
			requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{})
			userCtx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo)
			got, gErr := s.ListSessions(userCtx, tc.req)
			if tc.err != nil {
				require.Error(gErr)
				assert.True(errors.Is(gErr, tc.err), "ListSessions(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
				// Error cases are only checked for the first requester.
				return
			}
			require.NoError(gErr)
			if tc.res != nil {
				alignExpirations(tc.res.GetItems(), got.GetItems())
			}
			assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "ListSessions(%q) got response %q, wanted %q", tc.req, got, tc.res)

			// Test with other user
			otherRequestInfo := authpb.RequestInfo{
				TokenFormat: uint32(auth.AuthTokenTypeBearer),
				PublicId:    atOther.GetPublicId(),
				Token:       atOther.GetToken(),
			}
			otherRequestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{})
			otherCtx := auth.NewVerifierContext(otherRequestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &otherRequestInfo)
			got, gErr = s.ListSessions(otherCtx, tc.req)
			require.NoError(gErr)
			alignExpirations(tc.otherRes.GetItems(), got.GetItems())

			// Test with recovery user
			recoveryRequestInfo := authpb.RequestInfo{
				TokenFormat: uint32(auth.AuthTokenTypeRecoveryKms),
				PublicId:    at.GetPublicId(),
				Token:       at.GetToken(),
			}
			recoveryRequestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{})
			recoveryCtx := auth.NewVerifierContext(recoveryRequestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &recoveryRequestInfo)
			recoveryGot, gErr := s.ListSessions(recoveryCtx, tc.req)
			require.NoError(gErr)
			if tc.allSessionRes != nil {
				alignExpirations(tc.allSessionRes.GetItems(), recoveryGot.GetItems())
			}
		})
	}
}
// convertStates translates repository session states into their API
// representation, preserving input order. StartTime and EndTime are copied
// only when they are set on the input state.
//
// It returns the status of the first state together with the converted
// slice. For an empty or nil input it returns an empty status and an empty,
// non-nil slice rather than panicking on the missing first element.
//
// NOTE(review): callers in this file use the first state's status as the
// session's current status, which presumably relies on the repository
// returning the most recent state first — confirm against the session
// package.
func convertStates(in []*session.State) (string, []*pb.SessionState) {
	out := make([]*pb.SessionState, 0, len(in))
	for _, s := range in {
		sessState := &pb.SessionState{
			Status: s.Status.String(),
		}
		if s.StartTime != nil {
			sessState.StartTime = s.StartTime.GetTimestamp()
		}
		if s.EndTime != nil {
			sessState.EndTime = s.EndTime.GetTimestamp()
		}
		out = append(out, sessState)
	}
	if len(out) == 0 {
		return "", out
	}
	return out[0].Status, out
}
// TestCancel exercises Service.CancelSession against a test database. It
// creates one pending session and then cancels it, cancels it a second time,
// and sends three rejected IDs (unknown ID, wrong prefix, embedded space).
//
// The subtests are order dependent: they share the version variable, which
// is seeded from the freshly created session and updated from each
// successful response, so the "Already canceled" case sends the version
// returned by the first cancel.
func TestCancel(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	wrap := db.TestWrapper(t)
	kms := kms.TestKms(t, conn, wrap)
	iamRepo := iam.TestRepo(t, conn, wrap)
	rw := db.New(conn)
	ctx := context.Background()

	// Repository factories handed to the service and to the auth verifier.
	iamRepoFn := func() (*iam.Repository, error) {
		return iamRepo, nil
	}
	sessRepoFn := func(opt ...session.Option) (*session.Repository, error) {
		return session.NewRepository(ctx, rw, rw, kms, opt...)
	}
	tokenRepoFn := func() (*authtoken.Repository, error) {
		return authtoken.NewRepository(ctx, rw, rw, kms)
	}
	serversRepoFn := func() (*server.Repository, error) {
		return server.NewRepository(ctx, rw, rw, kms)
	}

	o, p := iam.TestScopes(t, iamRepo)
	at := authtoken.TestAuthToken(t, conn, kms, o.GetPublicId())
	uId := at.GetIamUserId()
	hc := static.TestCatalogs(t, conn, p.GetPublicId(), 1)[0]
	hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0]
	h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0]
	static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h})
	tar := tcp.TestTarget(context.Background(), t, conn, p.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()}))

	sess := session.TestSession(t, conn, wrap, session.ComposedOf{
		UserId:      uId,
		HostId:      h.GetPublicId(),
		TargetId:    tar.GetPublicId(),
		HostSetId:   hs.GetPublicId(),
		AuthTokenId: at.GetPublicId(),
		ProjectId:   p.GetPublicId(),
		Endpoint:    "tcp://127.0.0.1:22",
	})

	// wireSess is the expected response item. States and UpdatedTime are
	// deliberately absent: the subtest checks them separately and clears
	// them on the response before diffing.
	wireSess := &pb.Session{
		Id:                sess.GetPublicId(),
		ScopeId:           p.GetPublicId(),
		AuthTokenId:       at.GetPublicId(),
		UserId:            at.GetIamUserId(),
		TargetId:          sess.TargetId,
		HostSetId:         sess.HostSetId,
		HostId:            sess.HostId,
		Version:           sess.Version,
		Endpoint:          sess.Endpoint,
		CreatedTime:       sess.CreateTime.GetTimestamp(),
		ExpirationTime:    sess.ExpirationTime.GetTimestamp(),
		Scope:             &scopes.ScopeInfo{Id: p.GetPublicId(), Type: scope.Project.String(), ParentScopeId: o.GetPublicId()},
		Status:            session.StatusCanceling.String(),
		Certificate:       sess.Certificate,
		Type:              tcp.Subtype.String(),
		AuthorizedActions: testAuthorizedActions,
	}

	// version is sent with every request and advanced after each success.
	version := wireSess.GetVersion()

	cases := []struct {
		name string
		req  *pbs.CancelSessionRequest
		res  *pbs.CancelSessionResponse
		err  error
	}{
		{
			name: "Cancel a session",
			req:  &pbs.CancelSessionRequest{Id: sess.GetPublicId()},
			res:  &pbs.CancelSessionResponse{Item: wireSess},
		},
		{
			name: "Already canceled",
			req:  &pbs.CancelSessionRequest{Id: sess.GetPublicId()},
			res:  &pbs.CancelSessionResponse{Item: wireSess},
		},
		{
			name: "Cancel a non existing Session",
			req:  &pbs.CancelSessionRequest{Id: globals.SessionPrefix + "_DoesntExis"},
			res:  nil,
			err:  handlers.ApiErrorWithCode(codes.NotFound),
		},
		{
			name: "Wrong id prefix",
			req:  &pbs.CancelSessionRequest{Id: "j_1234567890"},
			res:  nil,
			err:  handlers.ApiErrorWithCode(codes.InvalidArgument),
		},
		{
			name: "space in id",
			req:  &pbs.CancelSessionRequest{Id: globals.SessionPrefix + "_1 23456789"},
			res:  nil,
			err:  handlers.ApiErrorWithCode(codes.InvalidArgument),
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			s, err := sessions.NewService(ctx, sessRepoFn, iamRepoFn)
			require.NoError(err, "Couldn't create new session service.")

			tc.req.Version = version

			requestInfo := authpb.RequestInfo{
				TokenFormat: uint32(auth.AuthTokenTypeBearer),
				PublicId:    at.GetPublicId(),
				Token:       at.GetToken(),
			}
			requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{})
			verifierCtx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo)

			got, gErr := s.CancelSession(verifierCtx, tc.req)
			if tc.err != nil {
				require.Error(gErr)
				// It's hard to mix and match api/error package errors right now
				// so use old/new behavior depending on the type. If validate
				// gets updated this can be standardized.
				if errors.Match(errors.T(errors.InvalidSessionState), gErr) {
					assert.True(errors.Match(errors.T(tc.err), gErr), "CancelSession(%+v) got error %#v, wanted %#v", tc.req, gErr, tc.err)
				} else {
					assert.True(errors.Is(gErr, tc.err), "CancelSession(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
				}
			} else {
				// A response accompanied by an unexpected error must not
				// pass silently.
				require.NoError(gErr)
			}

			if tc.res == nil {
				require.Nil(got)
				return
			}
			require.NotNil(got)

			// The version changes with each cancel, so take it from the
			// response instead of predicting it.
			tc.res.GetItem().Version = got.GetItem().Version

			// Compare the new canceling state and then remove it for the rest of the proto comparison
			assert.True(got.GetItem().GetUpdatedTime().AsTime().After(got.GetItem().GetCreatedTime().AsTime()))
			assert.Len(got.GetItem().GetStates(), 2)
			wantState := []*pb.SessionState{
				{
					Status:    session.StatusCanceling.String(),
					StartTime: got.GetItem().GetUpdatedTime(),
				},
				{
					Status:    session.StatusPending.String(),
					StartTime: got.GetItem().GetCreatedTime(),
					EndTime:   got.GetItem().GetUpdatedTime(),
				},
			}
			assert.Empty(cmp.Diff(got.GetItem().GetStates(), wantState, protocmp.Transform()), "CancelSession(%q) states", tc.req)
			got.GetItem().States = nil
			got.GetItem().UpdatedTime = nil

			// The expected expiration is checked against a 10ms bound and
			// then overwritten with the returned value so the full diff
			// below ignores that difference.
			assert.True(got.GetItem().GetExpirationTime().AsTime().Sub(tc.res.GetItem().GetExpirationTime().AsTime()) < 10*time.Millisecond)
			tc.res.GetItem().ExpirationTime = got.GetItem().GetExpirationTime()

			assert.Equal(got.GetItem().HostId, tc.res.GetItem().HostId)
			assert.Equal(got.GetItem().HostSetId, tc.res.GetItem().HostSetId)
			assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "CancelSession(%q) got response\n%q, wanted\n%q", tc.req, got, tc.res)

			// Carry the new version forward to the next subtest.
			version = got.GetItem().GetVersion()
		})
	}
}