You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/authtoken/repository_test.go

615 lines
16 KiB

package authtoken
import (
"context"
"errors"
"sort"
"testing"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/google/go-cmp/cmp"
"github.com/hashicorp/boundary/internal/auth/password"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
"github.com/hashicorp/boundary/internal/authtoken/store"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/iam"
)
// TestRepository_New checks the NewRepository constructor: with a reader,
// writer and kms supplied it must return a Repository carrying those values
// and the expected default limit, and with any of them nil it must fail with
// an error matching db.ErrInvalidParameter and return no Repository.
func TestRepository_New(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	readWriter := db.New(conn)
	kmsCache := kms.TestKms(t, conn, db.TestWrapper(t))

	// ctorArgs bundles the parameters handed to NewRepository.
	type ctorArgs struct {
		reader  db.Reader
		writer  db.Writer
		kmsRepo *kms.Kms
		options []Option
	}

	cases := []struct {
		label   string
		in      ctorArgs
		expect  *Repository
		failsAs error
	}{
		{
			label: "valid default limit",
			in:    ctorArgs{reader: readWriter, writer: readWriter, kmsRepo: kmsCache},
			expect: &Repository{
				reader:       readWriter,
				writer:       readWriter,
				kms:          kmsCache,
				defaultLimit: db.DefaultLimit,
			},
		},
		{
			label: "valid new limit",
			in: ctorArgs{
				reader:  readWriter,
				writer:  readWriter,
				kmsRepo: kmsCache,
				options: []Option{WithLimit(5)},
			},
			expect: &Repository{
				reader:       readWriter,
				writer:       readWriter,
				kms:          kmsCache,
				defaultLimit: 5,
			},
		},
		{
			label:   "nil-reader",
			in:      ctorArgs{reader: nil, writer: readWriter, kmsRepo: kmsCache},
			failsAs: db.ErrInvalidParameter,
		},
		{
			label:   "nil-writer",
			in:      ctorArgs{reader: readWriter, writer: nil, kmsRepo: kmsCache},
			failsAs: db.ErrInvalidParameter,
		},
		{
			label:   "nil-kms",
			in:      ctorArgs{reader: readWriter, writer: readWriter, kmsRepo: nil},
			failsAs: db.ErrInvalidParameter,
		},
		{
			label:   "all-nils",
			in:      ctorArgs{},
			failsAs: db.ErrInvalidParameter,
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.label, func(t *testing.T) {
			check := assert.New(t)
			repo, err := NewRepository(tc.in.reader, tc.in.writer, tc.in.kmsRepo, tc.in.options...)

			// Success path first; everything below it is the error contract.
			if tc.failsAs == nil {
				check.NoError(err)
				check.NotNil(repo)
				check.Equal(tc.expect, repo)
				return
			}

			check.Truef(errors.Is(err, tc.failsAs), "want err: %q got: %q", tc.failsAs, err)
			check.Nil(repo)
		})
	}
}
// TestRepository_CreateAuthToken exercises Repository.CreateAuthToken with a
// user that is linked to the auth account (the only case expected to succeed)
// and with several missing, mismatched, or invalid user / account ids, all of
// which are expected to return an error and no token.
func TestRepository_CreateAuthToken(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	// Named kmsCache (as in TestRepository_New) so the local does not shadow
	// the kms package.
	kmsCache := kms.TestKms(t, conn, wrapper)
	repo := iam.TestRepo(t, conn, wrapper)

	// u1 is obtained by looking up the login for aAcct with auto-vivify
	// enabled, so it is the user associated with that account.
	org1, _ := iam.TestScopes(t, repo)
	am := password.TestAuthMethods(t, conn, org1.GetPublicId(), 1)[0]
	aAcct := password.TestAccounts(t, conn, am.GetPublicId(), 1)[0]

	iamRepo, err := iam.NewRepository(rw, rw, kmsCache)
	require.NoError(t, err)
	u1, err := iamRepo.LookupUserWithLogin(context.Background(), aAcct.GetPublicId(), iam.WithAutoVivify(true))
	require.NoError(t, err)

	// u2 lives in a different org and is never associated with aAcct.
	org2, _ := iam.TestScopes(t, repo)
	u2 := iam.TestUser(t, repo, org2.GetPublicId())

	tests := []struct {
		name       string
		iamUser    *iam.User
		authAcctId string
		want       *AuthToken
		wantErr    bool
	}{
		{
			name:       "valid",
			iamUser:    u1,
			authAcctId: aAcct.GetPublicId(),
			want: &AuthToken{
				AuthToken: &store.AuthToken{
					AuthAccountId: aAcct.GetPublicId(),
				},
			},
		},
		{
			name:       "unconnected-authaccount-user",
			iamUser:    u2,
			authAcctId: aAcct.GetPublicId(),
			wantErr:    true,
		},
		{
			name:    "no-authacctid",
			iamUser: u1,
			wantErr: true,
		},
		{
			name:       "no-userid",
			authAcctId: aAcct.GetPublicId(),
			wantErr:    true,
		},
		{
			name:       "invalid-authacctid",
			iamUser:    u1,
			authAcctId: "this_is_invalid",
			wantErr:    true,
		},
		{
			name: "invalid-userid",
			// Work on a clone so u1 itself keeps its real public id.
			iamUser: func() *iam.User {
				u := u1.Clone().(*iam.User)
				u.PublicId = "this_is_invalid"
				return u
			}(),
			authAcctId: aAcct.GetPublicId(),
			wantErr:    true,
		},
	}

	for _, tt := range tests {
		// Capture the range variable, matching the other tests in this file.
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			repo, err := NewRepository(rw, rw, kmsCache)
			require.NoError(err)
			require.NotNil(repo)

			got, err := repo.CreateAuthToken(context.Background(), tt.iamUser, tt.authAcctId)
			if tt.wantErr {
				assert.Error(err)
				assert.Nil(got)
				return
			}
			require.NoError(err, "Got error for CreateAuthToken(ctx, %v, %v)", tt.iamUser, tt.authAcctId)
			// got is dereferenced below, so a nil result must stop the
			// subtest instead of panicking.
			require.NotNil(got)

			db.AssertPublicId(t, AuthTokenPrefix, got.PublicId)
			assert.Equal(tt.authAcctId, got.GetAuthAccountId())
			assert.Equal(got.CreateTime, got.UpdateTime)
			assert.Equal(got.CreateTime, got.ApproximateLastAccessTime)
			// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
			assert.Error(db.TestVerifyOplog(t, rw, got.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_CREATE)))
		})
	}
}
// TestRepository_LookupAuthToken covers Repository.LookupAuthToken for an
// existing token, an id that was generated but never stored, and an empty id
// (expected to fail with db.ErrInvalidParameter).
func TestRepository_LookupAuthToken(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	readWriter := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, _ := iam.TestScopes(t, iamRepo)

	// The fixture is compared against the lookup result with these three
	// fields blanked out.
	stored := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
	stored.Token = ""
	stored.CtToken = nil
	stored.KeyId = ""

	unknownID, err := newAuthTokenId()
	require.NoError(t, err)
	require.NotNil(t, unknownID)

	cases := []struct {
		label    string
		publicID string
		expect   *AuthToken
		failsAs  error
	}{
		{label: "found", publicID: stored.GetPublicId(), expect: stored},
		{label: "not-found", publicID: unknownID, expect: nil},
		{label: "bad-public-id", publicID: "", expect: nil, failsAs: db.ErrInvalidParameter},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.label, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			authTokenRepo, err := NewRepository(readWriter, readWriter, kmsCache)
			require.NoError(err)
			require.NotNil(authTokenRepo)

			found, err := authTokenRepo.LookupAuthToken(context.Background(), tc.publicID)
			if tc.failsAs != nil {
				assert.Truef(errors.Is(err, tc.failsAs), "want err: %q got: %q", tc.failsAs, err)
				return
			}
			assert.NoError(err)
			if found == nil {
				assert.Nil(tc.expect)
				return
			}
			require.NotNil(tc.expect, "got %v, wanted nil", found)

			// expiryOf converts a token's expiration timestamp to a
			// time.Time, aborting the subtest if the conversion fails.
			expiryOf := func(tok *AuthToken) time.Time {
				ts, err := ptypes.Timestamp(tok.AuthToken.GetExpirationTime().Timestamp)
				require.NoError(err)
				return ts
			}

			// TODO: Requiring exact equality here fails by a tiny margin
			// (on the order of 500ns has been observed). The suspected
			// cause is that the database stores timestamps at a different
			// resolution than Go does; worth confirming. Until then the
			// expiration is only required to match within a millisecond and
			// is then copied over so it does not show up in the full diff.
			wantExpiry := expiryOf(tc.expect)
			gotExpiry := expiryOf(found)
			assert.WithinDuration(wantExpiry, gotExpiry, time.Millisecond)
			tc.expect.AuthToken.ExpirationTime = found.AuthToken.ExpirationTime

			diff := cmp.Diff(tc.expect.AuthToken, found.AuthToken, protocmp.Transform())
			assert.Empty(diff)
		})
	}
}
// TestRepository_ValidateToken covers Repository.ValidateToken for a valid
// id/token pair, an unknown pair, an empty token (db.ErrInvalidParameter) and
// a known id with the wrong token. For the valid pair it additionally checks
// how the approximate last access time reported by successive calls changes
// as lastAccessedUpdateDuration is adjusted.
//
// The test overrides the package-level lastAccessedUpdateDuration and
// timeSkew settings; both are put back to their previous values when the test
// returns so the overrides cannot leak into other tests in the package.
func TestRepository_ValidateToken(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")

	origUpdateDuration, origTimeSkew := lastAccessedUpdateDuration, timeSkew
	defer func() {
		lastAccessedUpdateDuration = origUpdateDuration
		timeSkew = origTimeSkew
	}()
	lastAccessedUpdateDuration = 0
	timeSkew = 20 * time.Millisecond

	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	// Named kmsCache so the local does not shadow the kms package.
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	repo, err := NewRepository(rw, rw, kmsCache)
	require.NoError(t, err)
	require.NotNil(t, repo)

	org, _ := iam.TestScopes(t, iamRepo)
	at := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
	// Keep the token value for the calls below, then blank the fields that
	// are excluded from the comparison against ValidateToken's result.
	atToken := at.GetToken()
	at.Token = ""
	at.CtToken = nil
	at.KeyId = ""
	atTime, err := ptypes.Timestamp(at.GetApproximateLastAccessTime().GetTimestamp())
	require.NoError(t, err)
	require.NotNil(t, atTime)

	badId, err := newAuthTokenId()
	require.NoError(t, err)
	require.NotNil(t, badId)

	badToken, err := newAuthToken()
	require.NoError(t, err)
	require.NotNil(t, badToken)

	tests := []struct {
		name    string
		id      string
		token   string
		want    *AuthToken
		wantErr error
	}{
		{
			name:  "exists",
			id:    at.GetPublicId(),
			token: atToken,
			want:  at,
		},
		{
			name:  "doesnt-exist",
			id:    badId,
			token: badToken,
			want:  nil,
		},
		{
			name:    "empty-token",
			id:      at.GetPublicId(),
			token:   "",
			want:    nil,
			wantErr: db.ErrInvalidParameter,
		},
		{
			name:    "mismatched-token",
			id:      at.GetPublicId(),
			token:   badToken,
			want:    nil,
			wantErr: nil,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)

			got, err := repo.ValidateToken(context.Background(), tt.id, tt.token)
			if tt.wantErr != nil {
				assert.Truef(errors.Is(err, tt.wantErr), "want err: %q got: %q", tt.wantErr, err)
				return
			}
			assert.NoError(err)
			if got == nil {
				assert.Nil(tt.want)
				// No need to compare updated time if we didn't get an initial auth token to compare against.
				return
			}
			require.NotNil(tt.want, "Got %v but wanted nil", got)

			// Expiration times are only required to agree within a
			// millisecond (TestRepository_LookupAuthToken explains why);
			// the value is then copied over so it is ignored by the diff.
			wantGoTimeExpr, err := ptypes.Timestamp(tt.want.AuthToken.GetExpirationTime().Timestamp)
			require.NoError(err)
			gotGoTimeExpr, err := ptypes.Timestamp(got.AuthToken.GetExpirationTime().Timestamp)
			require.NoError(err)
			assert.WithinDuration(wantGoTimeExpr, gotGoTimeExpr, time.Millisecond)
			tt.want.AuthToken.ExpirationTime = got.AuthToken.ExpirationTime
			assert.Empty(cmp.Diff(tt.want.AuthToken, got.AuthToken, protocmp.Transform()))

			// preTime1 is expected to be the last access time as it was
			// before the first ValidateToken call, i.e. the creation time.
			preTime1, err := ptypes.Timestamp(got.GetApproximateLastAccessTime().GetTimestamp())
			require.NoError(err)
			assert.True(preTime1.Equal(atTime), "Create time %q doesn't match the time from the first call to ValidateToken: %q.", atTime, preTime1)

			// Enable the duration which limits how frequently a token's approximate last accessed time can be updated
			// so the next call doesn't cause the last accessed time to be updated.
			lastAccessedUpdateDuration = 1 * time.Hour

			got2, err := repo.ValidateToken(context.Background(), tt.id, tt.token)
			// got2 is dereferenced on the next line, so stop here on failure
			// rather than panic.
			require.NoError(err)
			require.NotNil(got2)
			preTime2, err := ptypes.Timestamp(got2.GetApproximateLastAccessTime().GetTimestamp())
			require.NoError(err)
			assert.True(preTime2.After(preTime1), "First updated time %q was not after the creation time %q", preTime2, preTime1)

			// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
			assert.Error(db.TestVerifyOplog(t, rw, got.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_UPDATE)))

			got3, err := repo.ValidateToken(context.Background(), tt.id, tt.token)
			require.NoError(err)
			require.NotNil(got3)
			preTime3, err := ptypes.Timestamp(got3.GetApproximateLastAccessTime().GetTimestamp())
			require.NoError(err)
			assert.True(preTime3.Equal(preTime2), "The 3rd timestamp %q was not equal to the second time %q", preTime3, preTime2)
		})
	}
}
// TestRepository_ValidateToken_expired checks that Repository.ValidateToken
// returns a token only while it is neither stale nor expired. Each subtest
// overrides the package-level maxStaleness and maxTokenDuration (and sets
// timeSkew), creates a fresh token under those settings and validates it.
//
// The overridden settings are restored with a deferred function so they are
// put back even when a require assertion aborts the subtest early.
func TestRepository_ValidateToken_expired(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	// Named kmsCache so the local does not shadow the kms package.
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	repo, err := NewRepository(rw, rw, kmsCache)
	require.NoError(t, err)
	require.NotNil(t, repo)

	org, _ := iam.TestScopes(t, iamRepo)
	baseAT := TestAuthToken(t, conn, kmsCache, org.GetPublicId())

	// Resolve the iam user behind the fixture token's auth account so that
	// further tokens can be created for the same user/account pair.
	aAcct := allocAuthAccount()
	aAcct.PublicId = baseAT.GetAuthAccountId()
	require.NoError(t, rw.LookupByPublicId(context.Background(), aAcct))
	iamUser, _, err := iamRepo.LookupUser(context.Background(), aAcct.GetIamUserId())
	require.NoError(t, err)
	require.NotNil(t, iamUser)

	defaultStaleTime := maxStaleness
	defaultExpireDuration := maxTokenDuration
	defaultTimeSkew := timeSkew

	tests := []struct {
		name               string
		staleDuration      time.Duration
		expirationDuration time.Duration
		wantReturned       bool
	}{
		{
			name:               "not-stale-or-expired",
			staleDuration:      maxStaleness,
			expirationDuration: maxTokenDuration,
			wantReturned:       true,
		},
		{
			name:               "stale",
			staleDuration:      0,
			expirationDuration: maxTokenDuration,
			wantReturned:       false,
		},
		{
			name:               "expired",
			staleDuration:      maxStaleness,
			expirationDuration: 0,
			wantReturned:       false,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)

			// reset the system default params when the subtest ends,
			// whether it finishes normally or is aborted by require.
			defer func() {
				maxStaleness = defaultStaleTime
				maxTokenDuration = defaultExpireDuration
				timeSkew = defaultTimeSkew
			}()
			maxStaleness = tt.staleDuration
			maxTokenDuration = tt.expirationDuration
			timeSkew = 20 * time.Millisecond

			ctx := context.Background()
			at, err := repo.CreateAuthToken(ctx, iamUser, baseAT.GetAuthAccountId())
			require.NoError(err)

			got, err := repo.ValidateToken(ctx, at.GetPublicId(), at.GetToken())
			require.NoError(err)

			if tt.wantReturned {
				assert.NotNil(got)
				return
			}
			// We should find no oplog since tokens are not replicated, so they don't need oplog entries.
			assert.Error(db.TestVerifyOplog(t, rw, at.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_DELETE)))
			assert.Nil(got)
		})
	}
}
// TestRepository_DeleteAuthToken covers Repository.DeleteAuthToken for a
// stored token (one row reported), an id that was never stored (zero rows)
// and an empty id (expected to fail with db.ErrInvalidParameter).
func TestRepository_DeleteAuthToken(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	readWriter := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, _ := iam.TestScopes(t, iamRepo)
	stored := TestAuthToken(t, conn, kmsCache, org.GetPublicId())

	unknownID, err := newAuthTokenId()
	require.NoError(t, err)
	require.NotNil(t, unknownID)

	cases := []struct {
		label    string
		publicID string
		rows     int
		failsAs  error
	}{
		{label: "found", publicID: stored.GetPublicId(), rows: 1},
		{label: "not-found", publicID: unknownID, rows: 0},
		{label: "empty-public-id", publicID: "", rows: 0, failsAs: db.ErrInvalidParameter},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.label, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			authTokenRepo, err := NewRepository(readWriter, readWriter, kmsCache)
			require.NoError(err)
			require.NotNil(authTokenRepo)

			deleted, err := authTokenRepo.DeleteAuthToken(context.Background(), tc.publicID)
			if tc.failsAs != nil {
				assert.Truef(errors.Is(err, tc.failsAs), "want err: %q got: %q", tc.failsAs, err)
				return
			}
			assert.NoError(err)
			assert.Equal(tc.rows, deleted, "row count")

			// The oplog is only inspected when a row was expected to be deleted.
			if tc.rows == 0 {
				return
			}
			// Tokens are not replicated, so the delete must not have
			// produced an oplog entry: verification is expected to fail.
			oplogErr := db.TestVerifyOplog(t, readWriter, tc.publicID, db.WithOperation(oplog.OpType_OP_TYPE_DELETE))
			assert.Error(oplogErr)
		})
	}
}
// TestRepository_ListAuthTokens covers Repository.ListAuthTokens for an org
// holding three tokens, an org holding none, and an empty org id (expected to
// fail with db.ErrInvalidParameter). Expected and actual results are sorted
// by public id before being diffed, so the order returned by the repository
// does not matter.
func TestRepository_ListAuthTokens(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	readWriter := db.New(conn)
	wrapper := db.TestWrapper(t)
	kmsCache := kms.TestKms(t, conn, wrapper)
	iamRepo := iam.TestRepo(t, conn, wrapper)
	org, _ := iam.TestScopes(t, iamRepo)

	// Three tokens in the populated org, each with Token and KeyId blanked
	// before being used as the expected value.
	const tokenCount = 3
	orgTokens := make([]*AuthToken, 0, tokenCount)
	for i := 0; i < tokenCount; i++ {
		tok := TestAuthToken(t, conn, kmsCache, org.GetPublicId())
		tok.Token = ""
		tok.KeyId = ""
		orgTokens = append(orgTokens, tok)
	}

	emptyOrg, _ := iam.TestScopes(t, iamRepo)

	cases := []struct {
		label   string
		scopeID string
		expect  []*AuthToken
		failsAs error
	}{
		{label: "populated", scopeID: org.GetPublicId(), expect: orgTokens},
		{label: "empty", scopeID: emptyOrg.GetPublicId(), expect: []*AuthToken{}},
		{label: "empty-org-id", scopeID: "", expect: nil, failsAs: db.ErrInvalidParameter},
	}

	// byPublicId sorts a token slice in place by ascending public id.
	byPublicId := func(toks []*AuthToken) {
		sort.Slice(toks, func(i, j int) bool { return toks[i].PublicId < toks[j].PublicId })
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.label, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			authTokenRepo, err := NewRepository(readWriter, readWriter, kmsCache)
			require.NoError(err)
			require.NotNil(authTokenRepo)

			listed, err := authTokenRepo.ListAuthTokens(context.Background(), tc.scopeID)
			if tc.failsAs != nil {
				assert.Truef(errors.Is(err, tc.failsAs), "want err: %q got: %q", tc.failsAs, err)
				return
			}
			assert.NoError(err)

			byPublicId(tc.expect)
			byPublicId(listed)
			diff := cmp.Diff(tc.expect, listed, protocmp.Transform())
			assert.Empty(diff, "row count")
		})
	}
}