Merge pull request #4804 from hashicorp/backport/tmessi-fix-session-idle-in-transaction/widely-engaged-haddock

This pull request was automerged via backport-assistant
pull/4807/head
hc-github-team-secure-boundary 2 years ago committed by GitHub
commit 09e32f7b76
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -252,15 +252,6 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt ..
return errors.Wrap(ctx, err, op)
}
session.Connections = connections
if session.ProjectId == "" || session.UserId == "" {
// Skip decryption if Project ID or UserId is missing,
// since it will just lead to errors, and the session
// is already canceled if either of those are empty.
return nil
}
if err := decryptAndMaybeUpdateSession(ctx, r.kms, &session, w); err != nil && !opts.withIgnoreDecryptionFailures {
return errors.Wrap(ctx, err, op)
}
return nil
},
)
@ -271,6 +262,15 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt ..
return nil, nil, errors.Wrap(ctx, err, op)
}
// Skip decryption if Project ID or UserId is missing,
// since it will just lead to errors, and the session
// is already canceled if either of those are empty.
if session.ProjectId != "" && session.UserId != "" {
if err := decrypt(ctx, r.kms, &session); err != nil && !opts.withIgnoreDecryptionFailures {
return nil, nil, errors.Wrap(ctx, err, op)
}
}
authzSummary, err := r.sessionAuthzSummary(ctx, sessionId)
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg("failed to get authz summary"))
@ -602,7 +602,7 @@ func (r *Repository) lookupActivatedSessionTx(ctx context.Context, reader db.Rea
if txErr = reader.LookupById(ctx, activatedSession); txErr != nil {
return errors.Wrap(ctx, txErr, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId)))
}
if txErr = decryptAndMaybeUpdateSession(ctx, r.kms, activatedSession, writer); txErr != nil {
if txErr = decrypt(ctx, r.kms, activatedSession); txErr != nil {
return errors.Wrap(ctx, txErr, op)
}
if len(activatedSession.TofuToken) > 0 && subtle.ConstantTimeCompare(activatedSession.TofuToken, tofuToken) != 1 {
@ -637,10 +637,10 @@ func (r *Repository) getActivatedSession(ctx context.Context, sessionId string,
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
err := r.lookupActivatedSessionTx(ctx, reader, w, sessionId, tofuToken, &activatedSession)
func(reader db.Reader, _ db.Writer) error {
err := reader.LookupById(ctx, &activatedSession)
if err != nil {
return errors.Wrap(ctx, err, op)
return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId)))
}
returnedStates, err = r.fetchActivatedSessionStatesTx(ctx, reader, sessionId)
if err != nil {
@ -652,6 +652,12 @@ func (r *Repository) getActivatedSession(ctx context.Context, sessionId string,
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
}
if err := decrypt(ctx, r.kms, &activatedSession); err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
}
if len(activatedSession.TofuToken) > 0 && subtle.ConstantTimeCompare(activatedSession.TofuToken, tofuToken) != 1 {
return nil, nil, errors.New(ctx, errors.TokenMismatch, op, "tofu token mismatch")
}
return &activatedSession, returnedStates, nil
}
@ -674,10 +680,28 @@ func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sess
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing tofu token")
}
// Lookup session first to get the project id so the correct kms wrapper can be used for encrypting the tofu.
foundSession := AllocSession()
foundSession.PublicId = sessionId
if err := r.reader.LookupById(ctx, &foundSession); err != nil {
return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId)))
}
// Encrypt the tofu before we start a database transaction to avoid holding the transaction while encrypting.
updatedSession := AllocSession()
updatedSession.PublicId = sessionId
updatedSession.TofuToken = tofuToken
sessionWrapper, err := r.kms.GetWrapper(ctx, foundSession.ProjectId, kms.KeyPurposeSessions)
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg("unable to get session wrapper"))
}
if err := updatedSession.encrypt(ctx, sessionWrapper); err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
}
var tofuSeen bool
var returnedStates []*State
_, err := r.writer.DoTx(
_, err = r.writer.DoTx(
ctx,
db.StdRetryCnt,
db.ExpBackoff{},
@ -692,21 +716,26 @@ func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sess
if rowsAffected == 0 {
return errors.New(ctx, errors.InvalidSessionState, op, "session is not in a pending state")
}
foundSession := AllocSession()
foundSession = AllocSession()
foundSession.PublicId = sessionId
err = r.lookupActivatedSessionTx(ctx, reader, w, sessionId, tofuToken, &foundSession)
err = reader.LookupById(ctx, &foundSession)
if err != nil {
return errors.Wrap(ctx, err, op)
return errors.Wrap(ctx, err, op, errors.WithMsg(fmt.Sprintf("failed for %s", sessionId)))
}
updatedSession.TofuToken = tofuToken
sessionWrapper, err := r.kms.GetWrapper(ctx, foundSession.ProjectId, kms.KeyPurposeSessions)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get session wrapper"))
}
if err := updatedSession.encrypt(ctx, sessionWrapper); err != nil {
return errors.Wrap(ctx, err, op)
// If we already have recorded a tofu, we don't need to update anything.
// Once we are out of the transaction, we can decrypt and check if the
// recorded tofu matches.
if len(foundSession.CtTofuToken) > 0 {
tofuSeen = true
returnedStates, err = r.fetchActivatedSessionStatesTx(ctx, reader, sessionId)
if err != nil {
return errors.Wrap(ctx, err, op)
}
return nil
}
rowsUpdated, err := w.Update(ctx, &updatedSession, []string{"CtTofuToken", "KeyId"}, nil)
if err != nil {
return errors.Wrap(ctx, err, op)
@ -731,6 +760,16 @@ func (r *Repository) ActivateSession(ctx context.Context, sessionId string, sess
}
return nil, nil, errors.Wrap(ctx, err, op)
}
if tofuSeen {
if err := decrypt(ctx, r.kms, &foundSession); err != nil {
return nil, nil, errors.Wrap(ctx, err, op)
}
if subtle.ConstantTimeCompare(foundSession.TofuToken, tofuToken) != 1 {
return nil, nil, errors.New(ctx, errors.TokenMismatch, op, "tofu token mismatch")
}
return &foundSession, returnedStates, nil
}
return &updatedSession, returnedStates, nil
}
@ -774,10 +813,6 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV
if rowsUpdated != 1 {
return errors.New(ctx, errors.MultipleRecords, op, fmt.Sprintf("updated session and %d rows updated", rowsUpdated))
}
if err := decryptAndMaybeUpdateSession(ctx, r.kms, &updatedSession, w); err != nil && !opts.withIgnoreDecryptionFailures {
return errors.Wrap(ctx, err, op)
}
rowsAffected, err = w.Exec(ctx, updateSessionState, []any{
sql.Named("session_id", sessionId),
sql.Named("status", s.String()),
@ -807,6 +842,11 @@ func (r *Repository) updateState(ctx context.Context, sessionId string, sessionV
if err != nil {
return nil, nil, errors.Wrap(ctx, err, op, errors.WithMsg("error creating new state"))
}
if err := decrypt(ctx, r.kms, &updatedSession); err != nil && !opts.withIgnoreDecryptionFailures {
return nil, nil, errors.Wrap(ctx, err, op)
}
return &updatedSession, returnedStates, nil
}
@ -916,21 +956,15 @@ func fetchHostSetHost(ctx context.Context, r db.Reader, sessionId string, opt ..
return hostSetHost, nil
}
// decryptAndMaybeUpdateSession switches between the database key and session key based on whether
// the session uses a legacy private key or not. It also updates the encrypted session
// in the database (if necessary). Eventually we should be able to remove this function
// and use the session key unconditionally.
func decryptAndMaybeUpdateSession(ctx context.Context, kmsRepo kms.GetWrapperer, session *Session, writer db.Writer) error {
const op = "session.decryptAndMaybeUpdateSession"
// decrypt decrypts encrypted fields of the Session.
func decrypt(ctx context.Context, kmsRepo kms.GetWrapperer, session *Session) error {
const op = "session.decrypt"
if util.IsNil(kmsRepo) {
return errors.New(ctx, errors.InvalidParameter, op, "missing kms repo")
}
if session == nil {
return errors.New(ctx, errors.InvalidParameter, op, "missing session")
}
if util.IsNil(writer) {
return errors.New(ctx, errors.InvalidParameter, op, "missing writer")
}
if session.ProjectId == "" {
return errors.New(ctx, errors.InvalidParameter, op, "missing session project ID")
}
@ -947,43 +981,8 @@ func decryptAndMaybeUpdateSession(ctx context.Context, kmsRepo kms.GetWrapperer,
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get session wrapper"))
}
if len(session.CtCertificatePrivateKey) > 0 {
// New-style session with private key stored in DB, just decrypt.
if err := session.decrypt(ctx, sessionWrapper); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to decrypt session value"))
}
return nil
}
// No certificate private key present, this is a legacy session with a
// private key derived from the key. Derive it again and store it back
// in the DB.
_, session.CertificatePrivateKey, err = DeriveED25519Key(ctx, sessionWrapper, session.UserId, session.PublicId)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("Error deriving session key"))
}
updatedFields := []string{"CtCertificatePrivateKey", "KeyId"}
if len(session.CtTofuToken) > 0 {
// If the TOFU token was set on this session, we can decrypt it using
// the database key.
databaseWrapper, err := kmsRepo.GetWrapper(ctx, session.ProjectId, kms.KeyPurposeDatabase)
if err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to get database wrapper"))
}
if err := session.decrypt(ctx, databaseWrapper); err != nil {
// Note; we can hit this error if the database key that
// was used to encrypt the TOFU token is no longer available
// in the wrapper. Try to return a useful error to the user.
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to decrypt session TOFU token. You may need to recreate your session."))
}
updatedFields = append(updatedFields, "CtTofuToken")
}
// Rewrap with the session wrapper. Next time we look up this session
// all values will be encrypted using the session key.
if err := session.encrypt(ctx, sessionWrapper); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to encrypt session value"))
}
if _, err := writer.Update(ctx, session, updatedFields, nil); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to update session value"))
if err := session.decrypt(ctx, sessionWrapper); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("unable to decrypt session value"))
}
return nil
}

