You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/authtoken/repository_test.go

1073 lines
31 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package authtoken
import (
"context"
"sort"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/auth/password"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/hashicorp/boundary/internal/authtoken/store"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
)
// TestRepository_New verifies that NewRepository applies its default limit and
// token durations, honors the WithLimit / WithTokenTimeToLiveDuration /
// WithTokenTimeToStaleDuration options, and rejects nil reader, writer, and kms
// arguments with the expected error code and message.
func TestRepository_New(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)

	type args struct {
		r    db.Reader
		w    db.Writer
		kms  *kms.Kms
		opts []Option
	}

	tests := []struct {
		name       string
		args       args
		want       *Repository
		wantIsErr  errors.Code
		wantErrMsg string
	}{
		{
			name: "valid default limit",
			args: args{
				r:    rw,
				w:    rw,
				kms:  kmsCache,
				opts: []Option{},
			},
			want: &Repository{
				reader:              rw,
				writer:              rw,
				kms:                 kmsCache,
				limit:               db.DefaultLimit,
				timeToLiveDuration:  defaultTokenTimeToLiveDuration,
				timeToStaleDuration: defaultTokenTimeToStaleDuration,
			},
		},
		{
			name: "valid new limit",
			args: args{
				r:   rw,
				w:   rw,
				kms: kmsCache,
				opts: []Option{
					WithLimit(5),
				},
			},
			want: &Repository{
				reader:              rw,
				writer:              rw,
				kms:                 kmsCache,
				limit:               5,
				timeToLiveDuration:  defaultTokenTimeToLiveDuration,
				timeToStaleDuration: defaultTokenTimeToStaleDuration,
			},
		},
		{
			name: "valid token time to live",
			args: args{
				r:   rw,
				w:   rw,
				kms: kmsCache,
				opts: []Option{
					WithTokenTimeToLiveDuration(1 * time.Hour),
				},
			},
			want: &Repository{
				reader:              rw,
				writer:              rw,
				kms:                 kmsCache,
				limit:               db.DefaultLimit,
				timeToLiveDuration:  1 * time.Hour,
				timeToStaleDuration: defaultTokenTimeToStaleDuration,
			},
		},
		{
			name: "valid token time to stale",
			args: args{
				r:   rw,
				w:   rw,
				kms: kmsCache,
				opts: []Option{
					WithTokenTimeToStaleDuration(1 * time.Hour),
				},
			},
			want: &Repository{
				reader:              rw,
				writer:              rw,
				kms:                 kmsCache,
				limit:               db.DefaultLimit,
				timeToStaleDuration: 1 * time.Hour,
				timeToLiveDuration:  defaultTokenTimeToLiveDuration,
			},
		},
		{
			name: "nil-reader",
			args: args{
				r:   nil,
				w:   rw,
				kms: kmsCache,
			},
			want:       nil,
			wantIsErr:  errors.InvalidParameter,
			wantErrMsg: "authtoken.NewRepository: nil db reader: parameter violation: error #100",
		},
		{
			name: "nil-writer",
			args: args{
				r:   rw,
				w:   nil,
				kms: kmsCache,
			},
			want:       nil,
			wantIsErr:  errors.InvalidParameter,
			wantErrMsg: "authtoken.NewRepository: nil db writer: parameter violation: error #100",
		},
		{
			name: "nil-kms",
			args: args{
				r:   rw,
				w:   rw,
				kms: nil,
			},
			want:       nil,
			wantIsErr:  errors.InvalidParameter,
			wantErrMsg: "authtoken.NewRepository: nil kms: parameter violation: error #100",
		},
		{
			// With every dependency nil, the expected message is the
			// nil-reader one.
			name: "all-nils",
			args: args{
				r:   nil,
				w:   nil,
				kms: nil,
			},
			want:       nil,
			wantIsErr:  errors.InvalidParameter,
			wantErrMsg: "authtoken.NewRepository: nil db reader: parameter violation: error #100",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			got, err := NewRepository(context.Background(), tt.args.r, tt.args.w, tt.args.kms, tt.args.opts...)
			if tt.wantIsErr != 0 {
				// Stop immediately if no error came back; calling
				// err.Error() on a nil error below would panic instead of
				// reporting a normal test failure.
				require.Error(err)
				assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
				assert.Equal(tt.wantErrMsg, err.Error())
				return
			}
			assert.NoError(err)
			assert.NotNil(got)
			assert.Equal(tt.want, got)
		})
	}
}
// TestRepository_CreateAuthToken exercises Repository.CreateAuthToken with a
// user that is associated with the auth account (success cases, with and
// without WithPublicId/WithStatus) and with missing, invalid, or unassociated
// user/account combinations (error cases).
func TestRepository_CreateAuthToken(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)

	// u1 is created with aAcct attached; u2 lives in a second org and is not
	// given the account.
	org1, _ := iam.TestScopes(t, iamRepo)
	am := password.TestAuthMethods(t, conn, org1.GetPublicId(), 1)[0]
	aAcct := password.TestAccount(t, conn, am.GetPublicId(), "name1")
	iam.TestSetPrimaryAuthMethod(t, iamRepo, org1, am.PublicId)
	u1 := iam.TestUser(t, iamRepo, org1.PublicId, iam.WithAccountIds(aAcct.PublicId))

	org2, _ := iam.TestScopes(t, iamRepo)
	u2 := iam.TestUser(t, iamRepo, org2.GetPublicId())

	testId, err := NewAuthTokenId(ctx)
	require.NoError(t, err)

	tests := []struct {
		name       string
		iamUser    *iam.User
		authAcctId string
		opt        []Option
		want       *AuthToken
		wantErr    bool
	}{
		{
			name:       "valid",
			iamUser:    u1,
			authAcctId: aAcct.GetPublicId(),
			want: &AuthToken{
				AuthToken: &store.AuthToken{
					AuthAccountId: aAcct.GetPublicId(),
				},
			},
		},
		{
			name:       "WithPublicId-WithStatus",
			iamUser:    u1,
			authAcctId: aAcct.GetPublicId(),
			opt:        []Option{WithPublicId(testId), WithStatus(PendingStatus)},
			want: &AuthToken{
				AuthToken: &store.AuthToken{
					AuthAccountId: aAcct.GetPublicId(),
					PublicId:      testId,
					Status:        string(PendingStatus),
				},
			},
		},
		{
			name:       "unconnected-authaccount-user",
			iamUser:    u2,
			authAcctId: aAcct.GetPublicId(),
			wantErr:    true,
		},
		{
			name:    "no-authacctid",
			iamUser: u1,
			wantErr: true,
		},
		{
			name:       "no-userid",
			authAcctId: aAcct.GetPublicId(),
			wantErr:    true,
		},
		{
			name:       "invalid-authacctid",
			iamUser:    u1,
			authAcctId: "this_is_invalid",
			wantErr:    true,
		},
		{
			name: "invalid-userid",
			// Clone u1 so the shared fixture is not modified.
			iamUser: func() *iam.User {
				u := u1.Clone().(*iam.User)
				u.PublicId = "this_is_invalid"
				return u
			}(),
			authAcctId: aAcct.GetPublicId(),
			wantErr:    true,
		},
	}
	for _, tt := range tests {
		tt := tt // capture per iteration, as the sibling tests in this file do
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			repo, err := NewRepository(ctx, rw, rw, kmsCache)
			require.NoError(err)
			require.NotNil(repo)

			got, err := repo.CreateAuthToken(ctx, tt.iamUser, tt.authAcctId, tt.opt...)
			if tt.wantErr {
				assert.Error(err)
				assert.Nil(got)
				return
			}
			require.NoErrorf(err, "Got error for CreateAuthToken(ctx, %v, %v)", tt.iamUser, tt.authAcctId)
			// got is dereferenced below, so a nil result must stop the
			// subtest rather than panic.
			require.NotNil(got)
			db.AssertPublicId(t, globals.AuthTokenPrefix, got.PublicId)
			assert.Equal(tt.authAcctId, got.GetAuthAccountId())
			assert.Equal(got.CreateTime, got.UpdateTime)
			assert.Equal(got.CreateTime, got.ApproximateLastAccessTime)

			opts := getOpts(tt.opt...)
			if opts.withPublicId != "" {
				assert.Equal(opts.withPublicId, got.PublicId)
			}
			if opts.withStatus != "" {
				assert.Equal(string(opts.withStatus), got.Status)
			}
			// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
			assert.Error(db.TestVerifyOplog(t, rw, got.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_CREATE)))
		})
	}
}
// TestRepository_LookupAuthToken checks Repository.LookupAuthToken for an
// existing token, an unknown id (nil result, no error), and an empty id
// (InvalidPublicId error).
func TestRepository_LookupAuthToken(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, _ := iam.TestScopes(t, iamRepo)

	at := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
	// Blank the token material on the expected value before it is diffed
	// against the looked-up token.
	at.Token = ""
	at.CtToken = nil
	at.KeyId = ""

	badId, err := NewAuthTokenId(ctx)
	require.NoError(t, err)
	require.NotNil(t, badId)

	tests := []struct {
		name       string
		id         string
		want       *AuthToken
		wantIsErr  errors.Code
		wantErrMsg string
	}{
		{
			name: "found",
			id:   at.GetPublicId(),
			want: at,
		},
		{
			name: "not-found",
			id:   badId,
			want: nil,
		},
		{
			name:       "bad-public-id",
			id:         "",
			want:       nil,
			wantIsErr:  errors.InvalidPublicId,
			wantErrMsg: "authtoken.(Repository).LookupAuthToken: missing public id: parameter violation: error #102",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			repo, err := NewRepository(ctx, rw, rw, kmsCache)
			require.NoError(err)
			require.NotNil(repo)

			got, err := repo.LookupAuthToken(ctx, tt.id)
			if tt.wantIsErr != 0 {
				// Stop immediately if no error came back; err.Error() on a
				// nil error below would panic rather than fail the test.
				require.Error(err)
				assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
				assert.Equal(tt.wantErrMsg, err.Error())
				return
			}
			assert.NoError(err)
			if got == nil {
				assert.Nil(tt.want)
				return
			}
			require.NotNil(tt.want, "got %v, wanted nil", got)

			// TODO: This test fails by a very small amount -- 500 nanos ish in
			// my experience -- if they are required to be equal. I think this
			// is because the resolution of the timestamp in the db does not
			// match the resolution in Go code. But might be worth checking
			// into.
			wantGoTimeExpr := tt.want.AuthToken.GetExpirationTime().Timestamp.AsTime()
			gotGoTimeExpr := got.AuthToken.GetExpirationTime().Timestamp.AsTime()
			assert.WithinDuration(wantGoTimeExpr, gotGoTimeExpr, time.Millisecond)
			// The expiration times were compared with a tolerance above, so
			// copy the returned value over before the exact diff.
			tt.want.AuthToken.ExpirationTime = got.AuthToken.ExpirationTime
			assert.Empty(cmp.Diff(tt.want.AuthToken, got.AuthToken, protocmp.Transform()))
		})
	}
}
// TestRepository_ValidateToken checks Repository.ValidateToken for a valid
// id/token pair, an unknown id, an empty token, and a mismatched token. For the
// valid pair it also checks how the approximate last access time changes across
// repeated calls as lastAccessedUpdateDuration is changed.
func TestRepository_ValidateToken(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")

	// This test overrides package-level settings; put the previous values back
	// when it finishes so tests that run afterwards are not affected.
	origUpdateDuration, origTimeSkew := lastAccessedUpdateDuration, timeSkew
	t.Cleanup(func() {
		lastAccessedUpdateDuration = origUpdateDuration
		timeSkew = origTimeSkew
	})
	lastAccessedUpdateDuration = 0
	timeSkew = 20 * time.Millisecond

	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	repo, err := NewRepository(ctx, rw, rw, kmsCache)
	require.NoError(t, err)
	require.NotNil(t, repo)

	org, _ := iam.TestScopes(t, iamRepo)
	at := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
	// Keep the raw token for the calls below, then blank the token material on
	// the expected value before it is diffed against results.
	atToken := at.GetToken()
	at.Token = ""
	at.CtToken = nil
	at.KeyId = ""
	atTime := at.GetApproximateLastAccessTime()
	require.NotNil(t, atTime)

	badId, err := NewAuthTokenId(ctx)
	require.NoError(t, err)
	require.NotNil(t, badId)

	badToken, err := newAuthToken(ctx)
	require.NoError(t, err)
	require.NotNil(t, badToken)

	tests := []struct {
		name       string
		id         string
		token      string
		want       *AuthToken
		wantIsErr  errors.Code
		wantErrMsg string
	}{
		{
			name:  "exists",
			id:    at.GetPublicId(),
			token: atToken,
			want:  at,
		},
		{
			name:  "doesnt-exist",
			id:    badId,
			token: badToken.Token,
			want:  nil,
		},
		{
			name:       "empty-token",
			id:         at.GetPublicId(),
			token:      "",
			want:       nil,
			wantIsErr:  errors.InvalidParameter,
			wantErrMsg: "authtoken.(Repository).ValidateToken: missing token: parameter violation: error #100",
		},
		{
			// NOTE(review): no wantErrMsg is set here; presumably
			// errors.Unknown is the zero Code, so this case takes the
			// "no error, nil result" path below -- confirm.
			name:      "mismatched-token",
			id:        at.GetPublicId(),
			token:     badToken.Token,
			want:      nil,
			wantIsErr: errors.Unknown,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)

			got, err := repo.ValidateToken(ctx, tt.id, tt.token)
			if tt.wantIsErr != 0 {
				// Stop immediately if no error came back; err.Error() on a
				// nil error below would panic rather than fail the test.
				require.Error(err)
				assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
				assert.Equal(tt.wantErrMsg, err.Error())
				return
			}
			require.NoError(err)
			if got == nil {
				assert.Nil(tt.want)
				// No need to compare updated time if we didn't get an initial auth token to compare against.
				return
			}
			require.NotNil(tt.want, "Got %v but wanted nil", got)

			// NOTE: See comment in LookupAuthToken about this logic
			wantGoTimeExpr := tt.want.AuthToken.GetExpirationTime().AsTime()
			gotGoTimeExpr := got.AuthToken.GetExpirationTime().AsTime()
			assert.WithinDuration(wantGoTimeExpr, gotGoTimeExpr, time.Millisecond)
			tt.want.AuthToken.ExpirationTime = got.AuthToken.ExpirationTime
			assert.Empty(cmp.Diff(tt.want.AuthToken, got.AuthToken, protocmp.Transform()))

			// preTime1 should be the value from before ValidateToken was
			// called, so it should equal the creation time.
			preTime1 := got.GetApproximateLastAccessTime()
			assert.True(preTime1.AsTime().Equal(atTime.AsTime()), "Create time %q doesn't match the time from the first call to MaybeUpdateLastAccesssed: %q.", atTime, preTime1)

			// Enable the duration which limits how frequently a token's approximate last accessed time can be updated
			// so the next call doesn't cause the last accessed time to be updated.
			lastAccessedUpdateDuration = 1 * time.Hour

			got2, err := repo.ValidateToken(ctx, tt.id, tt.token)
			require.NoError(err)
			preTime2 := got2.GetApproximateLastAccessTime().GetTimestamp()
			assert.True(preTime2.AsTime().After(preTime1.AsTime()), "First updated time %q was not after the creation time %q", preTime2, preTime1)
			// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
			assert.Error(db.TestVerifyOplog(t, rw, got.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_UPDATE)))

			// A third call is expected to report the same last access time as
			// the second one.
			got3, err := repo.ValidateToken(ctx, tt.id, tt.token)
			require.NoError(err)
			preTime3 := got3.GetApproximateLastAccessTime()
			assert.True(preTime3.AsTime().Equal(preTime2.AsTime()), "The 3rd timestamp %q was not equal to the second time %q", preTime3, preTime2)
		})
	}
}
// TestRepository_ValidateToken_expired creates tokens through repositories
// configured with different time-to-live and time-to-stale durations and checks
// whether ValidateToken returns the token (neither stale nor expired) or nil
// (stale or expired).
func TestRepository_ValidateToken_expired(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, _ := iam.TestScopes(t, iamRepo)

	// The subtests override the package-level timeSkew; put the previous value
	// back when this test finishes so later tests are not affected.
	origTimeSkew := timeSkew
	t.Cleanup(func() { timeSkew = origTimeSkew })

	// Use the account behind a fixture token to find the user that the
	// subtests create new tokens for.
	baseAT := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
	aAcct := allocAuthAccount()
	aAcct.PublicId = baseAT.GetAuthAccountId()
	require.NoError(t, rw.LookupByPublicId(ctx, aAcct))
	iamUser, _, err := iamRepo.LookupUser(ctx, aAcct.GetIamUserId())
	require.NoError(t, err)
	require.NotNil(t, iamUser)

	tests := []struct {
		name               string
		staleDuration      time.Duration
		expirationDuration time.Duration
		wantReturned       bool
	}{
		{
			name:               "not-stale-or-expired",
			staleDuration:      defaultTokenTimeToStaleDuration,
			expirationDuration: defaultTokenTimeToLiveDuration,
			wantReturned:       true,
		},
		{
			name:               "stale",
			staleDuration:      1 * time.Millisecond,
			expirationDuration: defaultTokenTimeToLiveDuration,
			wantReturned:       false,
		},
		{
			name:               "expired",
			staleDuration:      defaultTokenTimeToStaleDuration,
			expirationDuration: 1 * time.Millisecond,
			wantReturned:       false,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)

			timeSkew = 20 * time.Millisecond

			repo, err := NewRepository(ctx, rw, rw, kmsCache,
				WithTokenTimeToLiveDuration(tt.expirationDuration),
				WithTokenTimeToStaleDuration(tt.staleDuration))
			require.NoError(err)
			require.NotNil(repo)

			at, err := repo.CreateAuthToken(ctx, iamUser, baseAT.GetAuthAccountId())
			require.NoError(err)

			// Stale and expired tokens are expected to yield a nil token
			// without an error.
			got, err := repo.ValidateToken(ctx, at.GetPublicId(), at.GetToken())
			require.NoError(err)

			if tt.wantReturned {
				assert.NotNil(got)
				return
			}
			// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
			assert.Error(db.TestVerifyOplog(t, rw, at.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_DELETE)))
			assert.Nil(got)
		})
	}
}
// TestRepository_DeleteAuthToken checks the row count returned by
// Repository.DeleteAuthToken for an existing id (1) and an unknown id (0), and
// the InvalidPublicId error returned for an empty id.
func TestRepository_DeleteAuthToken(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, _ := iam.TestScopes(t, iamRepo)
	at := TestAuthToken(t, conn, kmsCache, org.GetPublicId())

	badId, err := NewAuthTokenId(ctx)
	require.NoError(t, err)
	require.NotNil(t, badId)

	tests := []struct {
		name       string
		id         string
		want       int
		wantIsErr  errors.Code
		wantErrMsg string
	}{
		{
			name: "found",
			id:   at.GetPublicId(),
			want: 1,
		},
		{
			name: "not-found",
			id:   badId,
			want: 0,
		},
		{
			name:       "empty-public-id",
			id:         "",
			want:       0,
			wantIsErr:  errors.InvalidPublicId,
			wantErrMsg: "authtoken.(Repository).DeleteAuthToken: missing public id: parameter violation: error #102",
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			repo, err := NewRepository(ctx, rw, rw, kmsCache)
			require.NoError(err)
			require.NotNil(repo)

			got, err := repo.DeleteAuthToken(ctx, tt.id)
			if tt.wantIsErr != 0 {
				// Stop immediately if no error came back; err.Error() on a
				// nil error below would panic rather than fail the test.
				require.Error(err)
				assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
				assert.Equal(tt.wantErrMsg, err.Error())
				return
			}
			assert.NoError(err)
			assert.Equal(tt.want, got, "row count")
			if tt.want != 0 {
				// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
				assert.Error(db.TestVerifyOplog(t, rw, tt.id, db.WithOperation(oplog.OpType_OP_TYPE_DELETE)))
			}
		})
	}
}
// TestRepository_ListAuthTokens checks Repository.listAuthTokens for a
// populated scope, an empty scope, and an empty scope id, then walks through
// paginated results with WithLimit/WithStartPageAfterItem and checks
// listAuthTokensRefresh after two more tokens are created.
func TestRepository_ListAuthTokens(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, _ := iam.TestScopes(t, iamRepo)

	// Token and KeyId are blanked on the expected values before they are
	// diffed against the listed tokens.
	at1 := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
	at1.Token = ""
	at1.KeyId = ""
	at2 := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
	at2.Token = ""
	at2.KeyId = ""
	at3 := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
	at3.Token = ""
	at3.KeyId = ""

	emptyOrg, _ := iam.TestScopes(t, iamRepo)

	// assertRecent checks that a transaction timestamp is within ~10 seconds
	// of the current time.
	assertRecent := func(t *testing.T, ts time.Time) {
		t.Helper()
		now := time.Now()
		assert.True(t, now.Before(ts.Add(10*time.Second)))
		assert.True(t, now.After(ts.Add(-10*time.Second)))
	}

	tests := []struct {
		name      string
		orgId     string
		want      []*AuthToken
		wantTTime time.Time
	}{
		{
			name:      "populated",
			orgId:     org.GetPublicId(),
			want:      []*AuthToken{at1, at2, at3},
			wantTTime: time.Now(),
		},
		{
			name:      "empty",
			orgId:     emptyOrg.GetPublicId(),
			want:      []*AuthToken{},
			wantTTime: time.Now(),
		},
		{
			name:      "empty-org-id",
			orgId:     "",
			want:      []*AuthToken{},
			wantTTime: time.Now(),
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			repo, err := NewRepository(ctx, rw, rw, kmsCache)
			require.NoError(err)
			require.NotNil(repo)

			got, ttime, err := repo.listAuthTokens(ctx, []string{tt.orgId})
			assert.NoError(err)
			// Sort both sides by public id so the diff does not depend on the
			// order in which the tokens are listed.
			sort.Slice(tt.want, func(i, j int) bool { return tt.want[i].PublicId < tt.want[j].PublicId })
			sort.Slice(got, func(i, j int) bool { return got[i].PublicId < got[j].PublicId })
			assert.Empty(cmp.Diff(tt.want, got, protocmp.Transform()), "row count")
			// Transaction timestamp should be within ~10 seconds of now
			assert.True(tt.wantTTime.Before(ttime.Add(10 * time.Second)))
			assert.True(tt.wantTTime.After(ttime.Add(-10 * time.Second)))
		})
	}

	t.Run("withStartPageAfter", func(t *testing.T) {
		assert, require := assert.New(t), require.New(t)

		// Together with at1..at3 this gives the org 10 tokens: five full
		// pages of two, followed by an empty page.
		for i := 0; i < 7; i++ {
			TestAuthToken(t, conn, kmsCache, org.GetPublicId())
		}

		repo, err := NewRepository(ctx, rw, rw, kmsCache)
		require.NoError(err)
		scopeIds := []string{org.GetPublicId()}

		prevPage, ttime, err := repo.listAuthTokens(ctx, scopeIds, WithLimit(2))
		require.NoError(err)
		require.Len(prevPage, 2)
		assertRecent(t, ttime)

		// Pages 2 through 5: each must hold two items, none of which also
		// appear on the page before it.
		for pageNum := 2; pageNum <= 5; pageNum++ {
			page, ttime, err := repo.listAuthTokens(ctx, scopeIds, WithLimit(2), WithStartPageAfterItem(prevPage[1]))
			require.NoError(err, "page %d", pageNum)
			require.Len(page, 2, "page %d", pageNum)
			assertRecent(t, ttime)
			for _, item := range prevPage {
				assert.NotEqual(item.GetPublicId(), page[0].GetPublicId())
				assert.NotEqual(item.GetPublicId(), page[1].GetPublicId())
			}
			prevPage = page
		}

		// Page 6: nothing is left after the tenth token.
		page6, ttime, err := repo.listAuthTokens(ctx, scopeIds, WithLimit(2), WithStartPageAfterItem(prevPage[1]))
		require.NoError(err)
		require.Empty(page6)
		assertRecent(t, ttime)

		// Create 2 new auth tokens
		newAt1 := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
		newAt2 := TestAuthToken(t, conn, kmsCache, org.GetPublicId())

		// The refresh listing is expected to return newest first, so the most
		// recently created token (newAt2) comes back before newAt1.
		page7, ttime, err := repo.listAuthTokensRefresh(
			ctx,
			time.Now().Add(-1*time.Second),
			scopeIds,
			WithLimit(1),
		)
		require.NoError(err)
		require.Len(page7, 1)
		require.Equal(newAt2.GetPublicId(), page7[0].GetPublicId())
		assertRecent(t, ttime)

		page8, ttime, err := repo.listAuthTokensRefresh(
			ctx,
			time.Now().Add(-1*time.Second),
			scopeIds,
			WithLimit(1),
			WithStartPageAfterItem(page7[0]),
		)
		require.NoError(err)
		require.Len(page8, 1)
		require.Equal(newAt1.GetPublicId(), page8[0].GetPublicId())
		assertRecent(t, ttime)
	})
}
// TestRepository_ListAuthTokens_Multiple_Scopes creates tokens in the global
// and org scopes and checks that a single listAuthTokens call across the
// global, org, and project scope ids returns all of them.
func TestRepository_ListAuthTokens_Multiple_Scopes(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	repo, err := NewRepository(ctx, rw, rw, kmsCache)
	require.NoError(t, err)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, proj := iam.TestScopes(t, iamRepo)

	const numPerScope = 10
	listedScopes := []string{"global", org.GetPublicId(), proj.GetPublicId()}

	// Only the first two listed scopes (global and org) receive tokens; the
	// project scope is listed but gets none.
	var total int
	for i := 0; i < numPerScope; i++ {
		for _, scopeId := range listedScopes[:2] {
			TestAuthToken(t, conn, kmsCache, scopeId)
			total++
		}
	}

	got, ttime, err := repo.listAuthTokens(ctx, listedScopes)
	require.NoError(t, err)
	assert.Equal(t, total, len(got))

	// Transaction timestamp should be within ~10 seconds of now
	assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
	assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
}
// Test_IssuePendingToken checks Repository.IssueAuthToken with a missing
// request id, with a token that was not created in the pending status, and with
// a pending token (success).
func Test_IssuePendingToken(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	rootWrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, rootWrapper)
	repo, err := NewRepository(ctx, rw, rw, kmsCache)
	require.NoError(t, err)
	org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, rootWrapper))

	// newId returns a fresh auth token public id, failing the test on error.
	newId := func() string {
		id, err := NewAuthTokenId(ctx)
		require.NoError(t, err)
		return id
	}

	// A token created without WithStatus(PendingStatus); the "not-found" case
	// expects IssueAuthToken to report it as a missing pending token.
	nonPendingId := TestAuthToken(t, conn, kmsCache, org.PublicId, WithPublicId(newId())).PublicId
	// A token created in the pending status, which can be issued.
	pendingId := TestAuthToken(t, conn, kmsCache, org.PublicId, WithStatus(PendingStatus), WithPublicId(newId())).PublicId

	tests := []struct {
		name            string
		tokenRequestId  string
		wantErrMatch    *errors.Template
		wantErrContains string
	}{
		{
			name:            "missing-id",
			wantErrMatch:    errors.T(errors.InvalidParameter),
			wantErrContains: "missing token request id",
		},
		{
			name:            "not-found",
			tokenRequestId:  nonPendingId,
			wantErrMatch:    errors.T(errors.RecordNotFound),
			wantErrContains: "pending auth token",
		},
		{
			name:           "success",
			tokenRequestId: pendingId,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)

			issued, err := repo.IssueAuthToken(ctx, tc.tokenRequestId)
			if tc.wantErrMatch == nil {
				require.NoError(err)
				assert.NotEmpty(issued)
				// Issuing is expected to move the approximate last access
				// time past the creation time.
				accessTime := issued.GetApproximateLastAccessTime()
				createTime := issued.GetCreateTime().GetTimestamp()
				assert.True(accessTime.AsTime().After(createTime.AsTime()), "last access time %q was not after the creation time %q", accessTime, createTime)
				return
			}

			require.Error(err)
			assert.Truef(errors.Match(tc.wantErrMatch, err), "wanted %s and got: %+v", tc.wantErrMatch.Code, err)
			if tc.wantErrContains != "" {
				assert.Contains(err.Error(), tc.wantErrContains)
			}
		})
	}
}
// Test_CloseExpiredPendingTokens seeds auth_token rows with various statuses
// and expiration times and checks how many Repository.CloseExpiredPendingTokens
// reports as closed.
func Test_CloseExpiredPendingTokens(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	rootWrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, rootWrapper)
	repo, err := NewRepository(ctx, rw, rw, kmsCache)
	require.NoError(t, err)

	iamRepo := iam.TestRepo(t, conn, rootWrapper)
	org, _ := iam.TestScopes(t, iamRepo)
	tokenOwner := iam.TestUser(t, iamRepo, org.GetPublicId())

	dbWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
	require.NoError(t, err)

	// seedTokens writes cnt auth tokens with the given status directly to the
	// database, each tied to a freshly created password account and expiring
	// expIn from now (truncated to the second).
	seedTokens := func(cnt int, expIn time.Duration, status Status) {
		methods := password.TestAuthMethods(t, conn, org.PublicId, 1)
		accounts := password.TestMultipleAccounts(t, conn, methods[0].PublicId, cnt)
		for idx := 0; idx < cnt; idx++ {
			// Re-read the user on every pass so the version handed to
			// AddUserAccounts is the latest one.
			var lookupErr error
			tokenOwner, _, lookupErr = iamRepo.LookupUser(ctx, tokenOwner.GetPublicId())
			require.NoError(t, lookupErr)
			// NOTE(review): the result and error are discarded here, as in
			// the original code -- confirm this is intentional.
			_, _ = iamRepo.AddUserAccounts(ctx, tokenOwner.GetPublicId(), tokenOwner.GetVersion(), []string{accounts[idx].GetPublicId()})

			tokenId, err := NewAuthTokenId(ctx)
			require.NoError(t, err)
			expiresAt := timestamppb.New(time.Now().Add(expIn).Truncate(time.Second))
			keyId, err := dbWrapper.KeyId(ctx)
			require.NoError(t, err)

			tok := allocAuthToken()
			tok.PublicId = tokenId
			tok.ExpirationTime = &timestamp.Timestamp{Timestamp: expiresAt}
			tok.Status = string(status)
			tok.AuthAccountId = accounts[idx].PublicId
			tok.KeyId = keyId
			// The public id doubles as placeholder ciphertext.
			tok.CtToken = []byte(tokenId)
			require.NoError(t, rw.Create(ctx, tok))
		}
	}

	tests := []struct {
		name            string
		setup           func()
		wantCnt         int
		wantErrMatch    *errors.Template
		wantErrContains string
	}{
		{
			name: "nada-todo",
		},
		{
			name: "close-2",
			setup: func() {
				seedTokens(2, -10*time.Second, PendingStatus)
				seedTokens(1, 10*time.Second, IssuedStatus)
				seedTokens(1, 10*time.Second, PendingStatus)
			},
			wantCnt: 2,
		},
		{
			name: "close-zero",
			setup: func() {
				seedTokens(2, 10*time.Second, PendingStatus)
				seedTokens(1, 10*time.Second, IssuedStatus)
			},
			wantCnt: 0,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)

			// start with no tokens in the db.
			for _, stmt := range []string{"delete from auth_token", "delete from auth_account"} {
				_, err := rw.Exec(ctx, stmt, nil)
				require.NoError(err)
			}
			if tc.setup != nil {
				tc.setup()
			}

			closed, err := repo.CloseExpiredPendingTokens(ctx)
			if tc.wantErrMatch == nil {
				require.NoError(err)
				assert.Equal(tc.wantCnt, closed)
				return
			}

			require.Error(err)
			assert.Equal(0, closed)
			assert.Truef(errors.Match(tc.wantErrMatch, err), "wanted %s and got: %s", tc.wantErrMatch.Code, err.Error())
			if tc.wantErrContains != "" {
				assert.Contains(err.Error(), tc.wantErrContains)
			}
		})
	}
}
// Test_listDeletedIds checks Repository.listDeletedIds before and after an auth
// token is deleted, and with a "since" time of now.
func Test_listDeletedIds(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, _ := iam.TestScopes(t, iamRepo)

	token := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
	token.Token = ""
	token.KeyId = ""

	repo, err := NewRepository(ctx, rw, rw, kmsCache)
	require.NoError(t, err)

	// withinTenSeconds asserts that the transaction timestamp is within
	// ~10 seconds of now.
	withinTenSeconds := func(ts time.Time) {
		t.Helper()
		assert.True(t, time.Now().Before(ts.Add(10*time.Second)))
		assert.True(t, time.Now().After(ts.Add(-10*time.Second)))
	}
	aYearAgo := func() time.Time { return time.Now().AddDate(-1, 0, 0) }

	// Expect no entries at the start
	ids, txTime, err := repo.listDeletedIds(ctx, aYearAgo())
	require.NoError(t, err)
	require.Empty(t, ids)
	withinTenSeconds(txTime)

	// Delete an auth token
	_, err = repo.DeleteAuthToken(ctx, token.GetPublicId())
	require.NoError(t, err)

	// Expect a single entry
	ids, txTime, err = repo.listDeletedIds(ctx, aYearAgo())
	require.NoError(t, err)
	require.Equal(t, []string{token.GetPublicId()}, ids)
	withinTenSeconds(txTime)

	// Try again with the time set to now, expect no entries
	ids, txTime, err = repo.listDeletedIds(ctx, time.Now())
	require.NoError(t, err)
	require.Empty(t, ids)
	withinTenSeconds(txTime)
}
// Test_estimatedCount checks that Repository.estimatedCount reports 0, then 1,
// then 0 as an auth token is created and deleted, running "analyze" before
// each read.
func Test_estimatedCount(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	sqlDb, err := conn.SqlDB(ctx)
	require.NoError(t, err)
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	repo, err := NewRepository(ctx, rw, rw, kmsCache)
	require.NoError(t, err)

	// expectEstimate runs analyze to update the estimate, then asserts the
	// estimated count.
	expectEstimate := func(want int) {
		t.Helper()
		_, err := sqlDb.ExecContext(ctx, "analyze")
		require.NoError(t, err)
		got, err := repo.estimatedCount(ctx)
		require.NoError(t, err)
		assert.Equal(t, want, got)
	}

	// Check total entries at start, expect 0
	expectEstimate(0)

	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, _ := iam.TestScopes(t, iamRepo)

	// Create an auth token, expect 1 entry
	token := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
	token.Token = ""
	token.KeyId = ""
	expectEstimate(1)

	// Delete the auth token, expect 0 again
	_, err = repo.DeleteAuthToken(ctx, token.GetPublicId())
	require.NoError(t, err)
	expectEstimate(0)
}