diff --git a/internal/db/schema/migrations/oss/postgres/50/03_session.up.sql b/internal/db/schema/migrations/oss/postgres/50/03_session.up.sql new file mode 100644 index 0000000000..71bbdad951 --- /dev/null +++ b/internal/db/schema/migrations/oss/postgres/50/03_session.up.sql @@ -0,0 +1,32 @@ +begin; + + create table session_credential_static ( + session_id wt_public_id not null + constraint session_fkey + references session (public_id) + on delete cascade + on update cascade, + credential_static_id wt_public_id + constraint credential_static_fkey + references credential_static (public_id) + on delete cascade + on update cascade, + credential_purpose text not null + constraint credential_purpose_fkey + references credential_purpose_enm (name) + on delete restrict + on update cascade, + primary key(session_id, credential_static_id, credential_purpose), + create_time wt_timestamp + ); + comment on table session_credential_dynamic is + 'session_credential_static is a join table between the session and static credential tables. ' + 'It also contains the credential purpose the relationship represents.'; + + create trigger default_create_time_column before insert on session_credential_static + for each row execute procedure default_create_time(); + + create trigger immutable_columns before update on session_credential_static + for each row execute procedure immutable_columns('session_id', 'credential_static_id', 'credential_purpose', 'create_time'); + +commit; diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index 07cfd76fb1..b225409fcd 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -93,6 +93,7 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping. func(read db.Reader, w db.Writer) error { returnedSession = newSession.Clone().(*Session) returnedSession.DynamicCredentials = nil + returnedSession.StaticCredentials = nil if err = w.Create(ctx, returnedSession); err != nil { return errors.Wrap(ctx, err, op) } @@ -101,6 +102,25 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping. cred.SessionId = newSession.PublicId } + var staticCreds []interface{} + for _, cred := range newSession.StaticCredentials { + cred.SessionId = newSession.PublicId + staticCreds = append(staticCreds, cred) + } + + if len(staticCreds) > 0 { + if err = w.CreateItems(ctx, staticCreds); err != nil { + return errors.Wrap(ctx, err, op, errors.WithMsg("failed to create static credentials")) + } + + // Get static creds back from the db for return + var c []*StaticCredential + if err := read.SearchWhere(ctx, &c, "session_id = ?", []interface{}{newSession.PublicId}); err != nil { + return errors.Wrap(ctx, err, op) + } + returnedSession.StaticCredentials = c + } + // TODO: after upgrading to gorm v2 this batch insert can be replaced, since gorm v2 supports batch inserts q, batchInsertArgs, err := batchInsertsessionCredentialDynamic(newSession.DynamicCredentials) if err == nil { @@ -165,12 +185,20 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, _ ...O } session.States = states - var creds []*DynamicCredential - if err := read.SearchWhere(ctx, &creds, "session_id = ?", []interface{}{sessionId}); err != nil { + var dynamicCreds []*DynamicCredential + if err := read.SearchWhere(ctx, &dynamicCreds, "session_id = ?", []interface{}{sessionId}); err != nil { + return errors.Wrap(ctx, err, op) + } + if len(dynamicCreds) > 0 { + session.DynamicCredentials = dynamicCreds + } + + var staticCreds []*StaticCredential + if err := read.SearchWhere(ctx, &staticCreds, "session_id = ?", []interface{}{sessionId}); err != nil { return errors.Wrap(ctx, err, op) } - if len(creds) > 0 { - session.DynamicCredentials = creds + if len(staticCreds) > 0 { + session.StaticCredentials = staticCreds } connections, err := fetchConnections(ctx, read, sessionId, db.WithOrder("create_time desc")) diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 9ecf7a6ca9..3a259348c9 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/boundary/internal/authtoken" authtokenStore "github.com/hashicorp/boundary/internal/authtoken/store" cred "github.com/hashicorp/boundary/internal/credential" + credstatic "github.com/hashicorp/boundary/internal/credential/static" "github.com/hashicorp/boundary/internal/credential/vault" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" @@ -406,6 +407,7 @@ func TestRepository_CreateSession(t *testing.T) { ExpirationTime: tt.args.composedOf.ExpirationTime, ConnectionLimit: tt.args.composedOf.ConnectionLimit, DynamicCredentials: tt.args.composedOf.DynamicCredentials, + StaticCredentials: tt.args.composedOf.StaticCredentials, } ses, privKey, err := repo.CreateSession(context.Background(), wrapper, s, tt.args.workerAddresses) if tt.wantErr { @@ -425,6 +427,7 @@ func TestRepository_CreateSession(t *testing.T) { require.NoError(err) assert.Equal(keyId, ses.KeyId) assert.Len(ses.DynamicCredentials, len(s.DynamicCredentials)) + assert.Len(ses.StaticCredentials, len(s.StaticCredentials)) foundSession, _, err := repo.LookupSession(context.Background(), ses.PublicId) assert.NoError(err) assert.Equal(keyId, foundSession.KeyId) @@ -447,6 +450,13 @@ func TestRepository_CreateSession(t *testing.T) { assert.NotEmpty(cred.LibraryId) assert.NotEmpty(cred.CredentialPurpose) } + + assert.Equal(s.StaticCredentials, foundSession.StaticCredentials) + for _, cred := range foundSession.StaticCredentials { + assert.NotEmpty(cred.CredentialStaticId) + assert.NotEmpty(cred.SessionId) + assert.NotEmpty(cred.CredentialPurpose) + } }) } } @@ -1440,19 +1450,28 @@ func testSessionCredentialParams(t *testing.T, conn *db.DB, wrapper wrapping.Wra require.NoError(err) require.NotNil(tar) - stores := vault.TestCredentialStores(t, conn, wrapper, params.ScopeId, 1) - libIds := vault.TestCredentialLibraries(t, conn, wrapper, stores[0].GetPublicId(), 2) + vaultStore := vault.TestCredentialStores(t, conn, wrapper, params.ScopeId, 1)[0] + libIds := vault.TestCredentialLibraries(t, conn, wrapper, vaultStore.GetPublicId(), 2) + + staticStore := credstatic.TestCredentialStore(t, conn, wrapper, params.ScopeId) + upCreds := credstatic.TestUsernamePasswordCredentials(t, conn, wrapper, "u", "p", staticStore.GetPublicId(), params.ScopeId, 2) ids := target.CredentialSources{ - ApplicationCredentialIds: []string{libIds[0].GetPublicId(), libIds[1].GetPublicId()}, + ApplicationCredentialIds: []string{libIds[0].GetPublicId(), libIds[1].GetPublicId(), upCreds[0].GetPublicId(), upCreds[1].GetPublicId()}, } _, _, _, err = targetRepo.AddTargetCredentialSources(ctx, tar.GetPublicId(), tar.GetVersion(), ids) require.NoError(err) - creds := []*DynamicCredential{ + dynamicCreds := []*DynamicCredential{ NewDynamicCredential(libIds[0].GetPublicId(), cred.ApplicationPurpose), NewDynamicCredential(libIds[1].GetPublicId(), cred.ApplicationPurpose), } - params.DynamicCredentials = creds + params.DynamicCredentials = dynamicCreds + + staticCreds := []*StaticCredential{ + NewStaticCredential(upCreds[0].GetPublicId(), cred.ApplicationPurpose), + NewStaticCredential(upCreds[1].GetPublicId(), cred.ApplicationPurpose), + } + params.StaticCredentials = staticCreds return params } diff --git a/internal/session/session.go b/internal/session/session.go index 5d6faf74ac..f497b3728e 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -54,6 +54,9 @@ type ComposedOf struct { // DynamicCredentials are dynamic credentials that will be retrieved // for the session. DynamicCredentials optional. DynamicCredentials []*DynamicCredential + // StaticCredentials are static credentials that will be retrieved + // for the session. StaticCredentials optional. + StaticCredentials []*StaticCredential } // Session contains information about a user's session with a target @@ -114,6 +117,9 @@ type Session struct { // DynamicCredentials for the session. DynamicCredentials []*DynamicCredential `gorm:"-"` + // StaticCredentials for the session. + StaticCredentials []*StaticCredential `gorm:"-"` + // Connections for the session are for read only and are ignored during write operations Connections []*Connection `gorm:"-"` @@ -153,6 +159,7 @@ func New(c ComposedOf, _ ...Option) (*Session, error) { ConnectionLimit: c.ConnectionLimit, WorkerFilter: c.WorkerFilter, DynamicCredentials: c.DynamicCredentials, + StaticCredentials: c.StaticCredentials, } if err := s.validateNewSession(); err != nil { return nil, errors.WrapDeprecated(err, op) @@ -198,6 +205,13 @@ func (s *Session) Clone() interface{} { clone.DynamicCredentials = append(clone.DynamicCredentials, cp) } } + if len(s.StaticCredentials) > 0 { + clone.StaticCredentials = make([]*StaticCredential, 0, len(s.StaticCredentials)) + for _, sc := range s.StaticCredentials { + cp := sc.clone() + clone.StaticCredentials = append(clone.StaticCredentials, cp) + } + } if s.TofuToken != nil { clone.TofuToken = make([]byte, len(s.TofuToken)) copy(clone.TofuToken, s.TofuToken) @@ -283,6 +297,8 @@ func (s *Session) VetForWrite(ctx context.Context, _ db.Reader, opType db.OpType return errors.New(ctx, errors.InvalidParameter, op, "worker filter is immutable") case contains(opts.WithFieldMaskPaths, "DynamicCredentials"): return errors.New(ctx, errors.InvalidParameter, op, "dynamic credentials are immutable") + case contains(opts.WithFieldMaskPaths, "StaticCredentials"): + return errors.New(ctx, errors.InvalidParameter, op, "static credentials are immutable") case contains(opts.WithFieldMaskPaths, "TerminationReason"): if _, err := convertToReason(s.TerminationReason); err != nil { return errors.Wrap(ctx, err, op) diff --git a/internal/session/session_test.go b/internal/session/session_test.go index a5699f14df..1d33d9a43d 100644 --- a/internal/session/session_test.go +++ b/internal/session/session_test.go @@ -58,6 +58,7 @@ func TestSession_Create(t *testing.T) { ExpirationTime: composedOf.ExpirationTime, ConnectionLimit: composedOf.ConnectionLimit, DynamicCredentials: composedOf.DynamicCredentials, + StaticCredentials: composedOf.StaticCredentials, }, create: true, }, @@ -158,6 +159,7 @@ func TestSession_Create(t *testing.T) { ExpirationTime: composedOf.ExpirationTime, ConnectionLimit: composedOf.ConnectionLimit, DynamicCredentials: composedOf.DynamicCredentials, + StaticCredentials: composedOf.StaticCredentials, }, wantAddrErr: true, wantIsErr: errors.InvalidParameter, diff --git a/internal/session/static_credential.go b/internal/session/static_credential.go new file mode 100644 index 0000000000..4cae6a1f70 --- /dev/null +++ b/internal/session/static_credential.go @@ -0,0 +1,45 @@ +package session + +import ( + cred "github.com/hashicorp/boundary/internal/credential" +) + +// A StaticCredential represents the relationship between a session, a +// credential and the purpose of the credential. +type StaticCredential struct { + SessionId string `json:"session_id,omitempty" gorm:"primary_key"` + CredentialPurpose string `json:"credential_purpose,omitempty" gorm:"primary_key"` + CredentialStaticId string `json:"credential_id,omitempty" gorm:"default:null"` + + tableName string `gorm:"-"` +} + +// NewStaticCredential creates a new in memory Credential representing the +// relationship between session a credential and the purpose of the credential. +func NewStaticCredential(id string, purpose cred.Purpose) *StaticCredential { + return &StaticCredential{ + CredentialStaticId: id, + CredentialPurpose: string(purpose), + } +} + +func (c *StaticCredential) clone() *StaticCredential { + return &StaticCredential{ + SessionId: c.SessionId, + CredentialPurpose: c.CredentialPurpose, + CredentialStaticId: c.CredentialStaticId, + } +} + +// TableName returns the table name. +func (c *StaticCredential) TableName() string { + if c.tableName != "" { + return c.tableName + } + return "session_credential_static" +} + +// SetTableName sets the table name. +func (c *StaticCredential) SetTableName(n string) { + c.tableName = n +}