mirror of https://github.com/hashicorp/boundary
Creates a new session_credential table and repository functions to add and list credentials attached to a session. All credentials are encrypted using the session scope id.pull/1804/head
parent
27ebfa1125
commit
5fe23ab14d
@ -0,0 +1,44 @@
|
||||
begin;
|
||||
|
||||
create table session_credential (
|
||||
session_id wt_public_id not null
|
||||
constraint session_fkey
|
||||
references session (public_id)
|
||||
on delete cascade
|
||||
on update cascade,
|
||||
credential bytea not null -- encrypted
|
||||
constraint credential_must_not_be_empty
|
||||
check(length(credential) > 0),
|
||||
key_id text not null
|
||||
constraint kms_database_key_version_fkey
|
||||
references kms_database_key_version (private_id)
|
||||
on delete restrict
|
||||
on update cascade,
|
||||
constraint session_credential_session_id_credential_uq
|
||||
unique(session_id, credential)
|
||||
);
|
||||
comment on table session_credential is
|
||||
'session_credential is a table where each row contains a credential to be used by '
|
||||
'by a worker when a connection is established for the session_id.';
|
||||
|
||||
create trigger immutable_columns before update on session_credential
|
||||
for each row execute procedure immutable_columns('session_id', 'credential', 'key_id');
|
||||
|
||||
-- delete_credentials deletes all credentials for a session when the
|
||||
-- session enters the canceling or terminated states.
|
||||
create function delete_session_credentials()
|
||||
returns trigger
|
||||
as $$
|
||||
begin
|
||||
if new.state in ('canceling', 'terminated') then
|
||||
delete from session_credential
|
||||
where session_id = new.session_id;
|
||||
end if;
|
||||
return new;
|
||||
end;
|
||||
$$ language plpgsql;
|
||||
|
||||
create trigger delete_session_credentials after insert on session_state
|
||||
for each row execute procedure delete_session_credentials();
|
||||
|
||||
commit;
|
||||
@ -0,0 +1,32 @@
|
||||
-- session_credential tests:
|
||||
-- the following triggers
|
||||
-- delete_session_credentials
|
||||
|
||||
begin;
|
||||
select plan(4);
|
||||
select wtt_load('widgets', 'iam', 'kms', 'auth');
|
||||
|
||||
-- validate the setup data
|
||||
select is(count(*), 1::bigint) from session where public_id = 's1_____clare';
|
||||
|
||||
insert into session_credential
|
||||
( session_id , credential , key_id)
|
||||
values
|
||||
('s1_____clare', 'clare credential data 1' , 'kdkv___widget'),
|
||||
('s1_____clare', 'clare credential data 2' , 'kdkv___widget'),
|
||||
('s1_____clare', 'clare credential data 3' , 'kdkv___widget');
|
||||
|
||||
select is(count(*), 3::bigint) from session_credential where session_id = 's1_____clare';
|
||||
|
||||
-- validate the delete triggers
|
||||
prepare delete_session_credentials as
|
||||
insert into session_state
|
||||
(session_id , state)
|
||||
values
|
||||
('s1_____clare' , 'canceling');
|
||||
select lives_ok('delete_session_credentials');
|
||||
|
||||
select is(count(*), 0::bigint) from session_credential where session_id = 's1_____clare';
|
||||
|
||||
select * from finish();
|
||||
rollback;
|
||||
@ -0,0 +1,32 @@
|
||||
-- session_credential tests:
|
||||
-- the following triggers
|
||||
-- delete_session_credentials
|
||||
|
||||
begin;
|
||||
select plan(4);
|
||||
select wtt_load('widgets', 'iam', 'kms', 'auth');
|
||||
|
||||
-- validate the setup data
|
||||
select is(count(*), 1::bigint) from session where public_id = 's1_____clare';
|
||||
|
||||
insert into session_credential
|
||||
( session_id , credential , key_id)
|
||||
values
|
||||
('s1_____clare', 'clare credential data 1' , 'kdkv___widget'),
|
||||
('s1_____clare', 'clare credential data 2' , 'kdkv___widget'),
|
||||
('s1_____clare', 'clare credential data 3' , 'kdkv___widget');
|
||||
|
||||
select is(count(*), 3::bigint) from session_credential where session_id = 's1_____clare';
|
||||
|
||||
-- validate the delete triggers
|
||||
prepare delete_session_credentials as
|
||||
insert into session_state
|
||||
(session_id , state)
|
||||
values
|
||||
('s1_____clare' , 'terminated');
|
||||
select lives_ok('delete_session_credentials');
|
||||
|
||||
select is(count(*), 0::bigint) from session_credential where session_id = 's1_____clare';
|
||||
|
||||
select * from finish();
|
||||
rollback;
|
||||
@ -0,0 +1,41 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
wrapping "github.com/hashicorp/go-kms-wrapping"
|
||||
"github.com/hashicorp/go-kms-wrapping/structwrapping"
|
||||
)
|
||||
|
||||
// Credential represents the credential data which is sent to the worker.
|
||||
type Credential []byte
|
||||
|
||||
type credential struct {
|
||||
SessionId string
|
||||
KeyId string
|
||||
Credential []byte `gorm:"-" wrapping:"pt,credential_data"`
|
||||
CtCredential []byte `gorm:"column:credential" wrapping:"ct,credential_data"`
|
||||
}
|
||||
|
||||
// TableName returns the table name.
|
||||
func (c *credential) TableName() string {
|
||||
return "session_credential"
|
||||
}
|
||||
|
||||
func (c *credential) encrypt(ctx context.Context, cipher wrapping.Wrapper) error {
|
||||
const op = "session.(credential).encrypt"
|
||||
if err := structwrapping.WrapStruct(ctx, cipher, c, nil); err != nil {
|
||||
return errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt))
|
||||
}
|
||||
c.KeyId = cipher.KeyID()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *credential) decrypt(ctx context.Context, cipher wrapping.Wrapper) error {
|
||||
const op = "session.(credential).decrypt"
|
||||
if err := structwrapping.UnwrapStruct(ctx, cipher, c, nil); err != nil {
|
||||
return errors.Wrap(ctx, err, op, errors.WithCode(errors.Decrypt))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,92 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
"github.com/hashicorp/boundary/internal/kms"
|
||||
)
|
||||
|
||||
// AddSessionCredentials encrypts the credData and adds the credentials to the repository. The credentials are linked
|
||||
// to the sessionID provided, and encrypted using the sessScopeId. Session credentials are only valid for pending and
|
||||
// active sessions, once a session ends, all session credentials are deleted.
|
||||
// All options are ignored.
|
||||
func (r *Repository) AddSessionCredentials(ctx context.Context, sessScopeId, sessionId string, credData []Credential, _ ...Option) error {
|
||||
const op = "session.(Repository).AddSessionCredentials"
|
||||
if sessionId == "" {
|
||||
return errors.New(ctx, errors.InvalidParameter, op, "missing session id")
|
||||
}
|
||||
if sessScopeId == "" {
|
||||
return errors.New(ctx, errors.InvalidParameter, op, "missing session scope id")
|
||||
}
|
||||
if len(credData) == 0 {
|
||||
return errors.New(ctx, errors.InvalidParameter, op, "missing credentials")
|
||||
}
|
||||
|
||||
databaseWrapper, err := r.kms.GetWrapper(ctx, sessScopeId, kms.KeyPurposeDatabase)
|
||||
if err != nil {
|
||||
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
|
||||
}
|
||||
|
||||
addCreds := make([]interface{}, 0, len(credData))
|
||||
for _, cred := range credData {
|
||||
if len(cred) == 0 {
|
||||
return errors.New(ctx, errors.InvalidParameter, op, "missing credential")
|
||||
}
|
||||
|
||||
sessCred := credential{
|
||||
SessionId: sessionId,
|
||||
Credential: cred,
|
||||
}
|
||||
if err := sessCred.encrypt(ctx, databaseWrapper); err != nil {
|
||||
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to encrypt credential"))
|
||||
}
|
||||
addCreds = append(addCreds, sessCred)
|
||||
}
|
||||
|
||||
_, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{},
|
||||
func(_ db.Reader, w db.Writer) error {
|
||||
if err := w.CreateItems(ctx, addCreds); err != nil {
|
||||
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add session credentials"))
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListSessionCredentials returns all Credential attached to the sessionId.
|
||||
// All options are ignored.
|
||||
func (r *Repository) ListSessionCredentials(ctx context.Context, sessScopeId, sessionId string, _ ...Option) ([]Credential, error) {
|
||||
const op = "session.(Repository).ListSessionCredentials"
|
||||
if sessionId == "" {
|
||||
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing session id")
|
||||
}
|
||||
if sessScopeId == "" {
|
||||
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing session scope id")
|
||||
}
|
||||
|
||||
databaseWrapper, err := r.kms.GetWrapper(ctx, sessScopeId, kms.KeyPurposeDatabase)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
|
||||
}
|
||||
|
||||
var creds []*credential
|
||||
if err := r.reader.SearchWhere(ctx, &creds, "session_id = ?", []interface{}{sessionId}); err != nil {
|
||||
return nil, errors.Wrap(ctx, err, op)
|
||||
}
|
||||
if len(creds) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
ret := make([]Credential, 0, len(creds))
|
||||
for _, c := range creds {
|
||||
if err := c.decrypt(ctx, databaseWrapper); err != nil {
|
||||
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to decrypt credential"))
|
||||
}
|
||||
ret = append(ret, c.Credential)
|
||||
}
|
||||
|
||||
return ret, nil
|
||||
}
|
||||
@ -0,0 +1,231 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/boundary/internal/db"
|
||||
"github.com/hashicorp/boundary/internal/errors"
|
||||
"github.com/hashicorp/boundary/internal/iam"
|
||||
"github.com/hashicorp/boundary/internal/kms"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRepository_AddSessionCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
rw := db.New(conn)
|
||||
kms := kms.TestKms(t, conn, wrapper)
|
||||
repo, err := NewRepository(rw, rw, kms)
|
||||
require.NoError(t, err)
|
||||
s := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
|
||||
|
||||
type args struct {
|
||||
creds []Credential
|
||||
sessionId string
|
||||
sessionScopeId string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
wantErrCode errors.Code
|
||||
wantErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "invalid-missing-credentials",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
sessionScopeId: s.ScopeId,
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).AddSessionCredentials: missing credentials: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "invalid-missing-scope-id",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
creds: []Credential{
|
||||
Credential("test-cred"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).AddSessionCredentials: missing session scope id: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "invalid-missing-session-id",
|
||||
args: args{
|
||||
sessionScopeId: s.ScopeId,
|
||||
creds: []Credential{
|
||||
Credential("test-cred"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).AddSessionCredentials: missing session id: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "invalid-empty-cred",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
sessionScopeId: s.ScopeId,
|
||||
creds: []Credential{
|
||||
Credential(""),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).AddSessionCredentials: missing credential: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "invalid-empty-cred-with-valid",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
sessionScopeId: s.ScopeId,
|
||||
creds: []Credential{
|
||||
Credential("test-cred"),
|
||||
Credential(""),
|
||||
Credential("test-cred2"),
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).AddSessionCredentials: missing credential: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
args: args{
|
||||
sessionId: s.PublicId,
|
||||
sessionScopeId: s.ScopeId,
|
||||
creds: []Credential{
|
||||
Credential("test-cred"),
|
||||
Credential("test-cred1"),
|
||||
Credential("test-cred2"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
|
||||
err := repo.AddSessionCredentials(context.Background(), tt.args.sessionScopeId, tt.args.sessionId, tt.args.creds)
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
assert.Truef(errors.Match(errors.T(tt.wantErrCode), err), "Unexpected error %s", err)
|
||||
assert.Equal(tt.wantErrMsg, err.Error())
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
|
||||
creds, err := repo.ListSessionCredentials(context.Background(), s.ScopeId, s.PublicId)
|
||||
require.NoError(err)
|
||||
assert.ElementsMatch(creds, tt.args.creds)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepository_ListSessionCredentials(t *testing.T) {
|
||||
t.Parallel()
|
||||
conn, _ := db.TestSetup(t, "postgres")
|
||||
wrapper := db.TestWrapper(t)
|
||||
iamRepo := iam.TestRepo(t, conn, wrapper)
|
||||
rw := db.New(conn)
|
||||
kms := kms.TestKms(t, conn, wrapper)
|
||||
repo, err := NewRepository(rw, rw, kms)
|
||||
require.NoError(t, err)
|
||||
s1 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
|
||||
s2 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
|
||||
s3 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo))
|
||||
|
||||
s1Creds := []Credential{
|
||||
Credential("cred1"),
|
||||
Credential("cred2"),
|
||||
Credential("cred3"),
|
||||
}
|
||||
s2Creds := []Credential{
|
||||
Credential("cred1"),
|
||||
}
|
||||
|
||||
err = repo.AddSessionCredentials(context.Background(), s1.ScopeId, s1.PublicId, s1Creds)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = repo.AddSessionCredentials(context.Background(), s2.ScopeId, s2.PublicId, s2Creds)
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
sessionId string
|
||||
sessionScopeId string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantCreds []Credential
|
||||
wantErr bool
|
||||
wantErrCode errors.Code
|
||||
wantErrMsg string
|
||||
}{
|
||||
{
|
||||
name: "invalid-missing-scope-id",
|
||||
args: args{
|
||||
sessionId: s1.PublicId,
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).ListSessionCredentials: missing session scope id: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "invalid-missing-session-id",
|
||||
args: args{
|
||||
sessionScopeId: s1.ScopeId,
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrCode: errors.InvalidParameter,
|
||||
wantErrMsg: "session.(Repository).ListSessionCredentials: missing session id: parameter violation: error #100",
|
||||
},
|
||||
{
|
||||
name: "valid-s1",
|
||||
args: args{
|
||||
sessionId: s1.PublicId,
|
||||
sessionScopeId: s1.ScopeId,
|
||||
},
|
||||
wantCreds: s1Creds,
|
||||
},
|
||||
{
|
||||
name: "valid-s2",
|
||||
args: args{
|
||||
sessionId: s2.PublicId,
|
||||
sessionScopeId: s2.ScopeId,
|
||||
},
|
||||
wantCreds: s2Creds,
|
||||
},
|
||||
{
|
||||
name: "valid-s3-no-creds",
|
||||
args: args{
|
||||
sessionId: s3.PublicId,
|
||||
sessionScopeId: s3.ScopeId,
|
||||
},
|
||||
wantCreds: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert, require := assert.New(t), require.New(t)
|
||||
|
||||
creds, err := repo.ListSessionCredentials(context.Background(), tt.args.sessionScopeId, tt.args.sessionId)
|
||||
if tt.wantErr {
|
||||
require.Error(err)
|
||||
assert.Truef(errors.Match(errors.T(tt.wantErrCode), err), "Unexpected error %s", err)
|
||||
assert.Equal(tt.wantErrMsg, err.Error())
|
||||
return
|
||||
}
|
||||
require.NoError(err)
|
||||
assert.ElementsMatch(creds, tt.wantCreds)
|
||||
})
|
||||
}
|
||||
}
|
||||
Loading…
Reference in new issue