test(sessions): Remove auth.DisabledAuthTestContext from tests

Disabling the auth was resulting in some unrealistic tests setup and
expectations. It also makes it difficult to refactor any behavior in the
authn/authz flow since this option can bypass most of it.
pull/2342/head
Timothy Messier 4 years ago
parent ccb17df01a
commit 02cef3d8a8
No known key found for this signature in database
GPG Key ID: EFD2F184F7600572

@ -19,6 +19,7 @@ import (
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/requests"
"github.com/hashicorp/boundary/internal/server"
"github.com/hashicorp/boundary/internal/session"
"github.com/hashicorp/boundary/internal/target"
@ -32,7 +33,7 @@ import (
"google.golang.org/protobuf/testing/protocmp"
)
var testAuthorizedActions = []string{"no-op", "read", "read:self", "cancel", "cancel:self"}
var testAuthorizedActions = []string{"read:self", "cancel:self"}
func TestGetSession(t *testing.T) {
conn, _ := db.TestSetup(t, "postgres")
@ -42,6 +43,7 @@ func TestGetSession(t *testing.T) {
iamRepo := iam.TestRepo(t, conn, wrap)
rw := db.New(conn)
sessRepo, err := session.NewRepository(rw, rw, kms)
require.NoError(t, err)
@ -51,6 +53,12 @@ func TestGetSession(t *testing.T) {
sessRepoFn := func() (*session.Repository, error) {
return sessRepo, nil
}
tokenRepoFn := func() (*authtoken.Repository, error) {
return authtoken.NewRepository(rw, rw, kms)
}
serversRepoFn := func() (*server.Repository, error) {
return server.NewRepository(rw, rw, kms)
}
o, p := iam.TestScopes(t, iamRepo)
at := authtoken.TestAuthToken(t, conn, kms, o.GetPublicId())
@ -106,7 +114,7 @@ func TestGetSession(t *testing.T) {
res: &pbs.GetSessionResponse{Item: wireSess},
},
{
name: "Get a non existant Session",
name: "Get a non existent Session",
req: &pbs.GetSessionRequest{Id: session.SessionPrefix + "_DoesntExis"},
res: nil,
err: handlers.ApiErrorWithCode(codes.NotFound),
@ -131,7 +139,15 @@ func TestGetSession(t *testing.T) {
s, err := sessions.NewService(sessRepoFn, iamRepoFn)
require.NoError(err, "Couldn't create new session service.")
got, gErr := s.GetSession(auth.DisabledAuthTestContext(iamRepoFn, tc.scopeId), tc.req)
requestInfo := authpb.RequestInfo{
TokenFormat: uint32(auth.AuthTokenTypeBearer),
PublicId: at.GetPublicId(),
Token: at.GetToken(),
}
requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{})
ctx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo)
got, gErr := s.GetSession(ctx, tc.req)
if tc.err != nil {
require.Error(gErr)
assert.True(errors.Is(gErr, tc.err), "GetSession(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
@ -140,7 +156,7 @@ func TestGetSession(t *testing.T) {
assert.True(got.GetItem().GetExpirationTime().AsTime().Sub(tc.res.GetItem().GetExpirationTime().AsTime()) < 10*time.Millisecond)
tc.res.GetItem().ExpirationTime = got.GetItem().GetExpirationTime()
}
assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "GetSession(%q) got response\n%q, wanted\n%q", tc.req, got, tc.res)
assert.Empty(cmp.Diff(tc.res, got, protocmp.Transform()), "GetSession(%q) got response\n%q, wanted\n%q", tc.req, got, tc.res)
})
}
}
@ -246,6 +262,12 @@ func TestList(t *testing.T) {
sessRepoFn := func() (*session.Repository, error) {
return sessRepo, nil
}
tokenRepoFn := func() (*authtoken.Repository, error) {
return authtoken.NewRepository(rw, rw, kms)
}
serversRepoFn := func() (*server.Repository, error) {
return server.NewRepository(rw, rw, kms)
}
_, pNoSessions := iam.TestScopes(t, iamRepo)
o, pWithSessions := iam.TestScopes(t, iamRepo)
@ -270,7 +292,7 @@ func TestList(t *testing.T) {
tarOther := tcp.TestTarget(ctx, t, conn, pWithOtherSessions.GetPublicId(), "test", target.WithHostSources([]string{hsOther.GetPublicId()}))
var wantSession []*pb.Session
var totalSession []*pb.Session
var wantOtherSession []*pb.Session
var wantIncludeTerminatedSessions []*pb.Session
for i := 0; i < 10; i++ {
sess := session.TestSession(t, conn, wrap, session.ComposedOf{
@ -309,7 +331,6 @@ func TestList(t *testing.T) {
Connections: []*pb.Connection{}, // connections should not be returned for list
})
totalSession = append(totalSession, wantSession[i])
wantIncludeTerminatedSessions = append(wantIncludeTerminatedSessions, wantSession[i])
sess = session.TestSession(t, conn, wrap, session.ComposedOf{
@ -326,11 +347,11 @@ func TestList(t *testing.T) {
status, states = convertStates(sess.States)
totalSession = append(totalSession, &pb.Session{
wantOtherSession = append(wantOtherSession, &pb.Session{
Id: sess.GetPublicId(),
ScopeId: pWithOtherSessions.GetPublicId(),
AuthTokenId: atOther.GetPublicId(),
UserId: atOther.GetIamUserId(),
ScopeId: pWithSessions.GetPublicId(),
AuthTokenId: at.GetPublicId(),
UserId: at.GetIamUserId(),
TargetId: sess.TargetId,
Endpoint: sess.Endpoint,
HostSetId: sess.HostSetId,
@ -339,7 +360,7 @@ func TestList(t *testing.T) {
UpdatedTime: sess.UpdateTime.GetTimestamp(),
CreatedTime: sess.CreateTime.GetTimestamp(),
ExpirationTime: sess.ExpirationTime.GetTimestamp(),
Scope: &scopes.ScopeInfo{Id: pWithOtherSessions.GetPublicId(), Type: scope.Project.String(), ParentScopeId: oOther.GetPublicId()},
Scope: &scopes.ScopeInfo{Id: pWithSessions.GetPublicId(), Type: scope.Project.String(), ParentScopeId: o.GetPublicId()},
Status: status,
States: states,
Certificate: sess.Certificate,
@ -397,35 +418,41 @@ func TestList(t *testing.T) {
}
cases := []struct {
name string
req *pbs.ListSessionsRequest
res *pbs.ListSessionsResponse
err error
name string
req *pbs.ListSessionsRequest
res *pbs.ListSessionsResponse
otherRes *pbs.ListSessionsResponse
err error
}{
{
name: "List Many Sessions",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId()},
res: &pbs.ListSessionsResponse{Items: wantSession},
name: "List Many Sessions",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId()},
res: &pbs.ListSessionsResponse{Items: wantSession},
otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}},
},
{
name: "List Many Include Terminated",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), IncludeTerminated: true},
res: &pbs.ListSessionsResponse{Items: wantIncludeTerminatedSessions},
name: "List Many Include Terminated",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), IncludeTerminated: true},
res: &pbs.ListSessionsResponse{Items: wantIncludeTerminatedSessions},
otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}},
},
{
name: "List No Sessions",
req: &pbs.ListSessionsRequest{ScopeId: pNoSessions.GetPublicId()},
res: &pbs.ListSessionsResponse{},
name: "List No Sessions",
req: &pbs.ListSessionsRequest{ScopeId: pNoSessions.GetPublicId()},
res: &pbs.ListSessionsResponse{},
otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}},
},
{
name: "List Sessions Recursively",
req: &pbs.ListSessionsRequest{ScopeId: scope.Global.String(), Recursive: true},
res: &pbs.ListSessionsResponse{Items: totalSession},
name: "List Sessions Recursively",
req: &pbs.ListSessionsRequest{ScopeId: scope.Global.String(), Recursive: true},
res: &pbs.ListSessionsResponse{Items: wantSession},
otherRes: &pbs.ListSessionsResponse{Items: wantOtherSession},
},
{
name: "Filter To Single Sessions",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: fmt.Sprintf(`"/item/id"==%q`, totalSession[4].Id)},
res: &pbs.ListSessionsResponse{Items: totalSession[4:5]},
name: "Filter To Single Sessions",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: fmt.Sprintf(`"/item/id"==%q`, wantSession[4].Id)},
res: &pbs.ListSessionsResponse{Items: wantSession[4:5]},
otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}},
},
{
name: "Filter To Many Sessions",
@ -433,17 +460,20 @@ func TestList(t *testing.T) {
ScopeId: scope.Global.String(), Recursive: true,
Filter: fmt.Sprintf(`"/item/scope/id" matches "^%s"`, pWithSessions.GetPublicId()[:8]),
},
res: &pbs.ListSessionsResponse{Items: wantSession},
res: &pbs.ListSessionsResponse{Items: wantSession},
otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}},
},
{
name: "Filter To Nothing",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: `"/item/id" == ""`},
res: &pbs.ListSessionsResponse{},
name: "Filter To Nothing",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: `"/item/id" == ""`},
res: &pbs.ListSessionsResponse{},
otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}},
},
{
name: "Filter Bad Format",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: `//badformat/`},
err: handlers.InvalidArgumentErrorf("bad format", nil),
name: "Filter Bad Format",
req: &pbs.ListSessionsRequest{ScopeId: pWithSessions.GetPublicId(), Filter: `//badformat/`},
err: handlers.InvalidArgumentErrorf("bad format", nil),
otherRes: &pbs.ListSessionsResponse{Items: []*pb.Session{}},
},
}
for _, tc := range cases {
@ -453,7 +483,14 @@ func TestList(t *testing.T) {
require.NoError(err, "Couldn't create new session service.")
// Test without anon user
got, gErr := s.ListSessions(auth.DisabledAuthTestContext(iamRepoFn, tc.req.GetScopeId()), tc.req)
requestInfo := authpb.RequestInfo{
TokenFormat: uint32(auth.AuthTokenTypeBearer),
PublicId: at.GetPublicId(),
Token: at.GetToken(),
}
requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{})
ctx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo)
got, gErr := s.ListSessions(ctx, tc.req)
if tc.err != nil {
require.Error(gErr)
assert.True(errors.Is(gErr, tc.err), "ListSessions(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
@ -470,22 +507,21 @@ func TestList(t *testing.T) {
}
assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "ListSessions(%q) got response %q, wanted %q", tc.req, got, tc.res)
// Test with anon user
got, gErr = s.ListSessions(auth.DisabledAuthTestContext(iamRepoFn, tc.req.GetScopeId(), auth.WithUserId(auth.AnonymousUserId)), tc.req)
// Test with other user
otherRequestInfo := authpb.RequestInfo{
TokenFormat: uint32(auth.AuthTokenTypeBearer),
PublicId: atOther.GetPublicId(),
Token: atOther.GetToken(),
}
otherRequestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{})
otherCtx := auth.NewVerifierContext(otherRequestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &otherRequestInfo)
got, gErr = s.ListSessions(otherCtx, tc.req)
require.NoError(gErr)
assert.Len(got.Items, len(tc.res.Items))
for _, item := range got.GetItems() {
require.Empty(item.Version)
require.Empty(item.UserId)
require.Empty(item.HostId)
require.Empty(item.HostSetId)
require.Empty(item.AuthTokenId)
require.Empty(item.Endpoint)
require.Nil(item.CreatedTime)
require.Nil(item.ExpirationTime)
require.Nil(item.UpdatedTime)
require.Empty(item.Certificate)
require.Empty(item.TerminationReason)
assert.Len(got.Items, len(tc.otherRes.Items))
for i, wantSess := range tc.otherRes.GetItems() {
assert.True(got.GetItems()[i].GetExpirationTime().AsTime().Sub(wantSess.GetExpirationTime().AsTime()) < 10*time.Millisecond)
assert.Equal(0, len(wantSess.GetConnections())) // no connections on list
wantSess.ExpirationTime = got.GetItems()[i].GetExpirationTime()
}
})
}
@ -525,6 +561,12 @@ func TestCancel(t *testing.T) {
sessRepoFn := func() (*session.Repository, error) {
return sessRepo, nil
}
tokenRepoFn := func() (*authtoken.Repository, error) {
return authtoken.NewRepository(rw, rw, kms)
}
serversRepoFn := func() (*server.Repository, error) {
return server.NewRepository(rw, rw, kms)
}
o, p := iam.TestScopes(t, iamRepo)
at := authtoken.TestAuthToken(t, conn, kms, o.GetPublicId())
@ -613,16 +655,23 @@ func TestCancel(t *testing.T) {
tc.req.Version = version
got, gErr := s.CancelSession(auth.DisabledAuthTestContext(iamRepoFn, tc.scopeId), tc.req)
requestInfo := authpb.RequestInfo{
TokenFormat: uint32(auth.AuthTokenTypeBearer),
PublicId: at.GetPublicId(),
Token: at.GetToken(),
}
requestContext := context.WithValue(context.Background(), requests.ContextRequestInformationKey, &requests.RequestContext{})
ctx := auth.NewVerifierContext(requestContext, iamRepoFn, tokenRepoFn, serversRepoFn, kms, &requestInfo)
got, gErr := s.CancelSession(ctx, tc.req)
if tc.err != nil {
require.Error(gErr)
// It's hard to mix and match api/error package errors right now
// so use old/new behavior depending on the type. If validate
// gets updated this can be standardized.
if errors.Match(errors.T(errors.InvalidSessionState), gErr) {
assert.True(errors.Match(errors.T(tc.err), gErr), "GetSession(%+v) got error %#v, wanted %#v", tc.req, gErr, tc.err)
assert.True(errors.Match(errors.T(tc.err), gErr), "CancelSession(%+v) got error %#v, wanted %#v", tc.req, gErr, tc.err)
} else {
assert.True(errors.Is(gErr, tc.err), "GetSession(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
assert.True(errors.Is(gErr, tc.err), "CancelSession(%+v) got error %v, wanted %v", tc.req, gErr, tc.err)
}
}
@ -630,6 +679,7 @@ func TestCancel(t *testing.T) {
require.Nil(got)
return
}
require.NotNil(got)
tc.res.GetItem().Version = got.GetItem().Version
// Compare the new canceling state and then remove it for the rest of the proto comparison
@ -647,7 +697,7 @@ func TestCancel(t *testing.T) {
EndTime: got.GetItem().GetUpdatedTime(),
},
}
assert.Empty(cmp.Diff(got.GetItem().GetStates(), wantState, protocmp.Transform()), "GetSession(%q) states")
assert.Empty(cmp.Diff(got.GetItem().GetStates(), wantState, protocmp.Transform()), "CancelSession(%q) states")
got.GetItem().States = nil
got.GetItem().UpdatedTime = nil
@ -656,7 +706,7 @@ func TestCancel(t *testing.T) {
tc.res.GetItem().ExpirationTime = got.GetItem().GetExpirationTime()
}
assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "GetSession(%q) got response\n%q, wanted\n%q", tc.req, got, tc.res)
assert.Empty(cmp.Diff(got, tc.res, protocmp.Transform()), "CancelSession(%q) got response\n%q, wanted\n%q", tc.req, got, tc.res)
if tc.req != nil {
require.NotNil(got)

Loading…
Cancel
Save