added Tofu token with unit tests

jimlambrt-session-basics
Jim Lambert 6 years ago
parent 7bb52e60be
commit 07ac120ade

@ -8,7 +8,9 @@ import (
"github.com/hashicorp/boundary/internal/db"
dbcommon "github.com/hashicorp/boundary/internal/db/common"
"github.com/hashicorp/boundary/internal/kms"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/vault/sdk/helper/strutil"
)
// CreateSession inserts into the repository and returns the new Session with
@ -49,6 +51,12 @@ func (r *Repository) CreateSession(ctx context.Context, sessionWrapper wrapping.
if newSession.ServerType != "" {
return nil, nil, fmt.Errorf("create session: server type must empty: %w", db.ErrInvalidParameter)
}
if newSession.CtTofuToken != nil {
return nil, nil, fmt.Errorf("create session: ct must be empty: %w", db.ErrInvalidParameter)
}
if newSession.TofuToken != nil {
return nil, nil, fmt.Errorf("create session: tofu token must be empty: %w", db.ErrInvalidParameter)
}
id, err := newId()
if err != nil {
@ -125,6 +133,17 @@ func (r *Repository) LookupSession(ctx context.Context, sessionId string, opt ..
}
return nil, nil, fmt.Errorf("lookup session: %w", err)
}
if len(session.CtTofuToken) > 0 {
databaseWrapper, err := r.kms.GetWrapper(ctx, session.ScopeId, kms.KeyPurposeDatabase, kms.WithKeyId(session.KeyId))
if err != nil {
return nil, nil, fmt.Errorf("lookup session: unable to get database wrapper: %w", err)
}
if err := session.decrypt(ctx, databaseWrapper); err != nil {
return nil, nil, fmt.Errorf("lookup session: cannot decrypt session value: %w", err)
}
} else {
session.CtTofuToken = nil
}
return &session, states, nil
}
@ -145,6 +164,11 @@ func (r *Repository) ListSessions(ctx context.Context, opt ...Option) ([]*Sessio
if err != nil {
return nil, fmt.Errorf("list sessions: %w", err)
}
for _, s := range sessions {
s.CtTofuToken = nil
s.TofuToken = nil
s.KeyId = ""
}
return sessions, nil
}
@ -195,23 +219,49 @@ func (r *Repository) UpdateSession(ctx context.Context, session *Session, versio
if session.PublicId == "" {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: missing session public id %w", db.ErrInvalidParameter)
}
if session.CtTofuToken != nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: ct must be empty: %w", db.ErrInvalidParameter)
}
translatedFieldMasks := make([]string, 0, len(fieldMaskPaths))
for _, f := range fieldMaskPaths {
switch {
case strings.EqualFold("TerminationReason", f):
translatedFieldMasks = append(translatedFieldMasks, f)
case strings.EqualFold("ServerId", f):
translatedFieldMasks = append(translatedFieldMasks, f)
case strings.EqualFold("ServerType", f):
translatedFieldMasks = append(translatedFieldMasks, f)
case strings.EqualFold("TofuToken", f):
translatedFieldMasks = append(translatedFieldMasks, "CtTofuToken")
default:
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: field: %s: %w", f, db.ErrInvalidFieldMask)
}
}
updateSession := session.Clone().(*Session)
if strutil.StrListContains(translatedFieldMasks, "CtTofuToken") && len(updateSession.TofuToken) != 0 {
if err := r.reader.LookupById(ctx, updateSession); err != nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w for %s", err, updateSession.PublicId)
}
databaseWrapper, err := r.kms.GetWrapper(ctx, updateSession.ScopeId, kms.KeyPurposeDatabase)
if err != nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: unable to get database wrapper: %w", err)
}
if err := updateSession.encrypt(ctx, databaseWrapper); err != nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("create session: %w", err)
}
}
var dbMask, nullFields []string
dbMask, nullFields = dbcommon.BuildUpdatePaths(
map[string]interface{}{
"TerminationReason": session.TerminationReason,
"ServerId": session.ServerId,
"ServerType": session.ServerType,
"TerminationReason": updateSession.TerminationReason,
"ServerId": updateSession.ServerId,
"ServerType": updateSession.ServerType,
"CtTofuToken": updateSession.CtTofuToken,
},
fieldMaskPaths,
translatedFieldMasks,
)
if len(dbMask) == 0 && len(nullFields) == 0 {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w", db.ErrEmptyFieldMask)
@ -226,7 +276,8 @@ func (r *Repository) UpdateSession(ctx context.Context, session *Session, versio
db.ExpBackoff{},
func(reader db.Reader, w db.Writer) error {
var err error
s = session.Clone().(*Session)
s = updateSession.Clone().(*Session)
rowsUpdated, err = w.Update(
ctx,
s,
@ -250,6 +301,9 @@ func (r *Repository) UpdateSession(ctx context.Context, session *Session, versio
if err != nil {
return nil, nil, db.NoRowsAffected, fmt.Errorf("update session: %w for %s", err, session.PublicId)
}
if len(s.CtTofuToken) == 0 {
s.CtTofuToken = nil
}
return s, states, rowsUpdated, err
}
@ -309,6 +363,9 @@ func (r *Repository) UpdateState(ctx context.Context, sessionId string, sessionV
if err != nil {
return nil, nil, fmt.Errorf("update session state: error creating new state: %w", err)
}
if len(updatedSession.CtTofuToken) == 0 {
updatedSession.CtTofuToken = nil
}
return &updatedSession, returnedStates, nil
}

@ -424,6 +424,7 @@ func TestRepository_UpdateSession(t *testing.T) {
terminationReason TerminationReason
serverId string
serverType string
tofu []byte
fieldMaskPaths []string
opt []Option
publicId *string // not updateable - db.ErrInvalidFieldMask
@ -447,7 +448,8 @@ func TestRepository_UpdateSession(t *testing.T) {
terminationReason: Terminated,
serverId: newServerFunc(),
serverType: servers.ServerTypeWorker.String(),
fieldMaskPaths: []string{"TerminationReason", "ServerId", "ServerType"},
tofu: TestTofu(t),
fieldMaskPaths: []string{"TerminationReason", "ServerId", "ServerType", "TofuToken"},
},
wantErr: false,
wantRowsUpdate: 1,
@ -576,6 +578,10 @@ func TestRepository_UpdateSession(t *testing.T) {
updateSession.ServerId = tt.args.serverId
updateSession.ServerType = tt.args.serverType
updateSession.TerminationReason = tt.args.terminationReason.String()
if tt.args.tofu != nil {
updateSession.TofuToken = make([]byte, len(tt.args.tofu))
copy(updateSession.TofuToken, tt.args.tofu)
}
updateSession.Version = s.Version
afterUpdateSession, afterUpdateState, updatedRows, err := repo.UpdateSession(context.Background(), &updateSession, updateSession.Version, tt.args.fieldMaskPaths, tt.args.opt...)
@ -612,6 +618,13 @@ func TestRepository_UpdateSession(t *testing.T) {
if tt.args.serverType == "" {
dbassrt.IsNull(foundSession, "ServerType")
}
if tt.args.tofu == nil {
dbassrt.IsNull(foundSession, "CtTofuToken")
dbassrt.IsNull(foundSession, "KeyId")
} else {
dbassrt.NotNull(foundSession, "CtTofuToken")
dbassrt.NotNull(foundSession, "KeyId")
}
assert.Equal(tt.args.terminationReason.String(), foundSession.TerminationReason)
assert.Equal(tt.args.serverId, foundSession.ServerId)
assert.Equal(tt.args.serverType, foundSession.ServerType)

@ -14,6 +14,7 @@ import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/go-kms-wrapping/structwrapping"
"google.golang.org/protobuf/types/known/timestamppb"
)
@ -63,6 +64,10 @@ type Session struct {
Certificate []byte `json:"certificate,omitempty" gorm:"default:null"`
// ExpirationTime - after this time the connection will be expired, e.g. forcefully terminated
ExpirationTime *timestamp.Timestamp `json:"expiration_time,omitempty" gorm:"default:null"`
// CtTofuToken is the ciphertext Tofutoken value stored in the database
CtTofuToken []byte `json:"ct_tofu_token,omitempty" gorm:"column:tofu_token;default:null" wrapping:"ct,tofu_token"`
// TofuToken - plain text of the "trust on first use" token for session
TofuToken []byte `json:"tofu_token,omitempty" gorm:"-" wrapping:"pt,tofu_token"`
// termination_reason for the session
TerminationReason string `json:"termination_reason,omitempty" gorm:"default:null"`
// CreateTime from the RDBMS
@ -72,6 +77,12 @@ type Session struct {
// Version for the session
Version uint32 `json:"version,omitempty" gorm:"default:null"`
// key_id is the key ID that was used for the encryption operation. It can be
// used to identify a specific version of the key needed to decrypt the value,
// which is useful for caching purposes.
// @inject_tag: `gorm:"not_null"`
KeyId string `json:"key_id,omitempty" gorm:"not_null"`
tableName string `gorm:"-"`
}
@ -122,6 +133,14 @@ func (s *Session) Clone() interface{} {
TerminationReason: s.TerminationReason,
Version: s.Version,
}
if s.TofuToken != nil {
clone.TofuToken = make([]byte, len(s.TofuToken))
copy(clone.TofuToken, s.TofuToken)
}
if s.CtTofuToken != nil {
clone.CtTofuToken = make([]byte, len(s.CtTofuToken))
copy(clone.CtTofuToken, s.CtTofuToken)
}
if s.Certificate != nil {
clone.Certificate = make([]byte, len(s.Certificate))
copy(clone.Certificate, s.Certificate)
@ -241,6 +260,12 @@ func (s *Session) validateNewSession(errorPrefix string) error {
if s.ServerType != "" {
return fmt.Errorf("%s server type must be empty: %w", errorPrefix, db.ErrInvalidParameter)
}
if s.TofuToken != nil {
return fmt.Errorf("%s tofu token must be empty: %w", errorPrefix, db.ErrInvalidParameter)
}
if s.CtTofuToken != nil {
return fmt.Errorf("%s ct must be empty: %w", errorPrefix, db.ErrInvalidParameter)
}
return nil
}
@ -287,3 +312,18 @@ func newCert(wrapper wrapping.Wrapper, userId, jobId string) (ed25519.PrivateKey
}
return privKey, certBytes, nil
}
func (s *Session) encrypt(ctx context.Context, cipher wrapping.Wrapper) error {
if err := structwrapping.WrapStruct(ctx, cipher, s, nil); err != nil {
return fmt.Errorf("error encrypting session: %w", err)
}
s.KeyId = cipher.KeyID()
return nil
}
func (s *Session) decrypt(ctx context.Context, cipher wrapping.Wrapper) error {
if err := structwrapping.UnwrapStruct(ctx, cipher, s, nil); err != nil {
return fmt.Errorf("error decrypting session: %w", err)
}
return nil
}

@ -16,6 +16,7 @@ import (
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/target"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/vault/sdk/helper/base62"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/require"
)
@ -132,6 +133,14 @@ func TestSessionParams(t *testing.T, conn *gorm.DB, wrapper wrapping.Wrapper, ia
}
}
func TestTofu(t *testing.T) []byte {
t.Helper()
require := require.New(t)
tofu, err := base62.Random(20)
require.NoError(err)
return []byte(tofu)
}
// TestCert is a temporary test func that intentionally doesn't take testing.T
// as a parameter. It's currently used in controller.jobTestingHandler() and
// should be deprecated once that function is refactored to use sessions properly.

Loading…
Cancel
Save