Store and use refresh tokens in the client cache (#3857)

* Refresh resources using refresh tokens
pull/4202/head
Todd 3 years ago committed by Johan Brandhorst-Satzkorn
parent 5daef58f54
commit b2fec2e4f0

@ -11,10 +11,11 @@ import (
)
var (
ErrNotFound = &Error{Kind: codes.NotFound.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusNotFound}}}
ErrInvalidArgument = &Error{Kind: codes.InvalidArgument.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusBadRequest}}}
ErrPermissionDenied = &Error{Kind: codes.PermissionDenied.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusForbidden}}}
ErrUnauthorized = &Error{Kind: codes.Unauthenticated.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusUnauthorized}}}
ErrNotFound = &Error{Kind: codes.NotFound.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusNotFound}}}
ErrInvalidArgument = &Error{Kind: codes.InvalidArgument.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusBadRequest}}}
ErrPermissionDenied = &Error{Kind: codes.PermissionDenied.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusForbidden}}}
ErrUnauthorized = &Error{Kind: codes.Unauthenticated.String(), response: &Response{resp: &http.Response{StatusCode: http.StatusUnauthorized}}}
ErrInvalidRefreshToken = &Error{Kind: "invalid refresh token", response: &Response{resp: &http.Response{StatusCode: http.StatusBadRequest}}}
)
// AsServerError returns an api *Error from the provided error. If the provided error

@ -33,7 +33,9 @@ func Test_GetOpts(t *testing.T) {
assert.Equal(t, opts, testOpts)
})
t.Run("WithTargetRetrievalFunc", func(t *testing.T) {
var f TargetRetrievalFunc = func(ctx context.Context, keyringstring, tokenName string) ([]*targets.Target, error) { return nil, nil }
var f TargetRetrievalFunc = func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) {
return nil, nil, "", nil
}
opts, err := getOpts(WithTargetRetrievalFunc(f))
require.NoError(t, err)
@ -44,8 +46,8 @@ func Test_GetOpts(t *testing.T) {
assert.Equal(t, opts, testOpts)
})
t.Run("WithSessionRetrievalFunc", func(t *testing.T) {
var f SessionRetrievalFunc = func(ctx context.Context, keyringstring, tokenName string) ([]*sessions.Session, error) {
return nil, nil
var f SessionRetrievalFunc = func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) {
return nil, nil, "", nil
}
opts, err := getOpts(WithSessionRetrievalFunc(f))
require.NoError(t, err)

@ -7,6 +7,7 @@ import (
"context"
"errors"
"fmt"
"strconv"
"sync"
"testing"
"time"
@ -24,10 +25,42 @@ import (
// testStaticResourceRetrievalFunc returns a function that always returns the
// provided slice and a nil error. The returned function can be passed into the
// options that provide a resource retrieval func such as
// WithTargetRetrievalFunc and WithSessionRetrievalFunc.
func testStaticResourceRetrievalFunc[T any](ret []T) func(context.Context, string, string) ([]T, error) {
return func(ctx context.Context, s1, s2 string) ([]T, error) {
return ret, nil
// WithTargetRetrievalFunc and WithSessionRetrievalFunc. The provided refresh
// token determines the returned value and is a string representation of an
// incrementing integer. This integer is the index into the provided return
// values and once it reaches the length of the provided slice it returns an
// empty slice and the same refresh token repeatedly.
func testStaticResourceRetrievalFunc[T any](t *testing.T, ret [][]T, removed [][]string) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) {
t.Helper()
require.Equal(t, len(ret), len(removed), "returned slice and removed slice must be the same length")
return func(ctx context.Context, s1, s2 string, refToken RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) {
index := 0
if refToken != "" {
var err error
index, err = strconv.Atoi(string(refToken))
require.NoError(t, err)
}
switch {
case len(ret) == 0:
return nil, nil, "", nil
case index > 0 && index >= len(ret):
return []T{}, []string{}, RefreshTokenValue(fmt.Sprintf("%d", index)), nil
default:
return ret[index], removed[index], RefreshTokenValue(fmt.Sprintf("%d", index+1)), nil
}
}
}
// testErroringForRefreshTokenRetrievalFunc returns a refresh token error when
// the refresh token is not empty. This is useful for testing behavior when
// the refresh token has expired or is otherwise invalid.
func testErroringForRefreshTokenRetrievalFunc[T any](t *testing.T, ret []T) func(context.Context, string, string, RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) {
return func(ctx context.Context, s1, s2 string, refToken RefreshTokenValue) ([]T, []string, RefreshTokenValue, error) {
if refToken != "" {
return nil, nil, "", api.ErrInvalidRefreshToken
}
return ret, nil, "1", nil
}
}
@ -297,32 +330,32 @@ func TestRefresh(t *testing.T) {
target("1"),
target("2"),
target("3"),
target("4"),
}
assert.NoError(t, rs.Refresh(ctx,
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)),
WithTargetRetrievalFunc(func(ctx context.Context, addr, token string) ([]*targets.Target, error) {
require.Equal(t, boundaryAddr, addr)
require.Equal(t, at.Token, token)
return retTargets, nil
})))
opts := []Option{
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)),
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t,
[][]*targets.Target{
retTargets[:3],
retTargets[3:],
},
[][]string{
nil,
{retTargets[0].Id, retTargets[1].Id},
},
)),
}
assert.NoError(t, rs.Refresh(ctx, opts...))
cachedTargets, err := r.ListTargets(ctx, at.Id)
assert.NoError(t, err)
assert.ElementsMatch(t, retTargets, cachedTargets)
t.Run("empty response clears it out", func(t *testing.T) {
assert.NoError(t, rs.Refresh(ctx,
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)),
WithTargetRetrievalFunc(func(ctx context.Context, addr, token string) ([]*targets.Target, error) {
require.Equal(t, boundaryAddr, addr)
require.Equal(t, at.Token, token)
return nil, nil
})))
cachedTargets, err := r.ListTargets(ctx, at.Id)
assert.NoError(t, err)
assert.Empty(t, cachedTargets)
})
assert.ElementsMatch(t, retTargets[:3], cachedTargets)
// Second call removes the first 2 resources from the cache and adds the last
assert.NoError(t, rs.Refresh(ctx, opts...))
cachedTargets, err = r.ListTargets(ctx, at.Id)
assert.NoError(t, err)
assert.ElementsMatch(t, retTargets[2:], cachedTargets)
})
t.Run("set sessions", func(t *testing.T) {
@ -330,42 +363,49 @@ func TestRefresh(t *testing.T) {
session("1"),
session("2"),
session("3"),
session("4"),
}
assert.NoError(t, rs.Refresh(ctx,
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil)),
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(retSess))))
opts := []Option{
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)),
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t,
[][]*sessions.Session{
retSess[:3],
retSess[3:],
},
[][]string{
nil,
{retSess[0].Id, retSess[1].Id},
},
)),
}
assert.NoError(t, rs.Refresh(ctx, opts...))
cachedSessions, err := r.ListSessions(ctx, at.Id)
assert.NoError(t, err)
assert.ElementsMatch(t, retSess, cachedSessions)
t.Run("empty response clears it out", func(t *testing.T) {
assert.NoError(t, rs.Refresh(ctx,
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil)),
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil))))
assert.ElementsMatch(t, retSess[:3], cachedSessions)
cachedTargets, err := r.ListSessions(ctx, at.Id)
assert.NoError(t, err)
assert.Empty(t, cachedTargets)
})
// Second call removes the first 2 resources from the cache and adds the last
assert.NoError(t, rs.Refresh(ctx, opts...))
cachedSessions, err = r.ListSessions(ctx, at.Id)
assert.NoError(t, err)
assert.ElementsMatch(t, retSess[2:], cachedSessions)
})
t.Run("error propogates up", func(t *testing.T) {
innerErr := errors.New("test error")
err := rs.Refresh(ctx,
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)),
WithTargetRetrievalFunc(func(ctx context.Context, addr, token string) ([]*targets.Target, error) {
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)),
WithTargetRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) {
require.Equal(t, boundaryAddr, addr)
require.Equal(t, at.Token, token)
return nil, innerErr
return nil, nil, "", innerErr
}))
assert.ErrorContains(t, err, innerErr.Error())
err = rs.Refresh(ctx,
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil)),
WithSessionRetrievalFunc(func(ctx context.Context, addr, token string) ([]*sessions.Session, error) {
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)),
WithSessionRetrievalFunc(func(ctx context.Context, addr, token string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) {
require.Equal(t, boundaryAddr, addr)
require.Equal(t, at.Token, token)
return nil, innerErr
return nil, nil, "", innerErr
}))
assert.ErrorContains(t, err, innerErr.Error())
})
@ -385,8 +425,8 @@ func TestRefresh(t *testing.T) {
assert.Len(t, us, 1)
rs.Refresh(ctx,
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](nil)),
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](nil)))
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc[*sessions.Session](t, nil, nil)),
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc[*targets.Target](t, nil, nil)))
ps, err = r.listTokens(ctx, u)
require.NoError(t, err)

