@ -19,6 +19,7 @@ import (
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/requests"
"github.com/hashicorp/boundary/internal/server"
"github.com/hashicorp/boundary/internal/session"
"github.com/hashicorp/boundary/internal/target"
@ -32,7 +33,7 @@ import (
"google.golang.org/protobuf/testing/protocmp"
)
var testAuthorizedActions = [ ] string { " no-op", "read" , " read:self", "cancel ", "cancel:self" }
var testAuthorizedActions = [ ] string { " read:self", "cancel:self" }
func TestGetSession ( t * testing . T ) {
conn , _ := db . TestSetup ( t , "postgres" )
@ -42,6 +43,7 @@ func TestGetSession(t *testing.T) {
iamRepo := iam . TestRepo ( t , conn , wrap )
rw := db . New ( conn )
sessRepo , err := session . NewRepository ( rw , rw , kms )
require . NoError ( t , err )
@ -51,6 +53,12 @@ func TestGetSession(t *testing.T) {
sessRepoFn := func ( ) ( * session . Repository , error ) {
return sessRepo , nil
}
tokenRepoFn := func ( ) ( * authtoken . Repository , error ) {
return authtoken . NewRepository ( rw , rw , kms )
}
serversRepoFn := func ( ) ( * server . Repository , error ) {
return server . NewRepository ( rw , rw , kms )
}
o , p := iam . TestScopes ( t , iamRepo )
at := authtoken . TestAuthToken ( t , conn , kms , o . GetPublicId ( ) )
@ -106,7 +114,7 @@ func TestGetSession(t *testing.T) {
res : & pbs . GetSessionResponse { Item : wireSess } ,
} ,
{
name : "Get a non exist a nt Session",
name : "Get a non exist e nt Session",
req : & pbs . GetSessionRequest { Id : session . SessionPrefix + "_DoesntExis" } ,
res : nil ,
err : handlers . ApiErrorWithCode ( codes . NotFound ) ,
@ -131,7 +139,15 @@ func TestGetSession(t *testing.T) {
s , err := sessions . NewService ( sessRepoFn , iamRepoFn )
require . NoError ( err , "Couldn't create new session service." )
got , gErr := s . GetSession ( auth . DisabledAuthTestContext ( iamRepoFn , tc . scopeId ) , tc . req )
requestInfo := authpb . RequestInfo {
TokenFormat : uint32 ( auth . AuthTokenTypeBearer ) ,
PublicId : at . GetPublicId ( ) ,
Token : at . GetToken ( ) ,
}
requestContext := context . WithValue ( context . Background ( ) , requests . ContextRequestInformationKey , & requests . RequestContext { } )
ctx := auth . NewVerifierContext ( requestContext , iamRepoFn , tokenRepoFn , serversRepoFn , kms , & requestInfo )
got , gErr := s . GetSession ( ctx , tc . req )
if tc . err != nil {
require . Error ( gErr )
assert . True ( errors . Is ( gErr , tc . err ) , "GetSession(%+v) got error %v, wanted %v" , tc . req , gErr , tc . err )
@ -140,7 +156,7 @@ func TestGetSession(t *testing.T) {
assert . True ( got . GetItem ( ) . GetExpirationTime ( ) . AsTime ( ) . Sub ( tc . res . GetItem ( ) . GetExpirationTime ( ) . AsTime ( ) ) < 10 * time . Millisecond )
tc . res . GetItem ( ) . ExpirationTime = got . GetItem ( ) . GetExpirationTime ( )
}
assert . Empty ( cmp . Diff ( got, tc. res , protocmp . Transform ( ) ) , "GetSession(%q) got response\n%q, wanted\n%q" , tc . req , got , tc . res )
assert . Empty ( cmp . Diff ( tc. res , got , protocmp . Transform ( ) ) , "GetSession(%q) got response\n%q, wanted\n%q" , tc . req , got , tc . res )
} )
}
}
@ -246,6 +262,12 @@ func TestList(t *testing.T) {
sessRepoFn := func ( ) ( * session . Repository , error ) {
return sessRepo , nil
}
tokenRepoFn := func ( ) ( * authtoken . Repository , error ) {
return authtoken . NewRepository ( rw , rw , kms )
}
serversRepoFn := func ( ) ( * server . Repository , error ) {
return server . NewRepository ( rw , rw , kms )
}
_ , pNoSessions := iam . TestScopes ( t , iamRepo )
o , pWithSessions := iam . TestScopes ( t , iamRepo )
@ -270,7 +292,7 @@ func TestList(t *testing.T) {
tarOther := tcp . TestTarget ( ctx , t , conn , pWithOtherSessions . GetPublicId ( ) , "test" , target . WithHostSources ( [ ] string { hsOther . GetPublicId ( ) } ) )
var wantSession [ ] * pb . Session
var total Session [ ] * pb . Session
var wantOther Session [ ] * pb . Session
var wantIncludeTerminatedSessions [ ] * pb . Session
for i := 0 ; i < 10 ; i ++ {
sess := session . TestSession ( t , conn , wrap , session . ComposedOf {
@ -309,7 +331,6 @@ func TestList(t *testing.T) {
Connections : [ ] * pb . Connection { } , // connections should not be returned for list
} )
totalSession = append ( totalSession , wantSession [ i ] )
wantIncludeTerminatedSessions = append ( wantIncludeTerminatedSessions , wantSession [ i ] )
sess = session . TestSession ( t , conn , wrap , session . ComposedOf {
@ -326,11 +347,11 @@ func TestList(t *testing.T) {
status , states = convertStates ( sess . States )
totalSession = append ( total Session, & pb . Session {
wantOtherSession = append ( wantOther Session, & pb . Session {
Id : sess . GetPublicId ( ) ,
ScopeId : pWith Other Sessions. GetPublicId ( ) ,
AuthTokenId : at Other . GetPublicId ( ) ,
UserId : at Other . GetIamUserId ( ) ,
ScopeId : pWith Sessions. GetPublicId ( ) ,
AuthTokenId : at . GetPublicId ( ) ,
UserId : at . GetIamUserId ( ) ,
TargetId : sess . TargetId ,
Endpoint : sess . Endpoint ,
HostSetId : sess . HostSetId ,
@ -339,7 +360,7 @@ func TestList(t *testing.T) {
UpdatedTime : sess . UpdateTime . GetTimestamp ( ) ,
CreatedTime : sess . CreateTime . GetTimestamp ( ) ,
ExpirationTime : sess . ExpirationTime . GetTimestamp ( ) ,
Scope : & scopes . ScopeInfo { Id : pWith Other Sessions. GetPublicId ( ) , Type : scope . Project . String ( ) , ParentScopeId : o Other . GetPublicId ( ) } ,
Scope : & scopes . ScopeInfo { Id : pWith Sessions. GetPublicId ( ) , Type : scope . Project . String ( ) , ParentScopeId : o . GetPublicId ( ) } ,
Status : status ,
States : states ,
Certificate : sess . Certificate ,
@ -397,35 +418,41 @@ func TestList(t *testing.T) {
}
cases := [ ] struct {
name string
req * pbs . ListSessionsRequest
res * pbs . ListSessionsResponse
err error
name string
req * pbs . ListSessionsRequest
res * pbs . ListSessionsResponse
otherRes * pbs . ListSessionsResponse
err error
} {
{
name : "List Many Sessions" ,
req : & pbs . ListSessionsRequest { ScopeId : pWithSessions . GetPublicId ( ) } ,
res : & pbs . ListSessionsResponse { Items : wantSession } ,
name : "List Many Sessions" ,
req : & pbs . ListSessionsRequest { ScopeId : pWithSessions . GetPublicId ( ) } ,
res : & pbs . ListSessionsResponse { Items : wantSession } ,
otherRes : & pbs . ListSessionsResponse { Items : [ ] * pb . Session { } } ,
} ,
{
name : "List Many Include Terminated" ,
req : & pbs . ListSessionsRequest { ScopeId : pWithSessions . GetPublicId ( ) , IncludeTerminated : true } ,
res : & pbs . ListSessionsResponse { Items : wantIncludeTerminatedSessions } ,
name : "List Many Include Terminated" ,
req : & pbs . ListSessionsRequest { ScopeId : pWithSessions . GetPublicId ( ) , IncludeTerminated : true } ,
res : & pbs . ListSessionsResponse { Items : wantIncludeTerminatedSessions } ,
otherRes : & pbs . ListSessionsResponse { Items : [ ] * pb . Session { } } ,
} ,
{
name : "List No Sessions" ,
req : & pbs . ListSessionsRequest { ScopeId : pNoSessions . GetPublicId ( ) } ,
res : & pbs . ListSessionsResponse { } ,
name : "List No Sessions" ,
req : & pbs . ListSessionsRequest { ScopeId : pNoSessions . GetPublicId ( ) } ,
res : & pbs . ListSessionsResponse { } ,
otherRes : & pbs . ListSessionsResponse { Items : [ ] * pb . Session { } } ,
} ,
{
name : "List Sessions Recursively" ,
req : & pbs . ListSessionsRequest { ScopeId : scope . Global . String ( ) , Recursive : true } ,
res : & pbs . ListSessionsResponse { Items : totalSession } ,
name : "List Sessions Recursively" ,
req : & pbs . ListSessionsRequest { ScopeId : scope . Global . String ( ) , Recursive : true } ,
res : & pbs . ListSessionsResponse { Items : wantSession } ,
otherRes : & pbs . ListSessionsResponse { Items : wantOtherSession } ,
} ,
{
name : "Filter To Single Sessions" ,
req : & pbs . ListSessionsRequest { ScopeId : pWithSessions . GetPublicId ( ) , Filter : fmt . Sprintf ( ` "/item/id"==%q ` , totalSession [ 4 ] . Id ) } ,
res : & pbs . ListSessionsResponse { Items : totalSession [ 4 : 5 ] } ,
name : "Filter To Single Sessions" ,
req : & pbs . ListSessionsRequest { ScopeId : pWithSessions . GetPublicId ( ) , Filter : fmt . Sprintf ( ` "/item/id"==%q ` , wantSession [ 4 ] . Id ) } ,
res : & pbs . ListSessionsResponse { Items : wantSession [ 4 : 5 ] } ,
otherRes : & pbs . ListSessionsResponse { Items : [ ] * pb . Session { } } ,
} ,
{
name : "Filter To Many Sessions" ,
@ -433,17 +460,20 @@ func TestList(t *testing.T) {
ScopeId : scope . Global . String ( ) , Recursive : true ,
Filter : fmt . Sprintf ( ` "/item/scope/id" matches "^%s" ` , pWithSessions . GetPublicId ( ) [ : 8 ] ) ,
} ,
res : & pbs . ListSessionsResponse { Items : wantSession } ,
res : & pbs . ListSessionsResponse { Items : wantSession } ,
otherRes : & pbs . ListSessionsResponse { Items : [ ] * pb . Session { } } ,
} ,
{
name : "Filter To Nothing" ,
req : & pbs . ListSessionsRequest { ScopeId : pWithSessions . GetPublicId ( ) , Filter : ` "/item/id" == "" ` } ,
res : & pbs . ListSessionsResponse { } ,
name : "Filter To Nothing" ,
req : & pbs . ListSessionsRequest { ScopeId : pWithSessions . GetPublicId ( ) , Filter : ` "/item/id" == "" ` } ,
res : & pbs . ListSessionsResponse { } ,
otherRes : & pbs . ListSessionsResponse { Items : [ ] * pb . Session { } } ,
} ,
{
name : "Filter Bad Format" ,
req : & pbs . ListSessionsRequest { ScopeId : pWithSessions . GetPublicId ( ) , Filter : ` //badformat/ ` } ,
err : handlers . InvalidArgumentErrorf ( "bad format" , nil ) ,
name : "Filter Bad Format" ,
req : & pbs . ListSessionsRequest { ScopeId : pWithSessions . GetPublicId ( ) , Filter : ` //badformat/ ` } ,
err : handlers . InvalidArgumentErrorf ( "bad format" , nil ) ,
otherRes : & pbs . ListSessionsResponse { Items : [ ] * pb . Session { } } ,
} ,
}
for _ , tc := range cases {
@ -453,7 +483,14 @@ func TestList(t *testing.T) {
require . NoError ( err , "Couldn't create new session service." )
// Test without anon user
got , gErr := s . ListSessions ( auth . DisabledAuthTestContext ( iamRepoFn , tc . req . GetScopeId ( ) ) , tc . req )
requestInfo := authpb . RequestInfo {
TokenFormat : uint32 ( auth . AuthTokenTypeBearer ) ,
PublicId : at . GetPublicId ( ) ,
Token : at . GetToken ( ) ,
}
requestContext := context . WithValue ( context . Background ( ) , requests . ContextRequestInformationKey , & requests . RequestContext { } )
ctx := auth . NewVerifierContext ( requestContext , iamRepoFn , tokenRepoFn , serversRepoFn , kms , & requestInfo )
got , gErr := s . ListSessions ( ctx , tc . req )
if tc . err != nil {
require . Error ( gErr )
assert . True ( errors . Is ( gErr , tc . err ) , "ListSessions(%+v) got error %v, wanted %v" , tc . req , gErr , tc . err )
@ -470,22 +507,21 @@ func TestList(t *testing.T) {
}
assert . Empty ( cmp . Diff ( got , tc . res , protocmp . Transform ( ) ) , "ListSessions(%q) got response %q, wanted %q" , tc . req , got , tc . res )
// Test with anon user
got , gErr = s . ListSessions ( auth . DisabledAuthTestContext ( iamRepoFn , tc . req . GetScopeId ( ) , auth . WithUserId ( auth . AnonymousUserId ) ) , tc . req )
// Test with other user
otherRequestInfo := authpb . RequestInfo {
TokenFormat : uint32 ( auth . AuthTokenTypeBearer ) ,
PublicId : atOther . GetPublicId ( ) ,
Token : atOther . GetToken ( ) ,
}
otherRequestContext := context . WithValue ( context . Background ( ) , requests . ContextRequestInformationKey , & requests . RequestContext { } )
otherCtx := auth . NewVerifierContext ( otherRequestContext , iamRepoFn , tokenRepoFn , serversRepoFn , kms , & otherRequestInfo )
got , gErr = s . ListSessions ( otherCtx , tc . req )
require . NoError ( gErr )
assert . Len ( got . Items , len ( tc . res . Items ) )
for _ , item := range got . GetItems ( ) {
require . Empty ( item . Version )
require . Empty ( item . UserId )
require . Empty ( item . HostId )
require . Empty ( item . HostSetId )
require . Empty ( item . AuthTokenId )
require . Empty ( item . Endpoint )
require . Nil ( item . CreatedTime )
require . Nil ( item . ExpirationTime )
require . Nil ( item . UpdatedTime )
require . Empty ( item . Certificate )
require . Empty ( item . TerminationReason )
assert . Len ( got . Items , len ( tc . otherRes . Items ) )
for i , wantSess := range tc . otherRes . GetItems ( ) {
assert . True ( got . GetItems ( ) [ i ] . GetExpirationTime ( ) . AsTime ( ) . Sub ( wantSess . GetExpirationTime ( ) . AsTime ( ) ) < 10 * time . Millisecond )
assert . Equal ( 0 , len ( wantSess . GetConnections ( ) ) ) // no connections on list
wantSess . ExpirationTime = got . GetItems ( ) [ i ] . GetExpirationTime ( )
}
} )
}
@ -525,6 +561,12 @@ func TestCancel(t *testing.T) {
sessRepoFn := func ( ) ( * session . Repository , error ) {
return sessRepo , nil
}
tokenRepoFn := func ( ) ( * authtoken . Repository , error ) {
return authtoken . NewRepository ( rw , rw , kms )
}
serversRepoFn := func ( ) ( * server . Repository , error ) {
return server . NewRepository ( rw , rw , kms )
}
o , p := iam . TestScopes ( t , iamRepo )
at := authtoken . TestAuthToken ( t , conn , kms , o . GetPublicId ( ) )
@ -613,16 +655,23 @@ func TestCancel(t *testing.T) {
tc . req . Version = version
got , gErr := s . CancelSession ( auth . DisabledAuthTestContext ( iamRepoFn , tc . scopeId ) , tc . req )
requestInfo := authpb . RequestInfo {
TokenFormat : uint32 ( auth . AuthTokenTypeBearer ) ,
PublicId : at . GetPublicId ( ) ,
Token : at . GetToken ( ) ,
}
requestContext := context . WithValue ( context . Background ( ) , requests . ContextRequestInformationKey , & requests . RequestContext { } )
ctx := auth . NewVerifierContext ( requestContext , iamRepoFn , tokenRepoFn , serversRepoFn , kms , & requestInfo )
got , gErr := s . CancelSession ( ctx , tc . req )
if tc . err != nil {
require . Error ( gErr )
// It's hard to mix and match api/error package errors right now
// so use old/new behavior depending on the type. If validate
// gets updated this can be standardized.
if errors . Match ( errors . T ( errors . InvalidSessionState ) , gErr ) {
assert . True ( errors . Match ( errors . T ( tc . err ) , gErr ) , "GetSession(%+v) got error %#v, wanted %#v" , tc . req , gErr , tc . err )
assert . True ( errors . Match ( errors . T ( tc . err ) , gErr ) , " Cancel Session(%+v) got error %#v, wanted %#v", tc . req , gErr , tc . err )
} else {
assert . True ( errors . Is ( gErr , tc . err ) , " Get Session(%+v) got error %v, wanted %v", tc . req , gErr , tc . err )
assert . True ( errors . Is ( gErr , tc . err ) , " Cancel Session(%+v) got error %v, wanted %v", tc . req , gErr , tc . err )
}
}
@ -630,6 +679,7 @@ func TestCancel(t *testing.T) {
require . Nil ( got )
return
}
require . NotNil ( got )
tc . res . GetItem ( ) . Version = got . GetItem ( ) . Version
// Compare the new canceling state and then remove it for the rest of the proto comparison
@ -647,7 +697,7 @@ func TestCancel(t *testing.T) {
EndTime : got . GetItem ( ) . GetUpdatedTime ( ) ,
} ,
}
assert . Empty ( cmp . Diff ( got . GetItem ( ) . GetStates ( ) , wantState , protocmp . Transform ( ) ) , " Get Session(%q) states")
assert . Empty ( cmp . Diff ( got . GetItem ( ) . GetStates ( ) , wantState , protocmp . Transform ( ) ) , " Cancel Session(%q) states")
got . GetItem ( ) . States = nil
got . GetItem ( ) . UpdatedTime = nil
@ -656,7 +706,7 @@ func TestCancel(t *testing.T) {
tc . res . GetItem ( ) . ExpirationTime = got . GetItem ( ) . GetExpirationTime ( )
}
assert . Empty ( cmp . Diff ( got , tc . res , protocmp . Transform ( ) ) , " Get Session(%q) got response\n%q, wanted\n%q", tc . req , got , tc . res )
assert . Empty ( cmp . Diff ( got , tc . res , protocmp . Transform ( ) ) , " Cancel Session(%q) got response\n%q, wanted\n%q", tc . req , got , tc . res )
if tc . req != nil {
require . NotNil ( got )