From 321be613e8768f72a9a5c2ed8f6289039f765a9a Mon Sep 17 00:00:00 2001 From: Jim Lambert Date: Sat, 5 Sep 2020 09:04:04 -0400 Subject: [PATCH] more validation --- internal/session/session.go | 56 +++++++++++++++++++++++++++++++++---- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/internal/session/session.go b/internal/session/session.go index eaba3f936f..1924b14d88 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -3,6 +3,7 @@ package session import ( "context" "fmt" + "strings" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/session/store" @@ -50,7 +51,7 @@ func New( }, } - if err := s.validate("new session:"); err != nil { + if err := s.validateNewSession("new session:"); err != nil { return nil, err } return &s, nil @@ -74,16 +75,47 @@ func (s *Session) Clone() interface{} { // VetForWrite implements db.VetForWrite() interface and validates the session // before it's written. func (s *Session) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error { + opts := db.GetOpts(opt...) if s.PublicId == "" { return fmt.Errorf("session vet for write: missing public id: %w", db.ErrInvalidParameter) } switch opType { case db.CreateOp: - if err := s.validate("session vet for write:"); err != nil { + if err := s.validateNewSession("session vet for write:"); err != nil { return err } case db.UpdateOp: - panic("not implemented") + switch { + case contains(opts.WithFieldMaskPaths, "PublicId"): + return fmt.Errorf("session vet for write: public id is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "UserId"): + return fmt.Errorf("session vet for write: user id is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "HostId"): + return fmt.Errorf("session vet for write: host id is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "ServerId"): + return fmt.Errorf("session vet for write: server id is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "TargetId"): + return fmt.Errorf("session vet for write: target id is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "SetId"): + return fmt.Errorf("session vet for write: set id is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "AuthTokenId"): + return fmt.Errorf("session vet for write: auth token id is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "ServerType"): + return fmt.Errorf("session vet for write: server type is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "Address"): + return fmt.Errorf("session vet for write: address is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "Port"): + return fmt.Errorf("session vet for write: port is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "CreateTime"): + case contains(opts.WithFieldMaskPaths, "Port"): + return fmt.Errorf("session vet for write: port is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "UpdateTime"): + return fmt.Errorf("session vet for write: update time is immutable: %w", db.ErrInvalidParameter) + case contains(opts.WithFieldMaskPaths, "TerminationReason"): + if _, err := convertToReason(s.TerminationReason); err != nil { + return fmt.Errorf("session vet for write: %w", db.ErrInvalidParameter) + } + } } return nil } @@ -103,8 +135,8 @@ func (s *Session) SetTableName(n string) { s.tableName = n } -// validate checks everything but the session's PublicId -func (s *Session) validate(errorPrefix string) error { +// validateNewSession checks everything but the session's PublicId +func (s *Session) validateNewSession(errorPrefix string) error { if s.UserId == "" { return fmt.Errorf("%s missing user id: %w", errorPrefix, db.ErrInvalidParameter) } @@ -135,5 +167,19 @@ func (s *Session) validate(errorPrefix string) error { if s.Port == "" { return fmt.Errorf("%s missing port: %w", errorPrefix, db.ErrInvalidParameter) } + if s.TerminationReason != "" { + if _, err := convertToReason(s.TerminationReason); err != nil { + return fmt.Errorf("session vet for write: %w", db.ErrInvalidParameter) + } + } return nil } + +func contains(ss []string, t string) bool { + for _, s := range ss { + if strings.EqualFold(s, t) { + return true + } + } + return false +}