feat(session): Store session credentials

Creates a new session_credential table and repository functions
to add and list credentials attached to a session. All credentials
are encrypted using the session scope id.
pull/1804/head
Louis Ruch 4 years ago
parent 27ebfa1125
commit 5fe23ab14d

@ -0,0 +1,44 @@
begin;
create table session_credential (
session_id wt_public_id not null
constraint session_fkey
references session (public_id)
on delete cascade
on update cascade,
credential bytea not null -- encrypted
constraint credential_must_not_be_empty
check(length(credential) > 0),
key_id text not null
constraint kms_database_key_version_fkey
references kms_database_key_version (private_id)
on delete restrict
on update cascade,
constraint session_credential_session_id_credential_uq
unique(session_id, credential)
);
comment on table session_credential is
'session_credential is a table where each row contains a credential to be used by '
'by a worker when a connection is established for the session_id.';
create trigger immutable_columns before update on session_credential
for each row execute procedure immutable_columns('session_id', 'credential', 'key_id');
-- delete_credentials deletes all credentials for a session when the
-- session enters the canceling or terminated states.
create function delete_session_credentials()
returns trigger
as $$
begin
if new.state in ('canceling', 'terminated') then
delete from session_credential
where session_id = new.session_id;
end if;
return new;
end;
$$ language plpgsql;
create trigger delete_session_credentials after insert on session_state
for each row execute procedure delete_session_credentials();
commit;

