You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/internal/plugin/repository_plugin_test.go

412 lines
9.3 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package plugin
import (
"context"
"fmt"
"strings"
"testing"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/kms"
"github.com/hashicorp/boundary/internal/plugin/store"
"github.com/hashicorp/boundary/internal/types/scope"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestRepository_CreatePlugin exercises Repository.CreatePlugin against a test
// database. The table covers inputs that must be rejected (nil plugin, nil
// embedded plugin, missing or non-global scope id, pre-set public id) and
// inputs that must be stored (name only, description only). A separate subtest
// verifies that creating the same named plugin twice fails with
// errors.NotUnique.
func TestRepository_CreatePlugin(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	iam.TestRepo(t, conn, wrapper)
	kmsCache := kms.TestKms(t, conn, wrapper)
	ctx := context.Background()

	// Every subtest calls methods on repo, so abort right away when it could
	// not be constructed instead of panicking on a nil receiver later.
	repo, err := NewRepository(ctx, rw, rw, kmsCache)
	require.NoError(t, err)
	require.NotNil(t, repo)

	tests := []struct {
		name    string
		in      *Plugin
		opts    []Option
		want    *Plugin
		wantErr string
	}{
		{
			name:    "nil-plugin",
			wantErr: "nil Plugin",
		},
		{
			name:    "nil-embedded-plugin",
			in:      &Plugin{},
			wantErr: "nil embedded Plugin",
		},
		{
			name: "no-scope-id",
			in: &Plugin{
				Plugin: &store.Plugin{
					ScopeId: "",
				},
			},
			wantErr: "no scope id",
		},
		{
			name: "non-global-scope",
			in: &Plugin{
				Plugin: &store.Plugin{
					ScopeId: "o_1234567890",
				},
			},
			wantErr: "scope id is not 'global'",
		},
		{
			name: "public-id-not-empty",
			in: &Plugin{
				Plugin: &store.Plugin{
					PublicId: "biscuit",
					ScopeId:  "global",
				},
			},
			wantErr: "public id not empty",
		},
		{
			name: "valid-with-name",
			in: &Plugin{
				Plugin: &store.Plugin{
					Name:    "test-name-repo",
					ScopeId: scope.Global.String(),
				},
			},
			want: &Plugin{
				Plugin: &store.Plugin{
					Name:    "test-name-repo",
					ScopeId: scope.Global.String(),
				},
			},
		},
		{
			name: "valid-with-description",
			in: &Plugin{
				Plugin: &store.Plugin{
					Description: "test-description-repo",
					ScopeId:     scope.Global.String(),
				},
			},
			want: &Plugin{
				Plugin: &store.Plugin{
					Description: "test-description-repo",
					ScopeId:     scope.Global.String(),
				},
			},
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert := assert.New(t)
			got, err := repo.CreatePlugin(ctx, tt.in, tt.opts...)
			if tt.wantErr != "" {
				assert.ErrorContains(err, tt.wantErr)
				assert.Nil(got)
				return
			}
			// got is dereferenced below, so both checks must be fatal.
			require.NoError(t, err)
			require.NotNil(t, got)

			// The input must not be mutated; the result is a distinct value.
			assert.Empty(tt.in.PublicId)
			assert.NotSame(tt.in, got)

			assertPublicId(t, PluginPrefix, got.PublicId)
			assert.Equal(tt.want.Name, got.Name)
			assert.Equal(tt.want.Description, got.Description)
			assert.Equal(got.CreateTime, got.UpdateTime)
		})
	}

	t.Run("invalid-duplicate-names", func(t *testing.T) {
		assert := assert.New(t)
		dupKms := kms.TestKms(t, conn, wrapper)
		repo, err := NewRepository(ctx, rw, rw, dupKms)
		require.NoError(t, err)
		require.NotNil(t, repo)

		in := &Plugin{
			Plugin: &store.Plugin{
				ScopeId: scope.Global.String(),
				Name:    "invalid-duplicate-names",
			},
		}

		// The first insert must succeed; got is dereferenced below.
		got, err := repo.CreatePlugin(ctx, in)
		require.NoError(t, err)
		require.NotNil(t, got)
		assertPublicId(t, PluginPrefix, got.PublicId)
		assert.NotSame(in, got)
		assert.Equal(in.Name, got.Name)
		assert.Equal(in.Description, got.Description)
		assert.Equal(got.CreateTime, got.UpdateTime)

		// A second insert of the same input must be rejected as a duplicate.
		got2, err := repo.CreatePlugin(ctx, in)
		assert.Truef(errors.Match(errors.T(errors.NotUnique), err), "want err code: %v got err: %v", errors.NotUnique, err)
		assert.Nil(got2)
	})
}
// assertPublicId checks that actual is a non-empty public id consisting of
// prefix followed by exactly one '_' separator. Failures are recorded on t
// without stopping the calling test.
func assertPublicId(t *testing.T, prefix, actual string) {
	t.Helper()
	assert.NotEmpty(t, actual)

	// Report the real separator count: the id may contain none at all, not
	// only too many.
	separators := strings.Count(actual, "_")
	assert.Equalf(t, 1, separators, "want exactly one '_' in PublicId, got %d in %q", separators, actual)

	// Everything before the first '_' (or the whole id when there is none).
	gotPrefix, _, _ := strings.Cut(actual, "_")
	assert.Equalf(t, prefix, gotPrefix, "PublicId want prefix: %q, got: %q in %q", prefix, gotPrefix, actual)
}
// TestRepository_LookupPlugin exercises Repository.LookupPlugin with the
// public id of an existing plugin, a freshly generated id that was never
// stored, and an empty id (expected to fail with errors.InvalidParameter).
func TestRepository_LookupPlugin(t *testing.T) {
	ctx := context.Background()
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	plg := TestPlugin(t, conn, "test")

	// badId is a string, so it has to be checked for emptiness; a nil check
	// on it could never fail.
	badId, err := newPluginId(ctx)
	require.NoError(t, err)
	require.NotEmpty(t, badId)

	tests := []struct {
		name    string
		id      string
		want    *Plugin
		wantErr errors.Code
	}{
		{
			name: "found",
			id:   plg.GetPublicId(),
			want: plg,
		},
		{
			name: "not-found",
			id:   badId,
			want: nil,
		},
		{
			name:    "bad-public-id",
			id:      "",
			want:    nil,
			wantErr: errors.InvalidParameter,
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert := assert.New(t)
			kmsCache := kms.TestKms(t, conn, wrapper)
			repo, err := NewRepository(ctx, rw, rw, kmsCache)
			require.NoError(t, err)
			require.NotNil(t, repo)

			got, err := repo.LookupPlugin(ctx, tt.id)
			if tt.wantErr != 0 {
				assert.Truef(errors.Match(errors.T(tt.wantErr), err), "want err: %q got: %q", tt.wantErr, err)
				return
			}
			assert.NoError(err)

			if tt.want == nil {
				assert.Nil(got)
				return
			}
			assert.NotNil(got)
			// Expected value goes first so failure diffs are labelled correctly.
			assert.Equal(tt.want, got)
		})
	}
}
// TestRepository_LookupPluginByName exercises Repository.LookupPluginByName
// with the name of an existing plugin, a name that was never stored, and an
// empty name (expected to fail with a "no plugin name" error).
func TestRepository_LookupPluginByName(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)
	plg := TestPlugin(t, conn, "name123")

	tests := []struct {
		name       string
		pluginName string
		want       *Plugin
		wantErr    string
	}{
		{
			name:       "found",
			pluginName: plg.GetName(),
			want:       plg,
		},
		{
			name:       "not-found",
			pluginName: "randomname",
			want:       nil,
		},
		{
			name:       "emptyname",
			pluginName: "",
			want:       nil,
			wantErr:    "no plugin name",
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert := assert.New(t)
			ctx := context.Background()
			kmsCache := kms.TestKms(t, conn, wrapper)
			// repo is used immediately, so construction failures are fatal.
			repo, err := NewRepository(ctx, rw, rw, kmsCache)
			require.NoError(t, err)
			require.NotNil(t, repo)

			got, err := repo.LookupPluginByName(ctx, tt.pluginName)
			if tt.wantErr != "" {
				assert.ErrorContains(err, tt.wantErr)
				assert.Nil(got)
				return
			}
			assert.NoError(err)

			if tt.want == nil {
				assert.Nil(got)
				return
			}
			assert.NotNil(got)
			// Expected value goes first so failure diffs are labelled correctly.
			assert.Equal(tt.want, got)
		})
	}
}
// TestRepository_AddSupportFlag exercises Repository.AddSupportFlag. Each
// table case creates a plugin (optionally already flagged via WithHostFlag),
// verifies the starting contents of the support table, calls AddSupportFlag,
// and then verifies that exactly one row exists for the plugin. One case
// overwrites the plugin's public id to force a constraint error. A separate
// subtest covers PluginTypeUnknown.
func TestRepository_AddSupportFlag(t *testing.T) {
	conn, _ := db.TestSetup(t, "postgres")
	rw := db.New(conn)
	wrapper := db.TestWrapper(t)

	tests := []struct {
		name       string
		pluginName string
		table      string
		flag       PluginType
		flagExists bool
		forceErr   bool
		err        string
	}{
		{
			name:       "flag doesn't exist and is added",
			pluginName: "test1",
			table:      "plugin_host_supported",
			flag:       PluginTypeHost,
			flagExists: false,
			forceErr:   false,
			err:        "",
		},
		{
			name:       "flag exists and is unchanged",
			pluginName: "test2",
			table:      "plugin_host_supported",
			flag:       PluginTypeHost,
			flagExists: true,
			forceErr:   false,
			err:        "",
		},
		{
			name:       "err thrown and is handled",
			pluginName: "test3",
			table:      "plugin_host_supported",
			flag:       PluginTypeHost,
			flagExists: false,
			forceErr:   true,
			err:        "wt_plugin_id_check constraint failed",
		},
	}

	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			assert := assert.New(t)
			require := require.New(t)
			ctx := context.Background()

			// flagRows reports how many rows in tt.table carry publicId and
			// the last id scanned ("" when there are none). The result set is
			// always closed, even when a require call aborts the subtest.
			flagRows := func(publicId string) (int, string) {
				t.Helper()
				rows, err := rw.Query(ctx, fmt.Sprintf("select public_id from %s where public_id = ?;", tt.table), []any{publicId})
				require.NoError(err)
				defer rows.Close()

				var (
					count int
					id    string
				)
				for rows.Next() {
					count++
					require.NoError(rows.Scan(&id))
				}
				require.NoError(rows.Err())
				return count, id
			}

			// require.Contains(pluginTypeDbMap, tt.flag)
			plg := TestPlugin(t, conn, tt.pluginName, WithHostFlag(tt.flagExists))

			// check to make sure the start state is as expected
			startCount, startId := flagRows(plg.PublicId)
			if tt.flagExists {
				assert.Equal(1, startCount)
				assert.Equal(plg.PublicId, startId)
			} else {
				assert.Equal(0, startCount)
				assert.Equal("", startId)
			}

			// create the plugin repo; it is used right away, so failures are fatal
			kmsCache := kms.TestKms(t, conn, wrapper)
			repo, err := NewRepository(ctx, rw, rw, kmsCache)
			require.NoError(err)
			require.NotNil(repo)

			if tt.forceErr {
				// Replace the public id so that the insert is rejected.
				plg.PublicId = "biscuit"
			}

			// do the thing
			err = repo.AddSupportFlag(ctx, plg, tt.flag)
			if tt.err != "" {
				require.ErrorContains(err, tt.err)
				return
			}
			require.NoError(err)

			// check to make sure the end state is as expected
			endCount, _ := flagRows(plg.PublicId)
			assert.Equal(1, endCount)
		})
	}

	// This needs its own subtest because of the setup required by the
	// table-driven cases above.
	t.Run("plugin type unknown", func(t *testing.T) {
		require := require.New(t)
		ctx := context.Background()
		plg := TestPlugin(t, conn, "test12345")

		// create the plugin repo
		kmsCache := kms.TestKms(t, conn, wrapper)
		repo, err := NewRepository(ctx, rw, rw, kmsCache)
		require.NoError(err)
		require.NotNil(repo)

		// do the thing
		err = repo.AddSupportFlag(ctx, plg, PluginTypeUnknown)
		require.ErrorContains(err, "plugin type does not exist: parameter violation")
	})
}