mirror of https://github.com/hashicorp/boundary
commit
58aca600a4
@ -0,0 +1,44 @@
|
||||
begin;
|
||||
|
||||
create table session_credential (
|
||||
session_id wt_public_id not null
|
||||
constraint session_fkey
|
||||
references session (public_id)
|
||||
on delete cascade
|
||||
on update cascade,
|
||||
credential bytea not null -- encrypted
|
||||
constraint credential_must_not_be_empty
|
||||
check(length(credential) > 0),
|
||||
key_id text not null
|
||||
constraint kms_database_key_version_fkey
|
||||
references kms_database_key_version (private_id)
|
||||
on delete restrict
|
||||
on update cascade,
|
||||
constraint session_credential_session_id_credential_uq
|
||||
unique(session_id, credential)
|
||||
);
|
||||
comment on table session_credential is
|
||||
'session_credential is a table where each row contains a credential to be used by '
|
||||
'by a worker when a connection is established for the session_id.';
|
||||
|
||||
create trigger immutable_columns before update on session_credential
|
||||
for each row execute procedure immutable_columns('session_id', 'credential', 'key_id');
|
||||
|
||||
-- delete_credentials deletes all credentials for a session when the
|
||||
-- session enters the canceling or terminated states.
|
||||
create function delete_session_credentials()
|
||||
returns trigger
|
||||
as $$
|
||||
begin
|
||||
if new.state in ('canceling', 'terminated') then
|
||||
delete from session_credential
|
||||
where session_id = new.session_id;
|
||||
end if;
|
||||
return new;
|
||||
end;
|
||||
$$ language plpgsql;
|
||||
|
||||
create trigger delete_session_credentials after insert on session_state
|
||||
for each row execute procedure delete_session_credentials();
|
||||
|
||||
commit;
|
||||
@ -0,0 +1,32 @@
|
||||
-- session_credential tests:
|
||||
-- the following triggers
|
||||
-- delete_session_credentials
|
||||
|
||||
begin;
|
||||
select plan(4);
|
||||
select wtt_load('widgets', 'iam', 'kms', 'auth');
|
||||
|
||||
-- validate the setup data
|
||||
select is(count(*), 1::bigint) from session where public_id = 's1_____clare';
|
||||
|
||||
insert into session_credential
|
||||
( session_id , credential , key_id)
|
||||
values
|
||||
('s1_____clare', 'clare credential data 1' , 'kdkv___widget'),
|
||||
('s1_____clare', 'clare credential data 2' , 'kdkv___widget'),
|
||||
('s1_____clare', 'clare credential data 3' , 'kdkv___widget');
|
||||
|
||||
select is(count(*), 3::bigint) from session_credential where session_id = 's1_____clare';
|
||||
|
||||
-- validate the delete triggers
|
||||
prepare delete_session_credentials as
|
||||
insert into session_state
|
||||
(session_id , state)
|
||||
values
|
||||
('s1_____clare' , 'canceling');
|
||||
select lives_ok('delete_session_credentials');
|
||||
|
||||
select is(count(*), 0::bigint) from session_credential where session_id = 's1_____clare';
|
||||
|
||||
select * from finish();
|
||||
rollback;
|
||||
@ -0,0 +1,32 @@
|
||||
-- session_credential tests:
|
||||
-- the following triggers
|
||||
-- delete_session_credentials
|
||||
|
||||
begin;
|
||||
select plan(4);
|
||||
select wtt_load('widgets', 'iam', 'kms', 'auth');
|
||||
|
||||
-- validate the setup data
|
||||
select is(count(*), 1::bigint) from session where public_id = 's1_____clare';
|
||||
|
||||
insert into session_credential
|
||||
( session_id , credential , key_id)
|
||||
values
|
||||
('s1_____clare', 'clare credential data 1' , 'kdkv___widget'),
|
||||
('s1_____clare', 'clare credential data 2' , 'kdkv___widget'),
|
||||
('s1_____clare', 'clare credential data 3' , 'kdkv___widget');
|
||||
|
||||
select is(count(*), 3::bigint) from session_credential where session_id = 's1_____clare';
|
||||
|
||||
-- validate the delete triggers
|
||||
prepare delete_session_credentials as
|
||||
insert into session_state
|
||||
(session_id , state)
|
||||
values
|
||||
('s1_____clare' , 'terminated');
|
||||
select lives_ok('delete_session_credentials');
|
||||
|
||||
select is(count(*), 0::bigint) from session_credential where session_id = 's1_____clare';
|
||||
|
||||
select * from finish();
|
||||
rollback;
|
||||
@ -0,0 +1,253 @@
|
||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||
// versions:
|
||||
// protoc-gen-go v1.27.1
|
||||
// protoc v3.17.3
|
||||
// source: controller/servers/services/v1/credential.proto
|
||||
|
||||
package services
|
||||
|
||||
import (
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
|
||||
reflect "reflect"
|
||||
sync "sync"
|
||||
)
|
||||
|
||||
const (
|
||||
// Verify that this generated code is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
|
||||
// Verify that runtime/protoimpl is sufficiently up-to-date.
|
||||
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
|
||||
)
|
||||
|
||||
type Credential struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// Types that are assignable to Credential:
|
||||
// *Credential_UserPassword
|
||||
Credential isCredential_Credential `protobuf_oneof:"credential"`
|
||||
}
|
||||
|
||||
func (x *Credential) Reset() {
|
||||
*x = Credential{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_controller_servers_services_v1_credential_proto_msgTypes[0]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *Credential) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*Credential) ProtoMessage() {}
|
||||
|
||||
func (x *Credential) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_controller_servers_services_v1_credential_proto_msgTypes[0]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use Credential.ProtoReflect.Descriptor instead.
|
||||
func (*Credential) Descriptor() ([]byte, []int) {
|
||||
return file_controller_servers_services_v1_credential_proto_rawDescGZIP(), []int{0}
|
||||
}
|
||||
|
||||
func (m *Credential) GetCredential() isCredential_Credential {
|
||||
if m != nil {
|
||||
return m.Credential
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *Credential) GetUserPassword() *UserPassword {
|
||||
if x, ok := x.GetCredential().(*Credential_UserPassword); ok {
|
||||
return x.UserPassword
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type isCredential_Credential interface {
|
||||
isCredential_Credential()
|
||||
}
|
||||
|
||||
type Credential_UserPassword struct {
|
||||
UserPassword *UserPassword `protobuf:"bytes,1,opt,name=user_password,json=userPassword,proto3,oneof"`
|
||||
}
|
||||
|
||||
func (*Credential_UserPassword) isCredential_Credential() {}
|
||||
|
||||
// UserPassword is a credential containing a username and a password.
|
||||
type UserPassword struct {
|
||||
state protoimpl.MessageState
|
||||
sizeCache protoimpl.SizeCache
|
||||
unknownFields protoimpl.UnknownFields
|
||||
|
||||
// The username of the credential
|
||||
Username string `protobuf:"bytes,10,opt,name=username,proto3" json:"username,omitempty"` // @gotags: `class:"public"`
|
||||
// The password of the credential
|
||||
Password string `protobuf:"bytes,20,opt,name=password,proto3" json:"password,omitempty"` // @gotags: `class:"secret"`
|
||||
}
|
||||
|
||||
func (x *UserPassword) Reset() {
|
||||
*x = UserPassword{}
|
||||
if protoimpl.UnsafeEnabled {
|
||||
mi := &file_controller_servers_services_v1_credential_proto_msgTypes[1]
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
}
|
||||
|
||||
func (x *UserPassword) String() string {
|
||||
return protoimpl.X.MessageStringOf(x)
|
||||
}
|
||||
|
||||
func (*UserPassword) ProtoMessage() {}
|
||||
|
||||
func (x *UserPassword) ProtoReflect() protoreflect.Message {
|
||||
mi := &file_controller_servers_services_v1_credential_proto_msgTypes[1]
|
||||
if protoimpl.UnsafeEnabled && x != nil {
|
||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||
if ms.LoadMessageInfo() == nil {
|
||||
ms.StoreMessageInfo(mi)
|
||||
}
|
||||
return ms
|
||||
}
|
||||
return mi.MessageOf(x)
|
||||
}
|
||||
|
||||
// Deprecated: Use UserPassword.ProtoReflect.Descriptor instead.
|
||||
func (*UserPassword) Descriptor() ([]byte, []int) {
|
||||
return file_controller_servers_services_v1_credential_proto_rawDescGZIP(), []int{1}
|
||||
}
|
||||
|
||||
func (x *UserPassword) GetUsername() string {
|
||||
if x != nil {
|
||||
return x.Username
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (x *UserPassword) GetPassword() string {
|
||||
if x != nil {
|
||||
return x.Password
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var File_controller_servers_services_v1_credential_proto protoreflect.FileDescriptor
|
||||
|
||||
var file_controller_servers_services_v1_credential_proto_rawDesc = []byte{
|
||||
0x0a, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2f, 0x73, 0x65, 0x72,
|
||||
0x76, 0x65, 0x72, 0x73, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2f, 0x76, 0x31,
|
||||
0x2f, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74,
|
||||
0x6f, 0x12, 0x1e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x73, 0x65,
|
||||
0x72, 0x76, 0x65, 0x72, 0x73, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2e, 0x76,
|
||||
0x31, 0x22, 0x6f, 0x0a, 0x0a, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x12,
|
||||
0x53, 0x0a, 0x0d, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64,
|
||||
0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2c, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c,
|
||||
0x6c, 0x65, 0x72, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x2e, 0x73, 0x65, 0x72, 0x76,
|
||||
0x69, 0x63, 0x65, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x55, 0x73, 0x65, 0x72, 0x50, 0x61, 0x73, 0x73,
|
||||
0x77, 0x6f, 0x72, 0x64, 0x48, 0x00, 0x52, 0x0c, 0x75, 0x73, 0x65, 0x72, 0x50, 0x61, 0x73, 0x73,
|
||||
0x77, 0x6f, 0x72, 0x64, 0x42, 0x0c, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69,
|
||||
0x61, 0x6c, 0x22, 0x46, 0x0a, 0x0c, 0x55, 0x73, 0x65, 0x72, 0x50, 0x61, 0x73, 0x73, 0x77, 0x6f,
|
||||
0x72, 0x64, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x0a,
|
||||
0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1a,
|
||||
0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x14, 0x20, 0x01, 0x28, 0x09,
|
||||
0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x42, 0x51, 0x5a, 0x4f, 0x67, 0x69,
|
||||
0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f,
|
||||
0x72, 0x70, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2f, 0x69, 0x6e, 0x74, 0x65,
|
||||
0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c,
|
||||
0x6c, 0x65, 0x72, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x2f, 0x73, 0x65, 0x72, 0x76,
|
||||
0x69, 0x63, 0x65, 0x73, 0x3b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x62, 0x06, 0x70,
|
||||
0x72, 0x6f, 0x74, 0x6f, 0x33,
|
||||
}
|
||||
|
||||
var (
|
||||
file_controller_servers_services_v1_credential_proto_rawDescOnce sync.Once
|
||||
file_controller_servers_services_v1_credential_proto_rawDescData = file_controller_servers_services_v1_credential_proto_rawDesc
|
||||
)
|
||||
|
||||
func file_controller_servers_services_v1_credential_proto_rawDescGZIP() []byte {
|
||||
file_controller_servers_services_v1_credential_proto_rawDescOnce.Do(func() {
|
||||
file_controller_servers_services_v1_credential_proto_rawDescData = protoimpl.X.CompressGZIP(file_controller_servers_services_v1_credential_proto_rawDescData)
|
||||
})
|
||||
return file_controller_servers_services_v1_credential_proto_rawDescData
|
||||
}
|
||||
|
||||
var file_controller_servers_services_v1_credential_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
|
||||
var file_controller_servers_services_v1_credential_proto_goTypes = []interface{}{
|
||||
(*Credential)(nil), // 0: controller.servers.services.v1.Credential
|
||||
(*UserPassword)(nil), // 1: controller.servers.services.v1.UserPassword
|
||||
}
|
||||
var file_controller_servers_services_v1_credential_proto_depIdxs = []int32{
|
||||
1, // 0: controller.servers.services.v1.Credential.user_password:type_name -> controller.servers.services.v1.UserPassword
|
||||
1, // [1:1] is the sub-list for method output_type
|
||||
1, // [1:1] is the sub-list for method input_type
|
||||
1, // [1:1] is the sub-list for extension type_name
|
||||
1, // [1:1] is the sub-list for extension extendee
|
||||
0, // [0:1] is the sub-list for field type_name
|
||||
}
|
||||
|
||||
func init() { file_controller_servers_services_v1_credential_proto_init() }
|
||||
func file_controller_servers_services_v1_credential_proto_init() {
|
||||
if File_controller_servers_services_v1_credential_proto != nil {
|
||||
return
|
||||
}
|
||||
if !protoimpl.UnsafeEnabled {
|
||||
file_controller_servers_services_v1_credential_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*Credential); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
file_controller_servers_services_v1_credential_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
|
||||
switch v := v.(*UserPassword); i {
|
||||
case 0:
|
||||
return &v.state
|
||||
case 1:
|
||||
return &v.sizeCache
|
||||
case 2:
|
||||
return &v.unknownFields
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
file_controller_servers_services_v1_credential_proto_msgTypes[0].OneofWrappers = []interface{}{
|
||||
(*Credential_UserPassword)(nil),
|
||||
}
|
||||
type x struct{}
|
||||
out := protoimpl.TypeBuilder{
|
||||
File: protoimpl.DescBuilder{
|
||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||
RawDescriptor: file_controller_servers_services_v1_credential_proto_rawDesc,
|
||||
NumEnums: 0,
|
||||
NumMessages: 2,
|
||||
NumExtensions: 0,
|
||||
NumServices: 0,
|
||||
},
|
||||
GoTypes: file_controller_servers_services_v1_credential_proto_goTypes,
|
||||
DependencyIndexes: file_controller_servers_services_v1_credential_proto_depIdxs,
|
||||
MessageInfos: file_controller_servers_services_v1_credential_proto_msgTypes,
|
||||
}.Build()
|
||||
File_controller_servers_services_v1_credential_proto = out.File
|
||||
file_controller_servers_services_v1_credential_proto_rawDesc = nil
|
||||
file_controller_servers_services_v1_credential_proto_goTypes = nil
|
||||
file_controller_servers_services_v1_credential_proto_depIdxs = nil
|
||||
}
|
||||
@ -0,0 +1,20 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package controller.servers.services.v1;
|
||||
|
||||
option go_package = "github.com/hashicorp/boundary/internal/gen/controller/servers/services;services";
|
||||
|
||||
message Credential {
|
||||
oneof credential {
|
||||
UserPassword user_password = 1;
|
||||
}
|
||||
}
|
||||
|
||||
// UserPassword is a credential containing a username and a password.
|
||||
message UserPassword {
|
||||
// The username of the credential
|
||||
string username = 10; // @gotags: `class:"public"`
|
||||
|
||||
// The password of the credential
|
||||
string password = 20; // @gotags: `class:"secret"`
|
||||
}
|
||||
@ -0,0 +1,173 @@
|
||||
package workers_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/hashicorp/boundary/internal/authtoken"
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
pbs "github.com/hashicorp/boundary/internal/gen/controller/servers/services"
|
||||
"github.com/hashicorp/boundary/internal/host/static"
|
||||
"github.com/hashicorp/boundary/internal/iam"
|
||||
"github.com/hashicorp/boundary/internal/kms"
|
||||
"github.com/hashicorp/boundary/internal/servers"
|
||||
"github.com/hashicorp/boundary/internal/servers/controller/handlers/workers"
|
||||
"github.com/hashicorp/boundary/internal/session"
|
||||
"github.com/hashicorp/boundary/internal/target"
|
||||
"github.com/hashicorp/boundary/internal/target/tcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
func TestLookupSession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
rw := db.New(conn)
|
||||
wrapper := db.TestWrapper(t)
|
||||
kms := kms.TestKms(t, conn, wrapper)
|
||||
org, prj := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
|
||||
|
||||
serversRepoFn := func() (*servers.Repository, error) {
|
||||
return servers.NewRepository(rw, rw, kms)
|
||||
}
|
||||
sessionRepoFn := func() (*session.Repository, error) {
|
||||
return session.NewRepository(rw, rw, kms)
|
||||
}
|
||||
|
||||
at := authtoken.TestAuthToken(t, conn, kms, org.GetPublicId())
|
||||
uId := at.GetIamUserId()
|
||||
hc := static.TestCatalogs(t, conn, prj.GetPublicId(), 1)[0]
|
||||
hs := static.TestSets(t, conn, hc.GetPublicId(), 1)[0]
|
||||
h := static.TestHosts(t, conn, hc.GetPublicId(), 1)[0]
|
||||
static.TestSetMembers(t, conn, hs.GetPublicId(), []*static.Host{h})
|
||||
tar := tcp.TestTarget(ctx, t, conn, prj.GetPublicId(), "test", target.WithHostSources([]string{hs.GetPublicId()}))
|
||||
|
||||
sess := session.TestSession(t, conn, wrapper, session.ComposedOf{
|
||||
UserId: uId,
|
||||
HostId: h.GetPublicId(),
|
||||
TargetId: tar.GetPublicId(),
|
||||
HostSetId: hs.GetPublicId(),
|
||||
AuthTokenId: at.GetPublicId(),
|
||||
ScopeId: prj.GetPublicId(),
|
||||
Endpoint: "tcp://127.0.0.1:22",
|
||||
})
|
||||
|
||||
egressSess := session.TestSession(t, conn, wrapper, session.ComposedOf{
|
||||
UserId: uId,
|
||||
HostId: h.GetPublicId(),
|
||||
TargetId: tar.GetPublicId(),
|
||||
HostSetId: hs.GetPublicId(),
|
||||
AuthTokenId: at.GetPublicId(),
|
||||
ScopeId: prj.GetPublicId(),
|
||||
Endpoint: "tcp://127.0.0.1:22",
|
||||
})
|
||||
|
||||
repo, err := sessionRepoFn()
|
||||
require.NoError(t, err)
|
||||
|
||||
creds := []*pbs.Credential{
|
||||
{
|
||||
Credential: &pbs.Credential_UserPassword{
|
||||
UserPassword: &pbs.UserPassword{
|
||||
Username: "username",
|
||||
Password: "password",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Credential: &pbs.Credential_UserPassword{
|
||||
UserPassword: &pbs.UserPassword{
|
||||
Username: "another-username",
|
||||
Password: "a different password",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
workerCreds := make([]session.Credential, 0, len(creds))
|
||||
for _, c := range creds {
|
||||
data, err := proto.Marshal(c)
|
||||
require.NoError(t, err)
|
||||
workerCreds = append(workerCreds, data)
|
||||
}
|
||||
err = repo.AddSessionCredentials(ctx, egressSess.ScopeId, egressSess.GetPublicId(), workerCreds)
|
||||
require.NoError(t, err)
|
||||
|
||||
s := workers.NewWorkerServiceServer(serversRepoFn, sessionRepoFn, new(sync.Map), kms)
|
||||
require.NotNil(t, s)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
wantErr bool
|
||||
wantErrMsg string
|
||||
want *pbs.LookupSessionResponse
|
||||
sessionId string
|
||||
}{
|
||||
{
|
||||
name: "Invalid session id",
|
||||
sessionId: "s_fakesession",
|
||||
wantErr: true,
|
||||
wantErrMsg: "rpc error: code = PermissionDenied desc = Unknown session ID.",
|
||||
},
|
||||
{
|
||||
name: "Valid",
|
||||
sessionId: sess.PublicId,
|
||||
want: &pbs.LookupSessionResponse{
|
||||
ConnectionLimit: 1,
|
||||
ConnectionsLeft: 1,
|
||||
Version: 1,
|
||||
Endpoint: sess.Endpoint,
|
||||
HostId: sess.HostId,
|
||||
HostSetId: sess.HostSetId,
|
||||
TargetId: sess.TargetId,
|
||||
UserId: sess.UserId,
|
||||
Status: pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Valid-with-egress-creds",
|
||||
sessionId: egressSess.PublicId,
|
||||
want: &pbs.LookupSessionResponse{
|
||||
ConnectionLimit: 1,
|
||||
ConnectionsLeft: 1,
|
||||
Version: 1,
|
||||
Endpoint: egressSess.Endpoint,
|
||||
HostId: egressSess.HostId,
|
||||
HostSetId: egressSess.HostSetId,
|
||||
TargetId: egressSess.TargetId,
|
||||
UserId: egressSess.UserId,
|
||||
Status: pbs.SESSIONSTATUS_SESSIONSTATUS_PENDING,
|
||||
Credentials: creds,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
req := &pbs.LookupSessionRequest{
|
||||
SessionId: tc.sessionId,
|
||||
}
|
||||
|
||||
got, err := s.LookupSession(ctx, req)
|
||||
if tc.wantErr {
|
||||
require.Error(err)
|
||||
assert.Nil(got)
|
||||
assert.Equal(tc.wantErrMsg, err.Error())
|
||||
return
|
||||
}
|
||||
assert.Empty(
|
||||
cmp.Diff(
|
||||
tc.want,
|
||||
got,
|
||||
cmpopts.IgnoreUnexported(pbs.LookupSessionResponse{}, pbs.Credential{}, pbs.UserPassword{}),
|
||||
cmpopts.IgnoreFields(pbs.LookupSessionResponse{}, "Expiration", "Authorization"),
|
||||
),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,41 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
wrapping "github.com/hashicorp/go-kms-wrapping"
|
||||
"github.com/hashicorp/go-kms-wrapping/structwrapping"
|
||||
)
|
||||
|
||||
// Credential represents the credential data which is sent to the worker.
|
||||
type Credential []byte
|
||||
|
||||
type credential struct {
|
||||
SessionId string
|
||||
KeyId string
|
||||
Credential []byte `gorm:"-" wrapping:"pt,credential_data"`
|
||||
CtCredential []byte `gorm:"column:credential" wrapping:"ct,credential_data"`
|
||||
}
|
||||
|
||||
// TableName returns the table name.
|
||||
func (c *credential) TableName() string {
|
||||
return "session_credential"
|
||||
}
|
||||
|
||||
func (c *credential) encrypt(ctx context.Context, cipher wrapping.Wrapper) error {
|
||||
const op = "session.(credential).encrypt"
|
||||
if err := structwrapping.WrapStruct(ctx, cipher, c, nil); err != nil {
|
||||
return errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt))
|
||||
}
|
||||
c.KeyId = cipher.KeyID()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *credential) decrypt(ctx context.Context, cipher wrapping.Wrapper) error {
|
||||
const op = "session.(credential).decrypt"
|
||||
if err := structwrapping.UnwrapStruct(ctx, cipher, c, nil); err != nil {
|
||||
return errors.Wrap(ctx, err, op, errors.WithCode(errors.Decrypt))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,92 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
"github.com/hashicorp/boundary/internal/kms"
|
||||
)
|
||||
|
||||
// AddSessionCredentials encrypts the credData and adds the credentials to the repository. The credentials are linked
|
||||
// to the sessionID provided, and encrypted using the sessScopeId. Session credentials are only valid for pending and
|
||||
// active sessions, once a session ends, all session credentials are deleted.
|
||||
// All options are ignored.
|
||||
func (r *Repository) AddSessionCredentials(ctx context.Context, sessScopeId, sessionId string, credData []Credential, _ ...Option) error {
|
||||
const op = "session.(Repository).AddSessionCredentials"
|
||||
if sessionId == "" {
|
||||
return errors.New(ctx, errors.InvalidParameter, op, "missing session id")
|
||||
}
|
||||
if sessScopeId == "" {
|
||||
return errors.New(ctx, errors.InvalidParameter, op, "missing session scope id")
|
||||
}
|
||||
if len(credData) == 0 {
|
||||
return errors.New(ctx, errors.InvalidParameter, op, "missing credentials")
|
||||
}
|
||||
|
||||
databaseWrapper, err := r.kms.GetWrapper(ctx, sessScopeId, kms.KeyPurposeDatabase)
|
||||
if err != nil {
|
||||
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
|
||||
}
|
||||
|
||||
addCreds := make([]interface{}, 0, len(credData))
|
||||
for _, cred := range credData {
|
||||
if len(cred) == 0 {
|
||||
return errors.New(ctx, errors.InvalidParameter, op, "missing credential")
|
||||
}
|
||||
|
||||
sessCred := credential{
|
||||
SessionId: sessionId,
|
||||
Credential: cred,
|
||||
}
|
||||
if err := sessCred.encrypt(ctx, databaseWrapper); err != nil {
|
||||
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to encrypt credential"))
|
||||
}
|
||||
addCreds = append(addCreds, sessCred)
|
||||
}
|
||||
|
||||
_, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{},
|
||||
func(_ db.Reader, w db.Writer) error {
|
||||
if err := w.CreateItems(ctx, addCreds); err != nil {
|
||||
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add session credentials"))
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListSessionCredentials returns all Credential attached to the sessionId.
|
||||
// All options are ignored.
|
||||
func (r *Repository) ListSessionCredentials(ctx context.Context, sessScopeId, sessionId string, _ ...Option) ([]Credential, error) {
|
||||
const op = "session.(Repository).ListSessionCredentials"
|
||||
if sessionId == "" {
|
||||
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing session id")
|
||||
}
|
||||
if sessScopeId == "" {
|
||||
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing session scope id")
|
||||
}
|
||||
|
||||
databaseWrapper, err := r.kms.GetWrapper(ctx, sessScopeId, kms.KeyPurposeDatabase)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
|
||||
}
|
||||
|
||||
var creds []*credential
|
||||
if err := r.reader.SearchWhere(ctx, &creds, "session_id = ?", []interface{}{sessionId}); err != nil {
|
||||
return nil, errors.Wrap(ctx, err, op)
|
||||
}
|
||||
if len(creds) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
ret := make([]Credential, 0, len(creds))
|
||||
for _, c := range creds {
|
||||
if err := c.decrypt(ctx, databaseWrapper); err != nil {
|
||||
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to decrypt credential"))
|
||||
}
|
||||
ret = append(ret, c.Credential)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
@ -0,0 +1,231 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
"github.com/hashicorp/boundary/internal/iam"
|
||||
"github.com/hashicorp/boundary/internal/kms"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRepository_AddSessionCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
rw := db.New(conn)
|
||||
kms := kms.TestKms(t, conn, wrapper)
|
||||
repo, err := NewRepository(rw, rw, kms)
|
||||
require.NoError(t, err)
|
||||
s := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
|
||||
|
||||
type args struct {
|
||||
creds []Credential
|
||||
sessionId string
|
||||
sessionScopeId string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
wantErrCode errors.Code
|
||||
wantErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "invalid-missing-credentials",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
sessionScopeId: s.ScopeId,
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).AddSessionCredentials: missing credentials: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "invalid-missing-scope-id",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
creds: []Credential{
|
||||
Credential("test-cred"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).AddSessionCredentials: missing session scope id: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "invalid-missing-session-id",
|
||||
args: args{
|
||||
sessionScopeId: s.ScopeId,
|
||||
creds: []Credential{
|
||||
Credential("test-cred"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).AddSessionCredentials: missing session id: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "invalid-empty-cred",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
sessionScopeId: s.ScopeId,
|
||||
creds: []Credential{
|
||||
Credential(""),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).AddSessionCredentials: missing credential: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "invalid-empty-cred-with-valid",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
sessionScopeId: s.ScopeId,
|
||||
creds: []Credential{
|
||||
Credential("test-cred"),
|
||||
Credential(""),
|
||||
Credential("test-cred2"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).AddSessionCredentials: missing credential: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
sessionScopeId: s.ScopeId,
|
||||
creds: []Credential{
|
||||
Credential("test-cred"),
|
||||
Credential("test-cred1"),
|
||||
Credential("test-cred2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
|
||||
err := repo.AddSessionCredentials(context.Background(), tt.args.sessionScopeId, tt.args.sessionId, tt.args.creds)
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
assert.Truef(errors.Match(errors.T(tt.wantErrCode), err), "Unexpected error %s", err)
|
||||
assert.Equal(tt.wantErrMsg, err.Error())
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
creds, err := repo.ListSessionCredentials(context.Background(), s.ScopeId, s.PublicId)
|
||||
require.NoError(err)
|
||||
assert.ElementsMatch(creds, tt.args.creds)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_ListSessionCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
rw := db.New(conn)
|
||||
kms := kms.TestKms(t, conn, wrapper)
|
||||
repo, err := NewRepository(rw, rw, kms)
|
||||
require.NoError(t, err)
|
||||
s1 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
|
||||
s2 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
|
||||
s3 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
|
||||
|
||||
s1Creds := []Credential{
|
||||
Credential("cred1"),
|
||||
Credential("cred2"),
|
||||
Credential("cred3"),
|
||||
}
|
||||
s2Creds := []Credential{
|
||||
Credential("cred1"),
|
||||
}
|
||||
|
||||
err = repo.AddSessionCredentials(context.Background(), s1.ScopeId, s1.PublicId, s1Creds)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.AddSessionCredentials(context.Background(), s2.ScopeId, s2.PublicId, s2Creds)
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
sessionId string
|
||||
sessionScopeId string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantCreds []Credential
|
||||
wantErr bool
|
||||
wantErrCode errors.Code
|
||||
wantErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "invalid-missing-scope-id",
|
||||
args: args{
|
||||
sessionId: s1.PublicId,
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).ListSessionCredentials: missing session scope id: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "invalid-missing-session-id",
|
||||
args: args{
|
||||
sessionScopeId: s1.ScopeId,
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).ListSessionCredentials: missing session id: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "valid-s1",
|
||||
args: args{
|
||||
sessionId: s1.PublicId,
|
||||
sessionScopeId: s1.ScopeId,
|
||||
},
|
||||
wantCreds: s1Creds,
|
||||
},
|
||||
{
|
||||
name: "valid-s2",
|
||||
args: args{
|
||||
sessionId: s2.PublicId,
|
||||
sessionScopeId: s2.ScopeId,
|
||||
},
|
||||
wantCreds: s2Creds,
|
||||
},
|
||||
{
|
||||
name: "valid-s3-no-creds",
|
||||
args: args{
|
||||
sessionId: s3.PublicId,
|
||||
sessionScopeId: s3.ScopeId,
|
||||
},
|
||||
wantCreds: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
|
||||
creds, err := repo.ListSessionCredentials(context.Background(), tt.args.sessionScopeId, tt.args.sessionId)
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
assert.Truef(errors.Match(errors.T(tt.wantErrCode), err), "Unexpected error %s", err)
|
||||
assert.Equal(tt.wantErrMsg, err.Error())
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.ElementsMatch(creds, tt.wantCreds)
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in new issue