@ -20,7 +20,8 @@ PROVE_OPTS ?=
TESTS ?= tests/setup/*.sql \
tests/wh/*/*.sql \
tests/sentinel/*.sql \
tests/credential/*/*.sql
tests/credential/*/*.sql \
tests/session/*.sql
POSTGRES_DOCKER_IMAGE_BASE ?= postgres

@ -359,4 +359,3 @@ begin;
end;
$$ language plpgsql;
commit;

@ -0,0 +1,32 @@
-- session_credential tests:
-- the following triggers
-- delete_session_credentials
begin;
select plan(4);
select wtt_load('widgets', 'iam', 'kms', 'auth');
-- validate the setup data
select is(count(*), 1::bigint) from session where public_id = 's1_____clare';
insert into session_credential
( session_id , credential , key_id)
values
('s1_____clare', 'clare credential data 1' , 'kdkv___widget'),
('s1_____clare', 'clare credential data 2' , 'kdkv___widget'),
('s1_____clare', 'clare credential data 3' , 'kdkv___widget');
select is(count(*), 3::bigint) from session_credential where session_id = 's1_____clare';
-- validate the delete triggers
prepare delete_session_credentials as
insert into session_state
(session_id , state)
values
('s1_____clare' , 'canceling');
select lives_ok('delete_session_credentials');
select is(count(*), 0::bigint) from session_credential where session_id = 's1_____clare';
select * from finish();
rollback;

@ -0,0 +1,32 @@
-- session_credential tests:
-- the following triggers
-- delete_session_credentials
begin;
select plan(4);
select wtt_load('widgets', 'iam', 'kms', 'auth');
-- validate the setup data
select is(count(*), 1::bigint) from session where public_id = 's1_____clare';
insert into session_credential
( session_id , credential , key_id)
values
('s1_____clare', 'clare credential data 1' , 'kdkv___widget'),
('s1_____clare', 'clare credential data 2' , 'kdkv___widget'),
('s1_____clare', 'clare credential data 3' , 'kdkv___widget');
select is(count(*), 3::bigint) from session_credential where session_id = 's1_____clare';
-- validate the delete triggers
prepare delete_session_credentials as
insert into session_state
(session_id , state)
values
('s1_____clare' , 'terminated');
select lives_ok('delete_session_credentials');
select is(count(*), 0::bigint) from session_credential where session_id = 's1_____clare';
select * from finish();
rollback;

@ -0,0 +1,41 @@
package session
import (
"context"
"github.com/hashicorp/boundary/internal/errors"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-kms-wrapping/structwrapping"
)
// Credential represents the credential data which is sent to the worker.
type Credential []byte
type credential struct {
SessionId string
KeyId string
Credential []byte `gorm:"-" wrapping:"pt,credential_data"`
CtCredential []byte `gorm:"column:credential" wrapping:"ct,credential_data"`
}
// TableName returns the table name.
func (c *credential) TableName() string {
return "session_credential"
}
func (c *credential) encrypt(ctx context.Context, cipher wrapping.Wrapper) error {
const op = "session.(credential).encrypt"
if err := structwrapping.WrapStruct(ctx, cipher, c, nil); err != nil {
return errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt))
}
c.KeyId = cipher.KeyID()
return nil
}
func (c *credential) decrypt(ctx context.Context, cipher wrapping.Wrapper) error {
const op = "session.(credential).decrypt"
if err := structwrapping.UnwrapStruct(ctx, cipher, c, nil); err != nil {
return errors.Wrap(ctx, err, op, errors.WithCode(errors.Decrypt))
}
return nil
}

@ -1,7 +1,7 @@
package session
import (
"github.com/hashicorp/boundary/internal/credential"
cred "github.com/hashicorp/boundary/internal/credential"
)
// A DynamicCredential represents the relationship between a session, a
@ -18,7 +18,7 @@ type DynamicCredential struct {
// NewDynamicCredential creates a new in memory Credential representing the
// relationship between session and a credential library.
func NewDynamicCredential(libraryId string, purpose credential.Purpose) *DynamicCredential {
func NewDynamicCredential(libraryId string, purpose cred.Purpose) *DynamicCredential {
return &DynamicCredential{
LibraryId: libraryId,
CredentialPurpose: string(purpose),

@ -0,0 +1,92 @@
package session
import (
"context"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/kms"
)
// AddSessionCredentials encrypts the credData and adds the credentials to the repository. The credentials are linked
// to the sessionID provided, and encrypted using the sessScopeId. Session credentials are only valid for pending and
// active sessions, once a session ends, all session credentials are deleted.
// All options are ignored.
func (r *Repository) AddSessionCredentials(ctx context.Context, sessScopeId, sessionId string, credData []Credential, _ ...Option) error {
const op = "session.(Repository).AddSessionCredentials"
if sessionId == "" {
return errors.New(ctx, errors.InvalidParameter, op, "missing session id")
}
if sessScopeId == "" {
return errors.New(ctx, errors.InvalidParameter, op, "missing session scope id")
}
if len(credData) == 0 {
return errors.New(ctx, errors.InvalidParameter, op, "missing credentials")
}
databaseWrapper, err := r.kms.GetWrapper(ctx, sessScopeId, kms.KeyPurposeDatabase)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
}
addCreds := make([]interface{}, 0, len(credData))
for _, cred := range credData {
if len(cred) == 0 {
return errors.New(ctx, errors.InvalidParameter, op, "missing credential")
}
sessCred := credential{
SessionId: sessionId,
Credential: cred,
}
if err := sessCred.encrypt(ctx, databaseWrapper); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to encrypt credential"))
}
addCreds = append(addCreds, sessCred)
}
_, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{},
func(_ db.Reader, w db.Writer) error {
if err := w.CreateItems(ctx, addCreds); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add session credentials"))
}
return nil
},
)
return err
}
// ListSessionCredentials returns all Credential attached to the sessionId.
// All options are ignored.
func (r *Repository) ListSessionCredentials(ctx context.Context, sessScopeId, sessionId string, _ ...Option) ([]Credential, error) {
const op = "session.(Repository).ListSessionCredentials"
if sessionId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing session id")
}
if sessScopeId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing session scope id")
}
databaseWrapper, err := r.kms.GetWrapper(ctx, sessScopeId, kms.KeyPurposeDatabase)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
}
var creds []*credential
if err := r.reader.SearchWhere(ctx, &creds, "session_id = ?", []interface{}{sessionId}); err != nil {
return nil, errors.Wrap(ctx, err, op)
}
if len(creds) == 0 {
return nil, nil
}
ret := make([]Credential, 0, len(creds))
for _, c := range creds {
if err := c.decrypt(ctx, databaseWrapper); err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to decrypt credential"))
}
ret = append(ret, c.Credential)
}
return ret, nil
}

@ -0,0 +1,231 @@
package session
import (
"context"
"testing"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestRepository_AddSessionCredentials(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
rw := db.New(conn)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, kms)
require.NoError(t, err)
s := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
type args struct {
creds []Credential
sessionId string
sessionScopeId string
}
tests := []struct {
name string
args args
wantErr bool
wantErrCode errors.Code
wantErrMsg string
}{
{
name: "invalid-missing-credentials",
args: args{
sessionId: s.PublicId,
sessionScopeId: s.ScopeId,
},
wantErr: true,
wantErrCode: errors.InvalidParameter,
wantErrMsg: "session.(Repository).AddSessionCredentials: missing credentials: parameter violation: error #100",
},
{
name: "invalid-missing-scope-id",
args: args{
sessionId: s.PublicId,
creds: []Credential{
Credential("test-cred"),
},
},
wantErr: true,
wantErrCode: errors.InvalidParameter,
wantErrMsg: "session.(Repository).AddSessionCredentials: missing session scope id: parameter violation: error #100",
},
{
name: "invalid-missing-session-id",
args: args{
sessionScopeId: s.ScopeId,
creds: []Credential{
Credential("test-cred"),
},
},
wantErr: true,
wantErrCode: errors.InvalidParameter,
wantErrMsg: "session.(Repository).AddSessionCredentials: missing session id: parameter violation: error #100",
},
{
name: "invalid-empty-cred",
args: args{
sessionId: s.PublicId,
sessionScopeId: s.ScopeId,
creds: []Credential{
Credential(""),
},
},
wantErr: true,
wantErrCode: errors.InvalidParameter,
wantErrMsg: "session.(Repository).AddSessionCredentials: missing credential: parameter violation: error #100",
},
{
name: "invalid-empty-cred-with-valid",
args: args{
sessionId: s.PublicId,
sessionScopeId: s.ScopeId,
creds: []Credential{
Credential("test-cred"),
Credential(""),
Credential("test-cred2"),
},
},
wantErr: true,
wantErrCode: errors.InvalidParameter,
wantErrMsg: "session.(Repository).AddSessionCredentials: missing credential: parameter violation: error #100",
},
{
name: "valid",
args: args{
sessionId: s.PublicId,
sessionScopeId: s.ScopeId,
creds: []Credential{
Credential("test-cred"),
Credential("test-cred1"),
Credential("test-cred2"),
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
err := repo.AddSessionCredentials(context.Background(), tt.args.sessionScopeId, tt.args.sessionId, tt.args.creds)
if tt.wantErr {
require.Error(err)
assert.Truef(errors.Match(errors.T(tt.wantErrCode), err), "Unexpected error %s", err)
assert.Equal(tt.wantErrMsg, err.Error())
return
}
require.NoError(err)
creds, err := repo.ListSessionCredentials(context.Background(), s.ScopeId, s.PublicId)
require.NoError(err)
assert.ElementsMatch(creds, tt.args.creds)
})
}
}
func TestRepository_ListSessionCredentials(t *testing.T) {
t.Parallel()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
iamRepo := iam.TestRepo(t, conn, wrapper)
rw := db.New(conn)
kms := kms.TestKms(t, conn, wrapper)
repo, err := NewRepository(rw, rw, kms)
require.NoError(t, err)
s1 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
s2 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
s3 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
s1Creds := []Credential{
Credential("cred1"),
Credential("cred2"),
Credential("cred3"),
}
s2Creds := []Credential{
Credential("cred1"),
}
err = repo.AddSessionCredentials(context.Background(), s1.ScopeId, s1.PublicId, s1Creds)
require.NoError(t, err)
err = repo.AddSessionCredentials(context.Background(), s2.ScopeId, s2.PublicId, s2Creds)
require.NoError(t, err)
type args struct {
sessionId string
sessionScopeId string
}
tests := []struct {
name string
args args
wantCreds []Credential
wantErr bool
wantErrCode errors.Code
wantErrMsg string
}{
{
name: "invalid-missing-scope-id",
args: args{
sessionId: s1.PublicId,
},
wantErr: true,
wantErrCode: errors.InvalidParameter,
wantErrMsg: "session.(Repository).ListSessionCredentials: missing session scope id: parameter violation: error #100",
},
{
name: "invalid-missing-session-id",
args: args{
sessionScopeId: s1.ScopeId,
},
wantErr: true,
wantErrCode: errors.InvalidParameter,
wantErrMsg: "session.(Repository).ListSessionCredentials: missing session id: parameter violation: error #100",
},
{
name: "valid-s1",
args: args{
sessionId: s1.PublicId,
sessionScopeId: s1.ScopeId,
},
wantCreds: s1Creds,
},
{
name: "valid-s2",
args: args{
sessionId: s2.PublicId,
sessionScopeId: s2.ScopeId,
},
wantCreds: s2Creds,
},
{
name: "valid-s3-no-creds",
args: args{
sessionId: s3.PublicId,
sessionScopeId: s3.ScopeId,
},
wantCreds: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
creds, err := repo.ListSessionCredentials(context.Background(), tt.args.sessionScopeId, tt.args.sessionId)
if tt.wantErr {
require.Error(err)
assert.Truef(errors.Match(errors.T(tt.wantErrCode), err), "Unexpected error %s", err)
assert.Equal(tt.wantErrMsg, err.Error())
return
}
require.NoError(err)
assert.ElementsMatch(creds, tt.wantCreds)
})
}
}

@ -8,28 +8,25 @@ import (
"github.com/hashicorp/boundary/internal/authtoken"
authtokenStore "github.com/hashicorp/boundary/internal/authtoken/store"
"github.com/hashicorp/boundary/internal/credential"
cred "github.com/hashicorp/boundary/internal/credential"
"github.com/hashicorp/boundary/internal/credential/vault"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/host/static"
staticStore "github.com/hashicorp/boundary/internal/host/static/store"
"github.com/hashicorp/boundary/internal/iam"
iamStore "github.com/hashicorp/boundary/internal/iam/store"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/boundary/internal/target"
"github.com/hashicorp/boundary/internal/target/tcp"
tcpStore "github.com/hashicorp/boundary/internal/target/tcp/store"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/jackc/pgconn"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
iamStore "github.com/hashicorp/boundary/internal/iam/store"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
)
func TestRepository_ListSession(t *testing.T) {
@ -1700,15 +1697,15 @@ func testSessionCredentialParams(t *testing.T, conn *db.DB, wrapper wrapping.Wra
stores := vault.TestCredentialStores(t, conn, wrapper, params.ScopeId, 1)
libIds := vault.TestCredentialLibraries(t, conn, wrapper, stores[0].GetPublicId(), 2)
libs := []*target.CredentialLibrary{
target.TestNewCredentialLibrary(tar.GetPublicId(), libIds[0].GetPublicId(), credential.ApplicationPurpose),
target.TestNewCredentialLibrary(tar.GetPublicId(), libIds[1].GetPublicId(), credential.ApplicationPurpose),
target.TestNewCredentialLibrary(tar.GetPublicId(), libIds[0].GetPublicId(), cred.ApplicationPurpose),
target.TestNewCredentialLibrary(tar.GetPublicId(), libIds[1].GetPublicId(), cred.ApplicationPurpose),
}
_, _, _, err = targetRepo.AddTargetCredentialSources(ctx, tar.GetPublicId(), tar.GetVersion(), libs)
require.NoError(err)
creds := []*DynamicCredential{
NewDynamicCredential(libIds[0].GetPublicId(), credential.ApplicationPurpose),
NewDynamicCredential(libIds[1].GetPublicId(), credential.ApplicationPurpose),
NewDynamicCredential(libIds[0].GetPublicId(), cred.ApplicationPurpose),
NewDynamicCredential(libIds[1].GetPublicId(), cred.ApplicationPurpose),
}
params.DynamicCredentials = creds
return params

Loading…
Cancel
Save