From b2fec2e4f0a328e390db9467826d83b4009c1e6e Mon Sep 17 00:00:00 2001 From: Todd Date: Fri, 27 Oct 2023 14:39:53 -0700 Subject: [PATCH] Store and use refresh tokens in the client cache (#3857) * Refresh resources using refresh tokens --- api/apierror.go | 9 +- .../internal/cache/options_test.go | 8 +- .../internal/cache/refresh_test.go | 136 +++--- .../clientcache/internal/cache/repository.go | 4 +- .../cache/repository_refresh_token.go | 140 ++++++ .../cache/repository_refresh_token_test.go | 123 ++++++ .../internal/cache/repository_sessions.go | 58 ++- .../cache/repository_sessions_test.go | 115 ++++- .../internal/cache/repository_targets.go | 60 ++- .../internal/cache/repository_targets_test.go | 116 ++++- .../internal/cache/repository_test.go | 2 +- .../clientcache/internal/cache/store_test.go | 104 +++++ .../clientcache/internal/daemon/testing.go | 12 +- .../clientcache/internal/daemon/ticker.go | 2 +- internal/clientcache/internal/db/schema.sql | 31 +- internal/daemon/controller/handlers/errors.go | 21 + .../daemon/controller/handlers/errors_test.go | 12 + internal/errors/code.go | 2 + internal/errors/code_test.go | 5 + internal/errors/info.go | 4 + internal/refreshtoken/refresh_token.go | 133 ++++++ internal/refreshtoken/refresh_token_test.go | 398 ++++++++++++++++++ 22 files changed, 1389 insertions(+), 106 deletions(-) create mode 100644 internal/clientcache/internal/cache/repository_refresh_token.go create mode 100644 internal/clientcache/internal/cache/repository_refresh_token_test.go create mode 100644 internal/refreshtoken/refresh_token.go create mode 100644 internal/refreshtoken/refresh_token_test.go diff --git a/api/apierror.go b/api/apierror.go index 90b7eb617b..073c163663 100644 --- a/api/apierror.go +++ b/api/apierror.go @@ -11,10 +11,11 @@ import ( ) var ( - ErrNotFound = &Error{Kind: codes.NotFound.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusNotFound}}} - ErrInvalidArgument = &Error{Kind: codes.InvalidArgument.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusBadRequest}}} - ErrPermissionDenied = &Error{Kind: codes.PermissionDenied.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusForbidden}}} - ErrUnauthorized = &Error{Kind: codes.Unauthenticated.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusUnauthorized}}} + ErrNotFound = &Error{Kind: codes.NotFound.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusNotFound}}} + ErrInvalidArgument = &Error{Kind: codes.InvalidArgument.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusBadRequest}}} + ErrPermissionDenied = &Error{Kind: codes.PermissionDenied.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusForbidden}}} + ErrUnauthorized = &Error{Kind: codes.Unauthenticated.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusUnauthorized}}} + ErrInvalidRefreshToken = &Error{Kind: "invalid refresh token", response: &Response{resp: &http.Response{StatusCode: http.StatusBadRequest}}} ) // AsServerError returns an api *Error from the provided error. If the provided error diff --git a/internal/clientcache/internal/cache/options_test.go b/internal/clientcache/internal/cache/options_test.go index 9e8e491bae..7724918f2e 100644 --- a/internal/clientcache/internal/cache/options_test.go +++ b/internal/clientcache/internal/cache/options_test.go @@ -33,7 +33,9 @@ func Test_GetOpts(t *testing.T) { assert.Equal(t, opts, testOpts) }) t.Run("WithTargetRetrievalFunc", func(t *testing.T) { - var f TargetRetrievalFunc = func(ctx context.Context, keyringstring, tokenName string) ([]*targets.Target, error) { return nil, nil } + var f TargetRetrievalFunc = func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { + return nil, nil, "", nil + } opts, err := getOpts(WithTargetRetrievalFunc(f)) require.NoError(t, err) @@ -44,8 +46,8 @@ func Test_GetOpts(t *testing.T) { assert.Equal(t, opts, testOpts) }) t.Run("WithSessionRetrievalFunc", func(t *testing.T) { - var f SessionRetrievalFunc = func(ctx context.Context, keyringstring, tokenName string) ([]*sessions.Session, error) { - return nil, nil + var f SessionRetrievalFunc = func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { + return nil, nil, "", nil } opts, err := getOpts(WithSessionRetrievalFunc(f)) require.NoError(t, err) diff --git a/internal/clientcache/internal/cache/refresh_test.go b/internal/clientcache/internal/cache/refresh_test.go index 98c01f20e1..ba3cc86f81 100644 --- a/internal/clientcache/internal/cache/refresh_test.go +++ b/internal/clientcache/internal/cache/refresh_test.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "strconv" "sync" "testing" "time" @@ -24,10 +25,42 @@ import ( // testStaticResourceRetrievalFunc returns a function that always returns the // provided slice and a nil error. The returned function can be passed into the // options that provide a resource retrieval func such as -// WithTargetRetrievalFunc and WithSessionRetrievalFunc. -func testStaticResourceRetrievalFunc[T any](ret []T) func(context.Context, string, string) ([]T, error) { - return func(ctx context.Context, s1, s2 string) ([]T, error) { - return ret, nil +// WithTargetRetrievalFunc and WithSessionRetrievalFunc. The provided refresh +// token determines the returned value and is a string representation of an +// incrementing integer. This integer is the index into the provided return +// values and once it reaches the length of the provided slice it returns an +// empty slice and the same refresh token repeatedly. +func testStaticResourceRetrievalFunc[T any](t *testing.T, ret [][]T, removed [][]string) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) { + t.Helper() + require.Equal(t, len(ret), len(removed), "returned slice and removed slice must be the same length") + return func(ctx context.Context, s1, s2 string, refToken RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) { + index := 0 + if refToken != "" { + var err error + index, err = strconv.Atoi(string(refToken)) + require.NoError(t, err) + } + + switch { + case len(ret) == 0: + return nil, nil, "", nil + case index > 0 && index >= len(ret): + return []T{}, []string{}, RefreshTokenValue(fmt.Sprintf("%d", index)), nil + default: + return ret[index], removed[index], RefreshTokenValue(fmt.Sprintf("%d", index+1)), nil + } + } +} + +// testErroringForRefreshTokenRetrievalFunc returns a refresh token error when +// the refresh token is not empty. This is useful for testing behavior when +// the refresh token has expired or is otherwise invalid. +func testErroringForRefreshTokenRetrievalFunc[T any](t *testing.T, ret []T) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) { + return func(ctx context.Context, s1, s2 string, refToken RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) { + if refToken != "" { + return nil, nil, "", api.ErrInvalidRefreshToken + } + return ret, nil, "1", nil } } @@ -297,32 +330,32 @@ func TestRefresh(t *testing.T) { target("1"), target("2"), target("3"), + target("4"), } - assert.NoError(t, rs.Refresh(ctx, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)), - WithTargetRetrievalFunc(func(ctx context.Context, addr, token string) ([]*targets.Target, error) { - require.Equal(t, boundaryAddr, addr) - require.Equal(t, at.Token, token) - return retTargets, nil - }))) + opts := []Option{ + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, + [][]*targets.Target{ + retTargets[:3], + retTargets[3:], + }, + [][]string{ + nil, + {retTargets[0].Id, retTargets[1].Id}, + }, + )), + } + assert.NoError(t, rs.Refresh(ctx, opts...)) cachedTargets, err := r.ListTargets(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retTargets, cachedTargets) - - t.Run("empty response clears it out", func(t *testing.T) { - assert.NoError(t, rs.Refresh(ctx, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)), - WithTargetRetrievalFunc(func(ctx context.Context, addr, token string) ([]*targets.Target, error) { - require.Equal(t, boundaryAddr, addr) - require.Equal(t, at.Token, token) - return nil, nil - }))) - - cachedTargets, err := r.ListTargets(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedTargets) - }) + assert.ElementsMatch(t, retTargets[:3], cachedTargets) + + // Second call removes the first 2 resources from the cache and adds the last + assert.NoError(t, rs.Refresh(ctx, opts...)) + cachedTargets, err = r.ListTargets(ctx, at.Id) + assert.NoError(t, err) + assert.ElementsMatch(t, retTargets[2:], cachedTargets) }) t.Run("set sessions", func(t *testing.T) { @@ -330,42 +363,49 @@ func TestRefresh(t *testing.T) { session("1"), session("2"), session("3"), + session("4"), } - assert.NoError(t, rs.Refresh(ctx, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(retSess)))) - + opts := []Option{ + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, + [][]*sessions.Session{ + retSess[:3], + retSess[3:], + }, + [][]string{ + nil, + {retSess[0].Id, retSess[1].Id}, + }, + )), + } + assert.NoError(t, rs.Refresh(ctx, opts...)) cachedSessions, err := r.ListSessions(ctx, at.Id) assert.NoError(t, err) - assert.ElementsMatch(t, retSess, cachedSessions) - - t.Run("empty response clears it out", func(t *testing.T) { - assert.NoError(t, rs.Refresh(ctx, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil)), - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)))) + assert.ElementsMatch(t, retSess[:3], cachedSessions) - cachedTargets, err := r.ListSessions(ctx, at.Id) - assert.NoError(t, err) - assert.Empty(t, cachedTargets) - }) + // Second call removes the first 2 resources from the cache and adds the last + assert.NoError(t, rs.Refresh(ctx, opts...)) + cachedSessions, err = r.ListSessions(ctx, at.Id) + assert.NoError(t, err) + assert.ElementsMatch(t, retSess[2:], cachedSessions) }) t.Run("error propogates up", func(t *testing.T) { innerErr := errors.New("test error") err := rs.Refresh(ctx, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)), - WithTargetRetrievalFunc(func(ctx context.Context, addr, token string) ([]*targets.Target, error) { + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithTargetRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) - return nil, innerErr + return nil, nil, "", innerErr })) assert.ErrorContains(t, err, innerErr.Error()) err = rs.Refresh(ctx, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil)), - WithSessionRetrievalFunc(func(ctx context.Context, addr, token string) ([]*sessions.Session, error) { + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)), + WithSessionRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { require.Equal(t, boundaryAddr, addr) require.Equal(t, at.Token, token) - return nil, innerErr + return nil, nil, "", innerErr })) assert.ErrorContains(t, err, innerErr.Error()) }) @@ -385,8 +425,8 @@ func TestRefresh(t *testing.T) { assert.Len(t, us, 1) rs.Refresh(ctx, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)), - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil))) + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)), + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil))) ps, err = r.listTokens(ctx, u) require.NoError(t, err) diff --git a/internal/clientcache/internal/cache/repository.go b/internal/clientcache/internal/cache/repository.go index 7fd528b7f0..c44c5aa894 100644 --- a/internal/clientcache/internal/cache/repository.go +++ b/internal/clientcache/internal/cache/repository.go @@ -56,7 +56,7 @@ func NewRepository(ctx context.Context, conn *db.DB, idToAuthToken *sync.Map, ke }, nil } -func (r *Repository) SaveError(ctx context.Context, u *user, resourceType string, err error) error { +func (r *Repository) SaveError(ctx context.Context, u *user, resourceType resourceType, err error) error { const op = "cache.(Repository).StoreError" switch { case resourceType == "": @@ -70,7 +70,7 @@ func (r *Repository) SaveError(ctx context.Context, u *user, resourceType string } apiErr := &ApiError{ UserId: u.Id, - ResourceType: resourceType, + ResourceType: string(resourceType), Error: err.Error(), } onConflict := db.OnConflict{ diff --git a/internal/clientcache/internal/cache/repository_refresh_token.go b/internal/clientcache/internal/cache/repository_refresh_token.go new file mode 100644 index 0000000000..0f53ca5ec6 --- /dev/null +++ b/internal/clientcache/internal/cache/repository_refresh_token.go @@ -0,0 +1,140 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "context" + + "github.com/hashicorp/boundary/internal/db" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/util" + "github.com/hashicorp/go-dbw" +) + +// RefreshTokenValue is the the type for the actual refresh token value handled +// by the client cache. +type RefreshTokenValue string + +// lookupRefreshToken returns the last known valid refresh token or an empty +// string if one is unkonwn. No error is returned if no valid refresh token is +// found. +func (r *Repository) lookupRefreshToken(ctx context.Context, u *user, resourceType resourceType) (RefreshTokenValue, error) { + const op = "cache.(Repsoitory).lookupRefreshToken" + switch { + case util.IsNil(u): + return "", errors.New(ctx, errors.InvalidParameter, op, "user is nil") + case u.Id == "": + return "", errors.New(ctx, errors.InvalidParameter, op, "user id is empty") + case !resourceType.valid(): + return "", errors.New(ctx, errors.InvalidParameter, op, "resource type is invalid") + } + + rt := &refreshToken{ + UserId: u.Id, + ResourceType: resourceType, + } + if err := r.rw.LookupById(ctx, rt); err != nil { + if errors.Is(err, dbw.ErrRecordNotFound) { + return "", nil + } + return "", errors.Wrap(ctx, err, op) + } + return rt.RefreshToken, nil +} + +// deleteRefreshToken deletes the refresh token for the provided user and resource type +func (r *Repository) deleteRefreshToken(ctx context.Context, u *user, rType resourceType) error { + const op = "cache.(Repository).deleteRefreshToken" + switch { + case util.IsNil(u): + return errors.New(ctx, errors.InvalidParameter, op, "user is nil") + case u.Id == "": + return errors.New(ctx, errors.InvalidParameter, op, "user id is empty") + case !rType.valid(): + return errors.New(ctx, errors.InvalidParameter, op, "resource type is invalid") + } + + _, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { + rt := &refreshToken{ + UserId: u.Id, + ResourceType: rType, + } + n, err := w.Delete(ctx, rt) + if err != nil { + return errors.Wrap(ctx, err, op) + } + if n > 1 { + return errors.New(ctx, errors.MultipleRecords, op, "attempted to delete a single resource but multiple were resource deletions were attempted") + } + return nil + }) + if err != nil { + return err + } + + return nil +} + +func upsertRefreshToken(ctx context.Context, writer db.Writer, u *user, rt resourceType, tok RefreshTokenValue) error { + const op = "cache.upsertRefreshToken" + switch { + case util.IsNil(writer): + return errors.New(ctx, errors.InvalidParameter, op, "writer is nil") + case !writer.IsTx(ctx): + return errors.New(ctx, errors.InvalidParameter, op, "writer isn't in a transaction") + case util.IsNil(u): + return errors.New(ctx, errors.InvalidParameter, op, "user is nil") + case u.Id == "": + return errors.New(ctx, errors.InvalidParameter, op, "user id is empty") + case !rt.valid(): + return errors.New(ctx, errors.InvalidParameter, op, "resource type is invalid") + } + + refTok := &refreshToken{ + UserId: u.Id, + ResourceType: rt, + RefreshToken: tok, + } + + switch tok { + case "": + writer.Delete(ctx, refTok) + default: + onConflict := &db.OnConflict{ + Target: db.Columns{"user_id", "resource_type"}, + Action: db.SetColumns([]string{"refresh_token"}), + } + if err := writer.Create(ctx, refTok, db.WithOnConflict(onConflict)); err != nil { + return errors.Wrap(ctx, err, op) + } + } + + return nil +} + +type resourceType string + +const ( + unknownResourceType resourceType = "unknown" + targetResourceType resourceType = "target" + sessionResourceType resourceType = "session" +) + +func (r resourceType) valid() bool { + switch r { + case targetResourceType, sessionResourceType: + return true + } + return false +} + +type refreshToken struct { + UserId string `gorm:"primaryKey"` + ResourceType resourceType `gorm:"primaryKey"` + RefreshToken RefreshTokenValue +} + +func (*refreshToken) TableName() string { + return "refresh_token" +} diff --git a/internal/clientcache/internal/cache/repository_refresh_token_test.go b/internal/clientcache/internal/cache/repository_refresh_token_test.go new file mode 100644 index 0000000000..bbc729f932 --- /dev/null +++ b/internal/clientcache/internal/cache/repository_refresh_token_test.go @@ -0,0 +1,123 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package cache + +import ( + "context" + "sync" + "testing" + + "github.com/hashicorp/boundary/api/authtokens" + cachedb "github.com/hashicorp/boundary/internal/clientcache/internal/db" + "github.com/hashicorp/boundary/internal/db" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLookupRefreshToken(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + r, err := NewRepository(ctx, s, &sync.Map{}, + mapBasedAuthTokenKeyringLookup(map[ringToken]*authtokens.AuthToken{}), + sliceBasedAuthTokenBoundaryReader(nil)) + require.NoError(t, err) + + t.Run("nil user", func(t *testing.T) { + _, err := r.lookupRefreshToken(ctx, nil, targetResourceType) + assert.Error(t, err) + assert.ErrorContains(t, err, "user is nil") + }) + + t.Run("user id is empty", func(t *testing.T) { + _, err := r.lookupRefreshToken(ctx, &user{Address: "addr"}, targetResourceType) + assert.Error(t, err) + assert.ErrorContains(t, err, "user id is empty") + }) + + t.Run("resource type is invalid", func(t *testing.T) { + _, err := r.lookupRefreshToken(ctx, &user{Id: "something", Address: "addr"}, resourceType("invalid")) + assert.Error(t, err) + assert.ErrorContains(t, err, "resource type is invalid") + }) + + t.Run("unknown user", func(t *testing.T) { + got, err := r.lookupRefreshToken(ctx, &user{Id: "unkonwnUser", Address: "addr"}, targetResourceType) + assert.NoError(t, err) + assert.Empty(t, got) + }) + + t.Run("no refresh token", func(t *testing.T) { + known := &user{Id: "known", Address: "addr"} + require.NoError(t, r.rw.Create(ctx, known)) + + got, err := r.lookupRefreshToken(ctx, known, targetResourceType) + assert.NoError(t, err) + assert.Empty(t, got) + }) + + t.Run("got refresh token", func(t *testing.T) { + token := RefreshTokenValue("something") + known := &user{Id: "withrefreshtoken", Address: "addr"} + require.NoError(t, r.rw.Create(ctx, known)) + + r.rw.DoTx(ctx, 1, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { + require.NoError(t, upsertRefreshToken(ctx, w, known, targetResourceType, token)) + return nil + }) + + got, err := r.lookupRefreshToken(ctx, known, targetResourceType) + assert.NoError(t, err) + assert.Equal(t, token, got) + }) +} + +func TestDeleteRefreshTokens(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + r, err := NewRepository(ctx, s, &sync.Map{}, + mapBasedAuthTokenKeyringLookup(map[ringToken]*authtokens.AuthToken{}), + sliceBasedAuthTokenBoundaryReader(nil)) + require.NoError(t, err) + + t.Run("nil user", func(t *testing.T) { + err := r.deleteRefreshToken(ctx, nil, targetResourceType) + assert.Error(t, err) + assert.ErrorContains(t, err, "user is nil") + }) + + t.Run("no user id", func(t *testing.T) { + err := r.deleteRefreshToken(ctx, &user{Address: "addr"}, targetResourceType) + assert.Error(t, err) + assert.ErrorContains(t, err, "user id is empty") + }) + + t.Run("invalid resource type", func(t *testing.T) { + err := r.deleteRefreshToken(ctx, &user{Id: "id", Address: "addr"}, "this is invalid") + assert.Error(t, err) + assert.ErrorContains(t, err, "resource type is invalid") + }) + + t.Run("success", func(t *testing.T) { + u := &user{Id: "id", Address: "addr"} + require.NoError(t, r.rw.Create(ctx, u)) + + r.rw.DoTx(ctx, 1, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { + require.NoError(t, upsertRefreshToken(ctx, w, u, targetResourceType, "token")) + return nil + }) + got, err := r.lookupRefreshToken(ctx, u, targetResourceType) + require.NoError(t, err) + require.NotEmpty(t, got) + + assert.NoError(t, r.deleteRefreshToken(ctx, u, targetResourceType)) + + got, err = r.lookupRefreshToken(ctx, u, targetResourceType) + require.NoError(t, err) + require.Empty(t, got) + }) +} diff --git a/internal/clientcache/internal/cache/repository_sessions.go b/internal/clientcache/internal/cache/repository_sessions.go index 1b3b13ab06..a28acb147c 100644 --- a/internal/clientcache/internal/cache/repository_sessions.go +++ b/internal/clientcache/internal/cache/repository_sessions.go @@ -15,30 +15,32 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/event" - "github.com/hashicorp/boundary/internal/types/resource" "github.com/hashicorp/boundary/internal/util" "github.com/hashicorp/mql" ) // SessionRetrievalFunc is a function that retrieves sessions // from the provided boundary addr using the provided token. -type SessionRetrievalFunc func(ctx context.Context, addr, token string) ([]*sessions.Session, error) +type SessionRetrievalFunc func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) (ret []*sessions.Session, removedIds []string, refreshToken RefreshTokenValue, err error) -func defaultSessionFunc(ctx context.Context, addr, token string) ([]*sessions.Session, error) { +func defaultSessionFunc(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) { const op = "cache.defaultSessionFunc" client, err := api.NewClient(&api.Config{ Addr: addr, - Token: token, + Token: authTok, }) if err != nil { - return nil, errors.Wrap(ctx, err, op) + return nil, nil, "", errors.Wrap(ctx, err, op) } sClient := sessions.NewClient(client) - l, err := sClient.List(ctx, "global", sessions.WithRecursive(true)) + l, err := sClient.List(ctx, "global", sessions.WithRecursive(true), sessions.WithRefreshToken(string(refreshTok))) if err != nil { - return nil, errors.Wrap(ctx, err, op) + if api.ErrInvalidRefreshToken.Is(err) { + return nil, nil, "", err + } + return nil, nil, "", errors.Wrap(ctx, err, op) } - return l.Items, nil + return l.Items, l.RemovedIds, RefreshTokenValue(l.RefreshToken), nil } func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error { @@ -51,6 +53,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au case u.Address == "": return errors.New(ctx, errors.InvalidParameter, op, "user boundary address is missing") } + const resourceType = sessionResourceType opts, err := getOpts(opt...) if err != nil { @@ -59,16 +62,28 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au if opts.withSessionRetrievalFunc == nil { opts.withSessionRetrievalFunc = defaultSessionFunc } + oldRefreshToken, err := r.lookupRefreshToken(ctx, u, resourceType) + if err != nil { + return errors.Wrap(ctx, err, op) + } // Find and use a token for retrieving sessions var gotResponse bool var resp []*sessions.Session + var newRefreshToken RefreshTokenValue + var removedIds []string var retErr error for at, t := range tokens { - resp, err = opts.withSessionRetrievalFunc(ctx, u.Address, t) + resp, removedIds, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, oldRefreshToken) + if api.ErrInvalidRefreshToken.Is(err) { + if err := r.deleteRefreshToken(ctx, u, resourceType); err != nil { + return errors.Wrap(ctx, err, op) + } + // try again without the refresh token + oldRefreshToken = "" + resp, removedIds, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "") + } if err != nil { - // TODO: If we get an error about the token no longer having - // permissions, remove it. retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) continue } @@ -77,7 +92,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au } if retErr != nil { - if saveErr := r.SaveError(ctx, u, resource.Session.String(), retErr); saveErr != nil { + if saveErr := r.SaveError(ctx, u, resourceType, retErr); saveErr != nil { return stderrors.Join(err, errors.Wrap(ctx, saveErr, op)) } } @@ -87,10 +102,17 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d sessions for user %v", len(resp), u)) _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { - // TODO: Instead of deleting everything, use refresh tokens and apply the delta - if _, err := w.Exec(ctx, "delete from session where user_id = @user_id", - []any{sql.Named("user_id", u.Id)}); err != nil { - return err + switch { + case oldRefreshToken == "": + if _, err := w.Exec(ctx, "delete from session where user_id = @user_id", + []any{sql.Named("user_id", u.Id)}); err != nil { + return err + } + case len(removedIds) > 0: + if _, err := w.Exec(ctx, "delete from session where id in @ids", + []any{sql.Named("ids", removedIds)}); err != nil { + return err + } } for _, s := range resp { @@ -114,6 +136,10 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au return err } } + + if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { + return err + } return nil }) if err != nil { diff --git a/internal/clientcache/internal/cache/repository_sessions_test.go b/internal/clientcache/internal/cache/repository_sessions_test.go index cfd307de41..369d07ee5e 100644 --- a/internal/clientcache/internal/cache/repository_sessions_test.go +++ b/internal/clientcache/internal/cache/repository_sessions_test.go @@ -116,13 +116,22 @@ func TestRepository_refreshSessions(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { err := r.refreshSessions(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(tc.sess))) + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{tc.sess}, [][]string{nil}))) if tc.errorContains == "" { assert.NoError(t, err) rw := db.New(s) var got []*Session require.NoError(t, rw.SearchWhere(ctx, &got, "true", nil)) assert.Len(t, got, tc.wantCount) + + t.Cleanup(func() { + refTok := &refreshToken{ + UserId: tc.u.Id, + ResourceType: sessionResourceType, + } + _, err := r.rw.Delete(ctx, refTok) + require.NoError(t, err) + }) } else { assert.ErrorContains(t, err, tc.errorContains) } @@ -130,6 +139,96 @@ func TestRepository_refreshSessions(t *testing.T) { } } +func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := user{ + Id: "u1", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u.Id, + } + kt := KeyringToken{ + KeyringType: "keyring", + TokenName: "token", + AuthTokenId: at.Id, + } + atMap := map[ringToken]*authtokens.AuthToken{ + {kt.KeyringType, kt.TokenName}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + ss := [][]*sessions.Session{ + { + { + Id: "ttcp_1", + Status: "status1", + Endpoint: "address1", + Type: "tcp", + }, + { + Id: "ttcp_2", + Status: "status2", + Endpoint: "address2", + Type: "tcp", + }, + }, + { + { + Id: "ttcp_3", + Status: "status3", + Endpoint: "address3", + Type: "tcp", + }, + }, + } + + err = r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil}))) + assert.NoError(t, err) + + got, err := r.ListSessions(ctx, at.Id) + require.NoError(t, err) + assert.Len(t, got, 2) + + // Refreshing again uses the refresh token and get additional sessions, appending + // them to the response + err = r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil}))) + assert.NoError(t, err) + + got, err = r.ListSessions(ctx, at.Id) + require.NoError(t, err) + assert.Len(t, got, 3) + + // Refreshing again wont return any more resources, but also none should be + // removed + require.NoError(t, r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil})))) + assert.NoError(t, err) + + got, err = r.ListSessions(ctx, at.Id) + require.NoError(t, err) + assert.Len(t, got, 3) + + // Refresh again with the refresh token being reported as invalid. + require.NoError(t, r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, + WithSessionRetrievalFunc(testErroringForRefreshTokenRetrievalFunc(t, ss[0])))) + assert.NoError(t, err) + + got, err = r.ListSessions(ctx, at.Id) + require.NoError(t, err) + assert.Len(t, got, 2) +} + func TestRepository_ListSessions(t *testing.T) { ctx := context.Background() s, err := cachedb.Open(ctx) @@ -200,7 +299,7 @@ func TestRepository_ListSessions(t *testing.T) { }, } require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(ss)))) + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil})))) t.Run("wrong user gets no sessions", func(t *testing.T) { l, err := r.ListSessions(ctx, kt2.AuthTokenId) @@ -307,7 +406,7 @@ func TestRepository_QuerySessions(t *testing.T) { }, } require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(ss)))) + WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil})))) t.Run("wrong token gets no sessions", func(t *testing.T) { l, err := r.QuerySessions(ctx, kt2.AuthTokenId, query) @@ -337,7 +436,15 @@ func TestDefaultSessionRetrievalFunc(t *testing.T) { _, err = tarClient.AuthorizeSession(tc.Context(), tar1.Item.Id) assert.NoError(t, err) - got, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token) + got, removed, refTok, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, "") assert.NoError(t, err) + assert.NotEmpty(t, refTok) + assert.Empty(t, removed) assert.Len(t, got, 1) + + got2, removed2, refTok2, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, refTok) + assert.NoError(t, err) + assert.NotEmpty(t, refTok2) + assert.Empty(t, removed2) + assert.Empty(t, got2) } diff --git a/internal/clientcache/internal/cache/repository_targets.go b/internal/clientcache/internal/cache/repository_targets.go index df0693ee3f..26c4077d88 100644 --- a/internal/clientcache/internal/cache/repository_targets.go +++ b/internal/clientcache/internal/cache/repository_targets.go @@ -15,30 +15,32 @@ import ( "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/event" - "github.com/hashicorp/boundary/internal/types/resource" "github.com/hashicorp/boundary/internal/util" "github.com/hashicorp/mql" ) // TargetRetrievalFunc is a function that retrieves targets // from the provided boundary addr using the provided token. -type TargetRetrievalFunc func(ctx context.Context, addr, token string) ([]*targets.Target, error) +type TargetRetrievalFunc func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) (ret []*targets.Target, removedIds []string, refreshToken RefreshTokenValue, err error) -func defaultTargetFunc(ctx context.Context, addr, token string) ([]*targets.Target, error) { +func defaultTargetFunc(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) { const op = "cache.defaultTargetFunc" client, err := api.NewClient(&api.Config{ Addr: addr, - Token: token, + Token: authTok, }) if err != nil { - return nil, errors.Wrap(ctx, err, op) + return nil, nil, "", errors.Wrap(ctx, err, op) } tarClient := targets.NewClient(client) - l, err := tarClient.List(ctx, "global", targets.WithRecursive(true)) + l, err := tarClient.List(ctx, "global", targets.WithRecursive(true), targets.WithRefreshToken(string(refreshTok))) if err != nil { - return nil, errors.Wrap(ctx, err, op) + if api.ErrInvalidRefreshToken.Is(err) { + return nil, nil, "", err + } + return nil, nil, "", errors.Wrap(ctx, err, op) } - return l.Items, nil + return l.Items, l.RemovedIds, RefreshTokenValue(l.RefreshToken), nil } func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error { @@ -49,6 +51,7 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut case u.Id == "": return errors.New(ctx, errors.InvalidParameter, op, "user id is missing") } + const resourceType = targetResourceType opts, err := getOpts(opt...) if err != nil { @@ -57,16 +60,28 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut if opts.withTargetRetrievalFunc == nil { opts.withTargetRetrievalFunc = defaultTargetFunc } + oldRefreshToken, err := r.lookupRefreshToken(ctx, u, resourceType) + if err != nil { + return errors.Wrap(ctx, err, op) + } // Find and use a token for retrieving targets var gotResponse bool var resp []*targets.Target + var removedIds []string + var newRefreshToken RefreshTokenValue var retErr error for at, t := range tokens { - resp, err = opts.withTargetRetrievalFunc(ctx, u.Address, t) + resp, removedIds, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, oldRefreshToken) + if api.ErrInvalidRefreshToken.Is(err) { + if err := r.deleteRefreshToken(ctx, u, resourceType); err != nil { + return errors.Wrap(ctx, err, op) + } + // try again without the refresh token + oldRefreshToken = "" + resp, removedIds, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, "") + } if err != nil { - // TODO: If we get an error about the token no longer having - // permissions, remove it. retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id))) continue } @@ -74,7 +89,7 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut break } if retErr != nil { - if saveErr := r.SaveError(ctx, u, resource.Target.String(), retErr); saveErr != nil { + if saveErr := r.SaveError(ctx, u, resourceType, retErr); saveErr != nil { return stderrors.Join(err, errors.Wrap(ctx, saveErr, op)) } } @@ -84,10 +99,17 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d targets for user %v", len(resp), u)) _, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error { - // TODO: Instead of deleting everything, use refresh tokens and apply the delta - if _, err := w.Exec(ctx, "delete from target where user_id = @user_id", - []any{sql.Named("user_id", u.Id)}); err != nil { - return err + switch { + case oldRefreshToken == "": + if _, err := w.Exec(ctx, "delete from target where user_id = @user_id", + []any{sql.Named("user_id", u.Id)}); err != nil { + return err + } + case len(removedIds) > 0: + if _, err := w.Exec(ctx, "delete from target where id in @ids", + []any{sql.Named("ids", removedIds)}); err != nil { + return err + } } for _, t := range resp { @@ -111,6 +133,12 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut return err } } + + if newRefreshToken != "" || oldRefreshToken != "" { + if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil { + return err + } + } return nil }) if err != nil { diff --git a/internal/clientcache/internal/cache/repository_targets_test.go b/internal/clientcache/internal/cache/repository_targets_test.go index ed3a1a9da8..a6a87b92bb 100644 --- a/internal/clientcache/internal/cache/repository_targets_test.go +++ b/internal/clientcache/internal/cache/repository_targets_test.go @@ -110,13 +110,22 @@ func TestRepository_refreshTargets(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { err := r.refreshTargets(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(tc.targets))) + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{tc.targets}, [][]string{nil}))) if tc.errorContains == "" { assert.NoError(t, err) rw := db.New(s) var got []*Target require.NoError(t, rw.SearchWhere(ctx, &got, "true", nil)) assert.Len(t, got, tc.wantCount) + + t.Cleanup(func() { + refTok := &refreshToken{ + UserId: tc.u.Id, + ResourceType: targetResourceType, + } + _, err := r.rw.Delete(ctx, refTok) + require.NoError(t, err) + }) } else { assert.ErrorContains(t, err, tc.errorContains) } @@ -124,6 +133,97 @@ func TestRepository_refreshTargets(t *testing.T) { } } +func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + addr := "address" + u := user{ + Id: "u1", + Address: addr, + } + at := &authtokens.AuthToken{ + Id: "at_1", + Token: "at_1_token", + UserId: u.Id, + } + kt := KeyringToken{ + KeyringType: "keyring", + TokenName: "token", + AuthTokenId: at.Id, + } + atMap := map[ringToken]*authtokens.AuthToken{ + {kt.KeyringType, kt.TokenName}: at, + } + r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap))) + require.NoError(t, err) + require.NoError(t, r.AddKeyringToken(ctx, addr, kt)) + + ts := [][]*targets.Target{ + { + { + Id: "ttcp_1", + Name: "name1", + Address: "address1", + Type: "tcp", + SessionMaxSeconds: 111, + }, + { + Id: "ttcp_2", + Name: "name2", + Address: "address2", + Type: "tcp", + SessionMaxSeconds: 222, + }, + }, { + { + Id: "ttcp_3", + Name: "name3", + Address: "address3", + Type: "tcp", + SessionMaxSeconds: 333, + }, + }, + } + + require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil})))) + + got, err := r.ListTargets(ctx, at.Id) + require.NoError(t, err) + assert.Len(t, got, 2) + + // Refreshing again uses the refresh token and get additional sessions, appending + // them to the response + require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil})))) + assert.NoError(t, err) + + got, err = r.ListTargets(ctx, at.Id) + require.NoError(t, err) + assert.Len(t, got, 3) + + // Refreshing again wont return any more resources, but also none should be + // removed + require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil})))) + assert.NoError(t, err) + + got, err = r.ListTargets(ctx, at.Id) + require.NoError(t, err) + assert.Len(t, got, 3) + + // Refresh again with the refresh token being reported as invalid. + require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"}, + WithTargetRetrievalFunc(testErroringForRefreshTokenRetrievalFunc(t, ts[0])))) + assert.NoError(t, err) + + got, err = r.ListTargets(ctx, at.Id) + require.NoError(t, err) + assert.Len(t, got, 2) +} + func TestRepository_ListTargets(t *testing.T) { ctx := context.Background() s, err := cachedb.Open(ctx) @@ -190,7 +290,7 @@ func TestRepository_ListTargets(t *testing.T) { }, } require.NoError(t, r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(ts)))) + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) t.Run("wrong user gets no targets", func(t *testing.T) { l, err := r.ListTargets(ctx, kt2.AuthTokenId) @@ -294,7 +394,7 @@ func TestRepository_QueryTargets(t *testing.T) { }, } require.NoError(t, r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"}, - WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(ts)))) + WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil})))) t.Run("wrong token gets no targets", func(t *testing.T) { l, err := r.QueryTargets(ctx, kt2.AuthTokenId, query) @@ -321,8 +421,16 @@ func TestDefaultTargetRetrievalFunc(t *testing.T) { require.NoError(t, err) require.NotNil(t, tar2) - got, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token) + got, removed, refTok, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, "") assert.NoError(t, err) + assert.NotEmpty(t, refTok) + assert.Empty(t, removed) assert.Contains(t, got, tar1.Item) assert.Contains(t, got, tar2.Item) + + got2, removed2, refTok2, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, refTok) + assert.NoError(t, err) + assert.NotEmpty(t, refTok2) + assert.Empty(t, removed2) + assert.Empty(t, got2) } diff --git a/internal/clientcache/internal/cache/repository_test.go b/internal/clientcache/internal/cache/repository_test.go index dfd9dc3a4f..43e995ed0e 100644 --- a/internal/clientcache/internal/cache/repository_test.go +++ b/internal/clientcache/internal/cache/repository_test.go @@ -54,7 +54,7 @@ func TestRepository_SaveError(t *testing.T) { sliceBasedAuthTokenBoundaryReader(nil)) require.NoError(t, err) - testResource := "test_resource_type" + testResource := targetResourceType testErr := fmt.Errorf("test error for %q", testResource) u := &user{ diff --git a/internal/clientcache/internal/cache/store_test.go b/internal/clientcache/internal/cache/store_test.go index a4b01b163f..21d40443f3 100644 --- a/internal/clientcache/internal/cache/store_test.go +++ b/internal/clientcache/internal/cache/store_test.go @@ -135,6 +135,110 @@ func TestAuthToken_NoMoreKeyringTokens(t *testing.T) { assert.NoError(t, rw.LookupById(ctx, u)) } +func TestRefreshToken(t *testing.T) { + ctx := context.Background() + s, err := cachedb.Open(ctx) + require.NoError(t, err) + + rw := db.New(s) + + u := &user{ + Id: "userId", + Address: "address", + } + + t.Run("no user foreign key constraint", func(t *testing.T) { + tok := &refreshToken{ + UserId: u.Id, + ResourceType: targetResourceType, + RefreshToken: "something", + } + require.ErrorContains(t, rw.Create(ctx, tok), "constraint failed") + }) + + require.NoError(t, rw.Create(ctx, u)) + + t.Run("no user id", func(t *testing.T) { + tok := &refreshToken{ + ResourceType: targetResourceType, + RefreshToken: "something", + } + require.ErrorContains(t, rw.Create(ctx, tok), "constraint failed") + }) + + t.Run("unknown resource type", func(t *testing.T) { + tok := &refreshToken{ + UserId: u.Id, + ResourceType: "thisisntknown", + RefreshToken: "something", + } + require.ErrorContains(t, rw.Create(ctx, tok), "constraint failed") + }) + + t.Run("empty refresh token", func(t *testing.T) { + tok := &refreshToken{ + UserId: u.Id, + ResourceType: "thisisntknown", + } + require.ErrorContains(t, rw.Create(ctx, tok), "constraint failed") + }) + + t.Run("create", func(t *testing.T) { + tok := &refreshToken{ + UserId: u.Id, + ResourceType: targetResourceType, + RefreshToken: "something", + } + require.NoError(t, rw.Create(ctx, tok)) + require.NoError(t, rw.LookupById(ctx, tok)) + assert.NotEmpty(t, tok.RefreshToken) + }) + + t.Run("update", func(t *testing.T) { + u := &user{ + Id: "updatethis", + Address: "updated", + } + require.NoError(t, rw.Create(ctx, u)) + + tok := &refreshToken{ + UserId: u.Id, + ResourceType: targetResourceType, + RefreshToken: "started", + } + require.NoError(t, rw.Create(ctx, tok)) + + tok.RefreshToken = "updated" + n, err := rw.Update(ctx, tok, []string{"RefreshToken"}, nil) + assert.NoError(t, err) + assert.Equal(t, 1, n) + }) + + t.Run("delete user deletes token", func(t *testing.T) { + u := &user{ + Id: "deletethis", + Address: "deleted", + } + require.NoError(t, rw.Create(ctx, u)) + + tok := &refreshToken{ + UserId: u.Id, + ResourceType: targetResourceType, + RefreshToken: "deleted_soon", + } + require.NoError(t, rw.Create(ctx, tok)) + + _, err = rw.Exec(ctx, "delete from user where id = ?", []any{u.Id}) + + require.True(t, errors.IsNotFoundError(rw.LookupById(ctx, tok))) + }) + + // TODO: When gorm sqlite driver fixes it's delete, use rw.Delete instead of the Exec. + // n, err := rw.Delete(ctx, p) + _, err = rw.Exec(ctx, "delete from refresh_token", nil) + assert.NoError(t, err) +} + func TestAuthToken(t *testing.T) { ctx := context.Background() s, err := cachedb.Open(ctx) diff --git a/internal/clientcache/internal/daemon/testing.go b/internal/clientcache/internal/daemon/testing.go index e9fb44888d..e873838f26 100644 --- a/internal/clientcache/internal/daemon/testing.go +++ b/internal/clientcache/internal/daemon/testing.go @@ -75,17 +75,17 @@ func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, tars [] r, err := cache.NewRepository(ctx, s.CacheServer.store, &sync.Map{}, s.cmd.ReadTokenFromKeyring, atReadFn) require.NoError(t, err) - tarFn := func(ctx context.Context, _, tok string) ([]*targets.Target, error) { + tarFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*targets.Target, []string, cache.RefreshTokenValue, error) { if tok != p.Token { - return nil, nil + return nil, nil, "", nil } - return tars, nil + return tars, nil, "", nil } - sessFn := func(ctx context.Context, _, tok string) ([]*sessions.Session, error) { + sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*sessions.Session, []string, cache.RefreshTokenValue, error) { if tok != p.Token { - return nil, nil + return nil, nil, "", nil } - return sess, nil + return sess, nil, "", nil } rs, err := cache.NewRefreshService(ctx, r) require.NoError(t, err) diff --git a/internal/clientcache/internal/daemon/ticker.go b/internal/clientcache/internal/daemon/ticker.go index d2068bbe92..caff76d80e 100644 --- a/internal/clientcache/internal/daemon/ticker.go +++ b/internal/clientcache/internal/daemon/ticker.go @@ -14,7 +14,7 @@ import ( ) const ( - DefaultRefreshIntervalSeconds = 5 * 60 + DefaultRefreshIntervalSeconds = 15 defaultRefreshInterval = DefaultRefreshIntervalSeconds * time.Second ) diff --git a/internal/clientcache/internal/db/schema.sql b/internal/clientcache/internal/db/schema.sql index d6de11fd75..aa22a891c1 100644 --- a/internal/clientcache/internal/db/schema.sql +++ b/internal/clientcache/internal/db/schema.sql @@ -13,6 +13,33 @@ create table if not exists user ( check (length(address) > 0) ); +-- Contains the known resource types contained in the boundary client cache +create table if not exists resource_type_enm( + string text not null primary key + constraint only_predefined_resource_types_allowed + check(string in ('unknown', 'target', 'session')) +); + +insert into resource_type_enm (string) +values + ('unknown'), + ('target'), + ('session'); + +-- Contains refresh tokens for list requests sent by the client daemon to the +-- boundary instance. +create table if not exists refresh_token( + user_id text not null + references user(id) + on delete cascade, + resource_type text not null + references resource_type_enm(string) + constraint only_known_resource_types_allowed, + refresh_token text not null + check (length(refresh_token) > 0), + primary key (user_id, resource_type) +); + -- Contains the boundary auth token create table if not exists auth_token ( -- id is the boundary id of the auth token @@ -107,7 +134,9 @@ create table if not exists api_error ( user_id text not null references user(id) on delete cascade, - resource_type text not null, + resource_type text not null + references resource_type_enm(string) + constraint only_known_resource_types_allowed, error text not null, create_time timestamp not null default current_timestamp, primary key (user_id, resource_type) diff --git a/internal/daemon/controller/handlers/errors.go b/internal/daemon/controller/handlers/errors.go index e381afe7b7..6316723296 100644 --- a/internal/daemon/controller/handlers/errors.go +++ b/internal/daemon/controller/handlers/errors.go @@ -153,6 +153,25 @@ func InvalidArgumentErrorf(msg string, fields map[string]string) *ApiError { return apiErr } +func invalidRefreshTokenError(err error) *ApiError { + const op = "handlers.invalidRefreshTokenError" + ctx := context.TODO() + + var domainErr *errors.Err + if !errors.As(err, &domainErr) { + event.WriteError(ctx, op, err, event.WithInfoMsg("Unable to build invalid argument api error.")) + } + + return &ApiError{ + Status: http.StatusBadRequest, + Inner: &pb.Error{ + Kind: domainErr.Info().Message, + Op: string(domainErr.Op), + Message: domainErr.Msg, + }, + } +} + // ConflictErrorf generates an ApiErr when a pre-conditional check is violated. // Note, this deliberately doesn't translate to the similarly named '412 // Precondition Failed' HTTP response status. The ApiErr returned is a 400 bad @@ -199,6 +218,8 @@ func backendErrorToApiError(inErr error) *ApiError { return NotFoundErrorf(genericNotFoundMsg) case errors.Match(errors.T(errors.AccountAlreadyAssociated), inErr): return InvalidArgumentErrorf(inErr.Error(), nil) + case errors.Match(errors.T(errors.InvalidRefreshToken), inErr): + return invalidRefreshTokenError(inErr) case errors.Match(errors.T(errors.InvalidFieldMask), inErr), errors.Match(errors.T(errors.EmptyFieldMask), inErr): return InvalidArgumentErrorf("Error in provided request", map[string]string{"update_mask": "Invalid update mask provided."}) case errors.IsUniqueError(inErr): diff --git a/internal/daemon/controller/handlers/errors_test.go b/internal/daemon/controller/handlers/errors_test.go index 67b098edc5..d332555169 100644 --- a/internal/daemon/controller/handlers/errors_test.go +++ b/internal/daemon/controller/handlers/errors_test.go @@ -219,6 +219,18 @@ func TestApiErrorHandler(t *testing.T) { }, }, }, + { + name: "Invalid refresh token error", + err: errors.New(ctx, errors.InvalidRefreshToken, errors.Op("test.op"), "this is a test invalid refresh token error"), + expected: ApiError{ + Status: http.StatusBadRequest, + Inner: &pb.Error{ + Kind: "invalid refresh token", + Op: "test.op", + Message: "this is a test invalid refresh token error", + }, + }, + }, { name: "Wrapped forbidden domain error", err: fmt.Errorf("got error: %w", errors.E(ctx, errors.WithCode(errors.Forbidden), errors.WithMsg("test msg"))), diff --git a/internal/errors/code.go b/internal/errors/code.go index 62950b4f38..088eda5e95 100644 --- a/internal/errors/code.go +++ b/internal/errors/code.go @@ -67,6 +67,8 @@ const ( Closed = 134 // Closed represents an error when an operation cannot be completed because the thing being operated on is closed ChecksumMismatch = 135 // ChecksumMismatch represents an error when a checksum is mismatched + InvalidRefreshToken Code = 136 // InvalidRefreshToken represents an error where the provided refresh token is invalid + AuthAttemptExpired Code = 198 // AuthAttemptExpired represents an expired authentication attempt AuthMethodInactive Code = 199 // AuthMethodInactive represents an error that means the auth method is not active. diff --git a/internal/errors/code_test.go b/internal/errors/code_test.go index 4ad8be5b85..0153170dbe 100644 --- a/internal/errors/code_test.go +++ b/internal/errors/code_test.go @@ -415,6 +415,11 @@ func TestCode_Both_String_Info(t *testing.T) { c: InvalidConfiguration, want: InvalidConfiguration, }, + { + name: "InvalidRefreshToken", + c: InvalidRefreshToken, + want: InvalidRefreshToken, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/internal/errors/info.go b/internal/errors/info.go index d1951be66b..e62c3bbdfc 100644 --- a/internal/errors/info.go +++ b/internal/errors/info.go @@ -347,4 +347,8 @@ var errorCodeInfo = map[Code]Info{ Message: "invalid configuration", Kind: Configuration, }, + InvalidRefreshToken: { + Message: "invalid refresh token", + Kind: Parameter, + }, } diff --git a/internal/refreshtoken/refresh_token.go b/internal/refreshtoken/refresh_token.go new file mode 100644 index 0000000000..3f744b5867 --- /dev/null +++ b/internal/refreshtoken/refresh_token.go @@ -0,0 +1,133 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +// The refreshtoken package encapsulates domain logic surrounding +// list endpoint refresh tokens. Refresh tokens are used when users +// paginate through results in our list endpoints, and also to +// allow users to request new, updated and deleted resources. +package refreshtoken + +import ( + "bytes" + "context" + "time" + + "github.com/hashicorp/boundary/internal/boundary" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/types/resource" +) + +// A Token is returned in list endpoints for the purposes of pagination +type Token struct { + CreatedTime time.Time + UpdatedTime time.Time + ResourceType resource.Type + GrantsHash []byte + LastItemId string + LastItemUpdatedTime time.Time +} + +// New creates a new refresh token from a createdTime, resource type, grants hash, and last item information +func New(ctx context.Context, createdTime time.Time, updatedTime time.Time, typ resource.Type, grantsHash []byte, lastItemId string, lastItemUpdatedTime time.Time) (*Token, error) { + const op = "refreshtoken.New" + + if len(grantsHash) == 0 { + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash") + } + if createdTime.After(time.Now()) { + return nil, errors.New(ctx, errors.InvalidParameter, op, "created time is in the future") + } + if createdTime.Before(time.Now().AddDate(0, 0, -30)) { + return nil, errors.New(ctx, errors.InvalidParameter, op, "created time is too old") + } + if updatedTime.Before(createdTime) { + return nil, errors.New(ctx, errors.InvalidParameter, op, "updated time is older than created time") + } + if updatedTime.After(time.Now()) { + return nil, errors.New(ctx, errors.InvalidParameter, op, "updated time is in the future") + } + if lastItemId == "" { + return nil, errors.New(ctx, errors.InvalidParameter, op, "missing last item ID") + } + if lastItemUpdatedTime.After(time.Now()) { + return nil, errors.New(ctx, errors.InvalidParameter, op, "last item updated time is in the future") + } + + return &Token{ + CreatedTime: createdTime, + UpdatedTime: updatedTime, + ResourceType: typ, + GrantsHash: grantsHash, + LastItemId: lastItemId, + LastItemUpdatedTime: lastItemUpdatedTime, + }, nil +} + +// FromResource creates a new refresh token from a resource and grants hash +func FromResource(res boundary.Resource, grantsHash []byte) *Token { + t := time.Now() + return &Token{ + CreatedTime: t, + UpdatedTime: t, + ResourceType: res.GetResourceType(), + GrantsHash: grantsHash, + LastItemId: res.GetPublicId(), + LastItemUpdatedTime: res.GetUpdateTime().AsTime(), + } +} + +// Refresh refreshes a token's updated time +func (rt *Token) Refresh(updatedTime time.Time) *Token { + rt.UpdatedTime = updatedTime + return rt +} + +// RefreshLastItem refreshes a token's updated time and last item +func (rt *Token) RefreshLastItem(res boundary.Resource, updatedTime time.Time) *Token { + rt.UpdatedTime = updatedTime + rt.LastItemId = res.GetPublicId() + rt.LastItemUpdatedTime = res.GetUpdateTime().AsTime() + return rt +} + +// Validate validates the refresh token. +func (rt *Token) Validate( + ctx context.Context, + expectedResourceType resource.Type, + expectedGrantsHash []byte, +) error { + const op = "refreshtoken.Validate" + if rt == nil { + return errors.New(ctx, errors.InvalidParameter, op, "refresh token was missing") + } + if len(rt.GrantsHash) == 0 { + return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token was missing its grants hash") + } + if !bytes.Equal(rt.GrantsHash, expectedGrantsHash) { + return errors.New(ctx, errors.InvalidRefreshToken, op, "grants have changed since refresh token was issued") + } + if rt.CreatedTime.After(time.Now()) { + return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token was created in the future") + } + // Tokens older than 30 days have expired + if rt.CreatedTime.Before(time.Now().AddDate(0, 0, -30)) { + return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token was expired") + } + if rt.UpdatedTime.Before(rt.CreatedTime) { + return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token was updated before its creation time") + } + if rt.UpdatedTime.After(time.Now()) { + return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token was updated in the future") + } + if rt.LastItemId == "" { + return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token missing last item ID") + } + if rt.LastItemUpdatedTime.After(time.Now()) { + return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token last item was updated in the future") + } + if rt.ResourceType != expectedResourceType { + return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token resource type does not match expected resource type") + } + + return nil +} diff --git a/internal/refreshtoken/refresh_token_test.go b/internal/refreshtoken/refresh_token_test.go new file mode 100644 index 0000000000..b3f18e71ec --- /dev/null +++ b/internal/refreshtoken/refresh_token_test.go @@ -0,0 +1,398 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package refreshtoken_test + +import ( + "context" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/boundary/internal/boundary" + "github.com/hashicorp/boundary/internal/db/timestamp" + "github.com/hashicorp/boundary/internal/errors" + "github.com/hashicorp/boundary/internal/refreshtoken" + "github.com/hashicorp/boundary/internal/types/resource" + "github.com/stretchr/testify/require" +) + +func Test_ValidateRefreshToken(t *testing.T) { + fiveDaysAgo := time.Now().AddDate(0, 0, -5) + tests := []struct { + name string + token *refreshtoken.Token + grantsHash []byte + resourceType resource.Type + wantErrString string + wantErrCode errors.Code + }{ + { + name: "valid token", + token: &refreshtoken.Token{ + CreatedTime: fiveDaysAgo, + UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "s_1234567890", + LastItemUpdatedTime: fiveDaysAgo, + }, + grantsHash: []byte("some hash"), + resourceType: resource.Target, + }, + { + name: "nil token", + token: nil, + grantsHash: []byte("some hash"), + resourceType: resource.Target, + wantErrString: "refresh token was missing", + wantErrCode: errors.InvalidParameter, + }, + { + name: "no grants hash", + token: &refreshtoken.Token{ + CreatedTime: fiveDaysAgo, + UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: nil, + LastItemId: "s_1234567890", + LastItemUpdatedTime: fiveDaysAgo, + }, + grantsHash: []byte("some hash"), + resourceType: resource.Target, + wantErrString: "refresh token was missing its grants hash", + wantErrCode: errors.InvalidRefreshToken, + }, + { + name: "changed grants hash", + token: &refreshtoken.Token{ + CreatedTime: fiveDaysAgo, + UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "s_1234567890", + LastItemUpdatedTime: fiveDaysAgo, + }, + grantsHash: []byte("some other hash"), + resourceType: resource.Target, + wantErrString: "grants have changed since refresh token was issued", + wantErrCode: errors.InvalidRefreshToken, + }, + { + name: "created in the future", + token: &refreshtoken.Token{ + CreatedTime: time.Now().AddDate(1, 0, 0), + UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "s_1234567890", + LastItemUpdatedTime: fiveDaysAgo, + }, + grantsHash: []byte("some hash"), + resourceType: resource.Target, + wantErrString: "refresh token was created in the future", + wantErrCode: errors.InvalidRefreshToken, + }, + { + name: "expired", + token: &refreshtoken.Token{ + CreatedTime: time.Now().AddDate(0, 0, -31), + UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "s_1234567890", + LastItemUpdatedTime: fiveDaysAgo, + }, + grantsHash: []byte("some hash"), + resourceType: resource.Target, + wantErrString: "refresh token was expired", + wantErrCode: errors.InvalidRefreshToken, + }, + { + name: "updated before created", + token: &refreshtoken.Token{ + CreatedTime: fiveDaysAgo, + UpdatedTime: fiveDaysAgo.AddDate(0, 0, -1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "s_1234567890", + LastItemUpdatedTime: fiveDaysAgo, + }, + grantsHash: []byte("some hash"), + resourceType: resource.Target, + wantErrString: "refresh token was updated before its creation time", + wantErrCode: errors.InvalidRefreshToken, + }, + { + name: "updated after now", + token: &refreshtoken.Token{ + CreatedTime: fiveDaysAgo, + UpdatedTime: time.Now().AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "s_1234567890", + LastItemUpdatedTime: fiveDaysAgo, + }, + grantsHash: []byte("some hash"), + resourceType: resource.Target, + wantErrString: "refresh token was updated in the future", + wantErrCode: errors.InvalidRefreshToken, + }, + { + name: "resource type mismatch", + token: &refreshtoken.Token{ + CreatedTime: fiveDaysAgo, + UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "s_1234567890", + LastItemUpdatedTime: fiveDaysAgo, + }, + grantsHash: []byte("some hash"), + resourceType: resource.SessionRecording, + wantErrString: "refresh token resource type does not match expected resource type", + wantErrCode: errors.InvalidRefreshToken, + }, + { + name: "last item ID unset", + token: &refreshtoken.Token{ + CreatedTime: fiveDaysAgo, + UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "", + LastItemUpdatedTime: fiveDaysAgo, + }, + grantsHash: []byte("some hash"), + resourceType: resource.Target, + wantErrString: "refresh token missing last item ID", + wantErrCode: errors.InvalidRefreshToken, + }, + { + name: "last item ID unset", + token: &refreshtoken.Token{ + CreatedTime: fiveDaysAgo, + UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "", + LastItemUpdatedTime: fiveDaysAgo, + }, + grantsHash: []byte("some hash"), + resourceType: resource.Target, + wantErrString: "refresh token missing last item ID", + wantErrCode: errors.InvalidRefreshToken, + }, + { + name: "updated in the future", + token: &refreshtoken.Token{ + CreatedTime: fiveDaysAgo, + UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "s_1234567890", + LastItemUpdatedTime: time.Now().AddDate(1, 0, 0), + }, + grantsHash: []byte("some hash"), + resourceType: resource.Target, + wantErrString: "refresh token last item was updated in the future", + wantErrCode: errors.InvalidRefreshToken, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.token.Validate(context.Background(), tt.resourceType, tt.grantsHash) + if tt.wantErrString != "" { + require.ErrorContains(t, err, tt.wantErrString) + require.Equal(t, errors.Convert(err).Code, tt.wantErrCode) + return + } + require.NoError(t, err) + }) + } +} + +func TestNew(t *testing.T) { + fiveDaysAgo := time.Now().AddDate(0, 0, -5) + tests := []struct { + name string + createdTime time.Time + updatedTime time.Time + typ resource.Type + grantsHash []byte + lastItemId string + lastItemUpdatedTime time.Time + want *refreshtoken.Token + wantErrString string + wantErrCode errors.Code + }{ + { + name: "valid refresh token", + createdTime: fiveDaysAgo, + updatedTime: fiveDaysAgo.AddDate(0, 0, 1), + typ: resource.Target, + grantsHash: []byte("some hash"), + lastItemId: "some id", + lastItemUpdatedTime: fiveDaysAgo, + want: &refreshtoken.Token{ + CreatedTime: fiveDaysAgo, + UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1), + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "some id", + LastItemUpdatedTime: fiveDaysAgo, + }, + }, + { + name: "missing grants hash", + createdTime: fiveDaysAgo, + updatedTime: fiveDaysAgo.AddDate(0, 0, 1), + typ: resource.Target, + grantsHash: nil, + lastItemId: "some id", + lastItemUpdatedTime: fiveDaysAgo, + wantErrString: "missing grants hash", + wantErrCode: errors.InvalidParameter, + }, + { + name: "new created time", + createdTime: fiveDaysAgo.AddDate(1, 0, 0), + updatedTime: fiveDaysAgo.AddDate(0, 0, 1), + typ: resource.Target, + grantsHash: []byte("some hash"), + lastItemId: "some id", + lastItemUpdatedTime: fiveDaysAgo, + wantErrString: "created time is in the future", + wantErrCode: errors.InvalidParameter, + }, + { + name: "old created time", + createdTime: fiveDaysAgo.AddDate(-1, 0, 0), + updatedTime: fiveDaysAgo.AddDate(0, 0, 1), + typ: resource.Target, + grantsHash: []byte("some hash"), + lastItemId: "some id", + lastItemUpdatedTime: fiveDaysAgo, + wantErrString: "created time is too old", + wantErrCode: errors.InvalidParameter, + }, + { + name: "new updated time", + createdTime: fiveDaysAgo, + updatedTime: fiveDaysAgo.AddDate(1, 0, 0), + typ: resource.Target, + grantsHash: []byte("some hash"), + lastItemId: "some id", + lastItemUpdatedTime: fiveDaysAgo, + wantErrString: "updated time is in the future", + wantErrCode: errors.InvalidParameter, + }, + { + name: "updated time older than created time", + createdTime: fiveDaysAgo, + updatedTime: fiveDaysAgo.AddDate(0, 0, -11), + typ: resource.Target, + grantsHash: []byte("some hash"), + lastItemId: "some id", + lastItemUpdatedTime: fiveDaysAgo, + wantErrString: "updated time is older than created time", + wantErrCode: errors.InvalidParameter, + }, + { + name: "missing last item id", + createdTime: fiveDaysAgo, + updatedTime: fiveDaysAgo.AddDate(0, 0, 1), + typ: resource.Target, + grantsHash: []byte("some hash"), + lastItemId: "", + lastItemUpdatedTime: fiveDaysAgo, + wantErrString: "missing last item ID", + wantErrCode: errors.InvalidParameter, + }, + { + name: "new last item updated time", + createdTime: fiveDaysAgo, + updatedTime: fiveDaysAgo.AddDate(0, 0, 1), + typ: resource.Target, + grantsHash: []byte("some hash"), + lastItemId: "some id", + lastItemUpdatedTime: fiveDaysAgo.AddDate(1, 0, 0), + wantErrString: "last item updated time is in the future", + wantErrCode: errors.InvalidParameter, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := refreshtoken.New(context.Background(), tt.createdTime, tt.updatedTime, tt.typ, tt.grantsHash, tt.lastItemId, tt.lastItemUpdatedTime) + if tt.wantErrString != "" { + require.ErrorContains(t, err, tt.wantErrString) + require.Equal(t, errors.Convert(err).Code, tt.wantErrCode) + return + } + require.NoError(t, err) + require.Empty(t, cmp.Diff(got, tt.want)) + }) + } +} + +type fakeTargetResource struct { + boundary.Resource + + publicId string + updateTime *timestamp.Timestamp +} + +func (m *fakeTargetResource) GetResourceType() resource.Type { + return resource.Target +} + +func (m *fakeTargetResource) GetPublicId() string { + return m.publicId +} + +func (m *fakeTargetResource) GetUpdateTime() *timestamp.Timestamp { + return m.updateTime +} + +func TestFromResource(t *testing.T) { + fiveDaysAgo := time.Now().AddDate(0, 0, -5) + res := &fakeTargetResource{ + publicId: "tcp_1234567890", + updateTime: timestamp.New(fiveDaysAgo), + } + + tok := refreshtoken.FromResource(res, []byte("some hash")) + + // Check that it's within 1 second of now according to the system + // If this is flaky... just increase the limit 😬. + require.True(t, tok.CreatedTime.Before(time.Now().Add(time.Second))) + require.True(t, tok.CreatedTime.After(time.Now().Add(-time.Second))) + require.True(t, tok.UpdatedTime.Before(time.Now().Add(time.Second))) + require.True(t, tok.UpdatedTime.After(time.Now().Add(-time.Second))) + + require.Equal(t, tok.ResourceType, res.GetResourceType()) + require.Equal(t, tok.GrantsHash, []byte("some hash")) + require.Equal(t, tok.LastItemId, res.GetPublicId()) + require.True(t, tok.LastItemUpdatedTime.Equal(res.GetUpdateTime().AsTime())) +} + +func TestRefresh(t *testing.T) { + createdTime := time.Now().AddDate(0, 0, -5) + updatedTime := time.Now() + tok := &refreshtoken.Token{ + CreatedTime: createdTime, + UpdatedTime: createdTime, + ResourceType: resource.Target, + GrantsHash: []byte("some hash"), + LastItemId: "tcp_1234567890", + LastItemUpdatedTime: createdTime, + } + newTok := tok.Refresh(updatedTime) + + require.True(t, newTok.UpdatedTime.Equal(updatedTime)) + require.True(t, newTok.CreatedTime.Equal(createdTime)) + require.Equal(t, newTok.ResourceType, tok.ResourceType) + require.Equal(t, newTok.GrantsHash, tok.GrantsHash) + require.Equal(t, newTok.LastItemId, tok.LastItemId) + require.True(t, newTok.LastItemUpdatedTime.Equal(tok.LastItemUpdatedTime)) +}