diff --git a/internal/db/schema/migrations/oss/postgres/23/01_session_credential.up.sql b/internal/db/schema/migrations/oss/postgres/23/01_session_credential.up.sql new file mode 100644 index 0000000000..6b4c9db177 --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/23/01_session_credential.up.sql @@ -0,0 +1,44 @@ +begin; + + create table session_credential ( + session_id wt_public_id not null + constraint session_fkey + references session (public_id) + on delete cascade + on update cascade, + credential bytea not null -- encrypted + constraint credential_must_not_be_empty + check(length(credential) > 0), + key_id text not null + constraint kms_database_key_version_fkey + references kms_database_key_version (private_id) + on delete restrict + on update cascade, + constraint session_credential_session_id_credential_uq + unique(session_id, credential) + ); + comment on table session_credential is + 'session_credential is a table where each row contains a credential to be used by ' + 'by a worker when a connection is established for the session_id.'; + + create trigger immutable_columns before update on session_credential + for each row execute procedure immutable_columns('session_id', 'credential', 'key_id'); + + -- delete_credentials deletes all credentials for a session when the + -- session enters the canceling or terminated states. + create function delete_session_credentials() + returns trigger + as $$ + begin + if new.state in ('canceling', 'terminated') then + delete from session_credential + where session_id = new.session_id; + end if; + return new; + end; + $$ language plpgsql; + + create trigger delete_session_credentials after insert on session_state + for each row execute procedure delete_session_credentials(); + +commit; diff --git a/internal/db/sqltest/Makefile b/internal/db/sqltest/Makefile index 81dbffa295..1d28fef9b6 100644 --- a/internal/db/sqltest/Makefile +++ b/internal/db/sqltest/Makefile @@ -20,7 +20,8 @@ PROVE_OPTS ?= TESTS ?= tests/setup/*.sql \ tests/wh/*/*.sql \ tests/sentinel/*.sql \ - tests/credential/*/*.sql + tests/credential/*/*.sql \ + tests/session/*.sql POSTGRES_DOCKER_IMAGE_BASE ?= postgres diff --git a/internal/db/sqltest/initdb.d/03_widgets_persona.sql b/internal/db/sqltest/initdb.d/03_widgets_persona.sql index 21603875cd..d0b3e8434f 100644 --- a/internal/db/sqltest/initdb.d/03_widgets_persona.sql +++ b/internal/db/sqltest/initdb.d/03_widgets_persona.sql @@ -359,4 +359,3 @@ begin; end; $$ language plpgsql; commit; - diff --git a/internal/db/sqltest/tests/session/session_credential_canceling.sql b/internal/db/sqltest/tests/session/session_credential_canceling.sql new file mode 100644 index 0000000000..0ed5e03cd6 --- /dev/null +++ b/internal/db/sqltest/tests/session/session_credential_canceling.sql @@ -0,0 +1,32 @@ +-- session_credential tests: +-- the following triggers +-- delete_session_credentials + +begin; + select plan(4); + select wtt_load('widgets', 'iam', 'kms', 'auth'); + + -- validate the setup data + select is(count(*), 1::bigint) from session where public_id = 's1_____clare'; + + insert into session_credential + ( session_id , credential , key_id) + values + ('s1_____clare', 'clare credential data 1' , 'kdkv___widget'), + ('s1_____clare', 'clare credential data 2' , 'kdkv___widget'), + ('s1_____clare', 'clare credential data 3' , 'kdkv___widget'); + + select is(count(*), 3::bigint) from session_credential where session_id = 's1_____clare'; + + -- validate the delete triggers + prepare delete_session_credentials as + insert into session_state + (session_id , state) + values + ('s1_____clare' , 'canceling'); + select lives_ok('delete_session_credentials'); + + select is(count(*), 0::bigint) from session_credential where session_id = 's1_____clare'; + + select * from finish(); +rollback; diff --git a/internal/db/sqltest/tests/session/session_credential_terminated.sql b/internal/db/sqltest/tests/session/session_credential_terminated.sql new file mode 100644 index 0000000000..21318ca4b5 --- /dev/null +++ b/internal/db/sqltest/tests/session/session_credential_terminated.sql @@ -0,0 +1,32 @@ +-- session_credential tests: +-- the following triggers +-- delete_session_credentials + +begin; + select plan(4); + select wtt_load('widgets', 'iam', 'kms', 'auth'); + + -- validate the setup data + select is(count(*), 1::bigint) from session where public_id = 's1_____clare'; + + insert into session_credential + ( session_id , credential , key_id) + values + ('s1_____clare', 'clare credential data 1' , 'kdkv___widget'), + ('s1_____clare', 'clare credential data 2' , 'kdkv___widget'), + ('s1_____clare', 'clare credential data 3' , 'kdkv___widget'); + + select is(count(*), 3::bigint) from session_credential where session_id = 's1_____clare'; + + -- validate the delete triggers + prepare delete_session_credentials as + insert into session_state + (session_id , state) + values + ('s1_____clare' , 'terminated'); + select lives_ok('delete_session_credentials'); + + select is(count(*), 0::bigint) from session_credential where session_id = 's1_____clare'; + + select * from finish(); +rollback; diff --git a/internal/session/credential.go b/internal/session/credential.go new file mode 100644 index 0000000000..0e57cbdfb2 --- /dev/null +++ b/internal/session/credential.go @@ -0,0 +1,41 @@ +package session + +import ( + "context" + + "github.com/hashicorp/boundary/internal/errors" + wrapping "github.com/hashicorp/go-kms-wrapping" + "github.com/hashicorp/go-kms-wrapping/structwrapping" +) + +// Credential represents the credential data which is sent to the worker. +type Credential []byte + +type credential struct { + SessionId string + KeyId string + Credential []byte `gorm:"-" wrapping:"pt,credential_data"` + CtCredential []byte `gorm:"column:credential" wrapping:"ct,credential_data"` +} + +// TableName returns the table name. +func (c *credential) TableName() string { + return "session_credential" +} + +func (c *credential) encrypt(ctx context.Context, cipher wrapping.Wrapper) error { + const op = "session.(credential).encrypt" + if err := structwrapping.WrapStruct(ctx, cipher, c, nil); err != nil { + return errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt)) + } + c.KeyId = cipher.KeyID() + return nil +} + +func (c *credential) decrypt(ctx context.Context, cipher wrapping.Wrapper) error { + const op = "session.(credential).decrypt" + if err := structwrapping.UnwrapStruct(ctx, cipher, c, nil); err != nil { + return errors.Wrap(ctx, err, op, errors.WithCode(errors.Decrypt)) + } + return nil +} diff --git a/internal/session/dynamic_credential.go b/internal/session/dynamic_credential.go index 264e36fcaf..1a66110d35 100644 --- a/internal/session/dynamic_credential.go +++ b/internal/session/dynamic_credential.go @@ -1,7 +1,7 @@ package session import ( - "github.com/hashicorp/boundary/internal/credential" + cred "github.com/hashicorp/boundary/internal/credential" ) // A DynamicCredential represents the relationship between a session, a @@ -18,7 +18,7 @@ type DynamicCredential struct { // NewDynamicCredential creates a new in memory Credential representing the // relationship between session and a credential library. -func NewDynamicCredential(libraryId string, purpose credential.Purpose) *DynamicCredential { +func NewDynamicCredential(libraryId string, purpose cred.Purpose) *DynamicCredential { return &DynamicCredential{ LibraryId: libraryId, CredentialPurpose: string(purpose), diff --git a/internal/session/repository_credential.go b/internal/session/repository_credential.go new file mode 100644 index 0000000000..0bba37ed76 --- /dev/null +++ b/internal/session/repository_credential.go @@ -0,0 +1,92 @@ +package session + +import ( + "context" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/kms" +) + +// AddSessionCredentials encrypts the credData and adds the credentials to the repository. The credentials are linked +// to the sessionID provided, and encrypted using the sessScopeId. Session credentials are only valid for pending and +// active sessions, once a session ends, all session credentials are deleted. +// All options are ignored. +func (r *Repository) AddSessionCredentials(ctx context.Context, sessScopeId, sessionId string, credData []Credential, _ ...Option) error { + const op = "session.(Repository).AddSessionCredentials" + if sessionId == "" { + return errors.New(ctx, errors.InvalidParameter, op, "missing session id") + } + if sessScopeId == "" { + return errors.New(ctx, errors.InvalidParameter, op, "missing session scope id") + } + if len(credData) == 0 { + return errors.New(ctx, errors.InvalidParameter, op, "missing credentials") + } + + databaseWrapper, err := r.kms.GetWrapper(ctx, sessScopeId, kms.KeyPurposeDatabase) + if err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper")) + } + + addCreds := make([]interface{}, 0, len(credData)) + for _, cred := range credData { + if len(cred) == 0 { + return errors.New(ctx, errors.InvalidParameter, op, "missing credential") + } + + sessCred := credential{ + SessionId: sessionId, + Credential: cred, + } + if err := sessCred.encrypt(ctx, databaseWrapper); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to encrypt credential")) + } + addCreds = append(addCreds, sessCred) + } + + _, err = r.writer.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, + func(_ db.Reader, w db.Writer) error { + if err := w.CreateItems(ctx, addCreds); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("unable to add session credentials")) + } + + return nil + }, + ) + return err +} + +// ListSessionCredentials returns all Credential attached to the sessionId. +// All options are ignored. +func (r *Repository) ListSessionCredentials(ctx context.Context, sessScopeId, sessionId string, _ ...Option) ([]Credential, error) { + const op = "session.(Repository).ListSessionCredentials" + if sessionId == "" { + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing session id") + } + if sessScopeId == "" { + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing session scope id") + } + + databaseWrapper, err := r.kms.GetWrapper(ctx, sessScopeId, kms.KeyPurposeDatabase) + if err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper")) + } + + var creds []*credential + if err := r.reader.SearchWhere(ctx, &creds, "session_id = ?", []interface{}{sessionId}); err != nil { + return nil, errors.Wrap(ctx, err, op) + } + if len(creds) == 0 { + return nil, nil + } + ret := make([]Credential, 0, len(creds)) + for _, c := range creds { + if err := c.decrypt(ctx, databaseWrapper); err != nil { + return nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to decrypt credential")) + } + ret = append(ret, c.Credential) + } + + return ret, nil +} diff --git a/internal/session/repository_credential_test.go b/internal/session/repository_credential_test.go new file mode 100644 index 0000000000..04bf283f2e --- /dev/null +++ b/internal/session/repository_credential_test.go @@ -0,0 +1,231 @@ +package session + +import ( + "context" + "testing" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/iam" + "github.com/hashicorp/boundary/internal/kms" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRepository_AddSessionCredentials(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + rw := db.New(conn) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + s := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo)) + + type args struct { + creds []Credential + sessionId string + sessionScopeId string + } + tests := []struct { + name string + args args + wantErr bool + wantErrCode errors.Code + wantErrMsg string + }{ + { + name: "invalid-missing-credentials", + args: args{ + sessionId: s.PublicId, + sessionScopeId: s.ScopeId, + }, + wantErr: true, + wantErrCode: errors.InvalidParameter, + wantErrMsg: "session.(Repository).AddSessionCredentials: missing credentials: parameter violation: error #100", + }, + { + name: "invalid-missing-scope-id", + args: args{ + sessionId: s.PublicId, + creds: []Credential{ + Credential("test-cred"), + }, + }, + wantErr: true, + wantErrCode: errors.InvalidParameter, + wantErrMsg: "session.(Repository).AddSessionCredentials: missing session scope id: parameter violation: error #100", + }, + { + name: "invalid-missing-session-id", + args: args{ + sessionScopeId: s.ScopeId, + creds: []Credential{ + Credential("test-cred"), + }, + }, + wantErr: true, + wantErrCode: errors.InvalidParameter, + wantErrMsg: "session.(Repository).AddSessionCredentials: missing session id: parameter violation: error #100", + }, + { + name: "invalid-empty-cred", + args: args{ + sessionId: s.PublicId, + sessionScopeId: s.ScopeId, + creds: []Credential{ + Credential(""), + }, + }, + wantErr: true, + wantErrCode: errors.InvalidParameter, + wantErrMsg: "session.(Repository).AddSessionCredentials: missing credential: parameter violation: error #100", + }, + { + name: "invalid-empty-cred-with-valid", + args: args{ + sessionId: s.PublicId, + sessionScopeId: s.ScopeId, + creds: []Credential{ + Credential("test-cred"), + Credential(""), + Credential("test-cred2"), + }, + }, + wantErr: true, + wantErrCode: errors.InvalidParameter, + wantErrMsg: "session.(Repository).AddSessionCredentials: missing credential: parameter violation: error #100", + }, + { + name: "valid", + args: args{ + sessionId: s.PublicId, + sessionScopeId: s.ScopeId, + creds: []Credential{ + Credential("test-cred"), + Credential("test-cred1"), + Credential("test-cred2"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + err := repo.AddSessionCredentials(context.Background(), tt.args.sessionScopeId, tt.args.sessionId, tt.args.creds) + if tt.wantErr { + require.Error(err) + assert.Truef(errors.Match(errors.T(tt.wantErrCode), err), "Unexpected error %s", err) + assert.Equal(tt.wantErrMsg, err.Error()) + return + } + require.NoError(err) + + creds, err := repo.ListSessionCredentials(context.Background(), s.ScopeId, s.PublicId) + require.NoError(err) + assert.ElementsMatch(creds, tt.args.creds) + }) + } +} + +func TestRepository_ListSessionCredentials(t *testing.T) { + t.Parallel() + conn, _ := db.TestSetup(t, "postgres") + wrapper := db.TestWrapper(t) + iamRepo := iam.TestRepo(t, conn, wrapper) + rw := db.New(conn) + kms := kms.TestKms(t, conn, wrapper) + repo, err := NewRepository(rw, rw, kms) + require.NoError(t, err) + s1 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo)) + s2 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo)) + s3 := TestSession(t, conn, wrapper, TestSessionParams(t, conn, wrapper, iamRepo)) + + s1Creds := []Credential{ + Credential("cred1"), + Credential("cred2"), + Credential("cred3"), + } + s2Creds := []Credential{ + Credential("cred1"), + } + + err = repo.AddSessionCredentials(context.Background(), s1.ScopeId, s1.PublicId, s1Creds) + require.NoError(t, err) + + err = repo.AddSessionCredentials(context.Background(), s2.ScopeId, s2.PublicId, s2Creds) + require.NoError(t, err) + + type args struct { + sessionId string + sessionScopeId string + } + tests := []struct { + name string + args args + wantCreds []Credential + wantErr bool + wantErrCode errors.Code + wantErrMsg string + }{ + { + name: "invalid-missing-scope-id", + args: args{ + sessionId: s1.PublicId, + }, + wantErr: true, + wantErrCode: errors.InvalidParameter, + wantErrMsg: "session.(Repository).ListSessionCredentials: missing session scope id: parameter violation: error #100", + }, + { + name: "invalid-missing-session-id", + args: args{ + sessionScopeId: s1.ScopeId, + }, + wantErr: true, + wantErrCode: errors.InvalidParameter, + wantErrMsg: "session.(Repository).ListSessionCredentials: missing session id: parameter violation: error #100", + }, + { + name: "valid-s1", + args: args{ + sessionId: s1.PublicId, + sessionScopeId: s1.ScopeId, + }, + wantCreds: s1Creds, + }, + { + name: "valid-s2", + args: args{ + sessionId: s2.PublicId, + sessionScopeId: s2.ScopeId, + }, + wantCreds: s2Creds, + }, + { + name: "valid-s3-no-creds", + args: args{ + sessionId: s3.PublicId, + sessionScopeId: s3.ScopeId, + }, + wantCreds: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert, require := assert.New(t), require.New(t) + + creds, err := repo.ListSessionCredentials(context.Background(), tt.args.sessionScopeId, tt.args.sessionId) + if tt.wantErr { + require.Error(err) + assert.Truef(errors.Match(errors.T(tt.wantErrCode), err), "Unexpected error %s", err) + assert.Equal(tt.wantErrMsg, err.Error()) + return + } + require.NoError(err) + assert.ElementsMatch(creds, tt.wantCreds) + }) + } +} diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index cf2545455e..946cdac161 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -8,28 +8,25 @@ import ( "github.com/hashicorp/boundary/internal/authtoken" authtokenStore "github.com/hashicorp/boundary/internal/authtoken/store" - "github.com/hashicorp/boundary/internal/credential" + cred "github.com/hashicorp/boundary/internal/credential" "github.com/hashicorp/boundary/internal/credential/vault" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" + "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/host/static" staticStore "github.com/hashicorp/boundary/internal/host/static/store" + "github.com/hashicorp/boundary/internal/iam" + iamStore "github.com/hashicorp/boundary/internal/iam/store" + "github.com/hashicorp/boundary/internal/kms" + "github.com/hashicorp/boundary/internal/oplog" "github.com/hashicorp/boundary/internal/target" "github.com/hashicorp/boundary/internal/target/tcp" tcpStore "github.com/hashicorp/boundary/internal/target/tcp/store" wrapping "github.com/hashicorp/go-kms-wrapping" "github.com/jackc/pgconn" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/hashicorp/boundary/internal/errors" - "github.com/hashicorp/boundary/internal/iam" - - iamStore "github.com/hashicorp/boundary/internal/iam/store" - - "github.com/hashicorp/boundary/internal/kms" - "github.com/hashicorp/boundary/internal/oplog" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestRepository_ListSession(t *testing.T) { @@ -1700,15 +1697,15 @@ func testSessionCredentialParams(t *testing.T, conn *db.DB, wrapper wrapping.Wra stores := vault.TestCredentialStores(t, conn, wrapper, params.ScopeId, 1) libIds := vault.TestCredentialLibraries(t, conn, wrapper, stores[0].GetPublicId(), 2) libs := []*target.CredentialLibrary{ - target.TestNewCredentialLibrary(tar.GetPublicId(), libIds[0].GetPublicId(), credential.ApplicationPurpose), - target.TestNewCredentialLibrary(tar.GetPublicId(), libIds[1].GetPublicId(), credential.ApplicationPurpose), + target.TestNewCredentialLibrary(tar.GetPublicId(), libIds[0].GetPublicId(), cred.ApplicationPurpose), + target.TestNewCredentialLibrary(tar.GetPublicId(), libIds[1].GetPublicId(), cred.ApplicationPurpose), } _, _, _, err = targetRepo.AddTargetCredentialSources(ctx, tar.GetPublicId(), tar.GetVersion(), libs) require.NoError(err) creds := []*DynamicCredential{ - NewDynamicCredential(libIds[0].GetPublicId(), credential.ApplicationPurpose), - NewDynamicCredential(libIds[1].GetPublicId(), credential.ApplicationPurpose), + NewDynamicCredential(libIds[0].GetPublicId(), cred.ApplicationPurpose), + NewDynamicCredential(libIds[1].GetPublicId(), cred.ApplicationPurpose), } params.DynamicCredentials = creds return params