|
|
|
|
@ -3,6 +3,7 @@ package session
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"fmt"
|
|
|
|
|
"strings"
|
|
|
|
|
|
|
|
|
|
"github.com/hashicorp/boundary/internal/db"
|
|
|
|
|
"github.com/hashicorp/boundary/internal/session/store"
|
|
|
|
|
@ -50,7 +51,7 @@ func New(
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := s.validate("new session:"); err != nil {
|
|
|
|
|
if err := s.validateNewSession("new session:"); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return &s, nil
|
|
|
|
|
@ -74,16 +75,47 @@ func (s *Session) Clone() interface{} {
|
|
|
|
|
// VetForWrite implements db.VetForWrite() interface and validates the session
|
|
|
|
|
// before it's written.
|
|
|
|
|
func (s *Session) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error {
|
|
|
|
|
opts := db.GetOpts(opt...)
|
|
|
|
|
if s.PublicId == "" {
|
|
|
|
|
return fmt.Errorf("session vet for write: missing public id: %w", db.ErrInvalidParameter)
|
|
|
|
|
}
|
|
|
|
|
switch opType {
|
|
|
|
|
case db.CreateOp:
|
|
|
|
|
if err := s.validate("session vet for write:"); err != nil {
|
|
|
|
|
if err := s.validateNewSession("session vet for write:"); err != nil {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
case db.UpdateOp:
|
|
|
|
|
panic("not implemented")
|
|
|
|
|
switch {
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "PublicId"):
|
|
|
|
|
return fmt.Errorf("session vet for write: public id is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "UserId"):
|
|
|
|
|
return fmt.Errorf("session vet for write: user id is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "HostId"):
|
|
|
|
|
return fmt.Errorf("session vet for write: host id is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "ServerId"):
|
|
|
|
|
return fmt.Errorf("session vet for write: server id is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "TargetId"):
|
|
|
|
|
return fmt.Errorf("session vet for write: target id is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "SetId"):
|
|
|
|
|
return fmt.Errorf("session vet for write: set id is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "AuthTokenId"):
|
|
|
|
|
return fmt.Errorf("session vet for write: auth token id is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "ServerType"):
|
|
|
|
|
return fmt.Errorf("session vet for write: server type is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "Address"):
|
|
|
|
|
return fmt.Errorf("session vet for write: address is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "Port"):
|
|
|
|
|
return fmt.Errorf("session vet for write: port is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "CreateTime"):
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "Port"):
|
|
|
|
|
return fmt.Errorf("session vet for write: port is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "UpdateTime"):
|
|
|
|
|
return fmt.Errorf("session vet for write: update time is immutable: %w", db.ErrInvalidParameter)
|
|
|
|
|
case contains(opts.WithFieldMaskPaths, "TerminationReason"):
|
|
|
|
|
if _, err := convertToReason(s.TerminationReason); err != nil {
|
|
|
|
|
return fmt.Errorf("session vet for write: %w", db.ErrInvalidParameter)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
@ -103,8 +135,8 @@ func (s *Session) SetTableName(n string) {
|
|
|
|
|
s.tableName = n
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// validate checks everything but the session's PublicId
|
|
|
|
|
func (s *Session) validate(errorPrefix string) error {
|
|
|
|
|
// validateNewSession checks everything but the session's PublicId
|
|
|
|
|
func (s *Session) validateNewSession(errorPrefix string) error {
|
|
|
|
|
if s.UserId == "" {
|
|
|
|
|
return fmt.Errorf("%s missing user id: %w", errorPrefix, db.ErrInvalidParameter)
|
|
|
|
|
}
|
|
|
|
|
@ -135,5 +167,19 @@ func (s *Session) validate(errorPrefix string) error {
|
|
|
|
|
if s.Port == "" {
|
|
|
|
|
return fmt.Errorf("%s missing port: %w", errorPrefix, db.ErrInvalidParameter)
|
|
|
|
|
}
|
|
|
|
|
if s.TerminationReason != "" {
|
|
|
|
|
if _, err := convertToReason(s.TerminationReason); err != nil {
|
|
|
|
|
return fmt.Errorf("session vet for write: %w", db.ErrInvalidParameter)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func contains(ss []string, t string) bool {
|
|
|
|
|
for _, s := range ss {
|
|
|
|
|
if strings.EqualFold(s, t) {
|
|
|
|
|
return true
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return false
|
|
|
|
|
}
|
|
|
|
|
|