diff --git a/internal/daemon/controller/handlers/sessions/session_service_test.go b/internal/daemon/controller/handlers/sessions/session_service_test.go index 158a059936..c68aba5f40 100644 --- a/internal/daemon/controller/handlers/sessions/session_service_test.go +++ b/internal/daemon/controller/handlers/sessions/session_service_test.go @@ -19,6 +19,7 @@ import ( "github.com/hashicorp/boundary/internal/host/static" "github.com/hashicorp/boundary/internal/iam" "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/requests" "github.com/hashicorp/boundary/internal/server" "github.com/hashicorp/boundary/internal/session" "github.com/hashicorp/boundary/internal/target" @@ -32,7 +33,7 @@ import ( "google.golang.org/protobuf/testing/protocmp" ) -var testAuthorizedActions = []string{"no-op", "read", "read:self", "cancel", "cancel:self"} +var testAuthorizedActions = []string{"read:self", "cancel:self"} func TestGetSession(t *testing.T) { conn, _ := db.TestSetup(t, "postgres") @@ -42,6 +43,7 @@ func TestGetSession(t *testing.T) { iamRepo := iam.TestRepo(t, conn, wrap) rw := db.New(conn) + sessRepo, err := session.NewRepository(rw, rw, kms) require.NoError(t, err) @@ -51,6 +53,12 @@ func TestGetSession(t *testing.T) { sessRepoFn := func() (*session.Repository, error) { return sessRepo, nil } + tokenRepoFn := func() (*authtoken.Repository, error) { + return authtoken.NewRepository(rw, rw, kms) + } + serversRepoFn := func() (*server.Repository, error) { + return server.NewRepository(rw, rw, kms) + } o, p := iam.TestScopes(t, iamRepo) at := authtoken.TestAuthToken(t, conn, kms, o.GetPublicId()) @@ -106,7 +114,7 @@ func TestGetSession(t *testing.T) { res: &pbs.GetSessionResponse{Item: wireSess}, }, { - name: "Get a non existant Session", + name: "Get a non existent Session", req: &pbs.GetSessionRequest{Id: session.SessionPrefix + "_DoesntExis"}, res: nil, err: handlers.ApiErrorWithCode(codes.NotFound), @@ -131,7 +139,15 @@ func TestGetSession(t *testing.T) { s, err := sessions.NewService(sessRepoFn, iamRepoFn) require.NoError(err, "Couldn't create new session service.") - got, gErr := s.GetSession(auth.DisabledAuthTestContext(iamRepoFn, tc.scopeId), tc.req) + requestInfo := authpb.RequestInfo{ + TokenFormat: uint32(auth.AuthTokenTypeBearer), + PublicId: at.GetPublicId(), + Token: at.GetToken(), + } + requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) + ctx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo) + + got, gErr := s.GetSession(ctx, tc.req) if tc.err != nil { require.Error(gErr) assert.True(errors.Is(gErr, tc.err), "GetSession(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) @@ -140,7 +156,7 @@ func TestGetSession(t *testing.T) { assert.True(got.GetItem().GetExpirationTime().AsTime().Sub(tc.res.GetItem().GetExpirationTime().AsTime()) < 10*time.Millisecond) tc.res.GetItem().ExpirationTime = got.GetItem().GetExpirationTime() } - assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "GetSession(%q) got response\n%q, wanted\n%q", tc.req, got, tc.res) + assert.Empty(cmp.Diff(tc.res, got, protocmp.Transform()), "GetSession(%q) got response\n%q, wanted\n%q", tc.req, got, tc.res) }) } } @@ -246,6 +262,12 @@ func TestList(t *testing.T) { sessRepoFn := func() (*session.Repository, error) { return sessRepo, nil } + tokenRepoFn := func() (*authtoken.Repository, error) { + return authtoken.NewRepository(rw, rw, kms) + } + serversRepoFn := func() (*server.Repository, error) { + return server.NewRepository(rw, rw, kms) + } _, pNoSessions := iam.TestScopes(t, iamRepo) o, pWithSessions := iam.TestScopes(t, iamRepo) @@ -270,7 +292,7 @@ func TestList(t *testing.T) { tarOther := tcp.TestTarget(ctx, t, conn, pWithOtherSessions.GetPublicId(), "test", target.WithHostSources([]string{hsOther.GetPublicId()})) var wantSession []*pb.Session - var totalSession []*pb.Session + var wantOtherSession []*pb.Session var wantIncludeTerminatedSessions []*pb.Session for i := 0; i < 10; i++ { sess := session.TestSession(t, conn, wrap, session.ComposedOf{ @@ -309,7 +331,6 @@ func TestList(t *testing.T) { Connections: []*pb.Connection{}, // connections should not be returned for list }) - totalSession = append(totalSession, wantSession[i]) wantIncludeTerminatedSessions = append(wantIncludeTerminatedSessions, wantSession[i]) sess = session.TestSession(t, conn, wrap, session.ComposedOf{ @@ -326,11 +347,11 @@ func TestList(t *testing.T) { status, states = convertStates(sess.States) - totalSession = append(totalSession, &pb.Session{ + wantOtherSession = append(wantOtherSession, &pb.Session{ Id: sess.GetPublicId(), - ScopeId: pWithOtherSessions.GetPublicId(), - AuthTokenId: atOther.GetPublicId(), - UserId: atOther.GetIamUserId(), + ScopeId: pWithSessions.GetPublicId(), + AuthTokenId: at.GetPublicId(), + UserId: at.GetIamUserId(), TargetId: sess.TargetId, Endpoint: sess.Endpoint, HostSetId: sess.HostSetId, @@ -339,7 +360,7 @@ func TestList(t *testing.T) { UpdatedTime: sess.UpdateTime.GetTimestamp(), CreatedTime: sess.CreateTime.GetTimestamp(), ExpirationTime: sess.ExpirationTime.GetTimestamp(), - Scope: &scopes.ScopeInfo{Id: pWithOtherSessions.GetPublicId(), Type: scope.Project.String(), ParentScopeId: oOther.GetPublicId()}, + Scope: &scopes.ScopeInfo{Id: pWithSessions.GetPublicId(), Type: scope.Project.String(), ParentScopeId: o.GetPublicId()}, Status: status, States: states, Certificate: sess.Certificate, @@ -397,35 +418,41 @@ func TestList(t *testing.T) { } cases := []struct { - name string - req *pbs.ListSessionsRequest - res *pbs.ListSessionsResponse - err error + name string + req *pbs.ListSessionsRequest + res *pbs.ListSessionsResponse + otherRes *pbs.ListSessionsResponse + err error }{ { - name: "List Many Sessions", - req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId()}, - res: &pbs.ListSessionsResponse{Items: wantSession}, + name: "List Many Sessions", + req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId()}, + res: &pbs.ListSessionsResponse{Items: wantSession}, + otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}}, }, { - name: "List Many Include Terminated", - req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), IncludeTerminated: true}, - res: &pbs.ListSessionsResponse{Items: wantIncludeTerminatedSessions}, + name: "List Many Include Terminated", + req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), IncludeTerminated: true}, + res: &pbs.ListSessionsResponse{Items: wantIncludeTerminatedSessions}, + otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}}, }, { - name: "List No Sessions", - req: &pbs.ListSessionsRequest{ScopeId: pNoSessions.GetPublicId()}, - res: &pbs.ListSessionsResponse{}, + name: "List No Sessions", + req: &pbs.ListSessionsRequest{ScopeId: pNoSessions.GetPublicId()}, + res: &pbs.ListSessionsResponse{}, + otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}}, }, { - name: "List Sessions Recursively", - req: &pbs.ListSessionsRequest{ScopeId: scope.Global.String(), Recursive: true}, - res: &pbs.ListSessionsResponse{Items: totalSession}, + name: "List Sessions Recursively", + req: &pbs.ListSessionsRequest{ScopeId: scope.Global.String(), Recursive: true}, + res: &pbs.ListSessionsResponse{Items: wantSession}, + otherRes: &pbs.ListSessionsResponse{Items: wantOtherSession}, }, { - name: "Filter To Single Sessions", - req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: fmt.Sprintf(`"/item/id"==%q`, totalSession[4].Id)}, - res: &pbs.ListSessionsResponse{Items: totalSession[4:5]}, + name: "Filter To Single Sessions", + req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: fmt.Sprintf(`"/item/id"==%q`, wantSession[4].Id)}, + res: &pbs.ListSessionsResponse{Items: wantSession[4:5]}, + otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}}, }, { name: "Filter To Many Sessions", @@ -433,17 +460,20 @@ func TestList(t *testing.T) { ScopeId: scope.Global.String(), Recursive: true, Filter: fmt.Sprintf(`"/item/scope/id" matches "^%s"`, pWithSessions.GetPublicId()[:8]), }, - res: &pbs.ListSessionsResponse{Items: wantSession}, + res: &pbs.ListSessionsResponse{Items: wantSession}, + otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}}, }, { - name: "Filter To Nothing", - req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: `"/item/id" == ""`}, - res: &pbs.ListSessionsResponse{}, + name: "Filter To Nothing", + req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: `"/item/id" == ""`}, + res: &pbs.ListSessionsResponse{}, + otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}}, }, { - name: "Filter Bad Format", - req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: `//badformat/`}, - err: handlers.InvalidArgumentErrorf("bad format", nil), + name: "Filter Bad Format", + req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: `//badformat/`}, + err: handlers.InvalidArgumentErrorf("bad format", nil), + otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}}, }, } for _, tc := range cases { @@ -453,7 +483,14 @@ func TestList(t *testing.T) { require.NoError(err, "Couldn't create new session service.") // Test without anon user - got, gErr := s.ListSessions(auth.DisabledAuthTestContext(iamRepoFn, tc.req.GetScopeId()), tc.req) + requestInfo := authpb.RequestInfo{ + TokenFormat: uint32(auth.AuthTokenTypeBearer), + PublicId: at.GetPublicId(), + Token: at.GetToken(), + } + requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) + ctx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo) + got, gErr := s.ListSessions(ctx, tc.req) if tc.err != nil { require.Error(gErr) assert.True(errors.Is(gErr, tc.err), "ListSessions(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) @@ -470,22 +507,21 @@ func TestList(t *testing.T) { } assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "ListSessions(%q) got response %q, wanted %q", tc.req, got, tc.res) - // Test with anon user - got, gErr = s.ListSessions(auth.DisabledAuthTestContext(iamRepoFn, tc.req.GetScopeId(), auth.WithUserId(auth.AnonymousUserId)), tc.req) + // Test with other user + otherRequestInfo := authpb.RequestInfo{ + TokenFormat: uint32(auth.AuthTokenTypeBearer), + PublicId: atOther.GetPublicId(), + Token: atOther.GetToken(), + } + otherRequestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) + otherCtx := auth.NewVerifierContext(otherRequestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &otherRequestInfo) + got, gErr = s.ListSessions(otherCtx, tc.req) require.NoError(gErr) - assert.Len(got.Items, len(tc.res.Items)) - for _, item := range got.GetItems() { - require.Empty(item.Version) - require.Empty(item.UserId) - require.Empty(item.HostId) - require.Empty(item.HostSetId) - require.Empty(item.AuthTokenId) - require.Empty(item.Endpoint) - require.Nil(item.CreatedTime) - require.Nil(item.ExpirationTime) - require.Nil(item.UpdatedTime) - require.Empty(item.Certificate) - require.Empty(item.TerminationReason) + assert.Len(got.Items, len(tc.otherRes.Items)) + for i, wantSess := range tc.otherRes.GetItems() { + assert.True(got.GetItems()[i].GetExpirationTime().AsTime().Sub(wantSess.GetExpirationTime().AsTime()) < 10*time.Millisecond) + assert.Equal(0, len(wantSess.GetConnections())) // no connections on list + wantSess.ExpirationTime = got.GetItems()[i].GetExpirationTime() } }) } @@ -525,6 +561,12 @@ func TestCancel(t *testing.T) { sessRepoFn := func() (*session.Repository, error) { return sessRepo, nil } + tokenRepoFn := func() (*authtoken.Repository, error) { + return authtoken.NewRepository(rw, rw, kms) + } + serversRepoFn := func() (*server.Repository, error) { + return server.NewRepository(rw, rw, kms) + } o, p := iam.TestScopes(t, iamRepo) at := authtoken.TestAuthToken(t, conn, kms, o.GetPublicId()) @@ -613,16 +655,23 @@ func TestCancel(t *testing.T) { tc.req.Version = version - got, gErr := s.CancelSession(auth.DisabledAuthTestContext(iamRepoFn, tc.scopeId), tc.req) + requestInfo := authpb.RequestInfo{ + TokenFormat: uint32(auth.AuthTokenTypeBearer), + PublicId: at.GetPublicId(), + Token: at.GetToken(), + } + requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{}) + ctx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo) + got, gErr := s.CancelSession(ctx, tc.req) if tc.err != nil { require.Error(gErr) // It's hard to mix and match api/error package errors right now // so use old/new behavior depending on the type. If validate // gets updated this can be standardized. if errors.Match(errors.T(errors.InvalidSessionState), gErr) { - assert.True(errors.Match(errors.T(tc.err), gErr), "GetSession(%+v) got error %#v, wanted %#v", tc.req, gErr, tc.err) + assert.True(errors.Match(errors.T(tc.err), gErr), "CancelSession(%+v) got error %#v, wanted %#v", tc.req, gErr, tc.err) } else { - assert.True(errors.Is(gErr, tc.err), "GetSession(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) + assert.True(errors.Is(gErr, tc.err), "CancelSession(%+v) got error %v, wanted %v", tc.req, gErr, tc.err) } } @@ -630,6 +679,7 @@ func TestCancel(t *testing.T) { require.Nil(got) return } + require.NotNil(got) tc.res.GetItem().Version = got.GetItem().Version // Compare the new canceling state and then remove it for the rest of the proto comparison @@ -647,7 +697,7 @@ func TestCancel(t *testing.T) { EndTime: got.GetItem().GetUpdatedTime(), }, } - assert.Empty(cmp.Diff(got.GetItem().GetStates(), wantState, protocmp.Transform()), "GetSession(%q) states") + assert.Empty(cmp.Diff(got.GetItem().GetStates(), wantState, protocmp.Transform()), "CancelSession(%q) states") got.GetItem().States = nil got.GetItem().UpdatedTime = nil @@ -656,7 +706,7 @@ func TestCancel(t *testing.T) { tc.res.GetItem().ExpirationTime = got.GetItem().GetExpirationTime() } - assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "GetSession(%q) got response\n%q, wanted\n%q", tc.req, got, tc.res) + assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "CancelSession(%q) got response\n%q, wanted\n%q", tc.req, got, tc.res) if tc.req != nil { require.NotNil(got)