feat(session): Add support for session static credentials

pull/2174/head
Louis Ruch 4 years ago
parent 00dfea1244
commit 58d9d42a88

@ -0,0 +1,32 @@
begin;
create table session_credential_static (
session_id wt_public_id not null
constraint session_fkey
references session (public_id)
on delete cascade
on update cascade,
credential_static_id wt_public_id
constraint credential_static_fkey
references credential_static (public_id)
on delete cascade
on update cascade,
credential_purpose text not null
constraint credential_purpose_fkey
references credential_purpose_enm (name)
on delete restrict
on update cascade,
primary key(session_id, credential_static_id, credential_purpose),
create_time wt_timestamp
);
comment on table session_credential_dynamic is
'session_credential_static is a join table between the session and static credential tables. '
'It also contains the credential purpose the relationship represents.';
create trigger default_create_time_column before insert on session_credential_static
for each row execute procedure default_create_time();
create trigger immutable_columns before update on session_credential_static
for each row execute procedure immutable_columns('session_id', 'credential_static_id', 'credential_purpose', 'create_time');
commit;

@ -93,6 +93,7 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.
func(read db.Reader, w db.Writer) error {
returnedSession = newSession.Clone().(*Session)
returnedSession.DynamicCredentials = nil
returnedSession.StaticCredentials = nil
if err = w.Create(ctx, returnedSession); err != nil {
return errors.Wrap(ctx, err, op)
}
@ -101,6 +102,25 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.
cred.SessionId = newSession.PublicId
}
var staticCreds []interface{}
for _, cred := range newSession.StaticCredentials {
cred.SessionId = newSession.PublicId
staticCreds = append(staticCreds, cred)
}
if len(staticCreds) > 0 {
if err = w.CreateItems(ctx, staticCreds); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("failed to create static credentials"))
}
// Get static creds back from the db for return
var c []*StaticCredential
if err := read.SearchWhere(ctx, &c, "session_id = ?", []interface{}{newSession.PublicId}); err != nil {
return errors.Wrap(ctx, err, op)
}
returnedSession.StaticCredentials = c
}
// TODO: after upgrading to gorm v2 this batch insert can be replaced, since gorm v2 supports batch inserts
q, batchInsertArgs, err := batchInsertsessionCredentialDynamic(newSession.DynamicCredentials)
if err == nil {
@ -165,12 +185,20 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, _ ...O
}
session.States = states
var creds []*DynamicCredential
if err := read.SearchWhere(ctx, &creds, "session_id = ?", []interface{}{sessionId}); err != nil {
var dynamicCreds []*DynamicCredential
if err := read.SearchWhere(ctx, &dynamicCreds, "session_id = ?", []interface{}{sessionId}); err != nil {
return errors.Wrap(ctx, err, op)
}
if len(dynamicCreds) > 0 {
session.DynamicCredentials = dynamicCreds
}
var staticCreds []*StaticCredential
if err := read.SearchWhere(ctx, &staticCreds, "session_id = ?", []interface{}{sessionId}); err != nil {
return errors.Wrap(ctx, err, op)
}
if len(creds) > 0 {
session.DynamicCredentials = creds
if len(staticCreds) > 0 {
session.StaticCredentials = staticCreds
}
connections, err := fetchConnections(ctx, read, sessionId, db.WithOrder("create_time desc"))

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/boundary/internal/authtoken"
authtokenStore "github.com/hashicorp/boundary/internal/authtoken/store"
cred "github.com/hashicorp/boundary/internal/credential"
credstatic "github.com/hashicorp/boundary/internal/credential/static"
"github.com/hashicorp/boundary/internal/credential/vault"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
@ -406,6 +407,7 @@ func TestRepository_CreateSession(t *testing.T) {
ExpirationTime: tt.args.composedOf.ExpirationTime,
ConnectionLimit: tt.args.composedOf.ConnectionLimit,
DynamicCredentials: tt.args.composedOf.DynamicCredentials,
StaticCredentials: tt.args.composedOf.StaticCredentials,
}
ses, privKey, err := repo.CreateSession(context.Background(), wrapper, s, tt.args.workerAddresses)
if tt.wantErr {
@ -425,6 +427,7 @@ func TestRepository_CreateSession(t *testing.T) {
require.NoError(err)
assert.Equal(keyId, ses.KeyId)
assert.Len(ses.DynamicCredentials, len(s.DynamicCredentials))
assert.Len(ses.StaticCredentials, len(s.StaticCredentials))
foundSession, _, err := repo.LookupSession(context.Background(), ses.PublicId)
assert.NoError(err)
assert.Equal(keyId, foundSession.KeyId)
@ -447,6 +450,13 @@ func TestRepository_CreateSession(t *testing.T) {
assert.NotEmpty(cred.LibraryId)
assert.NotEmpty(cred.CredentialPurpose)
}
assert.Equal(s.StaticCredentials, foundSession.StaticCredentials)
for _, cred := range foundSession.StaticCredentials {
assert.NotEmpty(cred.CredentialStaticId)
assert.NotEmpty(cred.SessionId)
assert.NotEmpty(cred.CredentialPurpose)
}
})
}
}
@ -1440,19 +1450,28 @@ func testSessionCredentialParams(t *testing.T, conn *db.DB, wrapper wrapping.Wra
require.NoError(err)
require.NotNil(tar)
stores := vault.TestCredentialStores(t, conn, wrapper, params.ScopeId, 1)
libIds := vault.TestCredentialLibraries(t, conn, wrapper, stores[0].GetPublicId(), 2)
vaultStore := vault.TestCredentialStores(t, conn, wrapper, params.ScopeId, 1)[0]
libIds := vault.TestCredentialLibraries(t, conn, wrapper, vaultStore.GetPublicId(), 2)
staticStore := credstatic.TestCredentialStore(t, conn, wrapper, params.ScopeId)
upCreds := credstatic.TestUsernamePasswordCredentials(t, conn, wrapper, "u", "p", staticStore.GetPublicId(), params.ScopeId, 2)
ids := target.CredentialSources{
ApplicationCredentialIds: []string{libIds[0].GetPublicId(), libIds[1].GetPublicId()},
ApplicationCredentialIds: []string{libIds[0].GetPublicId(), libIds[1].GetPublicId(), upCreds[0].GetPublicId(), upCreds[1].GetPublicId()},
}
_, _, _, err = targetRepo.AddTargetCredentialSources(ctx, tar.GetPublicId(), tar.GetVersion(), ids)
require.NoError(err)
creds := []*DynamicCredential{
dynamicCreds := []*DynamicCredential{
NewDynamicCredential(libIds[0].GetPublicId(), cred.ApplicationPurpose),
NewDynamicCredential(libIds[1].GetPublicId(), cred.ApplicationPurpose),
}
params.DynamicCredentials = creds
params.DynamicCredentials = dynamicCreds
staticCreds := []*StaticCredential{
NewStaticCredential(upCreds[0].GetPublicId(), cred.ApplicationPurpose),
NewStaticCredential(upCreds[1].GetPublicId(), cred.ApplicationPurpose),
}
params.StaticCredentials = staticCreds
return params
}

@ -54,6 +54,9 @@ type ComposedOf struct {
// DynamicCredentials are dynamic credentials that will be retrieved
// for the session. DynamicCredentials optional.
DynamicCredentials []*DynamicCredential
// StaticCredentials are static credentials that will be retrieved
// for the session. StaticCredentials optional.
StaticCredentials []*StaticCredential
}
// Session contains information about a user's session with a target
@ -114,6 +117,9 @@ type Session struct {
// DynamicCredentials for the session.
DynamicCredentials []*DynamicCredential `gorm:"-"`
// StaticCredentials for the session.
StaticCredentials []*StaticCredential `gorm:"-"`
// Connections for the session are for read only and are ignored during write operations
Connections []*Connection `gorm:"-"`
@ -153,6 +159,7 @@ func New(c ComposedOf, _ ...Option) (*Session, error) {
ConnectionLimit: c.ConnectionLimit,
WorkerFilter: c.WorkerFilter,
DynamicCredentials: c.DynamicCredentials,
StaticCredentials: c.StaticCredentials,
}
if err := s.validateNewSession(); err != nil {
return nil, errors.WrapDeprecated(err, op)
@ -198,6 +205,13 @@ func (s *Session) Clone() interface{} {
clone.DynamicCredentials = append(clone.DynamicCredentials, cp)
}
}
if len(s.StaticCredentials) > 0 {
clone.StaticCredentials = make([]*StaticCredential, 0, len(s.StaticCredentials))
for _, sc := range s.StaticCredentials {
cp := sc.clone()
clone.StaticCredentials = append(clone.StaticCredentials, cp)
}
}
if s.TofuToken != nil {
clone.TofuToken = make([]byte, len(s.TofuToken))
copy(clone.TofuToken, s.TofuToken)
@ -283,6 +297,8 @@ func (s *Session) VetForWrite(ctx context.Context, _ db.Reader, opType db.OpType
return errors.New(ctx, errors.InvalidParameter, op, "worker filter is immutable")
case contains(opts.WithFieldMaskPaths, "DynamicCredentials"):
return errors.New(ctx, errors.InvalidParameter, op, "dynamic credentials are immutable")
case contains(opts.WithFieldMaskPaths, "StaticCredentials"):
return errors.New(ctx, errors.InvalidParameter, op, "static credentials are immutable")
case contains(opts.WithFieldMaskPaths, "TerminationReason"):
if _, err := convertToReason(s.TerminationReason); err != nil {
return errors.Wrap(ctx, err, op)

@ -58,6 +58,7 @@ func TestSession_Create(t *testing.T) {
ExpirationTime: composedOf.ExpirationTime,
ConnectionLimit: composedOf.ConnectionLimit,
DynamicCredentials: composedOf.DynamicCredentials,
StaticCredentials: composedOf.StaticCredentials,
},
create: true,
},
@ -158,6 +159,7 @@ func TestSession_Create(t *testing.T) {
ExpirationTime: composedOf.ExpirationTime,
ConnectionLimit: composedOf.ConnectionLimit,
DynamicCredentials: composedOf.DynamicCredentials,
StaticCredentials: composedOf.StaticCredentials,
},
wantAddrErr: true,
wantIsErr: errors.InvalidParameter,

@ -0,0 +1,45 @@
package session
import (
cred "github.com/hashicorp/boundary/internal/credential"
)
// A StaticCredential represents the relationship between a session, a
// credential and the purpose of the credential.
type StaticCredential struct {
SessionId string `json:"session_id,omitempty" gorm:"primary_key"`
CredentialPurpose string `json:"credential_purpose,omitempty" gorm:"primary_key"`
CredentialStaticId string `json:"credential_id,omitempty" gorm:"default:null"`
tableName string `gorm:"-"`
}
// NewStaticCredential creates a new in memory Credential representing the
// relationship between session a credential and the purpose of the credential.
func NewStaticCredential(id string, purpose cred.Purpose) *StaticCredential {
return &StaticCredential{
CredentialStaticId: id,
CredentialPurpose: string(purpose),
}
}
func (c *StaticCredential) clone() *StaticCredential {
return &StaticCredential{
SessionId: c.SessionId,
CredentialPurpose: c.CredentialPurpose,
CredentialStaticId: c.CredentialStaticId,
}
}
// TableName returns the table name.
func (c *StaticCredential) TableName() string {
if c.tableName != "" {
return c.tableName
}
return "session_credential_static"
}
// SetTableName sets the table name.
func (c *StaticCredential) SetTableName(n string) {
c.tableName = n
}
Loading…
Cancel
Save