more validation

pull/347/head
Jim Lambert 6 years ago
parent edabf2dcaa
commit 321be613e8

@ -3,6 +3,7 @@ package session
import (
"context"
"fmt"
"strings"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/session/store"
@ -50,7 +51,7 @@ func New(
},
}
if err := s.validate("new session:"); err != nil {
if err := s.validateNewSession("new session:"); err != nil {
return nil, err
}
return &s, nil
@ -74,16 +75,47 @@ func (s *Session) Clone() interface{} {
// VetForWrite implements db.VetForWrite() interface and validates the session
// before it's written.
func (s *Session) VetForWrite(ctx context.Context, r db.Reader, opType db.OpType, opt ...db.Option) error {
opts := db.GetOpts(opt...)
if s.PublicId == "" {
return fmt.Errorf("session vet for write: missing public id: %w", db.ErrInvalidParameter)
}
switch opType {
case db.CreateOp:
if err := s.validate("session vet for write:"); err != nil {
if err := s.validateNewSession("session vet for write:"); err != nil {
return err
}
case db.UpdateOp:
panic("not implemented")
switch {
case contains(opts.WithFieldMaskPaths, "PublicId"):
return fmt.Errorf("session vet for write: public id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "UserId"):
return fmt.Errorf("session vet for write: user id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "HostId"):
return fmt.Errorf("session vet for write: host id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "ServerId"):
return fmt.Errorf("session vet for write: server id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "TargetId"):
return fmt.Errorf("session vet for write: target id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "SetId"):
return fmt.Errorf("session vet for write: set id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "AuthTokenId"):
return fmt.Errorf("session vet for write: auth token id is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "ServerType"):
return fmt.Errorf("session vet for write: server type is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "Address"):
return fmt.Errorf("session vet for write: address is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "Port"):
return fmt.Errorf("session vet for write: port is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "CreateTime"):
case contains(opts.WithFieldMaskPaths, "Port"):
return fmt.Errorf("session vet for write: port is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "UpdateTime"):
return fmt.Errorf("session vet for write: update time is immutable: %w", db.ErrInvalidParameter)
case contains(opts.WithFieldMaskPaths, "TerminationReason"):
if _, err := convertToReason(s.TerminationReason); err != nil {
return fmt.Errorf("session vet for write: %w", db.ErrInvalidParameter)
}
}
}
return nil
}
@ -103,8 +135,8 @@ func (s *Session) SetTableName(n string) {
s.tableName = n
}
// validate checks everything but the session's PublicId
func (s *Session) validate(errorPrefix string) error {
// validateNewSession checks everything but the session's PublicId
func (s *Session) validateNewSession(errorPrefix string) error {
if s.UserId == "" {
return fmt.Errorf("%s missing user id: %w", errorPrefix, db.ErrInvalidParameter)
}
@ -135,5 +167,19 @@ func (s *Session) validate(errorPrefix string) error {
if s.Port == "" {
return fmt.Errorf("%s missing port: %w", errorPrefix, db.ErrInvalidParameter)
}
if s.TerminationReason != "" {
if _, err := convertToReason(s.TerminationReason); err != nil {
return fmt.Errorf("session vet for write: %w", db.ErrInvalidParameter)
}
}
return nil
}
func contains(ss []string, t string) bool {
for _, s := range ss {
if strings.EqualFold(s, t) {
return true
}
}
return false
}

Loading…
Cancel
Save