From ffcda56f4069e9248a0f318d94a5526637eb9919 Mon Sep 17 00:00:00 2001 From: Jim Date: Wed, 10 Feb 2021 13:52:42 -0500 Subject: [PATCH] a small refactor, so we can easily support options when creating new auth tokens. (#920) --- internal/authtoken/authtoken.go | 59 +++++++++++++++++---------- internal/authtoken/authtoken_test.go | 23 +++++------ internal/authtoken/gorm.go | 18 +++++--- internal/authtoken/repository.go | 57 ++++++++++++-------------- internal/authtoken/repository_test.go | 4 +- 5 files changed, 89 insertions(+), 72 deletions(-) diff --git a/internal/authtoken/authtoken.go b/internal/authtoken/authtoken.go index 5968ec6483..975eed54e6 100644 --- a/internal/authtoken/authtoken.go +++ b/internal/authtoken/authtoken.go @@ -19,22 +19,29 @@ import ( "google.golang.org/protobuf/proto" ) -// writableAuthToken is used for auth token writes. Since gorm relies on the TableName interface this allows -// us to use a base table for writes and a view for reads. -type writableAuthToken struct { +// authTokenView is used for reading auth token's via the auth_token_account +// view which includes some columns from the auth_account table required by the +// API. Defining a type allows us to easily override the tableName to use the +// view name. authTokenViews share the same store struct/proto, which makes +// them easily convertable to vanilla AuthTokens when required. +type authTokenView struct { *store.AuthToken tableName string `gorm:"-"` } -func (s *writableAuthToken) clone() *writableAuthToken { - cp := proto.Clone(s.AuthToken) - return &writableAuthToken{ - AuthToken: cp.(*store.AuthToken), +// allocAuthTokenView is just easier/better than leaking the underlying type +// bits to the repo, since the repo needs to alloc this type quite often. +func allocAuthTokenView() *authTokenView { + fresh := &authTokenView{ + AuthToken: &store.AuthToken{}, } + return fresh } -func (s *writableAuthToken) toAuthToken() *AuthToken { - cp := proto.Clone(s.AuthToken) +// toAuthToken converts the view type to the type returned to repo callers and +// the API. +func (atv *authTokenView) toAuthToken() *AuthToken { + cp := proto.Clone(atv.AuthToken) return &AuthToken{ AuthToken: cp.(*store.AuthToken), } @@ -53,29 +60,31 @@ func (s *AuthToken) clone() *AuthToken { } } -func (s *AuthToken) toWritableAuthToken() *writableAuthToken { - cp := proto.Clone(s.AuthToken) - return &writableAuthToken{ - AuthToken: cp.(*store.AuthToken), +// allocAuthToken is just easier/better than leaking the underlying type +// bits to the repo, since the repo needs to alloc this type quite often. +func allocAuthToken() *AuthToken { + fresh := &AuthToken{ + AuthToken: &store.AuthToken{}, } + return fresh } // encrypt the entry's data using the provided cipher (wrapping.Wrapper) -func (s *writableAuthToken) encrypt(ctx context.Context, cipher wrapping.Wrapper) error { +func (at *AuthToken) encrypt(ctx context.Context, cipher wrapping.Wrapper) error { const op = "authtoken.(writableAuthToken).encrypt" // structwrapping doesn't support embedding, so we'll pass in the store.Entry directly - if err := structwrapping.WrapStruct(ctx, cipher, s.AuthToken, nil); err != nil { + if err := structwrapping.WrapStruct(ctx, cipher, at.AuthToken, nil); err != nil { return errors.Wrap(err, op, errors.WithCode(errors.Encrypt)) } - s.KeyId = cipher.KeyID() + at.KeyId = cipher.KeyID() return nil } // decrypt will decrypt the auth token's value using the provided cipher (wrapping.Wrapper) -func (s *AuthToken) decrypt(ctx context.Context, cipher wrapping.Wrapper) error { +func (at *AuthToken) decrypt(ctx context.Context, cipher wrapping.Wrapper) error { const op = "authtoken.(AuthToken).decrypt" // structwrapping doesn't support embedding, so we'll pass in the store.Entry directly - if err := structwrapping.UnwrapStruct(ctx, cipher, s.AuthToken, nil); err != nil { + if err := structwrapping.UnwrapStruct(ctx, cipher, at.AuthToken, nil); err != nil { return errors.Wrap(err, op, errors.WithCode(errors.Decrypt)) } return nil @@ -97,14 +106,20 @@ func newAuthTokenId() (string, error) { return id, nil } -// newAuthToken generates a token with a version prefix. -func newAuthToken() (string, error) { +// newAuthToken generates a new in-memory token. No options are currently +// supported. +func newAuthToken(_ ...Option) (*AuthToken, error) { const op = "authtoken.newAuthToken" token, err := base62.Random(tokenLength) if err != nil { - return "", errors.Wrap(err, op, errors.WithCode(errors.Io)) + return nil, errors.Wrap(err, op, errors.WithCode(errors.Io)) } - return fmt.Sprintf("%s%s", TokenValueVersionPrefix, token), nil + + return &AuthToken{ + AuthToken: &store.AuthToken{ + Token: fmt.Sprintf("%s%s", TokenValueVersionPrefix, token), + }, + }, nil } // EncryptToken is a shared function for encrypting a token value for return to diff --git a/internal/authtoken/authtoken_test.go b/internal/authtoken/authtoken_test.go index 4f03d34138..c4977b119e 100644 --- a/internal/authtoken/authtoken_test.go +++ b/internal/authtoken/authtoken_test.go @@ -1,9 +1,5 @@ package authtoken -// This file contains tests for methods defined in authtoken.go as well as tests which exercise the db -// functionality directly without going through the respository. Repository centric tests should be -// placed in repository_test.go - import ( "context" "testing" @@ -21,6 +17,10 @@ import ( "google.golang.org/protobuf/proto" ) +// This file contains tests for methods defined in authtoken.go as well as tests which exercise the db +// functionality directly without going through the respository. Repository centric tests should be +// placed in repository_test.go + func TestAuthToken_DbUpdate(t *testing.T) { conn, _ := db.TestSetup(t, "postgres") wrapper := db.TestWrapper(t) @@ -101,10 +101,9 @@ func TestAuthToken_DbUpdate(t *testing.T) { authTok := TestAuthToken(t, conn, kms, org.GetPublicId()) proto.Merge(authTok.AuthToken, tt.args.authTok) - wAuthToken := authTok.toWritableAuthToken() - err := wAuthToken.encrypt(context.Background(), wrapper) + err := authTok.encrypt(context.Background(), wrapper) require.NoError(t, err) - cnt, err := w.Update(context.Background(), wAuthToken, tt.args.fieldMask, tt.args.nullMask) + cnt, err := w.Update(context.Background(), authTok, tt.args.fieldMask, tt.args.nullMask) if tt.wantErr { t.Logf("Got error :%v", err) assert.Error(err) @@ -159,7 +158,7 @@ func TestAuthToken_DbCreate(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { assert := assert.New(t) - at := &writableAuthToken{AuthToken: tt.in} + at := &AuthToken{AuthToken: tt.in} err := at.encrypt(context.Background(), wrapper) require.NoError(t, err) err = db.New(conn).Create(context.Background(), at) @@ -187,18 +186,18 @@ func TestAuthToken_DbDelete(t *testing.T) { tests := []struct { name string - at *writableAuthToken + at *AuthToken wantError bool wantCnt int }{ { name: "basic", - at: &writableAuthToken{AuthToken: &store.AuthToken{PublicId: existingAuthTok.GetPublicId()}}, + at: &AuthToken{AuthToken: &store.AuthToken{PublicId: existingAuthTok.GetPublicId()}}, wantCnt: 1, }, { name: "delete-nothing", - at: &writableAuthToken{AuthToken: &store.AuthToken{PublicId: testAuthTokenId()}}, + at: &AuthToken{AuthToken: &store.AuthToken{PublicId: testAuthTokenId()}}, wantCnt: 0, }, { @@ -209,7 +208,7 @@ func TestAuthToken_DbDelete(t *testing.T) { }, { name: "delete-no-public-id", - at: &writableAuthToken{AuthToken: &store.AuthToken{}}, + at: &AuthToken{AuthToken: &store.AuthToken{}}, wantCnt: 0, wantError: true, }, diff --git a/internal/authtoken/gorm.go b/internal/authtoken/gorm.go index f8d9ecbeaa..abcc6c8ec9 100644 --- a/internal/authtoken/gorm.go +++ b/internal/authtoken/gorm.go @@ -1,8 +1,14 @@ package authtoken const ( - defaultAuthTokenTableName = "auth_token_account" - defaultWritableAuthTokenTableName = "auth_token" + // defaultAuthTokenTableName is the table where auth tokens are stored. + defaultAuthTokenTableName = "auth_token" + + // defaultAuthTokenViewName is a view that includes all the auth_token + // columns plus the auth_account columns of: scope_id, iam_user_id and + // auth_method_id. These additional columns are returned via the API for + // auth tokens, so the view's handy + defaultAuthTokenViewName = "auth_token_account" ) // TableName returns the table name for the auth token. @@ -19,16 +25,16 @@ func (s *AuthToken) SetTableName(n string) { s.tableName = n } -// TableName returns the table name for the auth token. -func (s *writableAuthToken) TableName() string { +// TableName returns the table name for the authTokenView. +func (s *authTokenView) TableName() string { if s.tableName != "" { return s.tableName } - return defaultWritableAuthTokenTableName + return defaultAuthTokenViewName } // SetTableName sets the table name. If the caller attempts to // set the name to "" the name will be reset to the default name. -func (s *writableAuthToken) SetTableName(n string) { +func (s *authTokenView) SetTableName(n string) { s.tableName = n } diff --git a/internal/authtoken/repository.go b/internal/authtoken/repository.go index 501d685ec3..414de59b5b 100644 --- a/internal/authtoken/repository.go +++ b/internal/authtoken/repository.go @@ -6,7 +6,6 @@ import ( "time" "github.com/golang/protobuf/ptypes" - "github.com/hashicorp/boundary/internal/authtoken/store" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/db/timestamp" "github.com/hashicorp/boundary/internal/errors" @@ -70,19 +69,16 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User, return nil, errors.New(errors.InvalidParameter, op, "missing auth account id") } - at := allocAuthToken() - at.AuthAccountId = withAuthAccountId - id, err := newAuthTokenId() + at, err := newAuthToken() if err != nil { return nil, errors.Wrap(err, op) } - at.PublicId = id - - token, err := newAuthToken() + at.AuthAccountId = withAuthAccountId + id, err := newAuthTokenId() if err != nil { return nil, errors.Wrap(err, op) } - at.Token = token + at.PublicId = id databaseWrapper, err := r.kms.GetWrapper(ctx, withIamUser.GetScopeId(), kms.KeyPurposeDatabase) if err != nil { @@ -97,7 +93,7 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User, } at.ExpirationTime = ×tamp.Timestamp{Timestamp: expiration} - var newAuthToken *writableAuthToken + var newAuthToken *AuthToken _, err = r.writer.DoTx( ctx, db.StdRetryCnt, @@ -116,7 +112,7 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User, at.AuthMethodId = acct.GetAuthMethodId() at.IamUserId = acct.GetIamUserId() - newAuthToken = at.toWritableAuthToken() + newAuthToken = at.clone() if err := newAuthToken.encrypt(ctx, databaseWrapper); err != nil { return err } @@ -133,7 +129,7 @@ func (r *Repository) CreateAuthToken(ctx context.Context, withIamUser *iam.User, if err != nil { return nil, errors.Wrap(err, op) } - return newAuthToken.toAuthToken(), nil + return newAuthToken, nil } // LookupAuthToken returns the AuthToken for the provided id. Returns nil, nil if no AuthToken is found for id. @@ -146,14 +142,18 @@ func (r *Repository) LookupAuthToken(ctx context.Context, id string, opt ...Opti } opts := getOpts(opt...) - at := allocAuthToken() - at.PublicId = id - if err := r.reader.LookupByPublicId(ctx, at); err != nil { + // use the view, to bring in the required account columns. Just don't forget + // to convert it before returning it. + atv := allocAuthTokenView() + atv.PublicId = id + if err := r.reader.LookupByPublicId(ctx, atv); err != nil { if errors.IsNotFoundError(err) { return nil, nil } return nil, errors.Wrap(err, op) } + + at := atv.toAuthToken() if opts.withTokenValue { databaseWrapper, err := r.kms.GetWrapper(ctx, at.GetScopeId(), kms.KeyPurposeDatabase, kms.WithKeyId(at.GetKeyId())) if err != nil { @@ -218,7 +218,7 @@ func (r *Repository) ValidateToken(ctx context.Context, id, token string, opt .. db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { - delAt := retAT.toWritableAuthToken() + delAt := retAT.clone() // tokens are not replicated, so they don't need oplog entries. if _, err := w.Delete(ctx, delAt); err != nil { return errors.Wrap(err, op, errors.WithMsg("delete auth token")) @@ -247,7 +247,7 @@ func (r *Repository) ValidateToken(ctx context.Context, id, token string, opt .. db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { - at := retAT.toWritableAuthToken() + at := retAT.clone() // Setting the ApproximateLastAccessTime to null through using the null mask allows a defined db's // trigger to set ApproximateLastAccessTime to the commit // timestamp. Tokens are not replicated, so they don't need oplog entries. @@ -280,14 +280,18 @@ func (r *Repository) ListAuthTokens(ctx context.Context, withScopeIds []string, } opts := getOpts(opt...) - var authTokens []*AuthToken - if err := r.reader.SearchWhere(ctx, &authTokens, "auth_account_id in (select public_id from auth_account where scope_id in (?))", []interface{}{withScopeIds}, db.WithLimit(opts.withLimit)); err != nil { + // use the view, to bring in the required account columns. Just don't forget + // to convert them before returning them + var atvs []*authTokenView + if err := r.reader.SearchWhere(ctx, &atvs, "auth_account_id in (select public_id from auth_account where scope_id in (?))", []interface{}{withScopeIds}, db.WithLimit(opts.withLimit)); err != nil { return nil, errors.Wrap(err, op) } - for _, at := range authTokens { - at.Token = "" - at.CtToken = nil - at.KeyId = "" + authTokens := make([]*AuthToken, 0, len(atvs)) + for _, atv := range atvs { + atv.Token = "" + atv.CtToken = nil + atv.KeyId = "" + authTokens = append(authTokens, atv.toAuthToken()) } return authTokens, nil } @@ -317,7 +321,7 @@ func (r *Repository) DeleteAuthToken(ctx context.Context, id string, opt ...Opti db.StdRetryCnt, db.ExpBackoff{}, func(_ db.Reader, w db.Writer) error { - deleteAT := at.toWritableAuthToken() + deleteAT := at.clone() // tokens are not replicated, so they don't need oplog entries. rowsDeleted, err = w.Delete(ctx, deleteAT) if err == nil && rowsDeleted > 1 { @@ -333,10 +337,3 @@ func (r *Repository) DeleteAuthToken(ctx context.Context, id string, opt ...Opti return rowsDeleted, nil } - -func allocAuthToken() *AuthToken { - fresh := &AuthToken{ - AuthToken: &store.AuthToken{}, - } - return fresh -} diff --git a/internal/authtoken/repository_test.go b/internal/authtoken/repository_test.go index e7db5729ac..289f8e2416 100644 --- a/internal/authtoken/repository_test.go +++ b/internal/authtoken/repository_test.go @@ -395,7 +395,7 @@ func TestRepository_ValidateToken(t *testing.T) { { name: "doesnt-exist", id: badId, - token: badToken, + token: badToken.Token, want: nil, }, { @@ -409,7 +409,7 @@ func TestRepository_ValidateToken(t *testing.T) { { name: "mismatched-token", id: at.GetPublicId(), - token: badToken, + token: badToken.Token, want: nil, wantIsErr: errors.Unknown, },