You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/host/plugin/repository_host_set_test.go

1986 lines
62 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package plugin
import (
"context"
"fmt"
"sort"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/hashicorp/boundary/globals"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/host"
"github.com/hashicorp/boundary/internal/host/plugin/store"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/libs/patchstruct"
"github.com/hashicorp/boundary/internal/oplog"
"github.com/hashicorp/boundary/internal/plugin"
"github.com/hashicorp/boundary/internal/plugin/loopback"
"github.com/hashicorp/boundary/internal/scheduler"
plgpb "github.com/hashicorp/boundary/sdk/pbs/plugin"
"github.com/mitchellh/mapstructure"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/wrapperspb"
)
func TestRepository_CreateSet(t *testing.T) {
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
kms := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
sched := scheduler.TestScheduler(t, conn, wrapper)
_, prj := iam.TestScopes(t, iamRepo)
plg := plugin.TestPlugin(t, conn, "create")
unimplementedPlugin := plugin.TestPlugin(t, conn, "unimplemented")
catalog := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
unimplementedPluginCatalog := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
attrs := []byte{}
const normalizeToSliceKey = "normalize_to_slice"
tests := []struct {
name string
in *HostSet
opts []Option
want *HostSet
wantPluginCalled bool
wantIsErr errors.Code
}{
{
name: "nil-HostSet",
wantIsErr: errors.InvalidParameter,
},
{
name: "nil-embedded-HostSet",
in: &HostSet{},
wantIsErr: errors.InvalidParameter,
},
{
name: "invalid-no-catalog-id",
in: &HostSet{
HostSet: &store.HostSet{
Attributes: attrs,
},
},
wantIsErr: errors.InvalidParameter,
},
{
name: "invalid-public-id-set",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
PublicId: "abcd_OOOOOOOOOO",
Attributes: attrs,
},
},
wantIsErr: errors.InvalidParameter,
},
{
name: "invalid-no-attribte",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
},
},
wantIsErr: errors.InvalidParameter,
},
{
name: "invalid-sync-interval-too-negative",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
SyncIntervalSeconds: -99,
Attributes: attrs,
},
},
wantIsErr: errors.InvalidParameter,
},
{
name: "valid-no-options",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Attributes: attrs,
},
},
want: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Attributes: attrs,
},
},
wantPluginCalled: true,
},
{
name: "valid-preferred-endpoints",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Attributes: attrs,
},
PreferredEndpoints: []string{"cidr:1.2.3.4/32", "dns:a.b.c"},
},
want: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Attributes: attrs,
},
PreferredEndpoints: []string{"cidr:1.2.3.4/32", "dns:a.b.c"},
},
wantPluginCalled: true,
},
{
name: "valid-with-name",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Name: "test-name-repo",
Attributes: attrs,
},
},
want: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Name: "test-name-repo",
Attributes: attrs,
},
},
wantPluginCalled: true,
},
{
name: "valid-sync-interval-disabled",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
SyncIntervalSeconds: -1,
Attributes: attrs,
},
},
want: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
SyncIntervalSeconds: -1,
Attributes: attrs,
},
},
wantPluginCalled: true,
},
{
name: "valid-sync-interval-positive",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
SyncIntervalSeconds: 60,
Attributes: attrs,
},
},
want: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
SyncIntervalSeconds: 60,
Attributes: attrs,
},
},
wantPluginCalled: true,
},
{
name: "valid-unimplemented-plugin",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: unimplementedPluginCatalog.PublicId,
Name: "valid-unimplemented-plugin",
Attributes: attrs,
},
},
want: &HostSet{
HostSet: &store.HostSet{
CatalogId: unimplementedPluginCatalog.PublicId,
Name: "valid-unimplemented-plugin",
Attributes: attrs,
},
},
wantPluginCalled: true,
},
{
name: "valid-with-description",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Description: ("test-description-repo"),
Attributes: attrs,
},
},
want: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Description: ("test-description-repo"),
Attributes: attrs,
},
},
wantPluginCalled: true,
},
{
name: "valid-with-custom-attributes",
in: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Description: ("test-description-repo"),
Attributes: func() []byte {
st, err := structpb.NewStruct(map[string]any{
"k1": nil,
"removed": nil,
normalizeToSliceKey: "normalizeme",
})
require.NoError(t, err)
b, err := proto.Marshal(st)
require.NoError(t, err)
return b
}(),
},
},
want: &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Description: ("test-description-repo"),
Attributes: func() []byte {
b, err := proto.Marshal(&structpb.Struct{Fields: map[string]*structpb.Value{
normalizeToSliceKey: structpb.NewListValue(
&structpb.ListValue{
Values: []*structpb.Value{
structpb.NewStringValue("normalizeme"),
},
}),
}})
require.NoError(t, err)
return b
}(),
},
},
wantPluginCalled: true,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
var origPluginAttrs, pluginReceivedAttrs *structpb.Struct
if tt.in != nil && tt.in.HostSet != nil && len(tt.in.Attributes) > 0 {
origPluginAttrs = new(structpb.Struct)
require.NoError(proto.Unmarshal(tt.in.Attributes, origPluginAttrs))
}
var pluginCalled bool
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): loopback.NewWrappingPluginHostClient(loopback.TestPluginHostServer{
NormalizeSetDataFn: func(_ context.Context, req *plgpb.NormalizeSetDataRequest) (*plgpb.NormalizeSetDataResponse, error) {
if req.Attributes == nil {
return new(plgpb.NormalizeSetDataResponse), nil
}
var attrs struct {
NormalizeToSlice string `mapstructure:"normalize_to_slice"`
}
require.NoError(mapstructure.Decode(req.Attributes.AsMap(), &attrs))
if attrs.NormalizeToSlice == "" {
return new(plgpb.NormalizeSetDataResponse), nil
}
retAttrs := proto.Clone(req.Attributes).(*structpb.Struct)
retAttrs.Fields[normalizeToSliceKey] = structpb.NewListValue(&structpb.ListValue{
Values: []*structpb.Value{structpb.NewStringValue(attrs.NormalizeToSlice)},
})
return &plgpb.NormalizeSetDataResponse{Attributes: retAttrs}, nil
},
OnCreateSetFn: func(ctx context.Context, req *plgpb.OnCreateSetRequest) (*plgpb.OnCreateSetResponse, error) {
pluginCalled = true
pluginReceivedAttrs = req.GetSet().GetAttributes()
return &plgpb.OnCreateSetResponse{}, nil
},
}),
unimplementedPlugin.GetPublicId(): loopback.NewWrappingPluginHostClient(loopback.TestPluginHostServer{OnCreateSetFn: func(ctx context.Context, req *plgpb.OnCreateSetRequest) (*plgpb.OnCreateSetResponse, error) {
pluginCalled = true
pluginReceivedAttrs = req.GetSet().GetAttributes()
return plgpb.UnimplementedHostPluginServiceServer{}.OnCreateSet(ctx, req)
}}),
}
repo, err := NewRepository(ctx, rw, rw, kms, sched, plgm)
require.NoError(err)
require.NotNil(repo)
got, plgInfo, err := repo.CreateSet(ctx, prj.GetPublicId(), tt.in, tt.opts...)
assert.Equal(tt.wantPluginCalled, pluginCalled)
if tt.wantIsErr != 0 {
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err: %q got: %q", tt.wantIsErr, err)
assert.Nil(got)
return
}
require.NoError(err)
assert.Empty(tt.in.PublicId)
require.NotNil(got)
assert.True(strings.HasPrefix(got.GetPublicId(), globals.PluginHostSetPrefix))
assert.NotSame(tt.in, got)
assert.Equal(tt.want.Name, got.GetName())
assert.Equal(tt.want.Description, got.GetDescription())
assert.Equal(got.GetCreateTime(), got.GetUpdateTime())
assert.Equal(string(tt.want.GetAttributes()), string(got.GetAttributes()))
if origPluginAttrs != nil {
if normalizeVal := origPluginAttrs.Fields[normalizeToSliceKey]; normalizeVal != nil {
gotVal := pluginReceivedAttrs.Fields[normalizeToSliceKey]
require.NotNil(gotVal)
listVal := gotVal.GetListValue()
require.NotNil(listVal)
require.Len(listVal.Values, 1)
assert.Equal(normalizeVal.GetStringValue(), listVal.Values[0].GetStringValue())
origPluginAttrs.Fields[normalizeToSliceKey] = structpb.NewListValue(listVal)
tt.want.Attributes, err = proto.Marshal(origPluginAttrs)
require.NoError(err)
tt.want.Attributes, err = patchstruct.PatchBytes([]byte{}, tt.want.Attributes)
require.NoError(err)
}
}
wantedPluginAttributes := &structpb.Struct{}
require.NoError(proto.Unmarshal(tt.want.Attributes, wantedPluginAttributes))
assert.Empty(cmp.Diff(wantedPluginAttributes, pluginReceivedAttrs, protocmp.Transform()))
assert.Empty(cmp.Diff(plgInfo, plg, protocmp.Transform()))
assert.NoError(db.TestVerifyOplog(t, rw, got.GetPublicId(), db.WithOperation(oplog.OpType_OP_TYPE_CREATE), db.WithCreateNotBefore(10*time.Second)))
})
}
t.Run("invalid-duplicate-names", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
var pluginCalled bool
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): loopback.NewWrappingPluginHostClient(loopback.TestPluginHostServer{OnCreateSetFn: func(ctx context.Context, req *plgpb.OnCreateSetRequest) (*plgpb.OnCreateSetResponse, error) {
pluginCalled = true
return &plgpb.OnCreateSetResponse{}, nil
}}),
}
repo, err := NewRepository(ctx, rw, rw, kms, sched, plgm)
require.NoError(err)
require.NotNil(repo)
_, prj := iam.TestScopes(t, iamRepo)
catalog := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
in := &HostSet{
HostSet: &store.HostSet{
CatalogId: catalog.PublicId,
Name: "test-name-repo",
Attributes: []byte{},
},
}
got, _, err := repo.CreateSet(context.Background(), prj.GetPublicId(), in)
require.NoError(err)
require.NotNil(got)
assert.True(pluginCalled)
assert.True(strings.HasPrefix(got.GetPublicId(), globals.PluginHostSetPrefix))
assert.NotSame(in, got)
assert.Equal(in.Name, got.GetName())
assert.Equal(in.Description, got.GetDescription())
assert.Equal(got.GetCreateTime(), got.GetUpdateTime())
// reset pluginCalled so we can see the duplicate name causes the plugin
// to not get called.
pluginCalled = false
got2, _, err := repo.CreateSet(context.Background(), prj.GetPublicId(), in)
assert.False(pluginCalled)
assert.Truef(errors.Match(errors.T(errors.NotUnique), err), "want err code: %v got err: %v", errors.NotUnique, err)
assert.Nil(got2)
})
t.Run("valid-duplicate-names-diff-catalogs", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
var pluginCalled bool
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): loopback.NewWrappingPluginHostClient(loopback.TestPluginHostServer{OnCreateSetFn: func(ctx context.Context, req *plgpb.OnCreateSetRequest) (*plgpb.OnCreateSetResponse, error) {
pluginCalled = true
return &plgpb.OnCreateSetResponse{}, nil
}}),
}
repo, err := NewRepository(ctx, rw, rw, kms, sched, plgm)
require.NoError(err)
require.NotNil(repo)
_, prj := iam.TestScopes(t, iamRepo)
catalogA := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
catalogB := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
in := &HostSet{
HostSet: &store.HostSet{
Name: "test-name-repo",
Attributes: []byte{},
},
}
in2 := in.clone()
in.CatalogId = catalogA.PublicId
// reset pluginCalled so we can see that the plugin is called during this
// createSet call
pluginCalled = false
got, _, err := repo.CreateSet(context.Background(), prj.GetPublicId(), in)
assert.True(pluginCalled)
require.NoError(err)
require.NotNil(got)
assert.True(strings.HasPrefix(got.GetPublicId(), globals.PluginHostSetPrefix))
assert.NotSame(in, got)
assert.Equal(in.Name, got.GetName())
assert.Equal(in.Description, got.GetDescription())
assert.Equal(got.GetCreateTime(), got.GetUpdateTime())
in2.CatalogId = catalogB.PublicId
// reset pluginCalled so we can see that the plugin is called during this
// createSet call
pluginCalled = false
got2, _, err := repo.CreateSet(context.Background(), prj.GetPublicId(), in2)
require.NoError(err)
require.NotNil(got2)
assert.True(pluginCalled)
assert.True(strings.HasPrefix(got.GetPublicId(), globals.PluginHostSetPrefix))
assert.NotSame(in2, got2)
assert.Equal(in2.Name, got2.GetName())
assert.Equal(in2.Description, got2.GetDescription())
assert.Equal(got2.GetCreateTime(), got2.GetUpdateTime())
})
}
func TestRepository_UpdateSet(t *testing.T) {
ctx := context.Background()
dbConn, _ := db.TestSetup(t, "postgres")
dbRW := db.New(dbConn)
dbWrapper := db.TestWrapper(t)
sched := scheduler.TestScheduler(t, dbConn, dbWrapper)
dbKmsCache := kms.TestKms(t, dbConn, dbWrapper)
_, projectScope := iam.TestScopes(t, iam.TestRepo(t, dbConn, dbWrapper))
testPlugin := plugin.TestPlugin(t, dbConn, "test")
dummyPluginMap := map[string]plgpb.HostPluginServiceClient{
testPlugin.GetPublicId(): &loopback.WrappingPluginHostClient{Server: &plgpb.UnimplementedHostPluginServiceServer{}},
}
// Set up a test catalog and the secrets for it
testCatalog := TestCatalog(t, dbConn, projectScope.PublicId, testPlugin.GetPublicId())
testCatalogSecret, err := newHostCatalogSecret(ctx, testCatalog.GetPublicId(), mustStruct(map[string]any{
"one": "two",
}))
require.NoError(t, err)
scopeWrapper, err := dbKmsCache.GetWrapper(ctx, testCatalog.GetProjectId(), kms.KeyPurposeDatabase)
require.NoError(t, err)
require.NoError(t, testCatalogSecret.encrypt(ctx, scopeWrapper))
err = dbRW.Create(ctx, testCatalogSecret)
require.NoError(t, err)
const normalizeToSliceKey = "normalize_to_slice"
// Create a test duplicate set. We don't use this set, it just
// exists to ensure that we can test for conflicts when setting
// the name.
const testDuplicateSetName = "duplicate-set-name"
testDuplicateSet := TestSet(t, dbConn, dbKmsCache, sched, testCatalog, dummyPluginMap)
testDuplicateSet.Name = testDuplicateSetName
setsUpdated, err := dbRW.Update(ctx, testDuplicateSet, []string{"name"}, []string{})
require.NoError(t, err)
require.Equal(t, 1, setsUpdated)
// Define some helpers here to make the test table more readable.
type changeHostSetFunc func(s *HostSet) *HostSet
changeSetToNil := func() changeHostSetFunc {
return func(_ *HostSet) *HostSet {
return nil
}
}
changeEmbeddedSetToNil := func() changeHostSetFunc {
return func(s *HostSet) *HostSet {
s.HostSet = nil
return s
}
}
changePublicId := func(v string) changeHostSetFunc {
return func(s *HostSet) *HostSet {
s.PublicId = v
return s
}
}
changeName := func(v string) changeHostSetFunc {
return func(s *HostSet) *HostSet {
s.Name = v
return s
}
}
changeDescription := func(s string) changeHostSetFunc {
return func(c *HostSet) *HostSet {
c.Description = s
return c
}
}
changeSyncInterval := func(s int32) changeHostSetFunc {
return func(c *HostSet) *HostSet {
c.SyncIntervalSeconds = s
return c
}
}
changePreferredEndpoints := func(s []string) changeHostSetFunc {
return func(c *HostSet) *HostSet {
c.PreferredEndpoints = s
return c
}
}
changeAttributes := func(m map[string]any) changeHostSetFunc {
return func(c *HostSet) *HostSet {
c.Attributes = mustMarshal(m)
return c
}
}
changeAttributesNil := func() changeHostSetFunc {
return func(c *HostSet) *HostSet {
c.Attributes = nil
return c
}
}
// Define some checks that will be used in the below tests. Some of
// these are re-used, so we define them here. Most of these are
// assertions and no particular one is non-fatal in that they will
// stop execution. Note that these are executed after wantIsErr, so
// if that is set in an individual table test, these will not be
// executed.
//
// Note that we define some state here, similar to how we
// previously defined gotOnUpdateCatalogRequest above next the
// plugin map.
type checkHostSetFunc func(t *testing.T, got *HostSet)
checkVersion := func(want uint32) checkHostSetFunc {
return func(t *testing.T, got *HostSet) {
t.Helper()
assert.Equal(t, want, got.Version, "checkVersion")
}
}
checkName := func(want string) checkHostSetFunc {
return func(t *testing.T, got *HostSet) {
t.Helper()
assert.Equal(t, want, got.Name, "checkName")
}
}
checkDescription := func(want string) checkHostSetFunc {
return func(t *testing.T, got *HostSet) {
t.Helper()
assert.Equal(t, want, got.Description, "checkDescription")
}
}
checkPreferredEndpoints := func(want []string) checkHostSetFunc {
return func(t *testing.T, got *HostSet) {
t.Helper()
assert.Equal(t, want, got.PreferredEndpoints, "checkPreferredEndpoints")
}
}
checkAttributes := func(want map[string]any) checkHostSetFunc {
return func(t *testing.T, got *HostSet) {
t.Helper()
st := &structpb.Struct{}
require.NoError(t, proto.Unmarshal(got.Attributes, st))
assert.Empty(t, cmp.Diff(mustStruct(want), st, protocmp.Transform()))
}
}
checkSyncInterval := func(want int32) checkHostSetFunc {
return func(t *testing.T, got *HostSet) {
t.Helper()
assert.Equal(t, want, got.SyncIntervalSeconds, "checkSyncInterval")
}
}
checkNeedSync := func(want bool) checkHostSetFunc {
return func(t *testing.T, got *HostSet) {
t.Helper()
assert.Equal(t, want, got.NeedSync, "checkNeedSync")
}
}
checkVerifySetOplog := func(op oplog.OpType) checkHostSetFunc {
return func(t *testing.T, got *HostSet) {
t.Helper()
assert.NoError(t,
db.TestVerifyOplog(
t,
dbRW,
got.PublicId,
db.WithOperation(op),
db.WithCreateNotBefore(10*time.Second),
),
)
}
}
type checkPluginReqFunc func(t *testing.T, got *plgpb.OnUpdateSetRequest)
checkUpdateSetRequestCurrentNameNil := func() checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Nil(t, got.CurrentSet.Name, "checkUpdateSetRequestCurrentNameNil")
}
}
checkUpdateSetRequestNewName := func(want string) checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Equal(t, wrapperspb.String(want), got.NewSet.Name, "checkUpdateSetRequestNewName")
}
}
checkUpdateSetRequestNewNameNil := func() checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Nil(t, got.NewSet.Name, "checkUpdateSetRequestNewNameNil")
}
}
checkUpdateSetRequestCurrentDescriptionNil := func() checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Nil(t, got.CurrentSet.Description, "checkUpdateSetRequestCurrentDescriptionNil")
}
}
checkUpdateSetRequestNewDescription := func(want string) checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Equal(t, wrapperspb.String(want), got.NewSet.Description, "checkUpdateSetRequestNewDescription")
}
}
checkUpdateSetRequestNewDescriptionNil := func() checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Nil(t, got.NewSet.Description, "checkUpdateSetRequestNewDescriptionNil")
}
}
checkUpdateSetRequestCurrentPreferredEndpoints := func(want []string) checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Equal(t, want, got.CurrentSet.PreferredEndpoints, "checkUpdateSetRequestCurrentPreferredEndpoints")
}
}
checkUpdateSetRequestNewPreferredEndpoints := func(want []string) checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Equal(t, want, got.NewSet.PreferredEndpoints, "checkUpdateSetRequestNewPreferredEndpoints")
}
}
checkUpdateSetRequestNewPreferredEndpointsNil := func() checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Nil(t, got.NewSet.PreferredEndpoints, "checkUpdateSetRequestNewPreferredEndpointsNil")
}
}
checkUpdateSetRequestCurrentAttributes := func(want map[string]any) checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Empty(t, cmp.Diff(mustStruct(want), got.CurrentSet.GetAttributes(), protocmp.Transform()), "checkUpdateSetRequestCurrentAttributes")
}
}
checkUpdateSetRequestNewAttributes := func(want map[string]any) checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Empty(t, cmp.Diff(mustStruct(want), got.NewSet.GetAttributes(), protocmp.Transform()), "checkUpdateSetRequestNewAttributes")
}
}
checkUpdateSetRequestPersistedSecrets := func(want map[string]any) checkPluginReqFunc {
return func(t *testing.T, got *plgpb.OnUpdateSetRequest) {
t.Helper()
assert.Empty(t, cmp.Diff(mustStruct(want), got.Persisted.Secrets, protocmp.Transform()), "checkUpdateSetRequestPersistedSecrets")
}
}
// Finally define a function for bringing the test subject host
// set.
setupBareHostSet := func(t *testing.T, ctx context.Context) (*HostSet, []*Host) {
t.Helper()
set := TestSet(
t,
dbConn,
dbKmsCache,
sched,
testCatalog,
dummyPluginMap,
)
t.Cleanup(func() {
t.Helper()
assert := assert.New(t)
n, err := dbRW.Delete(ctx, set)
assert.NoError(err)
assert.Equal(1, n)
})
return set, nil
}
// Create a host that will be verified as belonging to the created set below.
testHostExternIdPrefix := "test-host-external-id"
testHosts := make([]*Host, 3)
for i := 0; i <= 2; i++ {
testHosts[i] = TestHost(t, dbConn, testCatalog.PublicId, fmt.Sprintf("%s-%d", testHostExternIdPrefix, i))
}
// Finally define a function for bringing the test subject host
// set.
setupHostSet := func(t *testing.T, ctx context.Context) (*HostSet, []*Host) {
t.Helper()
require := require.New(t)
set := TestSet(
t,
dbConn,
dbKmsCache,
sched,
testCatalog,
dummyPluginMap,
WithPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}),
)
// Set some (default) attributes on our test set
set.Attributes = mustMarshal(map[string]any{
"foo": "bar",
})
// Set some fake sync detail to the set.
set.LastSyncTime = timestamp.New(time.Now())
set.NeedSync = false
numSetsUpdated, err := dbRW.Update(ctx, set, []string{"attributes", "LastSyncTime", "NeedSync"}, []string{})
require.NoError(err)
require.Equal(1, numSetsUpdated)
// Add some hosts to the host set.
TestSetMembers(t, dbConn, set.PublicId, testHosts)
t.Cleanup(func() {
t.Helper()
assert := assert.New(t)
n, err := dbRW.Delete(ctx, set)
assert.NoError(err)
assert.Equal(1, n)
})
return set, testHosts
}
tests := []struct {
name string
startingSet func(*testing.T, context.Context) (*HostSet, []*Host)
withProjectId *string
withEmptyPluginMap bool
withPluginError error
changeFuncs []changeHostSetFunc
version uint32
fieldMask []string
wantCheckSetFuncs []checkHostSetFunc
wantCheckPluginReqFuncs []checkPluginReqFunc
wantIsErr errors.Code
}{
{
name: "nil set",
startingSet: setupBareHostSet,
changeFuncs: []changeHostSetFunc{changeSetToNil()},
wantIsErr: errors.InvalidParameter,
},
{
name: "nil embedded set",
startingSet: setupBareHostSet,
changeFuncs: []changeHostSetFunc{changeEmbeddedSetToNil()},
wantIsErr: errors.InvalidParameter,
},
{
name: "missing public id",
startingSet: setupBareHostSet,
changeFuncs: []changeHostSetFunc{changePublicId("")},
wantIsErr: errors.InvalidParameter,
},
{
name: "missing project id",
startingSet: setupBareHostSet,
withProjectId: func() *string { a := ""; return &a }(),
wantIsErr: errors.InvalidParameter,
},
{
name: "empty field mask",
startingSet: setupBareHostSet,
fieldMask: nil, // Should be testing on len
wantIsErr: errors.EmptyFieldMask,
},
{
name: "bad set id",
startingSet: setupBareHostSet,
changeFuncs: []changeHostSetFunc{changePublicId("badid")},
fieldMask: []string{"name"},
wantIsErr: errors.RecordNotFound,
},
{
name: "version mismatch",
startingSet: setupBareHostSet,
changeFuncs: []changeHostSetFunc{changeName("foo")},
version: 4,
fieldMask: []string{"name"},
wantIsErr: errors.VersionMismatch,
},
{
name: "mismatched project id to catalog project",
startingSet: setupBareHostSet,
withProjectId: func() *string { a := "badid"; return &a }(),
fieldMask: []string{"name"},
wantIsErr: errors.InvalidParameter,
},
{
name: "plugin lookup error",
startingSet: setupBareHostSet,
withEmptyPluginMap: true,
fieldMask: []string{"name"},
wantIsErr: errors.Internal,
},
{
name: "plugin invocation error",
startingSet: setupBareHostSet,
withPluginError: errors.New(context.Background(), errors.Internal, "TestRepository_UpdateSet/plugin_invocation_error", "test plugin error"),
fieldMask: []string{"name"},
wantIsErr: errors.Internal,
},
{
name: "update name (duplicate)",
startingSet: setupBareHostSet,
changeFuncs: []changeHostSetFunc{changeName(testDuplicateSetName)},
fieldMask: []string{"name"},
wantIsErr: errors.NotUnique,
},
{
name: "update name",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changeName("foo")},
fieldMask: []string{"name"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentNameNil(),
checkUpdateSetRequestNewName("foo"),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(3),
checkName("foo"),
checkNeedSync(false),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update name to same",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changeName("")},
fieldMask: []string{"name"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentNameNil(),
checkUpdateSetRequestNewNameNil(),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(2), // Version remains same even though row is updated
checkName(""),
checkNeedSync(false),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update description",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changeDescription("foo")},
fieldMask: []string{"description"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentDescriptionNil(),
checkUpdateSetRequestNewDescription("foo"),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(3),
checkDescription("foo"),
checkNeedSync(false),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update description to same",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changeDescription("")},
fieldMask: []string{"description"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentDescriptionNil(),
checkUpdateSetRequestNewDescriptionNil(),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(2), // Version remains same even though row is updated
checkDescription(""),
checkNeedSync(false),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "set sync interval",
startingSet: setupBareHostSet,
changeFuncs: []changeHostSetFunc{changeSyncInterval(42)},
fieldMask: []string{"SyncIntervalSeconds"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(2),
checkNeedSync(true),
checkSyncInterval(42),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "add preferred endpoints",
startingSet: setupBareHostSet,
changeFuncs: []changeHostSetFunc{changePreferredEndpoints([]string{"cidr:10.0.0.0/24"})},
fieldMask: []string{"PreferredEndpoints"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentPreferredEndpoints(nil),
checkUpdateSetRequestNewPreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(2),
checkPreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
checkNeedSync(true),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update preferred endpoints",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changePreferredEndpoints([]string{"cidr:10.0.0.0/24"})},
fieldMask: []string{"PreferredEndpoints"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}),
checkUpdateSetRequestNewPreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(3),
checkPreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
checkNeedSync(false),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "clear preferred endpoints",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changePreferredEndpoints(nil)},
fieldMask: []string{"PreferredEndpoints"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}),
checkUpdateSetRequestNewPreferredEndpointsNil(),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(3),
checkPreferredEndpoints(nil),
checkNeedSync(false),
},
},
{
name: "update attributes (add)",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changeAttributes(map[string]any{
"baz": "qux",
})},
fieldMask: []string{"attributes"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentAttributes(map[string]any{
"foo": "bar",
}),
checkUpdateSetRequestNewAttributes(map[string]any{
"foo": "bar",
"baz": "qux",
}),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(3),
checkAttributes(map[string]any{
"foo": "bar",
"baz": "qux",
}),
checkNeedSync(true),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update attributes (overwrite)",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changeAttributes(map[string]any{
"foo": "baz",
normalizeToSliceKey: "normalizeme",
})},
fieldMask: []string{"attributes"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentAttributes(map[string]any{
"foo": "bar",
}),
checkUpdateSetRequestNewAttributes(map[string]any{
"foo": "baz",
normalizeToSliceKey: []any{"normalizeme"},
}),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(3),
checkAttributes(map[string]any{
"foo": "baz",
normalizeToSliceKey: []any{"normalizeme"},
}),
checkNeedSync(true),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update attributes (null)",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changeAttributes(map[string]any{
"foo": nil,
})},
fieldMask: []string{"attributes"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentAttributes(map[string]any{
"foo": "bar",
}),
checkUpdateSetRequestNewAttributes(map[string]any{}),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(3),
checkAttributes(map[string]any{}),
checkNeedSync(true),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update attributes (full null)",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changeAttributesNil()},
fieldMask: []string{"attributes"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentAttributes(map[string]any{
"foo": "bar",
}),
checkUpdateSetRequestNewAttributes(map[string]any{}),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(3),
checkAttributes(map[string]any{}),
checkNeedSync(true),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update attributes (combined)",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{changeAttributes(map[string]any{
"a": "b",
"foo": "baz",
})},
fieldMask: []string{"attributes.a", "attributes.foo"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentAttributes(map[string]any{
"foo": "bar",
}),
checkUpdateSetRequestNewAttributes(map[string]any{
"a": "b",
"foo": "baz",
}),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(3),
checkAttributes(map[string]any{
"a": "b",
"foo": "baz",
}),
checkNeedSync(true),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
{
name: "update name and preferred endpoints",
startingSet: setupHostSet,
changeFuncs: []changeHostSetFunc{
changePreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
changeName("foo"),
},
fieldMask: []string{"name", "PreferredEndpoints"},
wantCheckPluginReqFuncs: []checkPluginReqFunc{
checkUpdateSetRequestCurrentNameNil(),
checkUpdateSetRequestNewName("foo"),
checkUpdateSetRequestCurrentPreferredEndpoints([]string{"cidr:192.168.0.0/24", "cidr:192.168.1.0/24", "cidr:172.16.0.0/12"}),
checkUpdateSetRequestNewPreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
},
wantCheckSetFuncs: []checkHostSetFunc{
checkVersion(3),
checkName("foo"),
checkPreferredEndpoints([]string{"cidr:10.0.0.0/24"}),
checkNeedSync(false),
checkVerifySetOplog(oplog.OpType_OP_TYPE_UPDATE),
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
require := require.New(t)
assert := assert.New(t)
// Define a plugin "manager", basically just a map with a mock
// plugin in it. This also includes functionality to capture the
// state, and set an error and the returned secrets to nil. Note
// that the way that this is set up means that the tests cannot run
// in parallel, but there could be other factors affecting that as
// well.
var gotOnUpdateCallCount int
testPluginMap := map[string]plgpb.HostPluginServiceClient{
testPlugin.GetPublicId(): &loopback.WrappingPluginHostClient{
Server: &loopback.TestPluginHostServer{
NormalizeSetDataFn: func(_ context.Context, req *plgpb.NormalizeSetDataRequest) (*plgpb.NormalizeSetDataResponse, error) {
if req.Attributes == nil {
return new(plgpb.NormalizeSetDataResponse), nil
}
var attrs struct {
NormalizeToSlice string `mapstructure:"normalize_to_slice"`
}
require.NoError(mapstructure.Decode(req.Attributes.AsMap(), &attrs))
if attrs.NormalizeToSlice == "" {
return new(plgpb.NormalizeSetDataResponse), nil
}
retAttrs := proto.Clone(req.Attributes).(*structpb.Struct)
retAttrs.Fields[normalizeToSliceKey] = structpb.NewListValue(&structpb.ListValue{
Values: []*structpb.Value{structpb.NewStringValue(attrs.NormalizeToSlice)},
})
return &plgpb.NormalizeSetDataResponse{Attributes: retAttrs}, nil
},
OnUpdateSetFn: func(_ context.Context, req *plgpb.OnUpdateSetRequest) (*plgpb.OnUpdateSetResponse, error) {
gotOnUpdateCallCount++
for _, check := range tt.wantCheckPluginReqFuncs {
check(t, req)
}
return &plgpb.OnUpdateSetResponse{}, tt.withPluginError
},
},
},
}
origSet, wantedHosts := tt.startingSet(t, ctx)
pluginMap := testPluginMap
if tt.withEmptyPluginMap {
pluginMap = make(map[string]plgpb.HostPluginServiceClient)
}
repo, err := NewRepository(ctx, dbRW, dbRW, dbKmsCache, sched, pluginMap)
require.NoError(err)
require.NotNil(repo)
workingSet := origSet.clone()
for _, cf := range tt.changeFuncs {
workingSet = cf(workingSet)
}
projectId := testCatalog.ProjectId
if tt.withProjectId != nil {
projectId = *tt.withProjectId
}
version := origSet.Version
if tt.version != 0 {
version = tt.version
}
gotUpdatedSet, gotHosts, gotPlugin, gotNumUpdated, err := repo.UpdateSet(ctx, projectId, workingSet, version, tt.fieldMask)
t.Cleanup(func() { gotOnUpdateCallCount = 0 })
if tt.wantIsErr != 0 {
require.Equal(db.NoRowsAffected, gotNumUpdated)
require.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err: %q got: %q", tt.wantIsErr, err)
return
}
require.NoError(err)
require.Equal(1, gotOnUpdateCallCount)
assert.Equal(1, gotNumUpdated)
// Quick assertion that the set is not nil and that the plugin
// ID in the catalog referenced by the set matches the plugin
// ID in the returned plugin.
require.NotNil(gotUpdatedSet)
require.NotNil(gotPlugin)
assert.Equal(testCatalog.PublicId, gotUpdatedSet.CatalogId)
assert.Equal(testCatalog.PluginId, gotPlugin.PublicId)
// Also assert that the hosts returned by the request are the ones that belong to the set
wantHostMap := make(map[string]string, len(wantedHosts))
for _, h := range wantedHosts {
wantHostMap[h.PublicId] = h.ExternalId
}
gotHostMap := make(map[string]string, len(gotHosts))
for _, h := range gotHosts {
gotHostMap[h.PublicId] = h.ExternalId
}
assert.Equal(wantHostMap, gotHostMap)
// Perform checks
for _, check := range tt.wantCheckSetFuncs {
check(t, gotUpdatedSet)
}
gotLookupSet, _, err := repo.LookupSet(ctx, workingSet.GetPublicId())
assert.NoError(err)
assert.Empty(cmp.Diff(gotUpdatedSet, gotLookupSet, protocmp.Transform()))
})
}
t.Run("Unset Empty PreferredEndpoint", func(t *testing.T) {
var gotOnUpdateCallCount int
testPluginMap := map[string]plgpb.HostPluginServiceClient{
testPlugin.GetPublicId(): &loopback.WrappingPluginHostClient{
Server: &loopback.TestPluginHostServer{
OnUpdateSetFn: func(_ context.Context, req *plgpb.OnUpdateSetRequest) (*plgpb.OnUpdateSetResponse, error) {
gotOnUpdateCallCount++
for _, check := range []checkPluginReqFunc{
checkUpdateSetRequestCurrentPreferredEndpoints(nil),
checkUpdateSetRequestNewPreferredEndpointsNil(),
checkUpdateSetRequestPersistedSecrets(map[string]any{
"one": "two",
}),
} {
check(t, req)
}
return &plgpb.OnUpdateSetResponse{}, nil
},
},
},
}
origSet, _ := setupBareHostSet(t, ctx)
pluginMap := testPluginMap
repo, err := NewRepository(ctx, dbRW, dbRW, dbKmsCache, sched, pluginMap)
require.NoError(t, err)
require.NotNil(t, repo)
workingSet := origSet.clone()
workingSet = changePreferredEndpoints(nil)(workingSet)
gotUpdatedSet, _, _, gotNumUpdated, err := repo.UpdateSet(ctx, testCatalog.ProjectId, workingSet, origSet.Version, []string{"PreferredEndpoints"})
t.Cleanup(func() { gotOnUpdateCallCount = 0 })
require.NoError(t, err)
require.Equal(t, 1, gotOnUpdateCallCount)
assert.Equal(t, 0, gotNumUpdated)
assert.Empty(t, cmp.Diff(gotUpdatedSet, origSet, protocmp.Transform()))
})
}
func TestRepository_UpdateSet_UnsetEmptyPreferredEndpoint(t *testing.T) {
}
func TestRepository_LookupSet(t *testing.T) {
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
kms := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
sched := scheduler.TestScheduler(t, conn, wrapper)
_, prj := iam.TestScopes(t, iamRepo)
plg := plugin.TestPlugin(t, conn, "lookup")
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): loopback.NewWrappingPluginHostClient(&loopback.TestPluginServer{}),
}
catalog := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
hostSet := TestSet(t, conn, kms, sched, catalog, map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): loopback.NewWrappingPluginHostClient(&loopback.TestPluginHostServer{
ListHostsFn: func(ctx context.Context, req *plgpb.ListHostsRequest) (*plgpb.ListHostsResponse, error) {
require.NotEmpty(t, req.GetSets())
require.NotNil(t, req.GetCatalog())
return &plgpb.ListHostsResponse{}, nil
},
}),
}, WithSyncIntervalSeconds(5))
hostSetId, err := newHostSetId(ctx)
require.NoError(t, err)
tests := []struct {
name string
in string
want *HostSet
wantIsErr errors.Code
}{
{
name: "with-no-public-id",
wantIsErr: errors.InvalidParameter,
},
{
name: "with-non-existing-host-set-id",
in: hostSetId,
},
{
name: "with-existing-host-set-id",
in: hostSet.PublicId,
want: hostSet,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
repo, err := NewRepository(ctx, rw, rw, kms, sched, plgm)
assert.NoError(err)
require.NotNil(repo)
got, _, err := repo.LookupSet(ctx, tt.in)
if tt.wantIsErr != 0 {
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err: %q got: %q", tt.wantIsErr, err)
assert.Nil(got)
return
}
require.NoError(err)
if tt.want != nil {
assert.Empty(cmp.Diff(got, tt.want, protocmp.Transform()), "LookupSet(%q) got response %q, wanted %q", tt.in, got, tt.want)
}
})
}
}
func TestRepository_Endpoints(t *testing.T) {
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
kms := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
sched := scheduler.TestScheduler(t, conn, wrapper)
_, prj := iam.TestScopes(t, iamRepo)
plg := plugin.TestPlugin(t, conn, "endpoints")
hostlessCatalog := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): loopback.NewWrappingPluginHostClient(&loopback.TestPluginServer{}),
}
catalog := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
hostSet10 := TestSet(t, conn, kms, sched, catalog, plgm, WithName("hostSet10"), WithPreferredEndpoints([]string{"cidr:10.0.0.1/24"}))
hostSet192 := TestSet(t, conn, kms, sched, catalog, plgm, WithName("hostSet192"), WithPreferredEndpoints([]string{"cidr:192.168.0.1/24"}))
hostSet100 := TestSet(t, conn, kms, sched, catalog, plgm, WithName("hostSet100"), WithPreferredEndpoints([]string{"cidr:100.100.100.100/24"}))
hostSetDNS := TestSet(t, conn, kms, sched, catalog, plgm, WithName("hostSetDNS"), WithPreferredEndpoints([]string{"dns:*"}))
hostlessSet := TestSet(t, conn, kms, sched, hostlessCatalog, plgm)
h1 := TestHost(t, conn, catalog.GetPublicId(), "test", withIpAddresses([]string{"10.0.0.5", "192.168.0.5"}), withDnsNames([]string{"example.com"}))
TestSetMembers(t, conn, hostSet10.GetPublicId(), []*Host{h1})
TestSetMembers(t, conn, hostSet192.GetPublicId(), []*Host{h1})
TestSetMembers(t, conn, hostSet100.GetPublicId(), []*Host{h1})
TestSetMembers(t, conn, hostSetDNS.GetPublicId(), []*Host{h1})
tests := []struct {
name string
setIds []string
want []*host.Endpoint
wantIsErr errors.Code
}{
{
name: "with-no-set-id",
wantIsErr: errors.InvalidParameter,
},
{
name: "with-set10",
setIds: []string{hostSet10.GetPublicId()},
want: []*host.Endpoint{
{
HostId: func() string {
s, err := newHostId(ctx, catalog.GetPublicId(), "test")
require.NoError(t, err)
return s
}(),
SetId: hostSet10.GetPublicId(),
Address: "10.0.0.5",
},
},
},
{
name: "with-different-set",
setIds: []string{hostSet192.GetPublicId()},
want: []*host.Endpoint{
{
HostId: func() string {
s, err := newHostId(ctx, catalog.GetPublicId(), "test")
require.NoError(t, err)
return s
}(),
SetId: hostSet192.GetPublicId(),
Address: "192.168.0.5",
},
},
},
{
name: "with-all-addresses-filtered-set",
setIds: []string{hostSet100.GetPublicId()},
want: nil,
},
{
name: "with-no-hosts-from-plugin",
setIds: []string{hostlessSet.GetPublicId()},
want: nil,
},
{
name: "with-dns-names",
setIds: []string{hostSetDNS.GetPublicId()},
want: []*host.Endpoint{
{
HostId: func() string {
s, err := newHostId(ctx, catalog.GetPublicId(), "test")
require.NoError(t, err)
return s
}(),
SetId: hostSetDNS.GetPublicId(),
Address: "example.com",
},
},
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
repo, err := NewRepository(ctx, rw, rw, kms, sched, plgm)
assert.NoError(err)
require.NotNil(repo)
got, err := repo.Endpoints(ctx, tt.setIds)
if tt.wantIsErr != 0 {
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err: %q got: %q", tt.wantIsErr, err)
assert.Nil(got)
return
}
require.NoError(err)
if tt.want == nil {
return
}
sort.Slice(tt.want, func(i, j int) bool {
return tt.want[i].HostId < tt.want[j].HostId
})
sort.Slice(got, func(i, j int) bool {
return got[i].HostId < got[j].HostId
})
assert.Empty(cmp.Diff(got, tt.want, protocmp.Transform()))
for _, ep := range got {
h := allocHost()
h.PublicId = ep.HostId
require.NoError(rw.LookupByPublicId(ctx, h))
assert.Equal(ep.HostId, h.PublicId)
// TODO: Uncomment when we have a better way to lookup host
// with it's address
// assert.Equal(ep.Address, h.Address)
assert.Equal(catalog.GetPublicId(), h.GetCatalogId())
}
})
}
}
func TestRepository_listSets(t *testing.T) {
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
kms := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
sched := scheduler.TestScheduler(t, conn, wrapper)
_, prj := iam.TestScopes(t, iamRepo)
plg := plugin.TestPlugin(t, conn, "list")
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): loopback.NewWrappingPluginHostClient(&loopback.TestPluginServer{}),
}
catalogA := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
catalogB := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
hostSets := []*HostSet{
TestSet(t, conn, kms, sched, catalogA, plgm),
TestSet(t, conn, kms, sched, catalogA, plgm),
TestSet(t, conn, kms, sched, catalogA, plgm),
}
printoutTable(t, rw)
tests := []struct {
name string
in string
opts []Option
want []*HostSet
wantIsErr errors.Code
}{
{
name: "with-no-catalog-id",
wantIsErr: errors.InvalidParameter,
},
{
name: "Catalog-with-no-host-sets",
in: catalogB.PublicId,
want: nil,
},
{
name: "Catalog-with-host-sets",
in: catalogA.PublicId,
want: hostSets,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
repo, err := NewRepository(ctx, rw, rw, kms, sched, plgm)
assert.NoError(err)
require.NotNil(repo)
got, gotPlg, ttime, err := repo.listSets(ctx, tt.in, tt.opts...)
if tt.wantIsErr != 0 {
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err: %q got: %q", tt.wantIsErr, err)
assert.Nil(got)
return
}
require.NoError(err)
opts := []cmp.Option{
cmpopts.SortSlices(func(x, y *HostSet) bool { return x.PublicId < y.PublicId }),
protocmp.Transform(),
}
assert.Empty(cmp.Diff(tt.want, got, opts...))
if got != nil {
assert.Empty(cmp.Diff(plg, gotPlg, protocmp.Transform()))
}
// Transaction timestamp should be within ~10 seconds of now
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
})
}
}
func TestRepository_listSets_Limits(t *testing.T) {
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
kms := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
sched := scheduler.TestScheduler(t, conn, wrapper)
_, prj := iam.TestScopes(t, iamRepo)
plg := plugin.TestPlugin(t, conn, "listlimit")
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): loopback.NewWrappingPluginHostClient(&loopback.TestPluginServer{}),
}
catalog := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
count := 10
var hostSets []*HostSet
for i := 0; i < count; i++ {
hostSets = append(hostSets, TestSet(t, conn, kms, sched, catalog, plgm))
}
tests := []struct {
name string
repoOpts []host.Option
listOpts []Option
wantLen int
}{
{
name: "With no limits",
wantLen: count,
},
{
name: "With repo limit",
repoOpts: []host.Option{host.WithLimit(3)},
wantLen: 3,
},
{
name: "With List limit",
listOpts: []Option{WithLimit(3)},
wantLen: 3,
},
{
name: "With repo smaller than list limit",
repoOpts: []host.Option{host.WithLimit(2)},
listOpts: []Option{WithLimit(6)},
wantLen: 6,
},
{
name: "With repo larger than list limit",
repoOpts: []host.Option{host.WithLimit(6)},
listOpts: []Option{WithLimit(2)},
wantLen: 2,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
repo, err := NewRepository(ctx, rw, rw, kms, sched, plgm, tt.repoOpts...)
assert.NoError(err)
require.NotNil(repo)
got, gotPlg, ttime, err := repo.listSets(ctx, hostSets[0].CatalogId, tt.listOpts...)
require.NoError(err)
assert.Len(got, tt.wantLen)
assert.Empty(cmp.Diff(plg, gotPlg, protocmp.Transform()))
// Transaction timestamp should be within ~10 seconds of now
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
})
}
}
func TestRepository_listSets_Pagination(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
kms := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
_, proj1 := iam.TestScopes(t, iamRepo)
sched := scheduler.TestScheduler(t, conn, wrapper)
plg := plugin.TestPlugin(t, conn, "test")
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): &loopback.WrappingPluginHostClient{Server: &loopback.TestPluginServer{}},
}
catalog := TestCatalog(t, conn, proj1.GetPublicId(), plg.GetPublicId())
total := 5
for i := 0; i < total; i++ {
TestSet(t, conn, kms, sched, catalog, plgm)
}
rw := db.New(conn)
repo, err := NewRepository(ctx, rw, rw, kms, sched, plgm)
require.NoError(t, err)
t.Run("no-options", func(t *testing.T) {
got, retPlg, ttime, err := repo.listSets(ctx, catalog.GetPublicId())
require.NoError(t, err)
assert.Equal(t, total, len(got))
assert.Equal(t, retPlg, plg)
// Transaction timestamp should be within ~10 seconds of now
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
})
t.Run("withStartPageAfter", func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
page1, retPlg, ttime, err := repo.listSets(
context.Background(),
catalog.GetPublicId(),
WithLimit(2),
)
require.NoError(err)
require.Len(page1, 2)
assert.Equal(retPlg, plg)
// Transaction timestamp should be within ~10 seconds of now
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
page2, retPlg, ttime, err := repo.listSets(
context.Background(),
catalog.GetPublicId(),
WithLimit(2),
WithStartPageAfterItem(page1[1]),
)
require.NoError(err)
require.Len(page2, 2)
assert.Equal(retPlg, plg)
for _, item := range page1 {
assert.NotEqual(item.GetPublicId(), page2[0].GetPublicId())
assert.NotEqual(item.GetPublicId(), page2[1].GetPublicId())
}
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
page3, retPlg, ttime, err := repo.listSets(
context.Background(),
catalog.GetPublicId(),
WithLimit(2),
WithStartPageAfterItem(page2[1]),
)
require.NoError(err)
require.Len(page3, 1)
assert.Equal(retPlg, plg)
for _, item := range page2 {
assert.NotEqual(item.GetPublicId(), page3[0].GetPublicId())
}
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
page4, retPlg, ttime, err := repo.listSets(
context.Background(),
catalog.GetPublicId(),
WithLimit(2),
WithStartPageAfterItem(page3[0]),
)
require.NoError(err)
require.Empty(page4)
require.Empty(retPlg)
assert.True(time.Now().Before(ttime.Add(10 * time.Second)))
assert.True(time.Now().After(ttime.Add(-10 * time.Second)))
})
}
func Test_listDeletedSetIds(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
wrapper := db.TestWrapper(t)
testKms := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
_, proj := iam.TestScopes(t, iamRepo)
sched := scheduler.TestScheduler(t, conn, wrapper)
plg := plugin.TestPlugin(t, conn, "test")
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): &loopback.WrappingPluginHostClient{Server: &loopback.TestPluginServer{}},
}
catalog := TestCatalog(t, conn, proj.GetPublicId(), plg.GetPublicId())
rw := db.New(conn)
repo, err := NewRepository(ctx, rw, rw, testKms, sched, plgm)
require.NoError(t, err)
// Expect no entries at the start
deletedIds, ttime, err := repo.listDeletedSetIds(ctx, time.Now().AddDate(-1, 0, 0))
require.NoError(t, err)
require.Empty(t, deletedIds)
// Transaction timestamp should be within ~10 seconds of now
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
// Delete a set
s := TestSet(t, conn, testKms, sched, catalog, plgm)
_, err = repo.DeleteSet(ctx, proj.GetPublicId(), s.GetPublicId())
require.NoError(t, err)
// Expect a single entry
deletedIds, ttime, err = repo.listDeletedSetIds(ctx, time.Now().AddDate(-1, 0, 0))
require.NoError(t, err)
require.Equal(t, []string{s.GetPublicId()}, deletedIds)
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
// Try again with the time set to now, expect no entries
deletedIds, ttime, err = repo.listDeletedSetIds(ctx, time.Now())
require.NoError(t, err)
require.Empty(t, deletedIds)
assert.True(t, time.Now().Before(ttime.Add(10*time.Second)))
assert.True(t, time.Now().After(ttime.Add(-10*time.Second)))
}
func Test_estimatedHostSetCount(t *testing.T) {
t.Parallel()
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
sqlDb, err := conn.SqlDB(ctx)
require.NoError(t, err)
wrapper := db.TestWrapper(t)
testKms := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
_, proj := iam.TestScopes(t, iamRepo)
sched := scheduler.TestScheduler(t, conn, wrapper)
plg := plugin.TestPlugin(t, conn, "test")
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): &loopback.WrappingPluginHostClient{Server: &loopback.TestPluginServer{}},
}
catalog := TestCatalog(t, conn, proj.GetPublicId(), plg.GetPublicId())
rw := db.New(conn)
repo, err := NewRepository(ctx, rw, rw, testKms, sched, plgm)
require.NoError(t, err)
// Check total entries at start, expect 0
numItems, err := repo.estimatedSetCount(ctx)
require.NoError(t, err)
assert.Equal(t, 0, numItems)
// Create a set, expect 1 entries
s := TestSet(t, conn, testKms, sched, catalog, plgm)
// Run analyze to update estimate
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
numItems, err = repo.estimatedSetCount(ctx)
require.NoError(t, err)
assert.Equal(t, 1, numItems)
// Delete the set, expect 0 again
_, err = repo.DeleteSet(ctx, proj.GetPublicId(), s.GetPublicId())
require.NoError(t, err)
// Run analyze to update estimate
_, err = sqlDb.ExecContext(ctx, "analyze")
require.NoError(t, err)
numItems, err = repo.estimatedSetCount(ctx)
require.NoError(t, err)
assert.Equal(t, 0, numItems)
}
func TestRepository_DeleteSet(t *testing.T) {
ctx := context.Background()
conn, _ := db.TestSetup(t, "postgres")
rw := db.New(conn)
wrapper := db.TestWrapper(t)
kms := kms.TestKms(t, conn, wrapper)
iamRepo := iam.TestRepo(t, conn, wrapper)
sched := scheduler.TestScheduler(t, conn, wrapper)
_, prj := iam.TestScopes(t, iamRepo)
plg := plugin.TestPlugin(t, conn, "create")
plgm := map[string]plgpb.HostPluginServiceClient{
plg.GetPublicId(): loopback.NewWrappingPluginHostClient(loopback.TestPluginHostServer{OnCreateSetFn: func(ctx context.Context, req *plgpb.OnCreateSetRequest) (*plgpb.OnCreateSetResponse, error) {
return &plgpb.OnCreateSetResponse{}, nil
}}),
}
c := TestCatalog(t, conn, prj.PublicId, plg.GetPublicId())
hostSet := TestSet(t, conn, kms, sched, c, plgm)
hostSet2 := TestSet(t, conn, kms, sched, c, plgm)
newHostSetId, err := newHostSetId(ctx)
require.NoError(t, err)
tests := []struct {
name string
in string
pluginChecker func(*plgpb.OnDeleteSetRequest) error
want int
wantIsErr errors.Code
}{
{
name: "With no public id",
pluginChecker: func(req *plgpb.OnDeleteSetRequest) error {
assert.Fail(t, "The plugin shouldn't be called")
return nil
},
wantIsErr: errors.InvalidParameter,
},
{
name: "With non existing host set id",
in: newHostSetId,
pluginChecker: func(req *plgpb.OnDeleteSetRequest) error {
assert.Fail(t, "The plugin shouldn't be called")
return nil
},
want: 0,
},
{
name: "With existing host set id",
in: hostSet.PublicId,
pluginChecker: func(req *plgpb.OnDeleteSetRequest) error {
assert.Equal(t, c.GetPublicId(), req.GetCatalog().GetId())
assert.Equal(t, hostSet.GetPublicId(), req.GetSet().GetId())
return nil
},
want: 1,
},
{
name: "Ignores plugin errors",
in: hostSet2.PublicId,
pluginChecker: func(req *plgpb.OnDeleteSetRequest) error {
assert.Equal(t, c.GetPublicId(), req.GetCatalog().GetId())
assert.Equal(t, hostSet2.GetPublicId(), req.GetSet().GetId())
return fmt.Errorf("test error")
},
want: 1,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
assert, require := assert.New(t), require.New(t)
repo, err := NewRepository(ctx, rw, rw, kms, sched, plgm)
assert.NoError(err)
require.NotNil(repo)
got, err := repo.DeleteSet(ctx, prj.PublicId, tt.in)
if tt.wantIsErr != 0 {
assert.Truef(errors.Match(errors.T(tt.wantIsErr), err), "want err: %q got: %q", tt.wantIsErr, err)
assert.Zero(got)
return
}
require.NoError(err)
assert.EqualValues(tt.want, got)
})
}
}
func printoutTable(t *testing.T, rw *db.Db) {
ctx := context.Background()
hsas := []*hostSetAgg{}
require.NoError(t, rw.SearchWhere(ctx, &hsas, "", nil))
for _, hs := range hsas {
t.Logf("%#v", hs)
}
}