mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1073 lines
31 KiB
1073 lines
31 KiB
// Copyright (c) HashiCorp, Inc.
|
|
// SPDX-License-Identifier: BUSL-1.1
|
|
|
|
package authtoken
|
|
|
|
import (
|
|
"context"
|
|
"sort"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/hashicorp/boundary/globals"
|
|
"github.com/hashicorp/boundary/internal/auth/password"
|
|
"github.com/hashicorp/boundary/internal/db/timestamp"
|
|
"github.com/hashicorp/boundary/internal/kms"
|
|
"github.com/hashicorp/boundary/internal/oplog"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/testing/protocmp"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/hashicorp/boundary/internal/authtoken/store"
|
|
"github.com/hashicorp/boundary/internal/db"
|
|
"github.com/hashicorp/boundary/internal/errors"
|
|
"github.com/hashicorp/boundary/internal/iam"
|
|
)
|
|
|
|
func TestRepository_New(t *testing.T) {
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, wrapper)
|
|
|
|
type args struct {
|
|
r db.Reader
|
|
w db.Writer
|
|
kms *kms.Kms
|
|
opts []Option
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
want *Repository
|
|
wantIsErr errors.Code
|
|
wantErrMsg string
|
|
}{
|
|
{
|
|
name: "valid default limit",
|
|
args: args{
|
|
r: rw,
|
|
w: rw,
|
|
kms: kmsCache,
|
|
opts: []Option{},
|
|
},
|
|
want: &Repository{
|
|
reader: rw,
|
|
writer: rw,
|
|
kms: kmsCache,
|
|
limit: db.DefaultLimit,
|
|
timeToLiveDuration: defaultTokenTimeToLiveDuration,
|
|
timeToStaleDuration: defaultTokenTimeToStaleDuration,
|
|
},
|
|
},
|
|
{
|
|
name: "valid new limit",
|
|
args: args{
|
|
r: rw,
|
|
w: rw,
|
|
kms: kmsCache,
|
|
opts: []Option{
|
|
WithLimit(5),
|
|
},
|
|
},
|
|
want: &Repository{
|
|
reader: rw,
|
|
writer: rw,
|
|
kms: kmsCache,
|
|
limit: 5,
|
|
timeToLiveDuration: defaultTokenTimeToLiveDuration,
|
|
timeToStaleDuration: defaultTokenTimeToStaleDuration,
|
|
},
|
|
},
|
|
{
|
|
name: "valid token time to live",
|
|
args: args{
|
|
r: rw,
|
|
w: rw,
|
|
kms: kmsCache,
|
|
opts: []Option{
|
|
WithTokenTimeToLiveDuration(1 * time.Hour),
|
|
},
|
|
},
|
|
want: &Repository{
|
|
reader: rw,
|
|
writer: rw,
|
|
kms: kmsCache,
|
|
limit: db.DefaultLimit,
|
|
timeToLiveDuration: 1 * time.Hour,
|
|
timeToStaleDuration: defaultTokenTimeToStaleDuration,
|
|
},
|
|
},
|
|
{
|
|
name: "valid token time to stale",
|
|
args: args{
|
|
r: rw,
|
|
w: rw,
|
|
kms: kmsCache,
|
|
opts: []Option{
|
|
WithTokenTimeToStaleDuration(1 * time.Hour),
|
|
},
|
|
},
|
|
want: &Repository{
|
|
reader: rw,
|
|
writer: rw,
|
|
kms: kmsCache,
|
|
limit: db.DefaultLimit,
|
|
timeToStaleDuration: 1 * time.Hour,
|
|
timeToLiveDuration: defaultTokenTimeToLiveDuration,
|
|
},
|
|
},
|
|
|
|
{
|
|
name: "nil-reader",
|
|
args: args{
|
|
r: nil,
|
|
w: rw,
|
|
kms: kmsCache,
|
|
},
|
|
want: nil,
|
|
wantIsErr: errors.InvalidParameter,
|
|
wantErrMsg: "authtoken.NewRepository: nil db reader: parameter violation: error #100",
|
|
},
|
|
{
|
|
name: "nil-writer",
|
|
args: args{
|
|
r: rw,
|
|
w: nil,
|
|
kms: kmsCache,
|
|
},
|
|
want: nil,
|
|
wantIsErr: errors.InvalidParameter,
|
|
wantErrMsg: "authtoken.NewRepository: nil db writer: parameter violation: error #100",
|
|
},
|
|
{
|
|
name: "nil-kms",
|
|
args: args{
|
|
r: rw,
|
|
w: rw,
|
|
kms: nil,
|
|
},
|
|
want: nil,
|
|
wantIsErr: errors.InvalidParameter,
|
|
wantErrMsg: "authtoken.NewRepository: nil kms: parameter violation: error #100",
|
|
},
|
|
{
|
|
name: "all-nils",
|
|
args: args{
|
|
r: nil,
|
|
w: nil,
|
|
kms: nil,
|
|
},
|
|
want: nil,
|
|
wantIsErr: errors.InvalidParameter,
|
|
wantErrMsg: "authtoken.NewRepository: nil db reader: parameter violation: error #100",
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert := assert.New(t)
|
|
got, err := NewRepository(context.Background(), tt.args.r, tt.args.w, tt.args.kms, tt.args.opts...)
|
|
if tt.wantIsErr != 0 {
|
|
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
|
|
assert.Equal(tt.wantErrMsg, err.Error())
|
|
return
|
|
}
|
|
assert.NoError(err)
|
|
assert.NotNil(got)
|
|
assert.Equal(tt.want, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRepository_CreateAuthToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kms := kms.TestKms(t, conn, wrapper)
|
|
repo := iam.TestRepo(t, conn, wrapper)
|
|
|
|
org1, _ := iam.TestScopes(t, repo)
|
|
am := password.TestAuthMethods(t, conn, org1.GetPublicId(), 1)[0]
|
|
aAcct := password.TestAccount(t, conn, am.GetPublicId(), "name1")
|
|
iam.TestSetPrimaryAuthMethod(t, repo, org1, am.PublicId)
|
|
u1 := iam.TestUser(t, repo, org1.PublicId, iam.WithAccountIds(aAcct.PublicId))
|
|
|
|
org2, _ := iam.TestScopes(t, repo)
|
|
u2 := iam.TestUser(t, repo, org2.GetPublicId())
|
|
|
|
testId, err := NewAuthTokenId(ctx)
|
|
require.NoError(t, err)
|
|
|
|
tests := []struct {
|
|
name string
|
|
iamUser *iam.User
|
|
authAcctId string
|
|
opt []Option
|
|
want *AuthToken
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "valid",
|
|
iamUser: u1,
|
|
authAcctId: aAcct.GetPublicId(),
|
|
want: &AuthToken{
|
|
AuthToken: &store.AuthToken{
|
|
AuthAccountId: aAcct.GetPublicId(),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "WithPublicId-WithStatus",
|
|
iamUser: u1,
|
|
authAcctId: aAcct.GetPublicId(),
|
|
opt: []Option{WithPublicId(testId), WithStatus(PendingStatus)},
|
|
want: &AuthToken{
|
|
AuthToken: &store.AuthToken{
|
|
AuthAccountId: aAcct.GetPublicId(),
|
|
PublicId: testId,
|
|
Status: string(PendingStatus),
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "unconnected-authaccount-user",
|
|
iamUser: u2,
|
|
authAcctId: aAcct.GetPublicId(),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "no-authacctid",
|
|
iamUser: u1,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "no-userid",
|
|
authAcctId: aAcct.GetPublicId(),
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid-authacctid",
|
|
iamUser: u1,
|
|
authAcctId: "this_is_invalid",
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "invalid-userid",
|
|
iamUser: func() *iam.User { u := u1.Clone().(*iam.User); u.PublicId = "this_is_invalid"; return u }(),
|
|
authAcctId: aAcct.GetPublicId(),
|
|
wantErr: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
repo, err := NewRepository(ctx, rw, rw, kms)
|
|
require.NoError(err)
|
|
require.NotNil(repo)
|
|
got, err := repo.CreateAuthToken(ctx, tt.iamUser, tt.authAcctId, tt.opt...)
|
|
if tt.wantErr {
|
|
assert.Error(err)
|
|
assert.Nil(got)
|
|
return
|
|
}
|
|
require.NoError(err, "Got error for CreateAuthToken(ctx, %v, %v)", tt.iamUser, tt.authAcctId)
|
|
assert.NotNil(got)
|
|
db.AssertPublicId(t, globals.AuthTokenPrefix, got.PublicId)
|
|
assert.Equal(tt.authAcctId, got.GetAuthAccountId())
|
|
assert.Equal(got.CreateTime, got.UpdateTime)
|
|
assert.Equal(got.CreateTime, got.ApproximateLastAccessTime)
|
|
|
|
opts := getOpts(tt.opt...)
|
|
if opts.withPublicId != "" {
|
|
assert.Equal(opts.withPublicId, got.PublicId)
|
|
}
|
|
if opts.withStatus != "" {
|
|
assert.Equal(string(opts.withStatus), got.Status)
|
|
}
|
|
|
|
// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
|
|
assert.Error(db.TestVerifyOplog(t, rw, got.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_CREATE)))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRepository_LookupAuthToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kms := kms.TestKms(t, conn, wrapper)
|
|
repo := iam.TestRepo(t, conn, wrapper)
|
|
org, _ := iam.TestScopes(t, repo)
|
|
at := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
at.Token = ""
|
|
at.CtToken = nil
|
|
at.KeyId = ""
|
|
|
|
badId, err := NewAuthTokenId(ctx)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, badId)
|
|
|
|
tests := []struct {
|
|
name string
|
|
id string
|
|
want *AuthToken
|
|
wantIsErr errors.Code
|
|
wantErrMsg string
|
|
}{
|
|
{
|
|
name: "found",
|
|
id: at.GetPublicId(),
|
|
want: at,
|
|
},
|
|
{
|
|
name: "not-found",
|
|
id: badId,
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "bad-public-id",
|
|
id: "",
|
|
want: nil,
|
|
wantIsErr: errors.InvalidPublicId,
|
|
wantErrMsg: "authtoken.(Repository).LookupAuthToken: missing public id: parameter violation: error #102",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
repo, err := NewRepository(ctx, rw, rw, kms)
|
|
require.NoError(err)
|
|
require.NotNil(repo)
|
|
|
|
got, err := repo.LookupAuthToken(ctx, tt.id)
|
|
if tt.wantIsErr != 0 {
|
|
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
|
|
assert.Equal(tt.wantErrMsg, err.Error())
|
|
return
|
|
}
|
|
assert.NoError(err)
|
|
if got == nil {
|
|
assert.Nil(tt.want)
|
|
return
|
|
}
|
|
require.NotNil(tt.want, "got %v, wanted nil", got)
|
|
// TODO: This test fails by a very small amount -- 500 nanos ish in
|
|
// my experience -- if they are required to be equal. I think this
|
|
// is because the resolution of the timestamp in the db does not
|
|
// match the resolution in Go code. But might be worth checking
|
|
// into.
|
|
wantGoTimeExpr := tt.want.AuthToken.GetExpirationTime().Timestamp.AsTime()
|
|
gotGoTimeExpr := got.AuthToken.GetExpirationTime().Timestamp.AsTime()
|
|
assert.WithinDuration(wantGoTimeExpr, gotGoTimeExpr, time.Millisecond)
|
|
tt.want.AuthToken.ExpirationTime = got.AuthToken.ExpirationTime
|
|
assert.Empty(cmp.Diff(tt.want.AuthToken, got.AuthToken, protocmp.Transform()))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRepository_ValidateToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
lastAccessedUpdateDuration = 0
|
|
timeSkew = 20 * time.Millisecond
|
|
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kms := kms.TestKms(t, conn, wrapper)
|
|
iamRepo := iam.TestRepo(t, conn, wrapper)
|
|
repo, err := NewRepository(ctx, rw, rw, kms)
|
|
|
|
require.NoError(t, err)
|
|
require.NotNil(t, repo)
|
|
|
|
org, _ := iam.TestScopes(t, iamRepo)
|
|
at := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
atToken := at.GetToken()
|
|
at.Token = ""
|
|
at.CtToken = nil
|
|
at.KeyId = ""
|
|
atTime := at.GetApproximateLastAccessTime()
|
|
require.NotNil(t, atTime)
|
|
|
|
badId, err := NewAuthTokenId(ctx)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, badId)
|
|
|
|
badToken, err := newAuthToken(ctx)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, badToken)
|
|
|
|
tests := []struct {
|
|
name string
|
|
id string
|
|
token string
|
|
want *AuthToken
|
|
wantIsErr errors.Code
|
|
wantErrMsg string
|
|
}{
|
|
{
|
|
name: "exists",
|
|
id: at.GetPublicId(),
|
|
token: atToken,
|
|
want: at,
|
|
},
|
|
{
|
|
name: "doesnt-exist",
|
|
id: badId,
|
|
token: badToken.Token,
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "empty-token",
|
|
id: at.GetPublicId(),
|
|
token: "",
|
|
want: nil,
|
|
wantIsErr: errors.InvalidParameter,
|
|
wantErrMsg: "authtoken.(Repository).ValidateToken: missing token: parameter violation: error #100",
|
|
},
|
|
{
|
|
name: "mismatched-token",
|
|
id: at.GetPublicId(),
|
|
token: badToken.Token,
|
|
want: nil,
|
|
wantIsErr: errors.Unknown,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
got, err := repo.ValidateToken(ctx, tt.id, tt.token)
|
|
|
|
if tt.wantIsErr != 0 {
|
|
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
|
|
assert.Equal(tt.wantErrMsg, err.Error())
|
|
return
|
|
}
|
|
assert.NoError(err)
|
|
if got == nil {
|
|
assert.Nil(tt.want)
|
|
// No need to compare updated time if we didn't get an initial auth token to compare against.
|
|
return
|
|
}
|
|
require.NotNil(tt.want, "Got %v but wanted nil", got)
|
|
|
|
// NOTE: See comment in LookupAuthToken about this logic
|
|
wantGoTimeExpr := tt.want.AuthToken.GetExpirationTime().AsTime()
|
|
gotGoTimeExpr := got.AuthToken.GetExpirationTime().AsTime()
|
|
assert.WithinDuration(wantGoTimeExpr, gotGoTimeExpr, time.Millisecond)
|
|
tt.want.AuthToken.ExpirationTime = got.AuthToken.ExpirationTime
|
|
assert.Empty(cmp.Diff(tt.want.AuthToken, got.AuthToken, protocmp.Transform()))
|
|
|
|
// preTime1 should be the value prior to the ValidateToken was called so it should equal creation time
|
|
preTime1 := got.GetApproximateLastAccessTime()
|
|
require.NoError(err)
|
|
assert.True(preTime1.AsTime().Equal(atTime.AsTime()), "Create time %q doesn't match the time from the first call to MaybeUpdateLastAccesssed: %q.", atTime, preTime1)
|
|
|
|
// Enable the duration which limits how frequently a token's approximate last accessed time can be updated
|
|
// so the next call doesn't cause the last accessed time to be updated.
|
|
lastAccessedUpdateDuration = 1 * time.Hour
|
|
|
|
got2, err := repo.ValidateToken(ctx, tt.id, tt.token)
|
|
assert.NoError(err)
|
|
preTime2 := got2.GetApproximateLastAccessTime().GetTimestamp()
|
|
assert.True(preTime2.AsTime().After(preTime1.AsTime()), "First updated time %q was not after the creation time %q", preTime2, preTime1)
|
|
|
|
// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
|
|
assert.Error(db.TestVerifyOplog(t, rw, got.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_UPDATE)))
|
|
|
|
got3, err := repo.ValidateToken(ctx, tt.id, tt.token)
|
|
require.NoError(err)
|
|
preTime3 := got3.GetApproximateLastAccessTime()
|
|
assert.True(preTime3.AsTime().Equal(preTime2.AsTime()), "The 3rd timestamp %q was not equal to the second time %q", preTime3, preTime2)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRepository_ValidateToken_expired(t *testing.T) {
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kms := kms.TestKms(t, conn, wrapper)
|
|
iamRepo := iam.TestRepo(t, conn, wrapper)
|
|
|
|
org, _ := iam.TestScopes(t, iamRepo)
|
|
baseAT := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
baseAT.GetAuthAccountId()
|
|
aAcct := allocAuthAccount()
|
|
aAcct.PublicId = baseAT.GetAuthAccountId()
|
|
require.NoError(t, rw.LookupByPublicId(context.Background(), aAcct))
|
|
iamUser, _, err := iamRepo.LookupUser(context.Background(), aAcct.GetIamUserId())
|
|
require.NoError(t, err)
|
|
require.NotNil(t, iamUser)
|
|
|
|
tests := []struct {
|
|
name string
|
|
staleDuration time.Duration
|
|
expirationDuration time.Duration
|
|
wantReturned bool
|
|
}{
|
|
{
|
|
name: "not-stale-or-expired",
|
|
staleDuration: defaultTokenTimeToStaleDuration,
|
|
expirationDuration: defaultTokenTimeToLiveDuration,
|
|
wantReturned: true,
|
|
},
|
|
{
|
|
name: "stale",
|
|
staleDuration: 1 * time.Millisecond,
|
|
expirationDuration: defaultTokenTimeToLiveDuration,
|
|
wantReturned: false,
|
|
},
|
|
{
|
|
name: "expired",
|
|
staleDuration: defaultTokenTimeToStaleDuration,
|
|
expirationDuration: 1 * time.Millisecond,
|
|
wantReturned: false,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
ctx := context.Background()
|
|
|
|
timeSkew = 20 * time.Millisecond
|
|
|
|
repo, err := NewRepository(ctx, rw, rw, kms,
|
|
WithTokenTimeToLiveDuration(tt.expirationDuration),
|
|
WithTokenTimeToStaleDuration(tt.staleDuration))
|
|
require.NoError(err)
|
|
require.NotNil(repo)
|
|
|
|
at, err := repo.CreateAuthToken(ctx, iamUser, baseAT.GetAuthAccountId())
|
|
require.NoError(err)
|
|
|
|
got, err := repo.ValidateToken(ctx, at.GetPublicId(), at.GetToken())
|
|
require.NoError(err)
|
|
|
|
if tt.wantReturned {
|
|
assert.NotNil(got)
|
|
} else {
|
|
// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
|
|
assert.Error(db.TestVerifyOplog(t, rw, at.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_DELETE)))
|
|
assert.Nil(got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRepository_DeleteAuthToken(t *testing.T) {
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kms := kms.TestKms(t, conn, wrapper)
|
|
repo := iam.TestRepo(t, conn, wrapper)
|
|
org, _ := iam.TestScopes(t, repo)
|
|
at := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
badId, err := NewAuthTokenId(ctx)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, badId)
|
|
|
|
tests := []struct {
|
|
name string
|
|
id string
|
|
want int
|
|
wantIsErr errors.Code
|
|
wantErrMsg string
|
|
}{
|
|
{
|
|
name: "found",
|
|
id: at.GetPublicId(),
|
|
want: 1,
|
|
},
|
|
{
|
|
name: "not-found",
|
|
id: badId,
|
|
want: 0,
|
|
},
|
|
{
|
|
name: "empty-public-id",
|
|
id: "",
|
|
want: 0,
|
|
wantIsErr: errors.InvalidPublicId,
|
|
wantErrMsg: "authtoken.(Repository).DeleteAuthToken: missing public id: parameter violation: error #102",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
repo, err := NewRepository(ctx, rw, rw, kms)
|
|
require.NoError(err)
|
|
require.NotNil(repo)
|
|
|
|
got, err := repo.DeleteAuthToken(ctx, tt.id)
|
|
if tt.wantIsErr != 0 {
|
|
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "Unexpected error %s", err)
|
|
assert.Equal(tt.wantErrMsg, err.Error())
|
|
return
|
|
}
|
|
assert.NoError(err)
|
|
assert.Equal(tt.want, got, "row count")
|
|
if tt.want != 0 {
|
|
// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
|
|
assert.Error(db.TestVerifyOplog(t, rw, tt.id, db.WithOperation(oplog.OpType_OP_TYPE_DELETE)))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRepository_ListAuthTokens(t *testing.T) {
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kms := kms.TestKms(t, conn, wrapper)
|
|
repo := iam.TestRepo(t, conn, wrapper)
|
|
org, _ := iam.TestScopes(t, repo)
|
|
at1 := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
at1.Token = ""
|
|
at1.KeyId = ""
|
|
at2 := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
at2.Token = ""
|
|
at2.KeyId = ""
|
|
at3 := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
at3.Token = ""
|
|
at3.KeyId = ""
|
|
|
|
emptyOrg, _ := iam.TestScopes(t, repo)
|
|
|
|
tests := []struct {
|
|
name string
|
|
orgId string
|
|
want []*AuthToken
|
|
wantTTime time.Time
|
|
}{
|
|
{
|
|
name: "populated",
|
|
orgId: org.GetPublicId(),
|
|
want: []*AuthToken{at1, at2, at3},
|
|
wantTTime: time.Now(),
|
|
},
|
|
{
|
|
name: "empty",
|
|
orgId: emptyOrg.GetPublicId(),
|
|
want: []*AuthToken{},
|
|
wantTTime: time.Now(),
|
|
},
|
|
{
|
|
name: "empty-org-id",
|
|
orgId: "",
|
|
want: []*AuthToken{},
|
|
wantTTime: time.Now(),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
repo, err := NewRepository(ctx, rw, rw, kms)
|
|
require.NoError(err)
|
|
require.NotNil(repo)
|
|
got, ttime, err := repo.listAuthTokens(ctx, []string{tt.orgId})
|
|
assert.NoError(err)
|
|
sort.Slice(tt.want, func(i, j int) bool { return tt.want[i].PublicId < tt.want[j].PublicId })
|
|
sort.Slice(got, func(i, j int) bool { return got[i].PublicId < got[j].PublicId })
|
|
assert.Empty(cmp.Diff(tt.want, got, protocmp.Transform()), "row count")
|
|
// Transaction timestamp should be within ~10 seconds of now
|
|
assert.True(tt.wantTTime.Before(ttime.Add(10 * time.Second)))
|
|
assert.True(tt.wantTTime.After(ttime.Add(-10 * time.Second)))
|
|
})
|
|
}
|
|
t.Run("withStartPageAfter", func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
|
|
for i := 0; i < 7; i++ {
|
|
at := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
at.Token = ""
|
|
at.KeyId = ""
|
|
}
|
|
|
|
repo, err := NewRepository(ctx, rw, rw, kms)
|
|
require.NoError(err)
|
|
page1, ttime, err := repo.listAuthTokens(ctx, []string{org.GetPublicId()}, WithLimit(2))
|
|
require.NoError(err)
|
|
require.Len(page1, 2)
|
|
// Transaction timestamp should be within ~10 seconds of now
|
|
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
|
|
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
|
|
page2, ttime, err := repo.listAuthTokens(ctx, []string{org.GetPublicId()}, WithLimit(2), WithStartPageAfterItem(page1[1]))
|
|
require.NoError(err)
|
|
require.Len(page2, 2)
|
|
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
|
|
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
|
|
for _, item := range page1 {
|
|
assert.NotEqual(item.GetPublicId(), page2[0].GetPublicId())
|
|
assert.NotEqual(item.GetPublicId(), page2[1].GetPublicId())
|
|
}
|
|
page3, ttime, err := repo.listAuthTokens(ctx, []string{org.GetPublicId()}, WithLimit(2), WithStartPageAfterItem(page2[1]))
|
|
require.NoError(err)
|
|
require.Len(page3, 2)
|
|
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
|
|
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
|
|
for _, item := range page2 {
|
|
assert.NotEqual(item.GetPublicId(), page3[0].GetPublicId())
|
|
assert.NotEqual(item.GetPublicId(), page3[1].GetPublicId())
|
|
}
|
|
page4, ttime, err := repo.listAuthTokens(ctx, []string{org.GetPublicId()}, WithLimit(2), WithStartPageAfterItem(page3[1]))
|
|
require.NoError(err)
|
|
require.Len(page4, 2)
|
|
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
|
|
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
|
|
for _, item := range page3 {
|
|
assert.NotEqual(item.GetPublicId(), page4[0].GetPublicId())
|
|
assert.NotEqual(item.GetPublicId(), page4[1].GetPublicId())
|
|
}
|
|
page5, ttime, err := repo.listAuthTokens(ctx, []string{org.GetPublicId()}, WithLimit(2), WithStartPageAfterItem(page4[1]))
|
|
require.NoError(err)
|
|
require.Len(page5, 2)
|
|
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
|
|
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
|
|
for _, item := range page4 {
|
|
assert.NotEqual(item.GetPublicId(), page5[0].GetPublicId())
|
|
assert.NotEqual(item.GetPublicId(), page5[1].GetPublicId())
|
|
}
|
|
page6, ttime, err := repo.listAuthTokens(ctx, []string{org.GetPublicId()}, WithLimit(2), WithStartPageAfterItem(page5[1]))
|
|
require.NoError(err)
|
|
require.Empty(page6)
|
|
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
|
|
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
|
|
|
|
// Create 2 new auth tokens
|
|
newAt1 := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
newAt1.Token = ""
|
|
newAt1.KeyId = ""
|
|
newAt2 := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
newAt2.Token = ""
|
|
newAt2.KeyId = ""
|
|
|
|
// since it will return newest to oldest, we get page1[1] first
|
|
page7, ttime, err := repo.listAuthTokensRefresh(
|
|
ctx,
|
|
time.Now().Add(-1*time.Second),
|
|
[]string{org.GetPublicId()},
|
|
WithLimit(1),
|
|
)
|
|
require.NoError(err)
|
|
require.Len(page7, 1)
|
|
require.Equal(page7[0].GetPublicId(), newAt2.GetPublicId())
|
|
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
|
|
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
|
|
|
|
page8, ttime, err := repo.listAuthTokensRefresh(
|
|
context.Background(),
|
|
time.Now().Add(-1*time.Second),
|
|
[]string{org.GetPublicId()},
|
|
WithLimit(1),
|
|
WithStartPageAfterItem(page7[0]),
|
|
)
|
|
require.NoError(err)
|
|
require.Len(page8, 1)
|
|
require.Equal(page8[0].GetPublicId(), newAt1.GetPublicId())
|
|
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
|
|
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
|
|
})
|
|
}
|
|
|
|
func TestRepository_ListAuthTokens_Multiple_Scopes(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kms := kms.TestKms(t, conn, wrapper)
|
|
repo, err := NewRepository(ctx, rw, rw, kms)
|
|
require.NoError(t, err)
|
|
iamRepo := iam.TestRepo(t, conn, wrapper)
|
|
org, proj := iam.TestScopes(t, iamRepo)
|
|
|
|
const numPerScope = 10
|
|
var total int
|
|
for i := 0; i < numPerScope; i++ {
|
|
TestAuthToken(t, conn, kms, "global")
|
|
total++
|
|
TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
total++
|
|
}
|
|
|
|
got, ttime, err := repo.listAuthTokens(ctx, []string{"global", org.GetPublicId(), proj.GetPublicId()})
|
|
require.NoError(t, err)
|
|
assert.Equal(t, total, len(got)) // Transaction timestamp should be within ~10 seconds of now
|
|
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
|
|
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
|
|
}
|
|
|
|
func Test_IssuePendingToken(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
rootWrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, rootWrapper)
|
|
repo, err := NewRepository(ctx, rw, rw, kmsCache)
|
|
require.NoError(t, err)
|
|
|
|
org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, rootWrapper))
|
|
|
|
tests := []struct {
|
|
name string
|
|
tokenRequestId string
|
|
wantErrMatch *errors.Template
|
|
wantErrContains string
|
|
}{
|
|
{
|
|
name: "missing-id",
|
|
wantErrMatch: errors.T(errors.InvalidParameter),
|
|
wantErrContains: "missing token request id",
|
|
},
|
|
{
|
|
name: "not-found",
|
|
tokenRequestId: func() string {
|
|
tokenPublicId, err := NewAuthTokenId(ctx)
|
|
require.NoError(t, err)
|
|
tk := TestAuthToken(t, conn, kmsCache, org.PublicId, WithPublicId(tokenPublicId))
|
|
return tk.PublicId
|
|
}(),
|
|
wantErrMatch: errors.T(errors.RecordNotFound),
|
|
wantErrContains: "pending auth token",
|
|
},
|
|
{
|
|
name: "success",
|
|
tokenRequestId: func() string {
|
|
tokenPublicId, err := NewAuthTokenId(ctx)
|
|
require.NoError(t, err)
|
|
tk := TestAuthToken(t, conn, kmsCache, org.PublicId, WithStatus(PendingStatus), WithPublicId(tokenPublicId))
|
|
return tk.PublicId
|
|
}(),
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
tk, err := repo.IssueAuthToken(ctx, tt.tokenRequestId)
|
|
if tt.wantErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted %s and got: %+v", tt.wantErrMatch.Code, err)
|
|
if tt.wantErrContains != "" {
|
|
assert.Contains(err.Error(), tt.wantErrContains)
|
|
}
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
assert.NotEmpty(tk)
|
|
accessTime := tk.GetApproximateLastAccessTime()
|
|
createTime := tk.GetCreateTime().GetTimestamp()
|
|
assert.True(accessTime.AsTime().After(createTime.AsTime()), "last access time %q was not after the creation time %q", accessTime, createTime)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_CloseExpiredPendingTokens(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
rootWrapper := db.TestWrapper(t)
|
|
kmsCache := kms.TestKms(t, conn, rootWrapper)
|
|
repo, err := NewRepository(ctx, rw, rw, kmsCache)
|
|
require.NoError(t, err)
|
|
|
|
iamRepo := iam.TestRepo(t, conn, rootWrapper)
|
|
org, _ := iam.TestScopes(t, iamRepo)
|
|
user := iam.TestUser(t, iamRepo, org.GetPublicId())
|
|
databaseWrapper, err := kmsCache.GetWrapper(ctx, org.PublicId, kms.KeyPurposeDatabase)
|
|
require.NoError(t, err)
|
|
|
|
createWith := func(cnt int, expIn time.Duration, status Status) {
|
|
authMethods := password.TestAuthMethods(t, conn, org.PublicId, 1)
|
|
authMethodId := authMethods[0].PublicId
|
|
|
|
accts := password.TestMultipleAccounts(t, conn, authMethodId, cnt)
|
|
for i := 0; i < cnt; i++ {
|
|
user, _, err = iamRepo.LookupUser(ctx, user.GetPublicId())
|
|
require.NoError(t, err)
|
|
|
|
_, _ = iamRepo.AddUserAccounts(ctx, user.GetPublicId(), user.GetVersion(), []string{accts[i].GetPublicId()})
|
|
|
|
at := allocAuthToken()
|
|
id, err := NewAuthTokenId(ctx)
|
|
require.NoError(t, err)
|
|
at.PublicId = id
|
|
exp := timestamppb.New(time.Now().Add(expIn).Truncate(time.Second))
|
|
at.ExpirationTime = ×tamp.Timestamp{Timestamp: exp}
|
|
at.Status = string(status)
|
|
at.AuthAccountId = accts[i].PublicId
|
|
keyId, err := databaseWrapper.KeyId(ctx)
|
|
require.NoError(t, err)
|
|
at.KeyId = keyId
|
|
at.CtToken = []byte(id)
|
|
err = rw.Create(ctx, at)
|
|
require.NoError(t, err)
|
|
}
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
setup func()
|
|
wantCnt int
|
|
wantErrMatch *errors.Template
|
|
wantErrContains string
|
|
}{
|
|
{
|
|
name: "nada-todo",
|
|
},
|
|
{
|
|
name: "close-2",
|
|
setup: func() {
|
|
createWith(2, -10*time.Second, PendingStatus)
|
|
createWith(1, 10*time.Second, IssuedStatus)
|
|
createWith(1, 10*time.Second, PendingStatus)
|
|
},
|
|
wantCnt: 2,
|
|
},
|
|
{
|
|
name: "close-zero",
|
|
setup: func() {
|
|
createWith(2, 10*time.Second, PendingStatus)
|
|
createWith(1, 10*time.Second, IssuedStatus)
|
|
},
|
|
wantCnt: 0,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
assert, require := assert.New(t), require.New(t)
|
|
|
|
// start with no tokens in the db.
|
|
_, err := rw.Exec(ctx, "delete from auth_token", nil)
|
|
require.NoError(err)
|
|
_, err = rw.Exec(ctx, "delete from auth_account", nil)
|
|
require.NoError(err)
|
|
|
|
if tt.setup != nil {
|
|
tt.setup()
|
|
}
|
|
tokensClosed, err := repo.CloseExpiredPendingTokens(ctx)
|
|
if tt.wantErrMatch != nil {
|
|
require.Error(err)
|
|
assert.Equal(0, tokensClosed)
|
|
assert.Truef(errors.Match(tt.wantErrMatch, err), "wanted %s and got: %s", tt.wantErrMatch.Code, err.Error())
|
|
if tt.wantErrContains != "" {
|
|
assert.Contains(err.Error(), tt.wantErrContains)
|
|
}
|
|
return
|
|
}
|
|
require.NoError(err)
|
|
assert.Equal(tt.wantCnt, tokensClosed)
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_listDeletedIds(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kms := kms.TestKms(t, conn, wrapper)
|
|
iamRepo := iam.TestRepo(t, conn, wrapper)
|
|
org, _ := iam.TestScopes(t, iamRepo)
|
|
at := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
at.Token = ""
|
|
at.KeyId = ""
|
|
repo, err := NewRepository(ctx, rw, rw, kms)
|
|
require.NoError(t, err)
|
|
|
|
// Expect no entries at the start
|
|
deletedIds, ttime, err := repo.listDeletedIds(ctx, time.Now().AddDate(-1, 0, 0))
|
|
require.NoError(t, err)
|
|
require.Empty(t, deletedIds)
|
|
// Transaction timestamp should be within ~10 seconds of now
|
|
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
|
|
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
|
|
|
|
// Delete an auth token
|
|
_, err = repo.DeleteAuthToken(ctx, at.GetPublicId())
|
|
require.NoError(t, err)
|
|
|
|
// Expect a single entry
|
|
deletedIds, ttime, err = repo.listDeletedIds(ctx, time.Now().AddDate(-1, 0, 0))
|
|
require.NoError(t, err)
|
|
require.Equal(t, []string{at.GetPublicId()}, deletedIds)
|
|
// Transaction timestamp should be within ~10 seconds of now
|
|
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
|
|
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
|
|
|
|
// Try again with the time set to now, expect no entries
|
|
deletedIds, ttime, err = repo.listDeletedIds(ctx, time.Now())
|
|
require.NoError(t, err)
|
|
require.Empty(t, deletedIds)
|
|
// Transaction timestamp should be within ~10 seconds of now
|
|
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
|
|
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
|
|
}
|
|
|
|
func Test_estimatedCount(t *testing.T) {
|
|
t.Parallel()
|
|
ctx := context.Background()
|
|
conn, _ := db.TestSetup(t, "postgres")
|
|
sqlDb, err := conn.SqlDB(ctx)
|
|
require.NoError(t, err)
|
|
rw := db.New(conn)
|
|
wrapper := db.TestWrapper(t)
|
|
kms := kms.TestKms(t, conn, wrapper)
|
|
repo, err := NewRepository(ctx, rw, rw, kms)
|
|
require.NoError(t, err)
|
|
|
|
// Run analyze to update estimate
|
|
_, err = sqlDb.ExecContext(ctx, "analyze")
|
|
require.NoError(t, err)
|
|
|
|
// Check total entries at start, expect 0
|
|
numItems, err := repo.estimatedCount(ctx)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 0, numItems)
|
|
|
|
iamRepo := iam.TestRepo(t, conn, wrapper)
|
|
org, _ := iam.TestScopes(t, iamRepo)
|
|
// Create an auth token, expect 1 entry
|
|
at := TestAuthToken(t, conn, kms, org.GetPublicId())
|
|
at.Token = ""
|
|
at.KeyId = ""
|
|
|
|
// Run analyze to update estimate
|
|
_, err = sqlDb.ExecContext(ctx, "analyze")
|
|
require.NoError(t, err)
|
|
numItems, err = repo.estimatedCount(ctx)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 1, numItems)
|
|
|
|
// Delete the auth token, expect 0 again
|
|
_, err = repo.DeleteAuthToken(ctx, at.GetPublicId())
|
|
require.NoError(t, err)
|
|
|
|
// Run analyze to update estimate
|
|
_, err = sqlDb.ExecContext(ctx, "analyze")
|
|
require.NoError(t, err)
|
|
numItems, err = repo.estimatedCount(ctx)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, 0, numItems)
|
|
}
|