You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/kms/repository_test.go

252 lines
6.3 KiB

package kms_test
import (
"context"
"crypto/rand"
"errors"
"io"
"testing"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestNewRepository checks that kms.NewRepository succeeds when given a
// non-nil reader and writer, and that it rejects a nil reader or a nil writer
// with the expected error message.
func TestNewRepository(t *testing.T) {
	t.Parallel()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	type args struct {
		r db.Reader
		w db.Writer
	}
	tests := []struct {
		name          string
		args          args
		want          *kms.Repository
		wantErr       bool
		wantErrString string
	}{
		{
			name: "valid",
			args: args{
				r: rw,
				w: rw,
			},
			// Build the expected repository from the same reader/writer so the
			// assert.Equal below compares the complete returned value.
			want: func() *kms.Repository {
				ret, err := kms.NewRepository(rw, rw)
				require.NoError(t, err)
				return ret
			}(),
		},
		{
			name: "nil-writer",
			args: args{
				r: rw,
				w: nil,
			},
			want:          nil,
			wantErr:       true,
			wantErrString: "error creating db repository with nil writer",
		},
		{
			name: "nil-reader",
			args: args{
				r: nil,
				w: rw,
			},
			want:          nil,
			wantErr:       true,
			wantErrString: "error creating db repository with nil reader",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			got, err := kms.NewRepository(tt.args.r, tt.args.w)
			if tt.wantErr {
				require.Error(err)
				// EqualError takes (actual error, expected message), so testify
				// reports expected/actual the right way round on failure.
				assert.EqualError(err, tt.wantErrString)
				return
			}
			require.NoError(err)
			assert.Equal(tt.want, got)
		})
	}
}
// TestCreateKeysTx exercises kms.CreateKeysTx.
//
// It checks the parameter-validation failures: nil reader, nil writer, nil
// root wrapper, empty scope id, and a scope id that does not exist. For a
// valid call, it checks that every returned key can be read back from the
// database and references the expected parent key / root key version.
func TestCreateKeysTx(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	org, _ := iam.TestScopes(t, iam.TestRepo(t, conn, wrapper))
	// Delete every existing root key row before running the cases.
	// NOTE(review): presumably scope creation above already created keys for
	// org, and clearing them lets the "valid" case create a fresh set without
	// conflicting — confirm.
	require.NoError(t, conn.Where("1=1").Delete(kms.AllocRootKey()).Error)
	type args struct {
		ctx          context.Context
		dbReader     db.Reader
		dbWriter     db.Writer
		rootWrapper  wrapping.Wrapper
		randomReader io.Reader
		scopeId      string
	}
	tests := []struct {
		name      string
		args      args
		wantErr   bool
		wantErrIs error
	}{
		{
			name: "valid",
			args: args{
				ctx:          ctx,
				dbReader:     rw,
				dbWriter:     rw,
				rootWrapper:  wrapper,
				randomReader: rand.Reader,
				scopeId:      org.PublicId,
			},
		},
		{
			name: "nil-reader",
			args: args{
				ctx:          ctx,
				dbReader:     nil,
				dbWriter:     rw,
				rootWrapper:  wrapper,
				randomReader: rand.Reader,
				scopeId:      org.PublicId,
			},
			wantErr:   true,
			wantErrIs: db.ErrInvalidParameter,
		},
		{
			name: "nil-writer",
			args: args{
				ctx:          ctx,
				dbReader:     rw,
				dbWriter:     nil,
				rootWrapper:  wrapper,
				randomReader: rand.Reader,
				scopeId:      org.PublicId,
			},
			wantErr:   true,
			wantErrIs: db.ErrInvalidParameter,
		},
		{
			name: "nil-wrapper",
			args: args{
				ctx:          ctx,
				dbReader:     rw,
				dbWriter:     rw,
				rootWrapper:  nil,
				randomReader: rand.Reader,
				scopeId:      org.PublicId,
			},
			wantErr:   true,
			wantErrIs: db.ErrInvalidParameter,
		},
		{
			name: "empty-scope",
			args: args{
				ctx:          ctx,
				dbReader:     rw,
				dbWriter:     rw,
				rootWrapper:  wrapper,
				randomReader: rand.Reader,
				scopeId:      "",
			},
			wantErr:   true,
			wantErrIs: db.ErrInvalidParameter,
		},
		{
			// Well-formed but nonexistent scope: only "some error" is asserted,
			// since the failure presumably comes from the database layer.
			name: "bad-scope",
			args: args{
				ctx:          ctx,
				dbReader:     rw,
				dbWriter:     rw,
				rootWrapper:  wrapper,
				randomReader: rand.Reader,
				scopeId:      "o_thisIsAnInvalidScopeId",
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			assert, require := assert.New(t), require.New(t)
			keys, err := kms.CreateKeysTx(tt.args.ctx, tt.args.dbReader, tt.args.dbWriter, tt.args.rootWrapper, tt.args.randomReader, tt.args.scopeId)
			if tt.wantErr {
				require.Error(err)
				if tt.wantErrIs != nil {
					assert.Truef(errors.Is(err, tt.wantErrIs), "unexpected error: %s", err.Error())
				}
				return
			}
			require.NoError(err)

			// The root key must be stored under the requested scope.
			rk := kms.AllocRootKey()
			rk.PrivateId = keys[kms.KeyTypeRootKey].GetPrivateId()
			require.NoError(rw.LookupById(ctx, &rk))
			assert.Equal(tt.args.scopeId, rk.ScopeId)

			// The root key version must reference the root key.
			rkv := kms.AllocRootKeyVersion()
			rkv.PrivateId = keys[kms.KeyTypeRootKeyVersion].GetPrivateId()
			require.NoError(rw.LookupById(ctx, &rkv))
			assert.Equal(rk.PrivateId, rkv.RootKeyId)

			// Each of the database, oplog, session and token keys must
			// reference the root key. Each of their versions must reference
			// both that key and the root key version.
			dk := kms.AllocDatabaseKey()
			dk.PrivateId = keys[kms.KeyTypeDatabaseKey].GetPrivateId()
			require.NoError(rw.LookupById(ctx, &dk))
			assert.Equal(rk.PrivateId, dk.RootKeyId)

			dkv := kms.AllocDatabaseKeyVersion()
			dkv.PrivateId = keys[kms.KeyTypeDatabaseKeyVersion].GetPrivateId()
			require.NoError(rw.LookupById(ctx, &dkv))
			assert.Equal(dk.PrivateId, dkv.DatabaseKeyId)
			assert.Equal(rkv.PrivateId, dkv.RootKeyVersionId)

			opk := kms.AllocOplogKey()
			opk.PrivateId = keys[kms.KeyTypeOplogKey].GetPrivateId()
			require.NoError(rw.LookupById(ctx, &opk))
			assert.Equal(rk.PrivateId, opk.RootKeyId)

			opkv := kms.AllocOplogKeyVersion()
			opkv.PrivateId = keys[kms.KeyTypeOplogKeyVersion].GetPrivateId()
			require.NoError(rw.LookupById(ctx, &opkv))
			assert.Equal(opk.PrivateId, opkv.OplogKeyId)
			assert.Equal(rkv.PrivateId, opkv.RootKeyVersionId)

			sk := kms.AllocSessionKey()
			sk.PrivateId = keys[kms.KeyTypeSessionKey].GetPrivateId()
			require.NoError(rw.LookupById(ctx, &sk))
			assert.Equal(rk.PrivateId, sk.RootKeyId)

			skv := kms.AllocSessionKeyVersion()
			skv.PrivateId = keys[kms.KeyTypeSessionKeyVersion].GetPrivateId()
			require.NoError(rw.LookupById(ctx, &skv))
			assert.Equal(sk.PrivateId, skv.SessionKeyId)
			assert.Equal(rkv.PrivateId, skv.RootKeyVersionId)

			tk := kms.AllocTokenKey()
			tk.PrivateId = keys[kms.KeyTypeTokenKey].GetPrivateId()
			require.NoError(rw.LookupById(ctx, &tk))
			assert.Equal(rk.PrivateId, tk.RootKeyId)

			tkv := kms.AllocTokenKeyVersion()
			tkv.PrivateId = keys[kms.KeyTypeTokenKeyVersion].GetPrivateId()
			require.NoError(rw.LookupById(ctx, &tkv))
			assert.Equal(tk.PrivateId, tkv.TokenKeyId)
			assert.Equal(rkv.PrivateId, tkv.RootKeyVersionId)
		})
	}
}