@ -56,7 +56,7 @@ func NewRepository(ctx context.Context, conn *db.DB, idToAuthToken *sync.Map, ke
}, nil
}
func (r *Repository) SaveError(ctx context.Context, u *user, resourceType string, err error) error {
func (r *Repository) SaveError(ctx context.Context, u *user, resourceType resourceType, err error) error {
const op = "cache.(Repository).StoreError"
switch {
case resourceType == "":
@ -70,7 +70,7 @@ func (r *Repository) SaveError(ctx context.Context, u *user, resourceType string
}
apiErr := &ApiError{
UserId: u.Id,
ResourceType: resourceType,
ResourceType: string(resourceType),
Error: err.Error(),
}
onConflict := db.OnConflict{

@ -0,0 +1,140 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cache
import (
"context"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/util"
"github.com/hashicorp/go-dbw"
)
// RefreshTokenValue is the the type for the actual refresh token value handled
// by the client cache.
type RefreshTokenValue string
// lookupRefreshToken returns the last known valid refresh token or an empty
// string if one is unkonwn. No error is returned if no valid refresh token is
// found.
func (r *Repository) lookupRefreshToken(ctx context.Context, u *user, resourceType resourceType) (RefreshTokenValue, error) {
const op = "cache.(Repsoitory).lookupRefreshToken"
switch {
case util.IsNil(u):
return "", errors.New(ctx, errors.InvalidParameter, op, "user is nil")
case u.Id == "":
return "", errors.New(ctx, errors.InvalidParameter, op, "user id is empty")
case !resourceType.valid():
return "", errors.New(ctx, errors.InvalidParameter, op, "resource type is invalid")
}
rt := &refreshToken{
UserId: u.Id,
ResourceType: resourceType,
}
if err := r.rw.LookupById(ctx, rt); err != nil {
if errors.Is(err, dbw.ErrRecordNotFound) {
return "", nil
}
return "", errors.Wrap(ctx, err, op)
}
return rt.RefreshToken, nil
}
// deleteRefreshToken deletes the refresh token for the provided user and resource type
func (r *Repository) deleteRefreshToken(ctx context.Context, u *user, rType resourceType) error {
const op = "cache.(Repository).deleteRefreshToken"
switch {
case util.IsNil(u):
return errors.New(ctx, errors.InvalidParameter, op, "user is nil")
case u.Id == "":
return errors.New(ctx, errors.InvalidParameter, op, "user id is empty")
case !rType.valid():
return errors.New(ctx, errors.InvalidParameter, op, "resource type is invalid")
}
_, err := r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
rt := &refreshToken{
UserId: u.Id,
ResourceType: rType,
}
n, err := w.Delete(ctx, rt)
if err != nil {
return errors.Wrap(ctx, err, op)
}
if n > 1 {
return errors.New(ctx, errors.MultipleRecords, op, "attempted to delete a single resource but multiple were resource deletions were attempted")
}
return nil
})
if err != nil {
return err
}
return nil
}
func upsertRefreshToken(ctx context.Context, writer db.Writer, u *user, rt resourceType, tok RefreshTokenValue) error {
const op = "cache.upsertRefreshToken"
switch {
case util.IsNil(writer):
return errors.New(ctx, errors.InvalidParameter, op, "writer is nil")
case !writer.IsTx(ctx):
return errors.New(ctx, errors.InvalidParameter, op, "writer isn't in a transaction")
case util.IsNil(u):
return errors.New(ctx, errors.InvalidParameter, op, "user is nil")
case u.Id == "":
return errors.New(ctx, errors.InvalidParameter, op, "user id is empty")
case !rt.valid():
return errors.New(ctx, errors.InvalidParameter, op, "resource type is invalid")
}
refTok := &refreshToken{
UserId: u.Id,
ResourceType: rt,
RefreshToken: tok,
}
switch tok {
case "":
writer.Delete(ctx, refTok)
default:
onConflict := &db.OnConflict{
Target: db.Columns{"user_id", "resource_type"},
Action: db.SetColumns([]string{"refresh_token"}),
}
if err := writer.Create(ctx, refTok, db.WithOnConflict(onConflict)); err != nil {
return errors.Wrap(ctx, err, op)
}
}
return nil
}
type resourceType string
const (
unknownResourceType resourceType = "unknown"
targetResourceType resourceType = "target"
sessionResourceType resourceType = "session"
)
func (r resourceType) valid() bool {
switch r {
case targetResourceType, sessionResourceType:
return true
}
return false
}
type refreshToken struct {
UserId string `gorm:"primaryKey"`
ResourceType resourceType `gorm:"primaryKey"`
RefreshToken RefreshTokenValue
}
func (*refreshToken) TableName() string {
return "refresh_token"
}

@ -0,0 +1,123 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package cache
import (
"context"
"sync"
"testing"
"github.com/hashicorp/boundary/api/authtokens"
cachedb "github.com/hashicorp/boundary/internal/clientcache/internal/db"
"github.com/hashicorp/boundary/internal/db"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestLookupRefreshToken(t *testing.T) {
ctx := context.Background()
s, err := cachedb.Open(ctx)
require.NoError(t, err)
r, err := NewRepository(ctx, s, &sync.Map{},
mapBasedAuthTokenKeyringLookup(map[ringToken]*authtokens.AuthToken{}),
sliceBasedAuthTokenBoundaryReader(nil))
require.NoError(t, err)
t.Run("nil user", func(t *testing.T) {
_, err := r.lookupRefreshToken(ctx, nil, targetResourceType)
assert.Error(t, err)
assert.ErrorContains(t, err, "user is nil")
})
t.Run("user id is empty", func(t *testing.T) {
_, err := r.lookupRefreshToken(ctx, &user{Address: "addr"}, targetResourceType)
assert.Error(t, err)
assert.ErrorContains(t, err, "user id is empty")
})
t.Run("resource type is invalid", func(t *testing.T) {
_, err := r.lookupRefreshToken(ctx, &user{Id: "something", Address: "addr"}, resourceType("invalid"))
assert.Error(t, err)
assert.ErrorContains(t, err, "resource type is invalid")
})
t.Run("unknown user", func(t *testing.T) {
got, err := r.lookupRefreshToken(ctx, &user{Id: "unkonwnUser", Address: "addr"}, targetResourceType)
assert.NoError(t, err)
assert.Empty(t, got)
})
t.Run("no refresh token", func(t *testing.T) {
known := &user{Id: "known", Address: "addr"}
require.NoError(t, r.rw.Create(ctx, known))
got, err := r.lookupRefreshToken(ctx, known, targetResourceType)
assert.NoError(t, err)
assert.Empty(t, got)
})
t.Run("got refresh token", func(t *testing.T) {
token := RefreshTokenValue("something")
known := &user{Id: "withrefreshtoken", Address: "addr"}
require.NoError(t, r.rw.Create(ctx, known))
r.rw.DoTx(ctx, 1, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
require.NoError(t, upsertRefreshToken(ctx, w, known, targetResourceType, token))
return nil
})
got, err := r.lookupRefreshToken(ctx, known, targetResourceType)
assert.NoError(t, err)
assert.Equal(t, token, got)
})
}
func TestDeleteRefreshTokens(t *testing.T) {
ctx := context.Background()
s, err := cachedb.Open(ctx)
require.NoError(t, err)
r, err := NewRepository(ctx, s, &sync.Map{},
mapBasedAuthTokenKeyringLookup(map[ringToken]*authtokens.AuthToken{}),
sliceBasedAuthTokenBoundaryReader(nil))
require.NoError(t, err)
t.Run("nil user", func(t *testing.T) {
err := r.deleteRefreshToken(ctx, nil, targetResourceType)
assert.Error(t, err)
assert.ErrorContains(t, err, "user is nil")
})
t.Run("no user id", func(t *testing.T) {
err := r.deleteRefreshToken(ctx, &user{Address: "addr"}, targetResourceType)
assert.Error(t, err)
assert.ErrorContains(t, err, "user id is empty")
})
t.Run("invalid resource type", func(t *testing.T) {
err := r.deleteRefreshToken(ctx, &user{Id: "id", Address: "addr"}, "this is invalid")
assert.Error(t, err)
assert.ErrorContains(t, err, "resource type is invalid")
})
t.Run("success", func(t *testing.T) {
u := &user{Id: "id", Address: "addr"}
require.NoError(t, r.rw.Create(ctx, u))
r.rw.DoTx(ctx, 1, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
require.NoError(t, upsertRefreshToken(ctx, w, u, targetResourceType, "token"))
return nil
})
got, err := r.lookupRefreshToken(ctx, u, targetResourceType)
require.NoError(t, err)
require.NotEmpty(t, got)
assert.NoError(t, r.deleteRefreshToken(ctx, u, targetResourceType))
got, err = r.lookupRefreshToken(ctx, u, targetResourceType)
require.NoError(t, err)
require.Empty(t, got)
})
}

@ -15,30 +15,32 @@ import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/boundary/internal/util"
"github.com/hashicorp/mql"
)
// SessionRetrievalFunc is a function that retrieves sessions
// from the provided boundary addr using the provided token.
type SessionRetrievalFunc func(ctx context.Context, addr, token string) ([]*sessions.Session, error)
type SessionRetrievalFunc func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) (ret []*sessions.Session, removedIds []string, refreshToken RefreshTokenValue, err error)
func defaultSessionFunc(ctx context.Context, addr, token string) ([]*sessions.Session, error) {
func defaultSessionFunc(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*sessions.Session, []string, RefreshTokenValue, error) {
const op = "cache.defaultSessionFunc"
client, err := api.NewClient(&api.Config{
Addr: addr,
Token: token,
Token: authTok,
})
if err != nil {
return nil, errors.Wrap(ctx, err, op)
return nil, nil, "", errors.Wrap(ctx, err, op)
}
sClient := sessions.NewClient(client)
l, err := sClient.List(ctx, "global", sessions.WithRecursive(true))
l, err := sClient.List(ctx, "global", sessions.WithRecursive(true), sessions.WithRefreshToken(string(refreshTok)))
if err != nil {
return nil, errors.Wrap(ctx, err, op)
if api.ErrInvalidRefreshToken.Is(err) {
return nil, nil, "", err
}
return nil, nil, "", errors.Wrap(ctx, err, op)
}
return l.Items, nil
return l.Items, l.RemovedIds, RefreshTokenValue(l.RefreshToken), nil
}
func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error {
@ -51,6 +53,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au
case u.Address == "":
return errors.New(ctx, errors.InvalidParameter, op, "user boundary address is missing")
}
const resourceType = sessionResourceType
opts, err := getOpts(opt...)
if err != nil {
@ -59,16 +62,28 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au
if opts.withSessionRetrievalFunc == nil {
opts.withSessionRetrievalFunc = defaultSessionFunc
}
oldRefreshToken, err := r.lookupRefreshToken(ctx, u, resourceType)
if err != nil {
return errors.Wrap(ctx, err, op)
}
// Find and use a token for retrieving sessions
var gotResponse bool
var resp []*sessions.Session
var newRefreshToken RefreshTokenValue
var removedIds []string
var retErr error
for at, t := range tokens {
resp, err = opts.withSessionRetrievalFunc(ctx, u.Address, t)
resp, removedIds, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, oldRefreshToken)
if api.ErrInvalidRefreshToken.Is(err) {
if err := r.deleteRefreshToken(ctx, u, resourceType); err != nil {
return errors.Wrap(ctx, err, op)
}
// try again without the refresh token
oldRefreshToken = ""
resp, removedIds, newRefreshToken, err = opts.withSessionRetrievalFunc(ctx, u.Address, t, "")
}
if err != nil {
// TODO: If we get an error about the token no longer having
// permissions, remove it.
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}
@ -77,7 +92,7 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au
}
if retErr != nil {
if saveErr := r.SaveError(ctx, u, resource.Session.String(), retErr); saveErr != nil {
if saveErr := r.SaveError(ctx, u, resourceType, retErr); saveErr != nil {
return stderrors.Join(err, errors.Wrap(ctx, saveErr, op))
}
}
@ -87,10 +102,17 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au
event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d sessions for user %v", len(resp), u))
_, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
// TODO: Instead of deleting everything, use refresh tokens and apply the delta
if _, err := w.Exec(ctx, "delete from session where user_id = @user_id",
[]any{sql.Named("user_id", u.Id)}); err != nil {
return err
switch {
case oldRefreshToken == "":
if _, err := w.Exec(ctx, "delete from session where user_id = @user_id",
[]any{sql.Named("user_id", u.Id)}); err != nil {
return err
}
case len(removedIds) > 0:
if _, err := w.Exec(ctx, "delete from session where id in @ids",
[]any{sql.Named("ids", removedIds)}); err != nil {
return err
}
}
for _, s := range resp {
@ -114,6 +136,10 @@ func (r *Repository) refreshSessions(ctx context.Context, u *user, tokens map[Au
return err
}
}
if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil {
return err
}
return nil
})
if err != nil {

@ -116,13 +116,22 @@ func TestRepository_refreshSessions(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := r.refreshSessions(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"},
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(tc.sess)))
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{tc.sess}, [][]string{nil})))
if tc.errorContains == "" {
assert.NoError(t, err)
rw := db.New(s)
var got []*Session
require.NoError(t, rw.SearchWhere(ctx, &got, "true", nil))
assert.Len(t, got, tc.wantCount)
t.Cleanup(func() {
refTok := &refreshToken{
UserId: tc.u.Id,
ResourceType: sessionResourceType,
}
_, err := r.rw.Delete(ctx, refTok)
require.NoError(t, err)
})
} else {
assert.ErrorContains(t, err, tc.errorContains)
}
@ -130,6 +139,96 @@ func TestRepository_refreshSessions(t *testing.T) {
}
}
func TestRepository_RefreshSessions_withRefreshTokens(t *testing.T) {
ctx := context.Background()
s, err := cachedb.Open(ctx)
require.NoError(t, err)
addr := "address"
u := user{
Id: "u1",
Address: addr,
}
at := &authtokens.AuthToken{
Id: "at_1",
Token: "at_1_token",
UserId: u.Id,
}
kt := KeyringToken{
KeyringType: "keyring",
TokenName: "token",
AuthTokenId: at.Id,
}
atMap := map[ringToken]*authtokens.AuthToken{
{kt.KeyringType, kt.TokenName}: at,
}
r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap)))
require.NoError(t, err)
require.NoError(t, r.AddKeyringToken(ctx, addr, kt))
ss := [][]*sessions.Session{
{
{
Id: "ttcp_1",
Status: "status1",
Endpoint: "address1",
Type: "tcp",
},
{
Id: "ttcp_2",
Status: "status2",
Endpoint: "address2",
Type: "tcp",
},
},
{
{
Id: "ttcp_3",
Status: "status3",
Endpoint: "address3",
Type: "tcp",
},
},
}
err = r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"},
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil})))
assert.NoError(t, err)
got, err := r.ListSessions(ctx, at.Id)
require.NoError(t, err)
assert.Len(t, got, 2)
// Refreshing again uses the refresh token and get additional sessions, appending
// them to the response
err = r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"},
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil})))
assert.NoError(t, err)
got, err = r.ListSessions(ctx, at.Id)
require.NoError(t, err)
assert.Len(t, got, 3)
// Refreshing again wont return any more resources, but also none should be
// removed
require.NoError(t, r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"},
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, ss, [][]string{nil, nil}))))
assert.NoError(t, err)
got, err = r.ListSessions(ctx, at.Id)
require.NoError(t, err)
assert.Len(t, got, 3)
// Refresh again with the refresh token being reported as invalid.
require.NoError(t, r.refreshSessions(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"},
WithSessionRetrievalFunc(testErroringForRefreshTokenRetrievalFunc(t, ss[0]))))
assert.NoError(t, err)
got, err = r.ListSessions(ctx, at.Id)
require.NoError(t, err)
assert.Len(t, got, 2)
}
func TestRepository_ListSessions(t *testing.T) {
ctx := context.Background()
s, err := cachedb.Open(ctx)
@ -200,7 +299,7 @@ func TestRepository_ListSessions(t *testing.T) {
},
}
require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"},
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(ss))))
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil}))))
t.Run("wrong user gets no sessions", func(t *testing.T) {
l, err := r.ListSessions(ctx, kt2.AuthTokenId)
@ -307,7 +406,7 @@ func TestRepository_QuerySessions(t *testing.T) {
},
}
require.NoError(t, r.refreshSessions(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"},
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(ss))))
WithSessionRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*sessions.Session{ss}, [][]string{nil}))))
t.Run("wrong token gets no sessions", func(t *testing.T) {
l, err := r.QuerySessions(ctx, kt2.AuthTokenId, query)
@ -337,7 +436,15 @@ func TestDefaultSessionRetrievalFunc(t *testing.T) {
_, err = tarClient.AuthorizeSession(tc.Context(), tar1.Item.Id)
assert.NoError(t, err)
got, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token)
got, removed, refTok, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, "")
assert.NoError(t, err)
assert.NotEmpty(t, refTok)
assert.Empty(t, removed)
assert.Len(t, got, 1)
got2, removed2, refTok2, err := defaultSessionFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, refTok)
assert.NoError(t, err)
assert.NotEmpty(t, refTok2)
assert.Empty(t, removed2)
assert.Empty(t, got2)
}

