[enhancement] Refactor authtoken repo to return domain specific errors (#876)

pull/877/head^2
s-christoff 5 years ago committed by GitHub
parent bd7322dfb8
commit 81785e139f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,6 +9,7 @@ import (
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/authtoken/store"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/gen/controller/tokens"
"github.com/hashicorp/boundary/internal/kms"
wrapping "github.com/hashicorp/go-kms-wrapping"
@ -61,9 +62,10 @@ func (s *AuthToken) toWritableAuthToken() *writableAuthToken {
// encrypt the entry's data using the provided cipher (wrapping.Wrapper)
func (s *writableAuthToken) encrypt(ctx context.Context, cipher wrapping.Wrapper) error {
const op = "authtoken.(writableAuthToken).encrypt"
// structwrapping doesn't support embedding, so we'll pass in the store.Entry directly
if err := structwrapping.WrapStruct(ctx, cipher, s.AuthToken, nil); err != nil {
return fmt.Errorf("error encrypting auth token: %w", err)
return errors.Wrap(err, op, errors.WithCode(errors.Encrypt))
}
s.KeyId = cipher.KeyID()
return nil
@ -71,9 +73,10 @@ func (s *writableAuthToken) encrypt(ctx context.Context, cipher wrapping.Wrapper
// decrypt will decrypt the auth token's value using the provided cipher (wrapping.Wrapper)
func (s *AuthToken) decrypt(ctx context.Context, cipher wrapping.Wrapper) error {
const op = "authtoken.(AuthToken).decrypt"
// structwrapping doesn't support embedding, so we'll pass in the store.Entry directly
if err := structwrapping.UnwrapStruct(ctx, cipher, s.AuthToken, nil); err != nil {
return fmt.Errorf("error decrypting auth token: %w", err)
return errors.Wrap(err, op, errors.WithCode(errors.Decrypt))
}
return nil
}
@ -86,18 +89,20 @@ const (
)
func newAuthTokenId() (string, error) {
const op = "authtoken.newAuthTokenId"
id, err := db.NewPublicId(AuthTokenPrefix)
if err != nil {
return "", fmt.Errorf("new auth token id: %w", err)
return "", errors.Wrap(err, op)
}
return id, err
return id, nil
}
// newAuthToken generates a token with a version prefix.
func newAuthToken() (string, error) {
const op = "authtoken.newAuthToken"
token, err := base62.Random(tokenLength)
if err != nil {
return "", fmt.Errorf("unable to generate auth token: %w", err)
return "", errors.Wrap(err, op, errors.WithCode(errors.Io))
}
return fmt.Sprintf("%s%s", TokenValueVersionPrefix, token), nil
}
@ -105,6 +110,7 @@ func newAuthToken() (string, error) {
// EncryptToken is a shared function for encrypting a token value for return to
// the user.
func EncryptToken(ctx context.Context, kmsCache *kms.Kms, scopeId, publicId, token string) (string, error) {
const op = "authtoken.EncryptToken"
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
s1Info := &tokens.S1TokenInfo{
@ -115,22 +121,22 @@ func EncryptToken(ctx context.Context, kmsCache *kms.Kms, scopeId, publicId, tok
marshaledS1Info, err := proto.Marshal(s1Info)
if err != nil {
return "", fmt.Errorf("error marshaling token info: %w", err)
return "", errors.Wrap(err, op, errors.WithMsg("marshaling encrypted token"), errors.WithCode(errors.Encode))
}
tokenWrapper, err := kmsCache.GetWrapper(ctx, scopeId, kms.KeyPurposeTokens)
if err != nil {
return "", fmt.Errorf("unable to get wrapper: %w", err)
return "", errors.Wrap(err, op, errors.WithMsg("unable to get wrapper"))
}
blobInfo, err := tokenWrapper.Encrypt(ctx, []byte(marshaledS1Info), []byte(publicId))
if err != nil {
return "", fmt.Errorf("error encrypting token: %w", err)
return "", errors.Wrap(err, op, errors.WithMsg("marshaling token info"), errors.WithCode(errors.Encrypt))
}
marshaledBlob, err := proto.Marshal(blobInfo)
if err != nil {
return "", fmt.Errorf("error marshaling encrypted token: %w", err)
return "", errors.Wrap(err, op, errors.WithMsg("marshaling encrypted token"), errors.WithCode(errors.Encode))
}
encoded := base58.FastBase58Encoding(marshaledBlob)

@ -33,13 +33,14 @@ type Repository struct {
// NewRepository creates a new Repository. The returned repository is not safe for concurrent go
// routines to access it.
func NewRepository(r db.Reader, w db.Writer, kms *kms.Kms, opt ...Option) (*Repository, error) {
const op = "authtoken.NewRepository"
switch {
case r == nil:
return nil, fmt.Errorf("db.Reader: auth token: %w", errors.ErrInvalidParameter)
return nil, errors.New(errors.InvalidParameter, op, "nil db reader")
case w == nil:
return nil, fmt.Errorf("db.Writer: auth token: %w", errors.ErrInvalidParameter)
return nil, errors.New(errors.InvalidParameter, op, "nil db writer")
case kms == nil:
return nil, fmt.Errorf("kms: auth token: %w", errors.ErrInvalidParameter)
return nil, errors.New(errors.InvalidParameter, op, "nil kms")
}
opts := getOpts(opt...)
@ -58,34 +59,34 @@ func NewRepository(r db.Reader, w db.Writer, kms *kms.Kms, opt ...Option) (*Repo
// contains the auth token value. The provided IAM User ID must be associated to the provided auth account id
// or an error will be returned. All options are ignored.
func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User, withAuthAccountId string, opt ...Option) (*AuthToken, error) {
const op = "authtoken.(Repository).CreateAuthToken"
if withIamUser == nil {
return nil, fmt.Errorf("create: auth token: no user: %w", errors.ErrInvalidParameter)
return nil, errors.New(errors.InvalidParameter, op, "missing user")
}
if withIamUser.GetPublicId() == "" {
return nil, fmt.Errorf("create: auth token: no user id: %w", errors.ErrInvalidParameter)
return nil, errors.New(errors.InvalidParameter, op, "missing user id")
}
if withAuthAccountId == "" {
return nil, fmt.Errorf("create: auth token: no auth account id: %w", errors.ErrInvalidParameter)
return nil, errors.New(errors.InvalidParameter, op, "missing auth account id")
}
at := allocAuthToken()
at.AuthAccountId = withAuthAccountId
id, err := newAuthTokenId()
if err != nil {
return nil, fmt.Errorf("create: auth token id: %w", err)
return nil, errors.Wrap(err, op)
}
at.PublicId = id
token, err := newAuthToken()
if err != nil {
return nil, fmt.Errorf("create: auth token value: %w", err)
return nil, errors.Wrap(err, op)
}
at.Token = token
databaseWrapper, err := r.kms.GetWrapper(ctx, withIamUser.GetScopeId(), kms.KeyPurposeDatabase)
if err != nil {
return nil, fmt.Errorf("create: unable to get database wrapper: %w", err)
return nil, errors.Wrap(err, op, errors.WithMsg("unable to get database wrapper"))
}
// We truncate the expiration time to the nearest second to make testing in different platforms with
@ -105,10 +106,11 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User,
acct := allocAuthAccount()
acct.PublicId = withAuthAccountId
if err := read.LookupByPublicId(ctx, acct); err != nil {
return fmt.Errorf("create: auth token: auth account lookup: %w", err)
return errors.Wrap(err, op, errors.WithMsg("auth account lookup"))
}
if acct.GetIamUserId() != withIamUser.GetPublicId() {
return fmt.Errorf("create: auth token: auth account %q mismatch with iam user %q", withAuthAccountId, withIamUser.GetPublicId())
return errors.New(errors.InvalidParameter, op,
fmt.Sprintf("auth account %q mismatch with iam user %q", withAuthAccountId, withIamUser.GetPublicId()))
}
at.ScopeId = acct.GetScopeId()
at.AuthMethodId = acct.GetAuthMethodId()
@ -129,7 +131,7 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User,
)
if err != nil {
return nil, fmt.Errorf("create: auth token: %v: %w", at, err)
return nil, errors.Wrap(err, op)
}
return newAuthToken.toAuthToken(), nil
}
@ -138,8 +140,9 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User,
// For security reasons, the actual token is not included in the returned AuthToken.
// All exported options are ignored.
func (r *Repository) LookupAuthToken(ctx context.Context, id string, opt ...Option) (*AuthToken, error) {
const op = "authtoken.(Repository).LookupAuthToken"
if id == "" {
return nil, fmt.Errorf("lookup: auth token: missing public id: %w", errors.ErrInvalidParameter)
return nil, errors.New(errors.InvalidParameter, op, "missing public id")
}
opts := getOpts(opt...)
@ -149,15 +152,15 @@ func (r *Repository) LookupAuthToken(ctx context.Context, id string, opt ...Opti
if errors.IsNotFoundError(err) {
return nil, nil
}
return nil, fmt.Errorf("auth token: lookup: %w", err)
return nil, errors.Wrap(err, op)
}
if opts.withTokenValue {
databaseWrapper, err := r.kms.GetWrapper(ctx, at.GetScopeId(), kms.KeyPurposeDatabase, kms.WithKeyId(at.GetKeyId()))
if err != nil {
return nil, fmt.Errorf("lookup: unable to get database wrapper: %w", err)
return nil, errors.Wrap(err, op, errors.WithMsg("unable to get database wrapper"))
}
if err := at.decrypt(ctx, databaseWrapper); err != nil {
return nil, fmt.Errorf("lookup: auth token: cannot decrypt auth token value: %w", err)
return nil, errors.Wrap(err, op)
}
}
@ -174,11 +177,12 @@ func (r *Repository) LookupAuthToken(ctx context.Context, id string, opt ...Opti
//
// NOTE: Do not log or add the token string to any errors to avoid leaking it as it is a secret.
func (r *Repository) ValidateToken(ctx context.Context, id, token string, opt ...Option) (*AuthToken, error) {
const op = "authtoken.(Repository).ValidateToken"
if token == "" {
return nil, fmt.Errorf("validate token: auth token: missing token: %w", errors.ErrInvalidParameter)
return nil, errors.New(errors.InvalidParameter, op, "missing token")
}
if id == "" {
return nil, fmt.Errorf("validate token: auth token: missing public id: %w", errors.ErrInvalidParameter)
return nil, errors.New(errors.InvalidParameter, op, "missing public id")
}
retAT, err := r.LookupAuthToken(ctx, id, withTokenValue())
@ -187,7 +191,7 @@ func (r *Repository) ValidateToken(ctx context.Context, id, token string, opt ..
if errors.IsNotFoundError(err) {
return nil, nil
}
return nil, fmt.Errorf("validate token: %w", err)
return nil, errors.Wrap(err, op)
}
if retAT == nil {
return nil, nil
@ -196,11 +200,11 @@ func (r *Repository) ValidateToken(ctx context.Context, id, token string, opt ..
// If the token is too old or stale invalidate it and return nothing.
exp, err := ptypes.Timestamp(retAT.GetExpirationTime().GetTimestamp())
if err != nil {
return nil, fmt.Errorf("validate token: expiration time : %w", err)
return nil, errors.Wrap(err, op, errors.WithMsg("expiration time"), errors.WithCode(errors.InvalidTimeStamp))
}
lastAccessed, err := ptypes.Timestamp(retAT.GetApproximateLastAccessTime().GetTimestamp())
if err != nil {
return nil, fmt.Errorf("validate token: last accessed time : %w", err)
return nil, errors.Wrap(err, op, errors.WithMsg("last accessed time"), errors.WithCode(errors.InvalidTimeStamp))
}
now := time.Now()
@ -217,7 +221,7 @@ func (r *Repository) ValidateToken(ctx context.Context, id, token string, opt ..
delAt := retAT.toWritableAuthToken()
// tokens are not replicated, so they don't need oplog entries.
if _, err := w.Delete(ctx, delAt); err != nil {
return fmt.Errorf("validate token: delete auth token: %w", err)
return errors.Wrap(err, op, errors.WithMsg("delete auth token"))
}
retAT = nil
return nil
@ -262,7 +266,7 @@ func (r *Repository) ValidateToken(ctx context.Context, id, token string, opt ..
}
if err != nil {
return nil, fmt.Errorf("validate token: auth token: %s: %w", id, err)
return nil, errors.Wrap(err, op, errors.WithMsg(id))
}
return retAT, nil
}
@ -270,14 +274,15 @@ func (r *Repository) ValidateToken(ctx context.Context, id, token string, opt ..
// ListAuthTokens lists auth tokens in the given scopes and supports the
// WithLimit option.
func (r *Repository) ListAuthTokens(ctx context.Context, withScopeIds []string, opt ...Option) ([]*AuthToken, error) {
const op = "authtoken.(Repository).ListAuthTokens"
if len(withScopeIds) == 0 {
return nil, fmt.Errorf("list auth tokens: missing scope id %w", errors.ErrInvalidParameter)
return nil, errors.New(errors.InvalidParameter, op, "missing scope id")
}
opts := getOpts(opt...)
var authTokens []*AuthToken
if err := r.reader.SearchWhere(ctx, &authTokens, "auth_account_id in (select public_id from auth_account where scope_id in (?))", []interface{}{withScopeIds}, db.WithLimit(opts.withLimit)); err != nil {
return nil, fmt.Errorf("list auth tokens: %w", err)
return nil, errors.Wrap(err, op)
}
for _, at := range authTokens {
at.Token = ""
@ -290,8 +295,9 @@ func (r *Repository) ListAuthTokens(ctx context.Context, withScopeIds []string,
// DeleteAuthToken deletes the token with the provided id from the repository returning a count of the
// number of records deleted. All options are ignored.
func (r *Repository) DeleteAuthToken(ctx context.Context, id string, opt ...Option) (int, error) {
const op = "authtoken.(Repository).DeleteAuthToken"
if id == "" {
return db.NoRowsAffected, fmt.Errorf("delete: auth token: missing public id: %w", errors.ErrInvalidParameter)
return db.NoRowsAffected, errors.New(errors.InvalidParameter, op, "missing public id")
}
at, err := r.LookupAuthToken(ctx, id)
@ -299,7 +305,7 @@ func (r *Repository) DeleteAuthToken(ctx context.Context, id string, opt ...Opti
if errors.IsNotFoundError(err) {
return db.NoRowsAffected, nil
}
return db.NoRowsAffected, fmt.Errorf("delete: auth token: lookup %w", err)
return db.NoRowsAffected, errors.Wrap(err, op)
}
if at == nil {
return db.NoRowsAffected, nil
@ -322,7 +328,7 @@ func (r *Repository) DeleteAuthToken(ctx context.Context, id string, opt ...Opti
)
if err != nil {
return db.NoRowsAffected, fmt.Errorf("delete: auth token: %s: %w", id, err)
return db.NoRowsAffected, errors.Wrap(err, op, errors.WithMsg(id))
}
return rowsDeleted, nil

@ -35,10 +35,11 @@ func TestRepository_New(t *testing.T) {
}
tests := []struct {
name string
args args
want *Repository
wantIsErr error
name string
args args
want *Repository
wantIsErr errors.Code
wantErrMsg string
}{
{
name: "valid default limit",
@ -122,8 +123,9 @@ func TestRepository_New(t *testing.T) {
w: rw,
kms: kmsCache,
},
want: nil,
wantIsErr: errors.ErrInvalidParameter,
want: nil,
wantIsErr: errors.InvalidParameter,
wantErrMsg: "authtoken.NewRepository: nil db reader: parameter violation: error #100",
},
{
name: "nil-writer",
@ -132,8 +134,9 @@ func TestRepository_New(t *testing.T) {
w: nil,
kms: kmsCache,
},
want: nil,
wantIsErr: errors.ErrInvalidParameter,
want: nil,
wantIsErr: errors.InvalidParameter,
wantErrMsg: "authtoken.NewRepository: nil db writer: parameter violation: error #100",
},
{
name: "nil-kms",
@ -142,8 +145,9 @@ func TestRepository_New(t *testing.T) {
w: rw,
kms: nil,
},
want: nil,
wantIsErr: errors.ErrInvalidParameter,
want: nil,
wantIsErr: errors.InvalidParameter,
wantErrMsg: "authtoken.NewRepository: nil kms: parameter violation: error #100",
},
{
name: "all-nils",
@ -152,8 +156,9 @@ func TestRepository_New(t *testing.T) {
w: nil,
kms: nil,
},
want: nil,
wantIsErr: errors.ErrInvalidParameter,
want: nil,
wantIsErr: errors.InvalidParameter,
wantErrMsg: "authtoken.NewRepository: nil db reader: parameter violation: error #100",
},
}
for _, tt := range tests {
@ -161,9 +166,9 @@ func TestRepository_New(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
got, err := NewRepository(tt.args.r, tt.args.w, tt.args.kms, tt.args.opts...)
if tt.wantIsErr != nil {
assert.Truef(errors.Is(err, tt.wantIsErr), "want err: %q got: %q", tt.wantIsErr, err)
assert.Nil(got)
if tt.wantIsErr != 0 {
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
assert.Equal(tt.wantErrMsg, err.Error())
return
}
assert.NoError(err)
@ -280,10 +285,11 @@ func TestRepository_LookupAuthToken(t *testing.T) {
require.NotNil(t, badId)
tests := []struct {
name string
id string
want *AuthToken
wantErr error
name string
id string
want *AuthToken
wantIsErr errors.Code
wantErrMsg string
}{
{
name: "found",
@ -296,10 +302,11 @@ func TestRepository_LookupAuthToken(t *testing.T) {
want: nil,
},
{
name: "bad-public-id",
id: "",
want: nil,
wantErr: errors.ErrInvalidParameter,
name: "bad-public-id",
id: "",
want: nil,
wantIsErr: errors.InvalidParameter,
wantErrMsg: "authtoken.(Repository).LookupAuthToken: missing public id: parameter violation: error #100",
},
}
@ -312,8 +319,9 @@ func TestRepository_LookupAuthToken(t *testing.T) {
require.NotNil(repo)
got, err := repo.LookupAuthToken(context.Background(), tt.id)
if tt.wantErr != nil {
assert.Truef(errors.Is(err, tt.wantErr), "want err: %q got: %q", tt.wantErr, err)
if tt.wantIsErr != 0 {
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
assert.Equal(tt.wantErrMsg, err.Error())
return
}
assert.NoError(err)
@ -371,11 +379,12 @@ func TestRepository_ValidateToken(t *testing.T) {
require.NotNil(t, badToken)
tests := []struct {
name string
id string
token string
want *AuthToken
wantErr error
name string
id string
token string
want *AuthToken
wantIsErr errors.Code
wantErrMsg string
}{
{
name: "exists",
@ -390,18 +399,19 @@ func TestRepository_ValidateToken(t *testing.T) {
want: nil,
},
{
name: "empty-token",
id: at.GetPublicId(),
token: "",
want: nil,
wantErr: errors.ErrInvalidParameter,
name: "empty-token",
id: at.GetPublicId(),
token: "",
want: nil,
wantIsErr: errors.InvalidParameter,
wantErrMsg: "authtoken.(Repository).ValidateToken: missing token: parameter violation: error #100",
},
{
name: "mismatched-token",
id: at.GetPublicId(),
token: badToken,
want: nil,
wantErr: nil,
name: "mismatched-token",
id: at.GetPublicId(),
token: badToken,
want: nil,
wantIsErr: errors.Unknown,
},
}
for _, tt := range tests {
@ -410,8 +420,9 @@ func TestRepository_ValidateToken(t *testing.T) {
assert, require := assert.New(t), require.New(t)
got, err := repo.ValidateToken(context.Background(), tt.id, tt.token)
if tt.wantErr != nil {
assert.Truef(errors.Is(err, tt.wantErr), "want err: %q got: %q", tt.wantErr, err)
if tt.wantIsErr != 0 {
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
assert.Equal(tt.wantErrMsg, err.Error())
return
}
assert.NoError(err)
@ -544,10 +555,11 @@ func TestRepository_DeleteAuthToken(t *testing.T) {
require.NotNil(t, badId)
tests := []struct {
name string
id string
want int
wantErr error
name string
id string
want int
wantIsErr errors.Code
wantErrMsg string
}{
{
name: "found",
@ -560,10 +572,11 @@ func TestRepository_DeleteAuthToken(t *testing.T) {
want: 0,
},
{
name: "empty-public-id",
id: "",
want: 0,
wantErr: errors.ErrInvalidParameter,
name: "empty-public-id",
id: "",
want: 0,
wantIsErr: errors.InvalidParameter,
wantErrMsg: "authtoken.(Repository).DeleteAuthToken: missing public id: parameter violation: error #100",
},
}
@ -576,8 +589,9 @@ func TestRepository_DeleteAuthToken(t *testing.T) {
require.NotNil(repo)
got, err := repo.DeleteAuthToken(context.Background(), tt.id)
if tt.wantErr != nil {
assert.Truef(errors.Is(err, tt.wantErr), "want err: %q got: %q", tt.wantErr, err)
if tt.wantIsErr != 0 {
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
assert.Equal(tt.wantErrMsg, err.Error())
return
}
assert.NoError(err)
@ -610,10 +624,9 @@ func TestRepository_ListAuthTokens(t *testing.T) {
emptyOrg, _ := iam.TestScopes(t, repo)
tests := []struct {
name string
orgId string
want []*AuthToken
wantErr error
name string
orgId string
want []*AuthToken
}{
{
name: "populated",
@ -639,12 +652,7 @@ func TestRepository_ListAuthTokens(t *testing.T) {
repo, err := NewRepository(rw, rw, kms)
require.NoError(err)
require.NotNil(repo)
got, err := repo.ListAuthTokens(context.Background(), []string{tt.orgId})
if tt.wantErr != nil {
assert.Truef(errors.Is(err, tt.wantErr), "want err: %q got: %q", tt.wantErr, err)
return
}
assert.NoError(err)
sort.Slice(tt.want, func(i, j int) bool { return tt.want[i].PublicId < tt.want[j].PublicId })
sort.Slice(got, func(i, j int) bool { return got[i].PublicId < got[j].PublicId })

@ -23,7 +23,7 @@ func newId(prefix string) (string, error) {
}
publicId, err := base62.Random(10)
if err != nil {
return "", errors.Wrap(err, op, errors.WithMsg("unable to generate id"))
return "", errors.Wrap(err, op, errors.WithCode(errors.Io))
}
return fmt.Sprintf("%s_%s", prefix, publicId), nil
}

@ -30,6 +30,7 @@ const (
TicketAlreadyRedeemed Code = 106 // TicketAlreadyRedeemed represents that the ticket version has already been redeemed
TicketNotFound Code = 107 // TicketNotFound represents that the ticket was not found
Io Code = 108 // Io represents that an io error occurred in an underlying call (i.e binary.Write)
InvalidTimeStamp Code = 109 // InvalidTimeStamp represents an invalid time stamp for an operation
// PasswordTooShort results from attempting to set a password which is to short.
PasswordTooShort Code = 200

@ -72,6 +72,11 @@ func TestCode_Both_String_Info(t *testing.T) {
c: Io,
want: Io,
},
{
name: "InvalidTimeStamp",
c: InvalidTimeStamp,
want: InvalidTimeStamp,
},
{
name: "PasswordTooShort",
c: PasswordTooShort,

@ -52,6 +52,10 @@ var errorCodeInfo = map[Code]Info{
Message: "error during io operation",
Kind: Integrity,
},
InvalidTimeStamp: {
Message: "invalid time stamp",
Kind: Integrity,
},
PasswordTooShort: {
Message: "too short",
Kind: Password,

Loading…
Cancel
Save