a small refactor, so we can easily support options when creating new auth tokens. (#920)

pull/900/head^2
Jim 5 years ago committed by GitHub
parent b61dd0ef36
commit ffcda56f40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -19,22 +19,29 @@ import (
"google.golang.org/protobuf/proto"
)
// writableAuthToken is used for auth token writes. Since gorm relies on the TableName interface this allows
// us to use a base table for writes and a view for reads.
type writableAuthToken struct {
// authTokenView is used for reading auth token's via the auth_token_account
// view which includes some columns from the auth_account table required by the
// API. Defining a type allows us to easily override the tableName to use the
// view name. authTokenViews share the same store struct/proto, which makes
// them easily convertable to vanilla AuthTokens when required.
type authTokenView struct {
*store.AuthToken
tableName string `gorm:"-"`
}
func (s *writableAuthToken) clone() *writableAuthToken {
cp := proto.Clone(s.AuthToken)
return &writableAuthToken{
AuthToken: cp.(*store.AuthToken),
// allocAuthTokenView is just easier/better than leaking the underlying type
// bits to the repo, since the repo needs to alloc this type quite often.
func allocAuthTokenView() *authTokenView {
fresh := &authTokenView{
AuthToken: &store.AuthToken{},
}
return fresh
}
func (s *writableAuthToken) toAuthToken() *AuthToken {
cp := proto.Clone(s.AuthToken)
// toAuthToken converts the view type to the type returned to repo callers and
// the API.
func (atv *authTokenView) toAuthToken() *AuthToken {
cp := proto.Clone(atv.AuthToken)
return &AuthToken{
AuthToken: cp.(*store.AuthToken),
}
@ -53,29 +60,31 @@ func (s *AuthToken) clone() *AuthToken {
}
}
func (s *AuthToken) toWritableAuthToken() *writableAuthToken {
cp := proto.Clone(s.AuthToken)
return &writableAuthToken{
AuthToken: cp.(*store.AuthToken),
// allocAuthToken is just easier/better than leaking the underlying type
// bits to the repo, since the repo needs to alloc this type quite often.
func allocAuthToken() *AuthToken {
fresh := &AuthToken{
AuthToken: &store.AuthToken{},
}
return fresh
}
// encrypt the entry's data using the provided cipher (wrapping.Wrapper)
func (s *writableAuthToken) encrypt(ctx context.Context, cipher wrapping.Wrapper) error {
func (at *AuthToken) encrypt(ctx context.Context, cipher wrapping.Wrapper) error {
const op = "authtoken.(writableAuthToken).encrypt"
// structwrapping doesn't support embedding, so we'll pass in the store.Entry directly
if err := structwrapping.WrapStruct(ctx, cipher, s.AuthToken, nil); err != nil {
if err := structwrapping.WrapStruct(ctx, cipher, at.AuthToken, nil); err != nil {
return errors.Wrap(err, op, errors.WithCode(errors.Encrypt))
}
s.KeyId = cipher.KeyID()
at.KeyId = cipher.KeyID()
return nil
}
// decrypt will decrypt the auth token's value using the provided cipher (wrapping.Wrapper)
func (s *AuthToken) decrypt(ctx context.Context, cipher wrapping.Wrapper) error {
func (at *AuthToken) decrypt(ctx context.Context, cipher wrapping.Wrapper) error {
const op = "authtoken.(AuthToken).decrypt"
// structwrapping doesn't support embedding, so we'll pass in the store.Entry directly
if err := structwrapping.UnwrapStruct(ctx, cipher, s.AuthToken, nil); err != nil {
if err := structwrapping.UnwrapStruct(ctx, cipher, at.AuthToken, nil); err != nil {
return errors.Wrap(err, op, errors.WithCode(errors.Decrypt))
}
return nil
@ -97,14 +106,20 @@ func newAuthTokenId() (string, error) {
return id, nil
}
// newAuthToken generates a token with a version prefix.
func newAuthToken() (string, error) {
// newAuthToken generates a new in-memory token. No options are currently
// supported.
func newAuthToken(_ ...Option) (*AuthToken, error) {
const op = "authtoken.newAuthToken"
token, err := base62.Random(tokenLength)
if err != nil {
return "", errors.Wrap(err, op, errors.WithCode(errors.Io))
return nil, errors.Wrap(err, op, errors.WithCode(errors.Io))
}
return fmt.Sprintf("%s%s", TokenValueVersionPrefix, token), nil
return &AuthToken{
AuthToken: &store.AuthToken{
Token: fmt.Sprintf("%s%s", TokenValueVersionPrefix, token),
},
}, nil
}
// EncryptToken is a shared function for encrypting a token value for return to

@ -1,9 +1,5 @@
package authtoken
// This file contains tests for methods defined in authtoken.go as well as tests which exercise the db
// functionality directly without going through the respository. Repository centric tests should be
// placed in repository_test.go
import (
"context"
"testing"
@ -21,6 +17,10 @@ import (
"google.golang.org/protobuf/proto"
)
// This file contains tests for methods defined in authtoken.go as well as tests which exercise the db
// functionality directly without going through the respository. Repository centric tests should be
// placed in repository_test.go
func TestAuthToken_DbUpdate(t *testing.T) {
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
@ -101,10 +101,9 @@ func TestAuthToken_DbUpdate(t *testing.T) {
authTok := TestAuthToken(t, conn, kms, org.GetPublicId())
proto.Merge(authTok.AuthToken, tt.args.authTok)
wAuthToken := authTok.toWritableAuthToken()
err := wAuthToken.encrypt(context.Background(), wrapper)
err := authTok.encrypt(context.Background(), wrapper)
require.NoError(t, err)
cnt, err := w.Update(context.Background(), wAuthToken, tt.args.fieldMask, tt.args.nullMask)
cnt, err := w.Update(context.Background(), authTok, tt.args.fieldMask, tt.args.nullMask)
if tt.wantErr {
t.Logf("Got error :%v", err)
assert.Error(err)
@ -159,7 +158,7 @@ func TestAuthToken_DbCreate(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert := assert.New(t)
at := &writableAuthToken{AuthToken: tt.in}
at := &AuthToken{AuthToken: tt.in}
err := at.encrypt(context.Background(), wrapper)
require.NoError(t, err)
err = db.New(conn).Create(context.Background(), at)
@ -187,18 +186,18 @@ func TestAuthToken_DbDelete(t *testing.T) {
tests := []struct {
name string
at *writableAuthToken
at *AuthToken
wantError bool
wantCnt int
}{
{
name: "basic",
at: &writableAuthToken{AuthToken: &store.AuthToken{PublicId: existingAuthTok.GetPublicId()}},
at: &AuthToken{AuthToken: &store.AuthToken{PublicId: existingAuthTok.GetPublicId()}},
wantCnt: 1,
},
{
name: "delete-nothing",
at: &writableAuthToken{AuthToken: &store.AuthToken{PublicId: testAuthTokenId()}},
at: &AuthToken{AuthToken: &store.AuthToken{PublicId: testAuthTokenId()}},
wantCnt: 0,
},
{
@ -209,7 +208,7 @@ func TestAuthToken_DbDelete(t *testing.T) {
},
{
name: "delete-no-public-id",
at: &writableAuthToken{AuthToken: &store.AuthToken{}},
at: &AuthToken{AuthToken: &store.AuthToken{}},
wantCnt: 0,
wantError: true,
},

@ -1,8 +1,14 @@
package authtoken
const (
defaultAuthTokenTableName = "auth_token_account"
defaultWritableAuthTokenTableName = "auth_token"
// defaultAuthTokenTableName is the table where auth tokens are stored.
defaultAuthTokenTableName = "auth_token"
// defaultAuthTokenViewName is a view that includes all the auth_token
// columns plus the auth_account columns of: scope_id, iam_user_id and
// auth_method_id. These additional columns are returned via the API for
// auth tokens, so the view's handy
defaultAuthTokenViewName = "auth_token_account"
)
// TableName returns the table name for the auth token.
@ -19,16 +25,16 @@ func (s *AuthToken) SetTableName(n string) {
s.tableName = n
}
// TableName returns the table name for the auth token.
func (s *writableAuthToken) TableName() string {
// TableName returns the table name for the authTokenView.
func (s *authTokenView) TableName() string {
if s.tableName != "" {
return s.tableName
}
return defaultWritableAuthTokenTableName
return defaultAuthTokenViewName
}
// SetTableName sets the table name. If the caller attempts to
// set the name to "" the name will be reset to the default name.
func (s *writableAuthToken) SetTableName(n string) {
func (s *authTokenView) SetTableName(n string) {
s.tableName = n
}

@ -6,7 +6,6 @@ import (
"time"
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/boundary/internal/authtoken/store"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
@ -70,19 +69,16 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User,
return nil, errors.New(errors.InvalidParameter, op, "missing auth account id")
}
at := allocAuthToken()
at.AuthAccountId = withAuthAccountId
id, err := newAuthTokenId()
at, err := newAuthToken()
if err != nil {
return nil, errors.Wrap(err, op)
}
at.PublicId = id
token, err := newAuthToken()
at.AuthAccountId = withAuthAccountId
id, err := newAuthTokenId()
if err != nil {
return nil, errors.Wrap(err, op)
}
at.Token = token
at.PublicId = id
databaseWrapper, err := r.kms.GetWrapper(ctx, withIamUser.GetScopeId(), kms.KeyPurposeDatabase)
if err != nil {
@ -97,7 +93,7 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User,
}
at.ExpirationTime = &timestamp.Timestamp{Timestamp: expiration}
var newAuthToken *writableAuthToken
var newAuthToken *AuthToken
_, err = r.writer.DoTx(
ctx,
db.StdRetryCnt,
@ -116,7 +112,7 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User,
at.AuthMethodId = acct.GetAuthMethodId()
at.IamUserId = acct.GetIamUserId()
newAuthToken = at.toWritableAuthToken()
newAuthToken = at.clone()
if err := newAuthToken.encrypt(ctx, databaseWrapper); err != nil {
return err
}
@ -133,7 +129,7 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User,
if err != nil {
return nil, errors.Wrap(err, op)
}
return newAuthToken.toAuthToken(), nil
return newAuthToken, nil
}
// LookupAuthToken returns the AuthToken for the provided id. Returns nil, nil if no AuthToken is found for id.
@ -146,14 +142,18 @@ func (r *Repository) LookupAuthToken(ctx context.Context, id string, opt ...Opti
}
opts := getOpts(opt...)
at := allocAuthToken()
at.PublicId = id
if err := r.reader.LookupByPublicId(ctx, at); err != nil {
// use the view, to bring in the required account columns. Just don't forget
// to convert it before returning it.
atv := allocAuthTokenView()
atv.PublicId = id
if err := r.reader.LookupByPublicId(ctx, atv); err != nil {
if errors.IsNotFoundError(err) {
return nil, nil
}
return nil, errors.Wrap(err, op)
}
at := atv.toAuthToken()
if opts.withTokenValue {
databaseWrapper, err := r.kms.GetWrapper(ctx, at.GetScopeId(), kms.KeyPurposeDatabase, kms.WithKeyId(at.GetKeyId()))
if err != nil {
@ -218,7 +218,7 @@ func (r *Repository) ValidateToken(ctx context.Context, id, token string, opt ..
db.StdRetryCnt,
db.ExpBackoff{},
func(_ db.Reader, w db.Writer) error {
delAt := retAT.toWritableAuthToken()
delAt := retAT.clone()
// tokens are not replicated, so they don't need oplog entries.
if _, err := w.Delete(ctx, delAt); err != nil {
return errors.Wrap(err, op, errors.WithMsg("delete auth token"))
@ -247,7 +247,7 @@ func (r *Repository) ValidateToken(ctx context.Context, id, token string, opt ..
db.StdRetryCnt,
db.ExpBackoff{},
func(_ db.Reader, w db.Writer) error {
at := retAT.toWritableAuthToken()
at := retAT.clone()
// Setting the ApproximateLastAccessTime to null through using the null mask allows a defined db's
// trigger to set ApproximateLastAccessTime to the commit
// timestamp. Tokens are not replicated, so they don't need oplog entries.
@ -280,14 +280,18 @@ func (r *Repository) ListAuthTokens(ctx context.Context, withScopeIds []string,
}
opts := getOpts(opt...)
var authTokens []*AuthToken
if err := r.reader.SearchWhere(ctx, &authTokens, "auth_account_id in (select public_id from auth_account where scope_id in (?))", []interface{}{withScopeIds}, db.WithLimit(opts.withLimit)); err != nil {
// use the view, to bring in the required account columns. Just don't forget
// to convert them before returning them
var atvs []*authTokenView
if err := r.reader.SearchWhere(ctx, &atvs, "auth_account_id in (select public_id from auth_account where scope_id in (?))", []interface{}{withScopeIds}, db.WithLimit(opts.withLimit)); err != nil {
return nil, errors.Wrap(err, op)
}
for _, at := range authTokens {
at.Token = ""
at.CtToken = nil
at.KeyId = ""
authTokens := make([]*AuthToken, 0, len(atvs))
for _, atv := range atvs {
atv.Token = ""
atv.CtToken = nil
atv.KeyId = ""
authTokens = append(authTokens, atv.toAuthToken())
}
return authTokens, nil
}
@ -317,7 +321,7 @@ func (r *Repository) DeleteAuthToken(ctx context.Context, id string, opt ...Opti
db.StdRetryCnt,
db.ExpBackoff{},
func(_ db.Reader, w db.Writer) error {
deleteAT := at.toWritableAuthToken()
deleteAT := at.clone()
// tokens are not replicated, so they don't need oplog entries.
rowsDeleted, err = w.Delete(ctx, deleteAT)
if err == nil && rowsDeleted > 1 {
@ -333,10 +337,3 @@ func (r *Repository) DeleteAuthToken(ctx context.Context, id string, opt ...Opti
return rowsDeleted, nil
}
func allocAuthToken() *AuthToken {
fresh := &AuthToken{
AuthToken: &store.AuthToken{},
}
return fresh
}

@ -395,7 +395,7 @@ func TestRepository_ValidateToken(t *testing.T) {
{
name: "doesnt-exist",
id: badId,
token: badToken,
token: badToken.Token,
want: nil,
},
{
@ -409,7 +409,7 @@ func TestRepository_ValidateToken(t *testing.T) {
{
name: "mismatched-token",
id: at.GetPublicId(),
token: badToken,
token: badToken.Token,
want: nil,
wantIsErr: errors.Unknown,
},

Loading…
Cancel
Save