@ -15,30 +15,32 @@ import (
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/event"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/boundary/internal/util"
"github.com/hashicorp/mql"
)
// TargetRetrievalFunc is a function that retrieves targets
// from the provided boundary addr using the provided token.
type TargetRetrievalFunc func(ctx context.Context, addr, token string) ([]*targets.Target, error)
type TargetRetrievalFunc func(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) (ret []*targets.Target, removedIds []string, refreshToken RefreshTokenValue, err error)
func defaultTargetFunc(ctx context.Context, addr, token string) ([]*targets.Target, error) {
func defaultTargetFunc(ctx context.Context, addr, authTok string, refreshTok RefreshTokenValue) ([]*targets.Target, []string, RefreshTokenValue, error) {
const op = "cache.defaultTargetFunc"
client, err := api.NewClient(&api.Config{
Addr: addr,
Token: token,
Token: authTok,
})
if err != nil {
return nil, errors.Wrap(ctx, err, op)
return nil, nil, "", errors.Wrap(ctx, err, op)
}
tarClient := targets.NewClient(client)
l, err := tarClient.List(ctx, "global", targets.WithRecursive(true))
l, err := tarClient.List(ctx, "global", targets.WithRecursive(true), targets.WithRefreshToken(string(refreshTok)))
if err != nil {
return nil, errors.Wrap(ctx, err, op)
if api.ErrInvalidRefreshToken.Is(err) {
return nil, nil, "", err
}
return nil, nil, "", errors.Wrap(ctx, err, op)
}
return l.Items, nil
return l.Items, l.RemovedIds, RefreshTokenValue(l.RefreshToken), nil
}
func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[AuthToken]string, opt ...Option) error {
@ -49,6 +51,7 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut
case u.Id == "":
return errors.New(ctx, errors.InvalidParameter, op, "user id is missing")
}
const resourceType = targetResourceType
opts, err := getOpts(opt...)
if err != nil {
@ -57,16 +60,28 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut
if opts.withTargetRetrievalFunc == nil {
opts.withTargetRetrievalFunc = defaultTargetFunc
}
oldRefreshToken, err := r.lookupRefreshToken(ctx, u, resourceType)
if err != nil {
return errors.Wrap(ctx, err, op)
}
// Find and use a token for retrieving targets
var gotResponse bool
var resp []*targets.Target
var removedIds []string
var newRefreshToken RefreshTokenValue
var retErr error
for at, t := range tokens {
resp, err = opts.withTargetRetrievalFunc(ctx, u.Address, t)
resp, removedIds, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, oldRefreshToken)
if api.ErrInvalidRefreshToken.Is(err) {
if err := r.deleteRefreshToken(ctx, u, resourceType); err != nil {
return errors.Wrap(ctx, err, op)
}
// try again without the refresh token
oldRefreshToken = ""
resp, removedIds, newRefreshToken, err = opts.withTargetRetrievalFunc(ctx, u.Address, t, "")
}
if err != nil {
// TODO: If we get an error about the token no longer having
// permissions, remove it.
retErr = stderrors.Join(retErr, errors.Wrap(ctx, err, op, errors.WithMsg("for token %q", at.Id)))
continue
}
@ -74,7 +89,7 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut
break
}
if retErr != nil {
if saveErr := r.SaveError(ctx, u, resource.Target.String(), retErr); saveErr != nil {
if saveErr := r.SaveError(ctx, u, resourceType, retErr); saveErr != nil {
return stderrors.Join(err, errors.Wrap(ctx, saveErr, op))
}
}
@ -84,10 +99,17 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut
event.WriteSysEvent(ctx, op, fmt.Sprintf("updating %d targets for user %v", len(resp), u))
_, err = r.rw.DoTx(ctx, db.StdRetryCnt, db.ExpBackoff{}, func(r db.Reader, w db.Writer) error {
// TODO: Instead of deleting everything, use refresh tokens and apply the delta
if _, err := w.Exec(ctx, "delete from target where user_id = @user_id",
[]any{sql.Named("user_id", u.Id)}); err != nil {
return err
switch {
case oldRefreshToken == "":
if _, err := w.Exec(ctx, "delete from target where user_id = @user_id",
[]any{sql.Named("user_id", u.Id)}); err != nil {
return err
}
case len(removedIds) > 0:
if _, err := w.Exec(ctx, "delete from target where id in @ids",
[]any{sql.Named("ids", removedIds)}); err != nil {
return err
}
}
for _, t := range resp {
@ -111,6 +133,12 @@ func (r *Repository) refreshTargets(ctx context.Context, u *user, tokens map[Aut
return err
}
}
if newRefreshToken != "" || oldRefreshToken != "" {
if err := upsertRefreshToken(ctx, w, u, resourceType, newRefreshToken); err != nil {
return err
}
}
return nil
})
if err != nil {

@ -110,13 +110,22 @@ func TestRepository_refreshTargets(t *testing.T) {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := r.refreshTargets(ctx, tc.u, map[AuthToken]string{{Id: "id"}: "something"},
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(tc.targets)))
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{tc.targets}, [][]string{nil})))
if tc.errorContains == "" {
assert.NoError(t, err)
rw := db.New(s)
var got []*Target
require.NoError(t, rw.SearchWhere(ctx, &got, "true", nil))
assert.Len(t, got, tc.wantCount)
t.Cleanup(func() {
refTok := &refreshToken{
UserId: tc.u.Id,
ResourceType: targetResourceType,
}
_, err := r.rw.Delete(ctx, refTok)
require.NoError(t, err)
})
} else {
assert.ErrorContains(t, err, tc.errorContains)
}
@ -124,6 +133,97 @@ func TestRepository_refreshTargets(t *testing.T) {
}
}
func TestRepository_RefreshTargets_withRefreshTokens(t *testing.T) {
ctx := context.Background()
s, err := cachedb.Open(ctx)
require.NoError(t, err)
addr := "address"
u := user{
Id: "u1",
Address: addr,
}
at := &authtokens.AuthToken{
Id: "at_1",
Token: "at_1_token",
UserId: u.Id,
}
kt := KeyringToken{
KeyringType: "keyring",
TokenName: "token",
AuthTokenId: at.Id,
}
atMap := map[ringToken]*authtokens.AuthToken{
{kt.KeyringType, kt.TokenName}: at,
}
r, err := NewRepository(ctx, s, &sync.Map{}, mapBasedAuthTokenKeyringLookup(atMap), sliceBasedAuthTokenBoundaryReader(maps.Values(atMap)))
require.NoError(t, err)
require.NoError(t, r.AddKeyringToken(ctx, addr, kt))
ts := [][]*targets.Target{
{
{
Id: "ttcp_1",
Name: "name1",
Address: "address1",
Type: "tcp",
SessionMaxSeconds: 111,
},
{
Id: "ttcp_2",
Name: "name2",
Address: "address2",
Type: "tcp",
SessionMaxSeconds: 222,
},
}, {
{
Id: "ttcp_3",
Name: "name3",
Address: "address3",
Type: "tcp",
SessionMaxSeconds: 333,
},
},
}
require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"},
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil}))))
got, err := r.ListTargets(ctx, at.Id)
require.NoError(t, err)
assert.Len(t, got, 2)
// Refreshing again uses the refresh token and get additional sessions, appending
// them to the response
require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"},
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil}))))
assert.NoError(t, err)
got, err = r.ListTargets(ctx, at.Id)
require.NoError(t, err)
assert.Len(t, got, 3)
// Refreshing again wont return any more resources, but also none should be
// removed
require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"},
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, ts, [][]string{nil, nil}))))
assert.NoError(t, err)
got, err = r.ListTargets(ctx, at.Id)
require.NoError(t, err)
assert.Len(t, got, 3)
// Refresh again with the refresh token being reported as invalid.
require.NoError(t, r.refreshTargets(ctx, &u, map[AuthToken]string{{Id: "id"}: "something"},
WithTargetRetrievalFunc(testErroringForRefreshTokenRetrievalFunc(t, ts[0]))))
assert.NoError(t, err)
got, err = r.ListTargets(ctx, at.Id)
require.NoError(t, err)
assert.Len(t, got, 2)
}
func TestRepository_ListTargets(t *testing.T) {
ctx := context.Background()
s, err := cachedb.Open(ctx)
@ -190,7 +290,7 @@ func TestRepository_ListTargets(t *testing.T) {
},
}
require.NoError(t, r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"},
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(ts))))
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil}))))
t.Run("wrong user gets no targets", func(t *testing.T) {
l, err := r.ListTargets(ctx, kt2.AuthTokenId)
@ -294,7 +394,7 @@ func TestRepository_QueryTargets(t *testing.T) {
},
}
require.NoError(t, r.refreshTargets(ctx, u1, map[AuthToken]string{{Id: "id"}: "something"},
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(ts))))
WithTargetRetrievalFunc(testStaticResourceRetrievalFunc(t, [][]*targets.Target{ts}, [][]string{nil}))))
t.Run("wrong token gets no targets", func(t *testing.T) {
l, err := r.QueryTargets(ctx, kt2.AuthTokenId, query)
@ -321,8 +421,16 @@ func TestDefaultTargetRetrievalFunc(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, tar2)
got, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token)
got, removed, refTok, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, "")
assert.NoError(t, err)
assert.NotEmpty(t, refTok)
assert.Empty(t, removed)
assert.Contains(t, got, tar1.Item)
assert.Contains(t, got, tar2.Item)
got2, removed2, refTok2, err := defaultTargetFunc(tc.Context(), tc.ApiAddrs()[0], tc.Token().Token, refTok)
assert.NoError(t, err)
assert.NotEmpty(t, refTok2)
assert.Empty(t, removed2)
assert.Empty(t, got2)
}

