You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/session/session.go

718 lines
26 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package session
import (
"context"
"crypto/ed25519"
"crypto/x509"
"io"
"math/big"
mathrand "math/rand"
"net"
"strings"
"time"
"github.com/hashicorp/boundary/internal/boundary"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/boundary/internal/util"
wrapping "github.com/hashicorp/go-kms-wrapping/v2"
"github.com/hashicorp/go-kms-wrapping/v2/extras/structwrapping"
"google.golang.org/protobuf/types/known/timestamppb"
)
// defaultSessionTableName is the table name reported by (*Session).TableName
// when no override has been set on the Session.
const (
defaultSessionTableName = "session"
)
// ComposedOf defines the boundary data that is referenced to compose a session.
// New copies these values into a freshly allocated Session.
type ComposedOf struct {
// UserId of the session
UserId string
// HostId of the session
HostId string
// TargetId of the session
TargetId string
// HostSetId of the session
HostSetId string
// AuthTokenId of the session
AuthTokenId string
// ProjectId of the session
ProjectId string
// Endpoint. This is generated by the target, but is not stored in the
// warehouse as the worker may need to e.g. resolve DNS. This is to round
// trip the information to the worker when it validates a session.
Endpoint string
// Expiration time for the session
ExpirationTime *timestamp.Timestamp
// Max connections for the session
ConnectionLimit int32
// Worker filters (general, egress and ingress). These are the filters that
// were active when the session was created, used to validate the session
// via the same set of rules at consumption time as existed at creation
// time. Round tripping them through here saves a lookup in the DB. They are
// not stored in the warehouse.
WorkerFilter string
EgressWorkerFilter string
IngressWorkerFilter string
// DynamicCredentials are dynamic credentials that will be retrieved
// for the session. DynamicCredentials are optional.
DynamicCredentials []*DynamicCredential
// StaticCredentials are static credentials that will be retrieved
// for the session. StaticCredentials are optional.
StaticCredentials []*StaticCredential
// Which worker is performing protocol-related tasks
ProtocolWorkerId string
// CorrelationId is the CorrelationId of the authorize-session request
CorrelationId string `json:"-" gorm:"default:null"`
}
// Session contains information about a user's session with a target. Fields
// tagged gorm:"-" are not mapped to a database column by gorm.
type Session struct {
// PublicId is used to access the session via an API
PublicId string `json:"public_id,omitempty" gorm:"primary_key"`
// UserId for the session
UserId string `json:"user_id,omitempty" gorm:"default:null"`
// TargetId for the session
TargetId string `json:"target_id,omitempty" gorm:"default:null"`
// AuthTokenId for the session
AuthTokenId string `json:"auth_token_id,omitempty" gorm:"default:null"`
// ProjectId for the session
ProjectId string `json:"project_id,omitempty" gorm:"default:null"`
// Certificate to use when connecting (or if using custom certs, to
// serve as the "login"). Raw DER bytes.
Certificate []byte `json:"certificate,omitempty" gorm:"default:null"`
// CtCertificatePrivateKey is the ciphertext certificate private key which is stored in the database
CtCertificatePrivateKey []byte `json:"ct_certificate_private_key,omitempty" gorm:"column:certificate_private_key;default:null" wrapping:"ct,certificate_private_key"`
// CertificatePrivateKey is the certificate private key in plaintext.
// This may not be set for some sessions, in which case the private
// key should be derived from the encryption key referenced in key_id.
CertificatePrivateKey []byte `json:"certificate_private_key,omitempty" gorm:"-" wrapping:"pt,certificate_private_key"`
// ExpirationTime - after this time the connection will be expired, e.g. forcefully terminated
ExpirationTime *timestamp.Timestamp `json:"expiration_time,omitempty" gorm:"default:null"`
// CtTofuToken is the ciphertext TofuToken value stored in the database
CtTofuToken []byte `json:"ct_tofu_token,omitempty" gorm:"column:tofu_token;default:null" wrapping:"ct,tofu_token"`
// TofuToken - plain text of the "trust on first use" token for session
TofuToken []byte `json:"tofu_token,omitempty" gorm:"-" wrapping:"pt,tofu_token"`
// TerminationReason for the session
TerminationReason string `json:"termination_reason,omitempty" gorm:"default:null"`
// CreateTime from the RDBMS
CreateTime *timestamp.Timestamp `json:"create_time,omitempty" gorm:"default:current_timestamp"`
// UpdateTime from the RDBMS
UpdateTime *timestamp.Timestamp `json:"update_time,omitempty" gorm:"default:current_timestamp"`
// Version for the session
Version uint32 `json:"version,omitempty" gorm:"default:null"`
// Endpoint for the session; excluded from JSON output
Endpoint string `json:"-" gorm:"default:null"`
// Maximum number of connections in a session
ConnectionLimit int32 `json:"connection_limit,omitempty" gorm:"default:null"`
// Worker filters (general, egress and ingress); excluded from JSON output
WorkerFilter string `json:"-" gorm:"default:null"`
EgressWorkerFilter string `json:"-" gorm:"default:null"`
IngressWorkerFilter string `json:"-" gorm:"default:null"`
// KeyId is the ID of the key version used to encrypt any fields in this struct
KeyId string `json:"key_id,omitempty" gorm:"default:null"`
// States for the session which are for read only and are ignored during
// write operations
States []*State `gorm:"-"`
// DynamicCredentials for the session.
DynamicCredentials []*DynamicCredential `gorm:"-"`
// StaticCredentials for the session.
StaticCredentials []*StaticCredential `gorm:"-"`
// HostSetId for the session
HostSetId string `gorm:"-"`
// HostId of the session
HostId string `gorm:"-"`
// ProtocolWorkerId of the session
ProtocolWorkerId string `gorm:"-"`
// Connections for the session are for read only and are ignored during write operations
Connections []*Connection `gorm:"-"`
// CorrelationId is the CorrelationId of the authorize-session request
CorrelationId string `json:"-" gorm:"default:null"`
// tableName, when non-empty, overrides the default table name returned by
// TableName.
tableName string `gorm:"-"`
}
// GetPublicId returns the PublicId of the Session.
func (s Session) GetPublicId() string {
	return s.PublicId
}
// GetProjectId returns the ProjectId of the Session.
func (s Session) GetProjectId() string {
	return s.ProjectId
}
// GetUserId returns the UserId of the Session.
func (s Session) GetUserId() string {
	return s.UserId
}
// GetResourceType reports the resource type of a Session, which is always
// resource.Session.
func (s Session) GetResourceType() resource.Type {
	return resource.Session
}
// GetUpdateTime returns the UpdateTime of the Session. The returned pointer
// is the one held by the Session, not a copy.
func (s Session) GetUpdateTime() *timestamp.Timestamp {
	return s.UpdateTime
}
// GetCreateTime returns the CreateTime of the Session. The returned pointer
// is the one held by the Session, not a copy.
func (s Session) GetCreateTime() *timestamp.Timestamp {
	return s.CreateTime
}
// GetDescription always returns an empty string; a Session has no
// description field. The method exists so that Session will satisfy resource
// requirements.
func (s Session) GetDescription() string {
	return ""
}
// GetName always returns an empty string; a Session has no name field.
func (s Session) GetName() string {
	return ""
}
// GetVersion always returns 0. Note that it does not report the Session's
// Version field.
func (s Session) GetVersion() uint32 {
	return 0
}
// Compile-time assertions that *Session implements the interfaces listed
// below.
var (
_ Cloneable = (*Session)(nil)
_ db.VetForWriter = (*Session)(nil)
_ boundary.AuthzProtectedEntity = (*Session)(nil)
)
// New builds an in-memory Session from the supplied ComposedOf values and
// validates it with validateNewSession. No options are currently consumed.
// On validation failure a nil Session and the wrapped error are returned.
func New(ctx context.Context, c ComposedOf, _ ...Option) (*Session, error) {
	const op = "session.New"
	sess := &Session{
		UserId:              c.UserId,
		HostId:              c.HostId,
		TargetId:            c.TargetId,
		HostSetId:           c.HostSetId,
		AuthTokenId:         c.AuthTokenId,
		ProjectId:           c.ProjectId,
		Endpoint:            c.Endpoint,
		ExpirationTime:      c.ExpirationTime,
		ConnectionLimit:     c.ConnectionLimit,
		WorkerFilter:        c.WorkerFilter,
		EgressWorkerFilter:  c.EgressWorkerFilter,
		IngressWorkerFilter: c.IngressWorkerFilter,
		DynamicCredentials:  c.DynamicCredentials,
		StaticCredentials:   c.StaticCredentials,
		ProtocolWorkerId:    c.ProtocolWorkerId,
		CorrelationId:       c.CorrelationId,
	}
	err := sess.validateNewSession(ctx)
	if err != nil {
		return nil, errors.Wrap(ctx, err, op)
	}
	return sess, nil
}
// AllocSession returns a zero-value Session.
func AllocSession() Session {
	var s Session
	return s
}
// Clone creates a copy of the Session and returns it as any (the concrete
// type is *Session). Scalar fields are copied by value. States and credentials
// are copied through their own Clone/clone methods, and byte slices and
// timestamps are duplicated into fresh allocations. Connections and the table
// name override are not copied.
func (s *Session) Clone() any {
	// dupBytes returns an independent copy of src, preserving nil.
	dupBytes := func(src []byte) []byte {
		if src == nil {
			return nil
		}
		dst := make([]byte, len(src))
		copy(dst, src)
		return dst
	}
	// dupTime returns an independent copy of src, preserving nil.
	dupTime := func(src *timestamp.Timestamp) *timestamp.Timestamp {
		if src == nil {
			return nil
		}
		return &timestamp.Timestamp{
			Timestamp: &timestamppb.Timestamp{
				Seconds: src.Timestamp.Seconds,
				Nanos:   src.Timestamp.Nanos,
			},
		}
	}

	out := &Session{
		PublicId:            s.PublicId,
		UserId:              s.UserId,
		HostId:              s.HostId,
		TargetId:            s.TargetId,
		HostSetId:           s.HostSetId,
		AuthTokenId:         s.AuthTokenId,
		ProjectId:           s.ProjectId,
		TerminationReason:   s.TerminationReason,
		Version:             s.Version,
		Endpoint:            s.Endpoint,
		ConnectionLimit:     s.ConnectionLimit,
		WorkerFilter:        s.WorkerFilter,
		EgressWorkerFilter:  s.EgressWorkerFilter,
		IngressWorkerFilter: s.IngressWorkerFilter,
		KeyId:               s.KeyId,
		ProtocolWorkerId:    s.ProtocolWorkerId,
		CorrelationId:       s.CorrelationId,
	}

	// Slices of pointers are only allocated when the source has elements, so
	// an empty source leaves the corresponding field nil on the clone.
	if n := len(s.States); n > 0 {
		out.States = make([]*State, 0, n)
		for i := range s.States {
			out.States = append(out.States, s.States[i].Clone().(*State))
		}
	}
	if n := len(s.DynamicCredentials); n > 0 {
		out.DynamicCredentials = make([]*DynamicCredential, 0, n)
		for i := range s.DynamicCredentials {
			out.DynamicCredentials = append(out.DynamicCredentials, s.DynamicCredentials[i].clone())
		}
	}
	if n := len(s.StaticCredentials); n > 0 {
		out.StaticCredentials = make([]*StaticCredential, 0, n)
		for i := range s.StaticCredentials {
			out.StaticCredentials = append(out.StaticCredentials, s.StaticCredentials[i].clone())
		}
	}

	out.TofuToken = dupBytes(s.TofuToken)
	out.CtTofuToken = dupBytes(s.CtTofuToken)
	out.Certificate = dupBytes(s.Certificate)
	out.CertificatePrivateKey = dupBytes(s.CertificatePrivateKey)
	out.CtCertificatePrivateKey = dupBytes(s.CtCertificatePrivateKey)

	out.ExpirationTime = dupTime(s.ExpirationTime)
	out.CreateTime = dupTime(s.CreateTime)
	out.UpdateTime = dupTime(s.UpdateTime)

	return out
}
// VetForWrite implements db.VetForWrite() interface and validates the session
// before it's written.
//
// A public id is always required. For create operations the session must pass
// validateNewSession and carry a certificate. For update operations the field
// mask is checked, in a fixed order, against the list of immutable fields and
// the first match is rejected; if none match and the mask names
// TerminationReason, the reason must be accepted by convertToReason. Other
// operation types are not checked further.
func (s *Session) VetForWrite(ctx context.Context, _ db.Reader, opType db.OpType, opt ...db.Option) error {
	const op = "session.(Session).VetForWrite"
	opts := db.GetOpts(opt...)
	if s.PublicId == "" {
		return errors.New(ctx, errors.InvalidParameter, op, "missing public id")
	}
	switch opType {
	case db.CreateOp:
		if err := s.validateNewSession(ctx); err != nil {
			return errors.Wrap(ctx, err, op)
		}
		if len(s.Certificate) == 0 {
			return errors.New(ctx, errors.InvalidParameter, op, "missing certificate")
		}
	case db.UpdateOp:
		// Checked in order; the first field found in the mask decides the
		// error message.
		immutable := []struct {
			field string
			msg   string
		}{
			{"PublicId", "public id is immutable"},
			{"UserId", "user id is immutable"},
			{"HostId", "host id is immutable"},
			{"TargetId", "target id is immutable"},
			{"HostSetId", "host set id is immutable"},
			{"AuthTokenId", "auth token id is immutable"},
			{"Certificate", "certificate is immutable"},
			{"CreateTime", "create time is immutable"},
			{"UpdateTime", "update time is immutable"},
			{"Endpoint", "endpoint is immutable"},
			{"ExpirationTime", "expiration time is immutable"},
			{"ConnectionLimit", "connection limit is immutable"},
			{"WorkerFilter", "worker filter is immutable"},
			{"EgressWorkerFilter", "egress worker filter is immutable"},
			{"IngressWorkerFilter", "ingress worker filter is immutable"},
			{"DynamicCredentials", "dynamic credentials are immutable"},
			{"StaticCredentials", "static credentials are immutable"},
			{"ProtocolWorkerId", "protocol worker id is immutable"},
		}
		mask := opts.WithFieldMaskPaths
		for _, f := range immutable {
			if contains(mask, f.field) {
				return errors.New(ctx, errors.InvalidParameter, op, f.msg)
			}
		}
		if contains(mask, "TerminationReason") {
			if _, err := convertToReason(ctx, s.TerminationReason); err != nil {
				return errors.Wrap(ctx, err, op)
			}
		}
	}
	return nil
}
// TableName returns the table name gorm should use for Session: the override
// set through SetTableName when there is one, otherwise the default name.
func (s *Session) TableName() string {
	if s.tableName == "" {
		return defaultSessionTableName
	}
	return s.tableName
}
// SetTableName stores n as the table name override for this Session; it is
// intended to satisfy the ReplayableMessage interface. Passing "" clears the
// override, so TableName reports the default name again.
func (s *Session) SetTableName(n string) {
	s.tableName = n
}
// validateNewSession checks the fields a brand new session must (or must
// not) carry; the session's PublicId is not checked here. It requires a user
// id, target id, auth token id, project id, endpoint, expiration time and
// correlation id, and requires the termination reason and both tofu token
// fields to be unset. The first failing check is returned as an
// InvalidParameter error.
func (s *Session) validateNewSession(ctx context.Context) error {
	const op = "session.(Session).validateNewSession"
	if s.UserId == "" {
		return errors.New(ctx, errors.InvalidParameter, op, "missing user id")
	}
	if s.TargetId == "" {
		return errors.New(ctx, errors.InvalidParameter, op, "missing target id")
	}
	if s.AuthTokenId == "" {
		return errors.New(ctx, errors.InvalidParameter, op, "missing auth token id")
	}
	if s.ProjectId == "" {
		return errors.New(ctx, errors.InvalidParameter, op, "missing project id")
	}
	if s.Endpoint == "" {
		return errors.New(ctx, errors.InvalidParameter, op, "missing endpoint")
	}
	// A nil timestamp must be rejected explicitly: AsTime() on a nil protobuf
	// timestamp yields the Unix epoch (1970-01-01), and time.Time.IsZero only
	// reports true for year 1, so the IsZero check alone never catches a
	// session that has no expiration time at all.
	if s.ExpirationTime == nil ||
		s.ExpirationTime.GetTimestamp() == nil ||
		s.ExpirationTime.GetTimestamp().AsTime().IsZero() {
		return errors.New(ctx, errors.InvalidParameter, op, "missing expiration time")
	}
	if s.TerminationReason != "" {
		return errors.New(ctx, errors.InvalidParameter, op, "termination reason must be empty")
	}
	if s.TofuToken != nil {
		return errors.New(ctx, errors.InvalidParameter, op, "tofu token must be empty")
	}
	if s.CtTofuToken != nil {
		return errors.New(ctx, errors.InvalidParameter, op, "ct must be empty")
	}
	if s.CorrelationId == "" {
		return errors.New(ctx, errors.InvalidParameter, op, "missing correlation id")
	}
	// It is okay for the worker filter and protocol worker ID to be empty, so
	// they are not checked here.
	return nil
}
// contains reports whether t is present in ss, comparing entries
// case-insensitively (strings.EqualFold). A nil or empty ss yields false.
func contains(ss []string, t string) bool {
	for i := range ss {
		if strings.EqualFold(ss[i], t) {
			return true
		}
	}
	return false
}
// newCert creates a new self-signed ed25519 session certificate and returns
// the private key together with the certificate bytes produced by
// x509.CreateCertificate.
//
// jobId, at least one address, a non-zero expiry and a random source are all
// required. jobId is always the first DNS SAN. Each address has any port
// stripped and is then added to the IP SANs if it parses as an IP address, or
// to the DNS SANs otherwise. The certificate is valid from one minute before
// the current time until exp. The serial number comes from math/rand, not from
// the supplied rand reader.
func newCert(ctx context.Context, jobId string, addresses []string, exp time.Time, rand io.Reader) (ed25519.PrivateKey, []byte, error) {
	const op = "session.newCert"
	switch {
	case jobId == "":
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing job id")
	case len(addresses) == 0:
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing addresses")
	case exp.IsZero():
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing expiry")
	case util.IsNil(rand):
		return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "missing random data source")
	}

	public, private, err := ed25519.GenerateKey(rand)
	if err != nil {
		return nil, nil, errors.Wrap(ctx, err, op)
	}

	tmpl := &x509.Certificate{
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
			x509.ExtKeyUsageClientAuth,
		},
		DNSNames:              []string{jobId},
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement | x509.KeyUsageCertSign,
		SerialNumber:          big.NewInt(mathrand.Int63()),
		NotBefore:             time.Now().Add(-1 * time.Minute),
		NotAfter:              exp,
		BasicConstraintsValid: true,
		IsCA:                  true,
	}

	for i := range addresses {
		// Drop any port first, regardless of whether this is an IP or a name.
		// An address without a port is not an error.
		host, _, splitErr := util.SplitHostPort(addresses[i])
		if splitErr != nil && !errors.Is(splitErr, util.ErrMissingPort) {
			return nil, nil, errors.Wrap(ctx, splitErr, op)
		}
		// Anything net.ParseIP accepts becomes an IP SAN; everything else is
		// treated as a DNS name.
		if parsed := net.ParseIP(host); parsed == nil {
			tmpl.DNSNames = append(tmpl.DNSNames, host)
		} else {
			tmpl.IPAddresses = append(tmpl.IPAddresses, parsed)
		}
	}

	// The template doubles as the parent, so the certificate is self-signed.
	der, err := x509.CreateCertificate(rand, tmpl, tmpl, public, private)
	if err != nil {
		return nil, nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.GenCert))
	}
	return private, der, nil
}
// encrypt wraps the session's plaintext secrets with cipher. TofuToken and
// CertificatePrivateKey are encrypted independently, each only when
// non-empty, and the results are stored in CtTofuToken and
// CtCertificatePrivateKey. KeyId is then set to the cipher's key id, even
// when neither field needed encrypting. A nil cipher is an InvalidParameter
// error.
func (s *Session) encrypt(ctx context.Context, cipher wrapping.Wrapper) error {
	const op = "session.(Session).encrypt"
	if util.IsNil(cipher) {
		return errors.New(ctx, errors.InvalidParameter, op, "missing cipher")
	}

	// Each secret is wrapped through its own small struct because either one
	// may be absent independently of the other.
	if len(s.TofuToken) > 0 {
		type tofuPair struct {
			TofuToken   []byte `wrapping:"pt,tofu_token"`
			CtTofuToken []byte `wrapping:"ct,tofu_token"`
		}
		pair := &tofuPair{TofuToken: s.TofuToken}
		err := structwrapping.WrapStruct(ctx, cipher, pair)
		if err != nil {
			return errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt), errors.WithMsg("failed to encrypt tofu token"))
		}
		s.CtTofuToken = pair.CtTofuToken
	}

	if len(s.CertificatePrivateKey) > 0 {
		type keyPair struct {
			CertificatePrivateKey   []byte `wrapping:"pt,certificate_private_key"`
			CtCertificatePrivateKey []byte `wrapping:"ct,certificate_private_key"`
		}
		pair := &keyPair{CertificatePrivateKey: s.CertificatePrivateKey}
		err := structwrapping.WrapStruct(ctx, cipher, pair)
		if err != nil {
			return errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt), errors.WithMsg("failed to encrypt certificate private key"))
		}
		s.CtCertificatePrivateKey = pair.CtCertificatePrivateKey
	}

	id, err := cipher.KeyId(ctx)
	if err != nil {
		return errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt), errors.WithMsg("error getting cipher key id"))
	}
	s.KeyId = id
	return nil
}
// decrypt unwraps the session's ciphertext secrets with cipher. CtTofuToken
// and CtCertificatePrivateKey are decrypted independently, each only when
// non-empty, and the results are stored in TofuToken and
// CertificatePrivateKey. A nil cipher is an InvalidParameter error.
func (s *Session) decrypt(ctx context.Context, cipher wrapping.Wrapper) error {
	const op = "session.(Session).decrypt"
	if util.IsNil(cipher) {
		return errors.New(ctx, errors.InvalidParameter, op, "missing cipher")
	}

	// Each secret is unwrapped through its own small struct because either
	// one may be absent independently of the other.
	if len(s.CtTofuToken) > 0 {
		type tofuPair struct {
			TofuToken   []byte `wrapping:"pt,tofu_token"`
			CtTofuToken []byte `wrapping:"ct,tofu_token"`
		}
		pair := &tofuPair{CtTofuToken: s.CtTofuToken}
		err := structwrapping.UnwrapStruct(ctx, cipher, pair)
		if err != nil {
			return errors.Wrap(ctx, err, op, errors.WithCode(errors.Decrypt), errors.WithMsg("failed to decrypt tofu token"))
		}
		s.TofuToken = pair.TofuToken
	}

	if len(s.CtCertificatePrivateKey) > 0 {
		type keyPair struct {
			CertificatePrivateKey   []byte `wrapping:"pt,certificate_private_key"`
			CtCertificatePrivateKey []byte `wrapping:"ct,certificate_private_key"`
		}
		pair := &keyPair{CtCertificatePrivateKey: s.CtCertificatePrivateKey}
		err := structwrapping.UnwrapStruct(ctx, cipher, pair)
		if err != nil {
			return errors.Wrap(ctx, err, op, errors.WithCode(errors.Decrypt), errors.WithMsg("failed to decrypt certificate private key"))
		}
		s.CertificatePrivateKey = pair.CertificatePrivateKey
	}

	return nil
}
// sessionListView is the gorm model read from the "session_list" relation
// (see its TableName method). It combines a subset of the session columns with
// the state columns of a session.
type sessionListView struct {
// Session fields, we omit some fields that are not included when listing sessions.
PublicId string `gorm:"primary_key"`
UserId string `gorm:"default:null"`
HostId string `gorm:"default:null"`
HostSetId string `gorm:"default:null"`
TargetId string `gorm:"default:null"`
AuthTokenId string `gorm:"default:null"`
ProjectId string `gorm:"default:null"`
Certificate []byte `gorm:"default:null"`
ExpirationTime *timestamp.Timestamp `gorm:"default:null"`
TerminationReason string `gorm:"default:null"`
CreateTime *timestamp.Timestamp `gorm:"default:current_timestamp"`
UpdateTime *timestamp.Timestamp `gorm:"default:current_timestamp"`
Version uint32 `gorm:"default:null"`
Endpoint string `gorm:"default:null"`
ConnectionLimit int32 `gorm:"default:null"`
// State fields. Status is read from the "state" column.
Status string `gorm:"column:state"`
StartTime *timestamp.Timestamp `gorm:"column:start_time"`
EndTime *timestamp.Timestamp `gorm:"column:end_time"`
}
// TableName overrides the default gorm table name for sessionListView; it is
// always "session_list".
func (s *sessionListView) TableName() string {
	const name = "session_list"
	return name
}
// deletedSession is the gorm model for rows of the "session_deleted" table
// (see its TableName method): a session public id and a delete time.
// NOTE(review): presumably DeleteTime is when the session was deleted —
// confirm against the schema.
type deletedSession struct {
PublicId string `gorm:"primary_key"`
DeleteTime *timestamp.Timestamp
}
// TableName overrides the default gorm table name for deletedSession; it is
// always "session_deleted".
func (s *deletedSession) TableName() string {
	const name = "session_deleted"
	return name
}
// ProxyCertificate represents a session id to proxy certificate mapping.
// It's stored during session authorization and used at connection authorization
// time to look up the proxy certificate, if applicable.
type ProxyCertificate struct {
// SessionId is the primary key of the mapping
SessionId string `gorm:"primary_key"`
// Certificate is the proxy certificate
Certificate []byte `gorm:"not_null"`
// PrivateKey is the plaintext private key; it is not mapped to a column
PrivateKey []byte `gorm:"-" wrapping:"pt,private_key"`
// PrivateKeyEncrypted is the ciphertext counterpart of PrivateKey
PrivateKeyEncrypted []byte `gorm:"not_null" wrapping:"ct,private_key"`
// KeyId is set by Encrypt to the key id of the cipher that was used
KeyId string `gorm:"not_null"`
}
// NewProxyCertificate creates a new in-memory ProxyCertificate holding the
// given session id, certificate and plaintext private key. All three are
// required; the first one found missing (session id, then certificate, then
// private key) yields an InvalidParameter error. The byte slices are stored
// as given, not copied.
func NewProxyCertificate(ctx context.Context, sessionId string, privateKey, certificate []byte) (*ProxyCertificate, error) {
	const op = "session.NewProxyCertificate"
	if sessionId == "" {
		return nil, errors.New(ctx, errors.InvalidParameter, op, "missing session id")
	}
	if len(certificate) == 0 {
		return nil, errors.New(ctx, errors.InvalidParameter, op, "missing certificate")
	}
	if len(privateKey) == 0 {
		return nil, errors.New(ctx, errors.InvalidParameter, op, "missing private key")
	}
	pc := &ProxyCertificate{
		SessionId:   sessionId,
		Certificate: certificate,
		PrivateKey:  privateKey,
	}
	return pc, nil
}
// allocProxyCertificate returns a pointer to a zero-value ProxyCertificate.
func allocProxyCertificate() *ProxyCertificate {
	return new(ProxyCertificate)
}
// Clone creates a copy of the ProxyCertificate. The byte slices are
// duplicated into fresh allocations, so the clone does not share them with
// the receiver; nil slices stay nil.
func (t *ProxyCertificate) Clone() *ProxyCertificate {
	// dup returns an independent copy of src, preserving nil.
	dup := func(src []byte) []byte {
		if src == nil {
			return nil
		}
		dst := make([]byte, len(src))
		copy(dst, src)
		return dst
	}
	return &ProxyCertificate{
		SessionId:           t.SessionId,
		KeyId:               t.KeyId,
		Certificate:         dup(t.Certificate),
		PrivateKey:          dup(t.PrivateKey),
		PrivateKeyEncrypted: dup(t.PrivateKeyEncrypted),
	}
}
// VetForWrite implements db.VetForWrite() interface and validates the session
// proxy certificate before it is written. Whatever the operation type, the
// session id, certificate, encrypted private key and key id must all be set;
// the first one found missing is reported as an InvalidParameter error.
func (t *ProxyCertificate) VetForWrite(ctx context.Context, _ db.Reader, opType db.OpType, _ ...db.Option) error {
	const op = "session.(ProxyCertificate).VetForWrite"
	if t.SessionId == "" {
		return errors.New(ctx, errors.InvalidParameter, op, "missing session id")
	}
	if len(t.Certificate) == 0 {
		return errors.New(ctx, errors.InvalidParameter, op, "missing certificate")
	}
	if len(t.PrivateKeyEncrypted) == 0 {
		return errors.New(ctx, errors.InvalidParameter, op, "missing encrypted private key")
	}
	if t.KeyId == "" {
		return errors.New(ctx, errors.InvalidParameter, op, "missing key id")
	}
	return nil
}
// Encrypt the proxy cert before writing it to the db. The fields carrying
// wrapping tags are wrapped with cipher (PrivateKey into PrivateKeyEncrypted)
// and KeyId is set to the cipher's key id.
//
// A missing cipher is an InvalidParameter error. The check uses util.IsNil,
// matching Session.encrypt, so that an interface value holding a nil pointer
// is rejected here as well instead of being passed on to the wrapping code.
func (t *ProxyCertificate) Encrypt(ctx context.Context, cipher wrapping.Wrapper) error {
	const op = "session.(ProxyCertificate).Encrypt"
	if util.IsNil(cipher) {
		return errors.New(ctx, errors.InvalidParameter, op, "missing cipher")
	}
	if err := structwrapping.WrapStruct(ctx, cipher, t, nil); err != nil {
		return errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt))
	}
	keyId, err := cipher.KeyId(ctx)
	if err != nil {
		return errors.Wrap(ctx, err, op, errors.WithCode(errors.Encrypt), errors.WithMsg("failed to read cipher key id"))
	}
	t.KeyId = keyId
	return nil
}
// Decrypt the proxy cert after reading it from the db. The fields carrying
// wrapping tags are unwrapped with cipher (PrivateKeyEncrypted into
// PrivateKey).
//
// A missing cipher is an InvalidParameter error. The check uses util.IsNil,
// matching Session.decrypt, so that an interface value holding a nil pointer
// is rejected here as well instead of being passed on to the wrapping code.
func (t *ProxyCertificate) Decrypt(ctx context.Context, cipher wrapping.Wrapper) error {
	const op = "session.(ProxyCertificate).Decrypt"
	if util.IsNil(cipher) {
		return errors.New(ctx, errors.InvalidParameter, op, "missing cipher")
	}
	if err := structwrapping.UnwrapStruct(ctx, cipher, t, nil); err != nil {
		return errors.Wrap(ctx, err, op, errors.WithCode(errors.Decrypt))
	}
	return nil
}
// TableName returns the table name for ProxyCertificate, which is always
// "session_proxy_certificate".
func (t *ProxyCertificate) TableName() string {
	const name = "session_proxy_certificate"
	return name
}