@ -1919,9 +1919,8 @@ func TestRepository_deleteTerminated(t *testing.T) {
}
}
func Test_decryptAndMaybeUpdateSession(t *testing.T) {
func Test_decrypt(t *testing.T) {
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
kmsRepo := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
@ -1929,89 +1928,41 @@ func Test_decryptAndMaybeUpdateSession(t *testing.T) {
t.Run("errors-with-invalid-kms", func(t *testing.T) {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
err := decryptAndMaybeUpdateSession(ctx, nil, s, rw)
err := decrypt(ctx, nil, s)
require.Error(t, err)
})
t.Run("errors-with-invalid-session", func(t *testing.T) {
err := decryptAndMaybeUpdateSession(ctx, kmsRepo, nil, rw)
require.Error(t, err)
})
t.Run("errors-with-invalid-writer", func(t *testing.T) {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, nil)
err := decrypt(ctx, kmsRepo, nil)
require.Error(t, err)
})
t.Run("errors-with-invalid-session-project-id", func(t *testing.T) {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
s.ProjectId = ""
err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw)
err := decrypt(ctx, kmsRepo, s)
require.Error(t, err)
})
t.Run("errors-with-invalid-session-key-id", func(t *testing.T) {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
s.KeyId = ""
err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw)
err := decrypt(ctx, kmsRepo, s)
require.Error(t, err)
})
t.Run("errors-with-invalid-session-user-id", func(t *testing.T) {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
s.UserId = ""
err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw)
err := decrypt(ctx, kmsRepo, s)
require.Error(t, err)
})
t.Run("errors-with-invalid-session-public-id", func(t *testing.T) {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
s.PublicId = ""
err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw)
err := decrypt(ctx, kmsRepo, s)
require.Error(t, err)
})
t.Run("session-with-local-session-key-succeeds", func(t *testing.T) {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw)
require.NoError(t, err)
})
t.Run("session-with-derived-session-key-succeeds", func(t *testing.T) {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
s.CtCertificatePrivateKey = nil
s.CertificatePrivateKey = nil
s.TofuToken = nil
s.CtTofuToken = nil
err := decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw)
require.NoError(t, err)
})
t.Run("session-with-derived-session-key-and-tofu-token-succeeds", func(t *testing.T) {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
s.CtCertificatePrivateKey = nil
s.CertificatePrivateKey = nil
s.TofuToken = []byte("A token")
actualKeyId := s.KeyId
databaseWrapper, err := kmsRepo.GetWrapper(ctx, s.ProjectId, kms.KeyPurposeDatabase)
require.NoError(t, err)
err = s.encrypt(ctx, databaseWrapper)
require.NoError(t, err)
s.KeyId = actualKeyId // Restore this as the encrypt call above will overwrite it.
err = decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw)
require.NoError(t, err)
})
t.Run("session-with-derived-session-key-and-tofu-token-cannot-be-decrypted", func(t *testing.T) {
s := TestDefaultSession(t, conn, wrapper, iamRepo)
s.CtCertificatePrivateKey = nil
s.CertificatePrivateKey = nil
s.TofuToken = []byte("A token")
actualKeyId := s.KeyId
databaseWrapper, err := kmsRepo.GetWrapper(ctx, s.ProjectId, kms.KeyPurposeDatabase)
require.NoError(t, err)
err = s.encrypt(ctx, databaseWrapper)
require.NoError(t, err)
databaseKeyId := s.KeyId
err = kmsRepo.RotateKeys(ctx, s.ProjectId)
require.NoError(t, err)
ok, err := kmsRepo.DestroyKeyVersion(ctx, s.ProjectId, databaseKeyId)
err := decrypt(ctx, kmsRepo, s)
require.NoError(t, err)
assert.True(t, ok)
s.KeyId = actualKeyId // Restore this as the encrypt call above will overwrite it.
err = decryptAndMaybeUpdateSession(ctx, kmsRepo, s, rw)
require.ErrorContains(t, err, "You may need to recreate your session")
})
}

@ -109,7 +109,7 @@ func sessionRewrapFn(ctx context.Context, dataKeyVersionId string, scopeId strin
}
continue
}
if err := decryptAndMaybeUpdateSession(ctx, kmsRepo, session, writer); err != nil {
if err := decrypt(ctx, kmsRepo, session); err != nil {
return errors.Wrap(ctx, err, op, errors.WithMsg("failed to decrypt session"))
}
wrapper, err := kmsRepo.GetWrapper(ctx, session.GetProjectId(), kms.KeyPurposeSessions)

Loading…
Cancel
Save