@ -54,7 +54,7 @@ func TestRepository_SaveError(t *testing.T) {
sliceBasedAuthTokenBoundaryReader(nil))
require.NoError(t, err)
testResource := "test_resource_type"
testResource := targetResourceType
testErr := fmt.Errorf("test error for %q", testResource)
u := &user{

@ -135,6 +135,110 @@ func TestAuthToken_NoMoreKeyringTokens(t *testing.T) {
assert.NoError(t, rw.LookupById(ctx, u))
}
func TestRefreshToken(t *testing.T) {
ctx := context.Background()
s, err := cachedb.Open(ctx)
require.NoError(t, err)
rw := db.New(s)
u := &user{
Id: "userId",
Address: "address",
}
t.Run("no user foreign key constraint", func(t *testing.T) {
tok := &refreshToken{
UserId: u.Id,
ResourceType: targetResourceType,
RefreshToken: "something",
}
require.ErrorContains(t, rw.Create(ctx, tok), "constraint failed")
})
require.NoError(t, rw.Create(ctx, u))
t.Run("no user id", func(t *testing.T) {
tok := &refreshToken{
ResourceType: targetResourceType,
RefreshToken: "something",
}
require.ErrorContains(t, rw.Create(ctx, tok), "constraint failed")
})
t.Run("unknown resource type", func(t *testing.T) {
tok := &refreshToken{
UserId: u.Id,
ResourceType: "thisisntknown",
RefreshToken: "something",
}
require.ErrorContains(t, rw.Create(ctx, tok), "constraint failed")
})
t.Run("empty refresh token", func(t *testing.T) {
tok := &refreshToken{
UserId: u.Id,
ResourceType: "thisisntknown",
}
require.ErrorContains(t, rw.Create(ctx, tok), "constraint failed")
})
t.Run("create", func(t *testing.T) {
tok := &refreshToken{
UserId: u.Id,
ResourceType: targetResourceType,
RefreshToken: "something",
}
require.NoError(t, rw.Create(ctx, tok))
require.NoError(t, rw.LookupById(ctx, tok))
assert.NotEmpty(t, tok.RefreshToken)
})
t.Run("update", func(t *testing.T) {
u := &user{
Id: "updatethis",
Address: "updated",
}
require.NoError(t, rw.Create(ctx, u))
tok := &refreshToken{
UserId: u.Id,
ResourceType: targetResourceType,
RefreshToken: "started",
}
require.NoError(t, rw.Create(ctx, tok))
tok.RefreshToken = "updated"
n, err := rw.Update(ctx, tok, []string{"RefreshToken"}, nil)
assert.NoError(t, err)
assert.Equal(t, 1, n)
})
t.Run("delete user deletes token", func(t *testing.T) {
u := &user{
Id: "deletethis",
Address: "deleted",
}
require.NoError(t, rw.Create(ctx, u))
tok := &refreshToken{
UserId: u.Id,
ResourceType: targetResourceType,
RefreshToken: "deleted_soon",
}
require.NoError(t, rw.Create(ctx, tok))
_, err = rw.Exec(ctx, "delete from user where id = ?", []any{u.Id})
require.True(t, errors.IsNotFoundError(rw.LookupById(ctx, tok)))
})
// TODO: When gorm sqlite driver fixes it's delete, use rw.Delete instead of the Exec.
// n, err := rw.Delete(ctx, p)
_, err = rw.Exec(ctx, "delete from refresh_token", nil)
assert.NoError(t, err)
}
func TestAuthToken(t *testing.T) {
ctx := context.Background()
s, err := cachedb.Open(ctx)

@ -75,17 +75,17 @@ func (s *TestServer) AddResources(t *testing.T, p *authtokens.AuthToken, tars []
r, err := cache.NewRepository(ctx, s.CacheServer.store, &sync.Map{}, s.cmd.ReadTokenFromKeyring, atReadFn)
require.NoError(t, err)
tarFn := func(ctx context.Context, _, tok string) ([]*targets.Target, error) {
tarFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*targets.Target, []string, cache.RefreshTokenValue, error) {
if tok != p.Token {
return nil, nil
return nil, nil, "", nil
}
return tars, nil
return tars, nil, "", nil
}
sessFn := func(ctx context.Context, _, tok string) ([]*sessions.Session, error) {
sessFn := func(ctx context.Context, _, tok string, _ cache.RefreshTokenValue) ([]*sessions.Session, []string, cache.RefreshTokenValue, error) {
if tok != p.Token {
return nil, nil
return nil, nil, "", nil
}
return sess, nil
return sess, nil, "", nil
}
rs, err := cache.NewRefreshService(ctx, r)
require.NoError(t, err)

@ -14,7 +14,7 @@ import (
)
const (
DefaultRefreshIntervalSeconds = 5 * 60
DefaultRefreshIntervalSeconds = 15
defaultRefreshInterval = DefaultRefreshIntervalSeconds * time.Second
)

@ -13,6 +13,33 @@ create table if not exists user (
check (length(address) > 0)
);
-- Contains the known resource types contained in the boundary client cache
create table if not exists resource_type_enm(
string text not null primary key
constraint only_predefined_resource_types_allowed
check(string in ('unknown', 'target', 'session'))
);
insert into resource_type_enm (string)
values
('unknown'),
('target'),
('session');
-- Contains refresh tokens for list requests sent by the client daemon to the
-- boundary instance.
create table if not exists refresh_token(
user_id text not null
references user(id)
on delete cascade,
resource_type text not null
references resource_type_enm(string)
constraint only_known_resource_types_allowed,
refresh_token text not null
check (length(refresh_token) > 0),
primary key (user_id, resource_type)
);
-- Contains the boundary auth token
create table if not exists auth_token (
-- id is the boundary id of the auth token
@ -107,7 +134,9 @@ create table if not exists api_error (
user_id text not null
references user(id)
on delete cascade,
resource_type text not null,
resource_type text not null
references resource_type_enm(string)
constraint only_known_resource_types_allowed,
error text not null,
create_time timestamp not null default current_timestamp,
primary key (user_id, resource_type)

@ -153,6 +153,25 @@ func InvalidArgumentErrorf(msg string, fields map[string]string) *ApiError {
return apiErr
}
func invalidRefreshTokenError(err error) *ApiError {
const op = "handlers.invalidRefreshTokenError"
ctx := context.TODO()
var domainErr *errors.Err
if !errors.As(err, &domainErr) {
event.WriteError(ctx, op, err, event.WithInfoMsg("Unable to build invalid argument api error."))
}
return &ApiError{
Status: http.StatusBadRequest,
Inner: &pb.Error{
Kind: domainErr.Info().Message,
Op: string(domainErr.Op),
Message: domainErr.Msg,
},
}
}
// ConflictErrorf generates an ApiErr when a pre-conditional check is violated.
// Note, this deliberately doesn't translate to the similarly named '412
// Precondition Failed' HTTP response status. The ApiErr returned is a 400 bad
@ -199,6 +218,8 @@ func backendErrorToApiError(inErr error) *ApiError {
return NotFoundErrorf(genericNotFoundMsg)
case errors.Match(errors.T(errors.AccountAlreadyAssociated), inErr):
return InvalidArgumentErrorf(inErr.Error(), nil)
case errors.Match(errors.T(errors.InvalidRefreshToken), inErr):
return invalidRefreshTokenError(inErr)
case errors.Match(errors.T(errors.InvalidFieldMask), inErr), errors.Match(errors.T(errors.EmptyFieldMask), inErr):
return InvalidArgumentErrorf("Error in provided request", map[string]string{"update_mask": "Invalid update mask provided."})
case errors.IsUniqueError(inErr):

@ -219,6 +219,18 @@ func TestApiErrorHandler(t *testing.T) {
},
},
},
{
name: "Invalid refresh token error",
err: errors.New(ctx, errors.InvalidRefreshToken, errors.Op("test.op"), "this is a test invalid refresh token error"),
expected: ApiError{
Status: http.StatusBadRequest,
Inner: &pb.Error{
Kind: "invalid refresh token",
Op: "test.op",
Message: "this is a test invalid refresh token error",
},
},
},
{
name: "Wrapped forbidden domain error",
err: fmt.Errorf("got error: %w", errors.E(ctx, errors.WithCode(errors.Forbidden), errors.WithMsg("test msg"))),

@ -67,6 +67,8 @@ const (
Closed = 134 // Closed represents an error when an operation cannot be completed because the thing being operated on is closed
ChecksumMismatch = 135 // ChecksumMismatch represents an error when a checksum is mismatched
InvalidRefreshToken Code = 136 // InvalidRefreshToken represents an error where the provided refresh token is invalid
AuthAttemptExpired Code = 198 // AuthAttemptExpired represents an expired authentication attempt
AuthMethodInactive Code = 199 // AuthMethodInactive represents an error that means the auth method is not active.

@ -415,6 +415,11 @@ func TestCode_Both_String_Info(t *testing.T) {
c: InvalidConfiguration,
want: InvalidConfiguration,
},
{
name: "InvalidRefreshToken",
c: InvalidRefreshToken,
want: InvalidRefreshToken,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

@ -347,4 +347,8 @@ var errorCodeInfo = map[Code]Info{
Message: "invalid configuration",
Kind: Configuration,
},
InvalidRefreshToken: {
Message: "invalid refresh token",
Kind: Parameter,
},
}

@ -0,0 +1,133 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
// The refreshtoken package encapsulates domain logic surrounding
// list endpoint refresh tokens. Refresh tokens are used when users
// paginate through results in our list endpoints, and also to
// allow users to request new, updated and deleted resources.
package refreshtoken
import (
"bytes"
"context"
"time"
"github.com/hashicorp/boundary/internal/boundary"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/types/resource"
)
// A Token is returned in list endpoints for the purposes of pagination
type Token struct {
CreatedTime time.Time
UpdatedTime time.Time
ResourceType resource.Type
GrantsHash []byte
LastItemId string
LastItemUpdatedTime time.Time
}
// New creates a new refresh token from a createdTime, resource type, grants hash, and last item information
func New(ctx context.Context, createdTime time.Time, updatedTime time.Time, typ resource.Type, grantsHash []byte, lastItemId string, lastItemUpdatedTime time.Time) (*Token, error) {
const op = "refreshtoken.New"
if len(grantsHash) == 0 {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing grants hash")
}
if createdTime.After(time.Now()) {
return nil, errors.New(ctx, errors.InvalidParameter, op, "created time is in the future")
}
if createdTime.Before(time.Now().AddDate(0, 0, -30)) {
return nil, errors.New(ctx, errors.InvalidParameter, op, "created time is too old")
}
if updatedTime.Before(createdTime) {
return nil, errors.New(ctx, errors.InvalidParameter, op, "updated time is older than created time")
}
if updatedTime.After(time.Now()) {
return nil, errors.New(ctx, errors.InvalidParameter, op, "updated time is in the future")
}
if lastItemId == "" {
return nil, errors.New(ctx, errors.InvalidParameter, op, "missing last item ID")
}
if lastItemUpdatedTime.After(time.Now()) {
return nil, errors.New(ctx, errors.InvalidParameter, op, "last item updated time is in the future")
}
return &Token{
CreatedTime: createdTime,
UpdatedTime: updatedTime,
ResourceType: typ,
GrantsHash: grantsHash,
LastItemId: lastItemId,
LastItemUpdatedTime: lastItemUpdatedTime,
}, nil
}
// FromResource creates a new refresh token from a resource and grants hash
func FromResource(res boundary.Resource, grantsHash []byte) *Token {
t := time.Now()
return &Token{
CreatedTime: t,
UpdatedTime: t,
ResourceType: res.GetResourceType(),
GrantsHash: grantsHash,
LastItemId: res.GetPublicId(),
LastItemUpdatedTime: res.GetUpdateTime().AsTime(),
}
}
// Refresh refreshes a token's updated time
func (rt *Token) Refresh(updatedTime time.Time) *Token {
rt.UpdatedTime = updatedTime
return rt
}
// RefreshLastItem refreshes a token's updated time and last item
func (rt *Token) RefreshLastItem(res boundary.Resource, updatedTime time.Time) *Token {
rt.UpdatedTime = updatedTime
rt.LastItemId = res.GetPublicId()
rt.LastItemUpdatedTime = res.GetUpdateTime().AsTime()
return rt
}
// Validate validates the refresh token.
func (rt *Token) Validate(
ctx context.Context,
expectedResourceType resource.Type,
expectedGrantsHash []byte,
) error {
const op = "refreshtoken.Validate"
if rt == nil {
return errors.New(ctx, errors.InvalidParameter, op, "refresh token was missing")
}
if len(rt.GrantsHash) == 0 {
return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token was missing its grants hash")
}
if !bytes.Equal(rt.GrantsHash, expectedGrantsHash) {
return errors.New(ctx, errors.InvalidRefreshToken, op, "grants have changed since refresh token was issued")
}
if rt.CreatedTime.After(time.Now()) {
return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token was created in the future")
}
// Tokens older than 30 days have expired
if rt.CreatedTime.Before(time.Now().AddDate(0, 0, -30)) {
return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token was expired")
}
if rt.UpdatedTime.Before(rt.CreatedTime) {
return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token was updated before its creation time")
}
if rt.UpdatedTime.After(time.Now()) {
return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token was updated in the future")
}
if rt.LastItemId == "" {
return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token missing last item ID")
}
if rt.LastItemUpdatedTime.After(time.Now()) {
return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token last item was updated in the future")
}
if rt.ResourceType != expectedResourceType {
return errors.New(ctx, errors.InvalidRefreshToken, op, "refresh token resource type does not match expected resource type")
}
return nil
}

@ -0,0 +1,398 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package refreshtoken_test
import (
"context"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/boundary/internal/boundary"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/refreshtoken"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/stretchr/testify/require"
)
func Test_ValidateRefreshToken(t *testing.T) {
fiveDaysAgo := time.Now().AddDate(0, 0, -5)
tests := []struct {
name string
token *refreshtoken.Token
grantsHash []byte
resourceType resource.Type
wantErrString string
wantErrCode errors.Code
}{
{
name: "valid token",
token: &refreshtoken.Token{
CreatedTime: fiveDaysAgo,
UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "s_1234567890",
LastItemUpdatedTime: fiveDaysAgo,
},
grantsHash: []byte("some hash"),
resourceType: resource.Target,
},
{
name: "nil token",
token: nil,
grantsHash: []byte("some hash"),
resourceType: resource.Target,
wantErrString: "refresh token was missing",
wantErrCode: errors.InvalidParameter,
},
{
name: "no grants hash",
token: &refreshtoken.Token{
CreatedTime: fiveDaysAgo,
UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: nil,
LastItemId: "s_1234567890",
LastItemUpdatedTime: fiveDaysAgo,
},
grantsHash: []byte("some hash"),
resourceType: resource.Target,
wantErrString: "refresh token was missing its grants hash",
wantErrCode: errors.InvalidRefreshToken,
},
{
name: "changed grants hash",
token: &refreshtoken.Token{
CreatedTime: fiveDaysAgo,
UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "s_1234567890",
LastItemUpdatedTime: fiveDaysAgo,
},
grantsHash: []byte("some other hash"),
resourceType: resource.Target,
wantErrString: "grants have changed since refresh token was issued",
wantErrCode: errors.InvalidRefreshToken,
},
{
name: "created in the future",
token: &refreshtoken.Token{
CreatedTime: time.Now().AddDate(1, 0, 0),
UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "s_1234567890",
LastItemUpdatedTime: fiveDaysAgo,
},
grantsHash: []byte("some hash"),
resourceType: resource.Target,
wantErrString: "refresh token was created in the future",
wantErrCode: errors.InvalidRefreshToken,
},
{
name: "expired",
token: &refreshtoken.Token{
CreatedTime: time.Now().AddDate(0, 0, -31),
UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "s_1234567890",
LastItemUpdatedTime: fiveDaysAgo,
},
grantsHash: []byte("some hash"),
resourceType: resource.Target,
wantErrString: "refresh token was expired",
wantErrCode: errors.InvalidRefreshToken,
},
{
name: "updated before created",
token: &refreshtoken.Token{
CreatedTime: fiveDaysAgo,
UpdatedTime: fiveDaysAgo.AddDate(0, 0, -1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "s_1234567890",
LastItemUpdatedTime: fiveDaysAgo,
},
grantsHash: []byte("some hash"),
resourceType: resource.Target,
wantErrString: "refresh token was updated before its creation time",
wantErrCode: errors.InvalidRefreshToken,
},
{
name: "updated after now",
token: &refreshtoken.Token{
CreatedTime: fiveDaysAgo,
UpdatedTime: time.Now().AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "s_1234567890",
LastItemUpdatedTime: fiveDaysAgo,
},
grantsHash: []byte("some hash"),
resourceType: resource.Target,
wantErrString: "refresh token was updated in the future",
wantErrCode: errors.InvalidRefreshToken,
},
{
name: "resource type mismatch",
token: &refreshtoken.Token{
CreatedTime: fiveDaysAgo,
UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "s_1234567890",
LastItemUpdatedTime: fiveDaysAgo,
},
grantsHash: []byte("some hash"),
resourceType: resource.SessionRecording,
wantErrString: "refresh token resource type does not match expected resource type",
wantErrCode: errors.InvalidRefreshToken,
},
{
name: "last item ID unset",
token: &refreshtoken.Token{
CreatedTime: fiveDaysAgo,
UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "",
LastItemUpdatedTime: fiveDaysAgo,
},
grantsHash: []byte("some hash"),
resourceType: resource.Target,
wantErrString: "refresh token missing last item ID",
wantErrCode: errors.InvalidRefreshToken,
},
{
name: "last item ID unset",
token: &refreshtoken.Token{
CreatedTime: fiveDaysAgo,
UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "",
LastItemUpdatedTime: fiveDaysAgo,
},
grantsHash: []byte("some hash"),
resourceType: resource.Target,
wantErrString: "refresh token missing last item ID",
wantErrCode: errors.InvalidRefreshToken,
},
{
name: "updated in the future",
token: &refreshtoken.Token{
CreatedTime: fiveDaysAgo,
UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "s_1234567890",
LastItemUpdatedTime: time.Now().AddDate(1, 0, 0),
},
grantsHash: []byte("some hash"),
resourceType: resource.Target,
wantErrString: "refresh token last item was updated in the future",
wantErrCode: errors.InvalidRefreshToken,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.token.Validate(context.Background(), tt.resourceType, tt.grantsHash)
if tt.wantErrString != "" {
require.ErrorContains(t, err, tt.wantErrString)
require.Equal(t, errors.Convert(err).Code, tt.wantErrCode)
return
}
require.NoError(t, err)
})
}
}
func TestNew(t *testing.T) {
fiveDaysAgo := time.Now().AddDate(0, 0, -5)
tests := []struct {
name string
createdTime time.Time
updatedTime time.Time
typ resource.Type
grantsHash []byte
lastItemId string
lastItemUpdatedTime time.Time
want *refreshtoken.Token
wantErrString string
wantErrCode errors.Code
}{
{
name: "valid refresh token",
createdTime: fiveDaysAgo,
updatedTime: fiveDaysAgo.AddDate(0, 0, 1),
typ: resource.Target,
grantsHash: []byte("some hash"),
lastItemId: "some id",
lastItemUpdatedTime: fiveDaysAgo,
want: &refreshtoken.Token{
CreatedTime: fiveDaysAgo,
UpdatedTime: fiveDaysAgo.AddDate(0, 0, 1),
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "some id",
LastItemUpdatedTime: fiveDaysAgo,
},
},
{
name: "missing grants hash",
createdTime: fiveDaysAgo,
updatedTime: fiveDaysAgo.AddDate(0, 0, 1),
typ: resource.Target,
grantsHash: nil,
lastItemId: "some id",
lastItemUpdatedTime: fiveDaysAgo,
wantErrString: "missing grants hash",
wantErrCode: errors.InvalidParameter,
},
{
name: "new created time",
createdTime: fiveDaysAgo.AddDate(1, 0, 0),
updatedTime: fiveDaysAgo.AddDate(0, 0, 1),
typ: resource.Target,
grantsHash: []byte("some hash"),
lastItemId: "some id",
lastItemUpdatedTime: fiveDaysAgo,
wantErrString: "created time is in the future",
wantErrCode: errors.InvalidParameter,
},
{
name: "old created time",
createdTime: fiveDaysAgo.AddDate(-1, 0, 0),
updatedTime: fiveDaysAgo.AddDate(0, 0, 1),
typ: resource.Target,
grantsHash: []byte("some hash"),
lastItemId: "some id",
lastItemUpdatedTime: fiveDaysAgo,
wantErrString: "created time is too old",
wantErrCode: errors.InvalidParameter,
},
{
name: "new updated time",
createdTime: fiveDaysAgo,
updatedTime: fiveDaysAgo.AddDate(1, 0, 0),
typ: resource.Target,
grantsHash: []byte("some hash"),
lastItemId: "some id",
lastItemUpdatedTime: fiveDaysAgo,
wantErrString: "updated time is in the future",
wantErrCode: errors.InvalidParameter,
},
{
name: "updated time older than created time",
createdTime: fiveDaysAgo,
updatedTime: fiveDaysAgo.AddDate(0, 0, -11),
typ: resource.Target,
grantsHash: []byte("some hash"),
lastItemId: "some id",
lastItemUpdatedTime: fiveDaysAgo,
wantErrString: "updated time is older than created time",
wantErrCode: errors.InvalidParameter,
},
{
name: "missing last item id",
createdTime: fiveDaysAgo,
updatedTime: fiveDaysAgo.AddDate(0, 0, 1),
typ: resource.Target,
grantsHash: []byte("some hash"),
lastItemId: "",
lastItemUpdatedTime: fiveDaysAgo,
wantErrString: "missing last item ID",
wantErrCode: errors.InvalidParameter,
},
{
name: "new last item updated time",
createdTime: fiveDaysAgo,
updatedTime: fiveDaysAgo.AddDate(0, 0, 1),
typ: resource.Target,
grantsHash: []byte("some hash"),
lastItemId: "some id",
lastItemUpdatedTime: fiveDaysAgo.AddDate(1, 0, 0),
wantErrString: "last item updated time is in the future",
wantErrCode: errors.InvalidParameter,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := refreshtoken.New(context.Background(), tt.createdTime, tt.updatedTime, tt.typ, tt.grantsHash, tt.lastItemId, tt.lastItemUpdatedTime)
if tt.wantErrString != "" {
require.ErrorContains(t, err, tt.wantErrString)
require.Equal(t, errors.Convert(err).Code, tt.wantErrCode)
return
}
require.NoError(t, err)
require.Empty(t, cmp.Diff(got, tt.want))
})
}
}
type fakeTargetResource struct {
boundary.Resource
publicId string
updateTime *timestamp.Timestamp
}
func (m *fakeTargetResource) GetResourceType() resource.Type {
return resource.Target
}
func (m *fakeTargetResource) GetPublicId() string {
return m.publicId
}
func (m *fakeTargetResource) GetUpdateTime() *timestamp.Timestamp {
return m.updateTime
}
func TestFromResource(t *testing.T) {
fiveDaysAgo := time.Now().AddDate(0, 0, -5)
res := &fakeTargetResource{
publicId: "tcp_1234567890",
updateTime: timestamp.New(fiveDaysAgo),
}
tok := refreshtoken.FromResource(res, []byte("some hash"))
// Check that it's within 1 second of now according to the system
// If this is flaky... just increase the limit 😬.
require.True(t, tok.CreatedTime.Before(time.Now().Add(time.Second)))
require.True(t, tok.CreatedTime.After(time.Now().Add(-time.Second)))
require.True(t, tok.UpdatedTime.Before(time.Now().Add(time.Second)))
require.True(t, tok.UpdatedTime.After(time.Now().Add(-time.Second)))
require.Equal(t, tok.ResourceType, res.GetResourceType())
require.Equal(t, tok.GrantsHash, []byte("some hash"))
require.Equal(t, tok.LastItemId, res.GetPublicId())
require.True(t, tok.LastItemUpdatedTime.Equal(res.GetUpdateTime().AsTime()))
}
func TestRefresh(t *testing.T) {
createdTime := time.Now().AddDate(0, 0, -5)
updatedTime := time.Now()
tok := &refreshtoken.Token{
CreatedTime: createdTime,
UpdatedTime: createdTime,
ResourceType: resource.Target,
GrantsHash: []byte("some hash"),
LastItemId: "tcp_1234567890",
LastItemUpdatedTime: createdTime,
}
newTok := tok.Refresh(updatedTime)
require.True(t, newTok.UpdatedTime.Equal(updatedTime))
require.True(t, newTok.CreatedTime.Equal(createdTime))
require.Equal(t, newTok.ResourceType, tok.ResourceType)
require.Equal(t, newTok.GrantsHash, tok.GrantsHash)
require.Equal(t, newTok.LastItemId, tok.LastItemId)
require.True(t, newTok.LastItemUpdatedTime.Equal(tok.LastItemUpdatedTime))
}
Loading…
Cancel
Save