From 07ac120ade59f47c1564e12bbb73d177f9eeb7b8 Mon Sep 17 00:00:00 2001 From: Jim Lambert Date: Tue, 15 Sep 2020 10:08:48 -0400 Subject: [PATCH] added Tofu token with unit tests --- internal/session/repository_session.go | 67 +++++++++++++++++++-- internal/session/repository_session_test.go | 15 ++++- internal/session/session.go | 40 ++++++++++++ internal/session/testing.go | 9 +++ 4 files changed, 125 insertions(+), 6 deletions(-) diff --git a/internal/session/repository_session.go b/internal/session/repository_session.go index d80449702c..ff9ff36736 100644 --- a/internal/session/repository_session.go +++ b/internal/session/repository_session.go @@ -8,7 +8,9 @@ import ( "github.com/hashicorp/boundary/internal/db" dbcommon "github.com/hashicorp/boundary/internal/db/common" + "github.com/hashicorp/boundary/internal/kms" wrapping "github.com/hashicorp/go-kms-wrapping" + "github.com/hashicorp/vault/sdk/helper/strutil" ) // CreateSession inserts into the repository and returns the new Session with @@ -49,6 +51,12 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping. if newSession.ServerType != "" { return nil, nil, fmt.Errorf("create session: server type must empty: %w", db.ErrInvalidParameter) } + if newSession.CtTofuToken != nil { + return nil, nil, fmt.Errorf("create session: ct must be empty: %w", db.ErrInvalidParameter) + } + if newSession.TofuToken != nil { + return nil, nil, fmt.Errorf("create session: tofu token must be empty: %w", db.ErrInvalidParameter) + } id, err := newId() if err != nil { @@ -125,6 +133,17 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt .. } return nil, nil, fmt.Errorf("lookup session: %w", err) } + if len(session.CtTofuToken) > 0 { + databaseWrapper, err := r.kms.GetWrapper(ctx, session.ScopeId, kms.KeyPurposeDatabase, kms.WithKeyId(session.KeyId)) + if err != nil { + return nil, nil, fmt.Errorf("lookup session: unable to get database wrapper: %w", err) + } + if err := session.decrypt(ctx, databaseWrapper); err != nil { + return nil, nil, fmt.Errorf("lookup session: cannot decrypt session value: %w", err) + } + } else { + session.CtTofuToken = nil + } return &session, states, nil } @@ -145,6 +164,11 @@ func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Sessio if err != nil { return nil, fmt.Errorf("list sessions: %w", err) } + for _, s := range sessions { + s.CtTofuToken = nil + s.TofuToken = nil + s.KeyId = "" + } return sessions, nil } @@ -195,23 +219,49 @@ func (r *Repository) UpdateSession(ctx context.Context, session *Session, versio if session.PublicId == "" { return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: missing session public id %w", db.ErrInvalidParameter) } + if session.CtTofuToken != nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: ct must be empty: %w", db.ErrInvalidParameter) + } + + translatedFieldMasks := make([]string, 0, len(fieldMaskPaths)) for _, f := range fieldMaskPaths { switch { case strings.EqualFold("TerminationReason", f): + translatedFieldMasks = append(translatedFieldMasks, f) case strings.EqualFold("ServerId", f): + translatedFieldMasks = append(translatedFieldMasks, f) case strings.EqualFold("ServerType", f): + translatedFieldMasks = append(translatedFieldMasks, f) + case strings.EqualFold("TofuToken", f): + translatedFieldMasks = append(translatedFieldMasks, "CtTofuToken") default: return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: field: %s: %w", f, db.ErrInvalidFieldMask) } } + + updateSession := session.Clone().(*Session) + if strutil.StrListContains(translatedFieldMasks, "CtTofuToken") && len(updateSession.TofuToken) != 0 { + if err := r.reader.LookupById(ctx, updateSession); err != nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w for %s", err, updateSession.PublicId) + } + databaseWrapper, err := r.kms.GetWrapper(ctx, updateSession.ScopeId, kms.KeyPurposeDatabase) + if err != nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: unable to get database wrapper: %w", err) + } + if err := updateSession.encrypt(ctx, databaseWrapper); err != nil { + return nil, nil, db.NoRowsAffected, fmt.Errorf("create session: %w", err) + } + } + var dbMask, nullFields []string dbMask, nullFields = dbcommon.BuildUpdatePaths( map[string]interface{}{ - "TerminationReason": session.TerminationReason, - "ServerId": session.ServerId, - "ServerType": session.ServerType, + "TerminationReason": updateSession.TerminationReason, + "ServerId": updateSession.ServerId, + "ServerType": updateSession.ServerType, + "CtTofuToken": updateSession.CtTofuToken, }, - fieldMaskPaths, + translatedFieldMasks, ) if len(dbMask) == 0 && len(nullFields) == 0 { return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w", db.ErrEmptyFieldMask) @@ -226,7 +276,8 @@ func (r *Repository) UpdateSession(ctx context.Context, session *Session, versio db.ExpBackoff{}, func(reader db.Reader, w db.Writer) error { var err error - s = session.Clone().(*Session) + s = updateSession.Clone().(*Session) + rowsUpdated, err = w.Update( ctx, s, @@ -250,6 +301,9 @@ func (r *Repository) UpdateSession(ctx context.Context, session *Session, versio if err != nil { return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w for %s", err, session.PublicId) } + if len(s.CtTofuToken) == 0 { + s.CtTofuToken = nil + } return s, states, rowsUpdated, err } @@ -309,6 +363,9 @@ func (r *Repository) UpdateState(ctx context.Context, sessionId string, sessionV if err != nil { return nil, nil, fmt.Errorf("update session state: error creating new state: %w", err) } + if len(updatedSession.CtTofuToken) == 0 { + updatedSession.CtTofuToken = nil + } return &updatedSession, returnedStates, nil } diff --git a/internal/session/repository_session_test.go b/internal/session/repository_session_test.go index 5e626ae0f7..98b797518b 100644 --- a/internal/session/repository_session_test.go +++ b/internal/session/repository_session_test.go @@ -424,6 +424,7 @@ func TestRepository_UpdateSession(t *testing.T) { terminationReason TerminationReason serverId string serverType string + tofu []byte fieldMaskPaths []string opt []Option publicId *string // not updateable - db.ErrInvalidFieldMask @@ -447,7 +448,8 @@ func TestRepository_UpdateSession(t *testing.T) { terminationReason: Terminated, serverId: newServerFunc(), serverType: servers.ServerTypeWorker.String(), - fieldMaskPaths: []string{"TerminationReason", "ServerId", "ServerType"}, + tofu: TestTofu(t), + fieldMaskPaths: []string{"TerminationReason", "ServerId", "ServerType", "TofuToken"}, }, wantErr: false, wantRowsUpdate: 1, @@ -576,6 +578,10 @@ func TestRepository_UpdateSession(t *testing.T) { updateSession.ServerId = tt.args.serverId updateSession.ServerType = tt.args.serverType updateSession.TerminationReason = tt.args.terminationReason.String() + if tt.args.tofu != nil { + updateSession.TofuToken = make([]byte, len(tt.args.tofu)) + copy(updateSession.TofuToken, tt.args.tofu) + } updateSession.Version = s.Version afterUpdateSession, afterUpdateState, updatedRows, err := repo.UpdateSession(context.Background(), &updateSession, updateSession.Version, tt.args.fieldMaskPaths, tt.args.opt...) @@ -612,6 +618,13 @@ func TestRepository_UpdateSession(t *testing.T) { if tt.args.serverType == "" { dbassrt.IsNull(foundSession, "ServerType") } + if tt.args.tofu == nil { + dbassrt.IsNull(foundSession, "CtTofuToken") + dbassrt.IsNull(foundSession, "KeyId") + } else { + dbassrt.NotNull(foundSession, "CtTofuToken") + dbassrt.NotNull(foundSession, "KeyId") + } assert.Equal(tt.args.terminationReason.String(), foundSession.TerminationReason) assert.Equal(tt.args.serverId, foundSession.ServerId) assert.Equal(tt.args.serverType, foundSession.ServerType) diff --git a/internal/session/session.go b/internal/session/session.go index 8c23db93c2..669434a1a1 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" wrapping "github.com/hashicorp/go-kms-wrapping" + "github.com/hashicorp/go-kms-wrapping/structwrapping" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -63,6 +64,10 @@ type Session struct { Certificate []byte `json:"certificate,omitempty" gorm:"default:null"` // ExpirationTime - after this time the connection will be expired, e.g. forcefully terminated ExpirationTime *timestamp.Timestamp `json:"expiration_time,omitempty" gorm:"default:null"` + // CtTofuToken is the ciphertext Tofutoken value stored in the database + CtTofuToken []byte `json:"ct_tofu_token,omitempty" gorm:"column:tofu_token;default:null" wrapping:"ct,tofu_token"` + // TofuToken - plain text of the "trust on first use" token for session + TofuToken []byte `json:"tofu_token,omitempty" gorm:"-" wrapping:"pt,tofu_token"` // termination_reason for the session TerminationReason string `json:"termination_reason,omitempty" gorm:"default:null"` // CreateTime from the RDBMS @@ -72,6 +77,12 @@ type Session struct { // Version for the session Version uint32 `json:"version,omitempty" gorm:"default:null"` + // key_id is the key ID that was used for the encryption operation. It can be + // used to identify a specific version of the key needed to decrypt the value, + // which is useful for caching purposes. + // @inject_tag: `gorm:"not_null"` + KeyId string `json:"key_id,omitempty" gorm:"not_null"` + tableName string `gorm:"-"` } @@ -122,6 +133,14 @@ func (s *Session) Clone() interface{} { TerminationReason: s.TerminationReason, Version: s.Version, } + if s.TofuToken != nil { + clone.TofuToken = make([]byte, len(s.TofuToken)) + copy(clone.TofuToken, s.TofuToken) + } + if s.CtTofuToken != nil { + clone.CtTofuToken = make([]byte, len(s.CtTofuToken)) + copy(clone.CtTofuToken, s.CtTofuToken) + } if s.Certificate != nil { clone.Certificate = make([]byte, len(s.Certificate)) copy(clone.Certificate, s.Certificate) @@ -241,6 +260,12 @@ func (s *Session) validateNewSession(errorPrefix string) error { if s.ServerType != "" { return fmt.Errorf("%s server type must be empty: %w", errorPrefix, db.ErrInvalidParameter) } + if s.TofuToken != nil { + return fmt.Errorf("%s tofu token must be empty: %w", errorPrefix, db.ErrInvalidParameter) + } + if s.CtTofuToken != nil { + return fmt.Errorf("%s ct must be empty: %w", errorPrefix, db.ErrInvalidParameter) + } return nil } @@ -287,3 +312,18 @@ func newCert(wrapper wrapping.Wrapper, userId, jobId string) (ed25519.PrivateKey } return privKey, certBytes, nil } + +func (s *Session) encrypt(ctx context.Context, cipher wrapping.Wrapper) error { + if err := structwrapping.WrapStruct(ctx, cipher, s, nil); err != nil { + return fmt.Errorf("error encrypting session: %w", err) + } + s.KeyId = cipher.KeyID() + return nil +} + +func (s *Session) decrypt(ctx context.Context, cipher wrapping.Wrapper) error { + if err := structwrapping.UnwrapStruct(ctx, cipher, s, nil); err != nil { + return fmt.Errorf("error decrypting session: %w", err) + } + return nil +} diff --git a/internal/session/testing.go b/internal/session/testing.go index 826a4ccbee..365b4e4bf7 100644 --- a/internal/session/testing.go +++ b/internal/session/testing.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/boundary/internal/kms" "github.com/hashicorp/boundary/internal/target" wrapping "github.com/hashicorp/go-kms-wrapping" + "github.com/hashicorp/vault/sdk/helper/base62" "github.com/jinzhu/gorm" "github.com/stretchr/testify/require" ) @@ -132,6 +133,14 @@ func TestSessionParams(t *testing.T, conn *gorm.DB, wrapper wrapping.Wrapper, ia } } +func TestTofu(t *testing.T) []byte { + t.Helper() + require := require.New(t) + tofu, err := base62.Random(20) + require.NoError(err) + return []byte(tofu) +} + // TestCert is a temporary test func that intentionally doesn't take testing.T // as a parameter. It's currently used in controller.jobTestingHandler() and // should be deprecated once that function is refactored to use sessions properly.