mirror of https://github.com/hashicorp/boundary
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
368 lines
8.8 KiB
368 lines
8.8 KiB
package controller
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"testing"
|
|
|
|
"github.com/hashicorp/boundary/api"
|
|
"github.com/hashicorp/boundary/internal/cmd/base"
|
|
"github.com/hashicorp/boundary/internal/cmd/config"
|
|
"github.com/hashicorp/boundary/internal/iam"
|
|
"github.com/hashicorp/boundary/internal/servers"
|
|
"github.com/hashicorp/go-hclog"
|
|
wrapping "github.com/hashicorp/go-kms-wrapping"
|
|
"github.com/hashicorp/vault/sdk/helper/base62"
|
|
"github.com/jinzhu/gorm"
|
|
)
|
|
|
|
// TestController wraps a base.Server and Controller to provide a
|
|
// fully-programmatic controller for tests. Error checking (for instance, for
|
|
// valid config) is not stringent at the moment.
|
|
type TestController struct {
|
|
b *base.Server
|
|
c *Controller
|
|
t *testing.T
|
|
apiAddrs []string // The address the Controller API is listening on
|
|
clusterAddrs []string
|
|
client *api.Client
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
name string
|
|
}
|
|
|
|
// Controller returns the underlying controller
|
|
func (tc *TestController) Controller() *Controller {
|
|
return tc.c
|
|
}
|
|
|
|
func (tc *TestController) Config() *Config {
|
|
return tc.c.conf
|
|
}
|
|
|
|
func (tc *TestController) Client() *api.Client {
|
|
return tc.client
|
|
}
|
|
|
|
func (tc *TestController) Context() context.Context {
|
|
return tc.ctx
|
|
}
|
|
|
|
func (tc *TestController) IamRepo() *iam.Repository {
|
|
repo, err := tc.c.IamRepoFn()
|
|
if err != nil {
|
|
tc.t.Fatal(err)
|
|
}
|
|
return repo
|
|
}
|
|
|
|
func (tc *TestController) ServersRepo() *servers.Repository {
|
|
repo, err := tc.c.ServersRepoFn()
|
|
if err != nil {
|
|
tc.t.Fatal(err)
|
|
}
|
|
return repo
|
|
}
|
|
|
|
func (tc *TestController) Cancel() {
|
|
tc.cancel()
|
|
}
|
|
|
|
func (tc *TestController) Name() string {
|
|
return tc.name
|
|
}
|
|
|
|
func (tc *TestController) ApiAddrs() []string {
|
|
return tc.addrs("api")
|
|
}
|
|
|
|
func (tc *TestController) ClusterAddrs() []string {
|
|
return tc.addrs("cluster")
|
|
}
|
|
|
|
func (tc *TestController) DbConn() *gorm.DB {
|
|
return tc.b.Database
|
|
}
|
|
|
|
func (tc *TestController) addrs(purpose string) []string {
|
|
var prefix string
|
|
switch purpose {
|
|
case "api":
|
|
if tc.apiAddrs != nil {
|
|
return tc.apiAddrs
|
|
}
|
|
prefix = "http://"
|
|
case "cluster":
|
|
if tc.clusterAddrs != nil {
|
|
return tc.clusterAddrs
|
|
}
|
|
}
|
|
|
|
addrs := make([]string, 0, len(tc.b.Listeners))
|
|
for _, listener := range tc.b.Listeners {
|
|
if listener.Config.Purpose[0] == purpose {
|
|
tcpAddr, ok := listener.Mux.Addr().(*net.TCPAddr)
|
|
if !ok {
|
|
tc.t.Fatal("could not parse address as a TCP addr")
|
|
}
|
|
addr := fmt.Sprintf("%s%s:%d", prefix, tcpAddr.IP.String(), tcpAddr.Port)
|
|
addrs = append(addrs, addr)
|
|
}
|
|
}
|
|
|
|
switch purpose {
|
|
case "api":
|
|
tc.apiAddrs = addrs
|
|
case "cluster":
|
|
tc.clusterAddrs = addrs
|
|
}
|
|
|
|
return addrs
|
|
}
|
|
|
|
func (tc *TestController) buildClient() {
|
|
client, err := api.NewClient(nil)
|
|
if err != nil {
|
|
tc.t.Fatal(fmt.Errorf("error creating client: %w", err))
|
|
}
|
|
apiAddrs := tc.ApiAddrs()
|
|
if len(apiAddrs) == 0 {
|
|
tc.t.Fatal("no API addresses found")
|
|
}
|
|
if err := client.SetAddr(apiAddrs[0]); err != nil {
|
|
tc.t.Fatal(fmt.Errorf("error setting client address: %w", err))
|
|
}
|
|
// Because this is using the real lib it can pick up from stored locations
|
|
// like the system keychain. Explicitly clear the token to ensure we
|
|
// understand the client state at the start of each test.
|
|
client.SetToken("")
|
|
|
|
tc.client = client
|
|
}
|
|
|
|
// Shutdown runs any cleanup functions; be sure to run this after your test is
|
|
// done
|
|
func (tc *TestController) Shutdown() {
|
|
if tc.b != nil {
|
|
close(tc.b.ShutdownCh)
|
|
}
|
|
|
|
tc.cancel()
|
|
|
|
if tc.c != nil {
|
|
if err := tc.c.Shutdown(false); err != nil {
|
|
tc.t.Error(err)
|
|
}
|
|
}
|
|
if tc.b != nil {
|
|
if err := tc.b.RunShutdownFuncs(); err != nil {
|
|
tc.t.Error(err)
|
|
}
|
|
if tc.b.DestroyDevDatabase() != nil {
|
|
if err := tc.b.DestroyDevDatabase(); err != nil {
|
|
tc.t.Error(err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
type TestControllerOpts struct {
|
|
// Config; if not provided a dev one will be created
|
|
Config *config.Config
|
|
|
|
// DefaultAuthMethodId is the default auth method ID to use, if set.
|
|
DefaultAuthMethodId string
|
|
|
|
// DefaultLoginName is the login name used when creating the default account.
|
|
DefaultLoginName string
|
|
|
|
// DefaultPassword is the password used when creating the default account.
|
|
DefaultPassword string
|
|
|
|
// DisableAuthMethodCreation can be set true to disable creating an auth
|
|
// method automatically.
|
|
DisableAuthMethodCreation bool
|
|
|
|
// DisableDatabaseCreation can be set true to disable creating a dev
|
|
// database
|
|
DisableDatabaseCreation bool
|
|
|
|
// If set, instead of creating a dev database, it will connect to an
|
|
// existing database given the url
|
|
DatabaseUrl string
|
|
|
|
// If true, the controller will not be started
|
|
DisableAutoStart bool
|
|
|
|
// DisableAuthorizationFailures will still cause authz checks to be
|
|
// performed but they won't cause 403 Forbidden. Useful for API-level
|
|
// testing to avoid a lot of faff.
|
|
DisableAuthorizationFailures bool
|
|
|
|
// The controller KMS to use, or one will be created
|
|
RootKms wrapping.Wrapper
|
|
|
|
// The worker auth KMS to use, or one will be created
|
|
WorkerAuthKms wrapping.Wrapper
|
|
|
|
// The recovery KMS to use, or one will be created
|
|
RecoveryKms wrapping.Wrapper
|
|
|
|
// The name to use for the controller, otherwise one will be randomly
|
|
// generated, unless provided in a non-nil Config
|
|
Name string
|
|
|
|
// The logger to use, or one will be created
|
|
Logger hclog.Logger
|
|
}
|
|
|
|
func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
tc := &TestController{
|
|
t: t,
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
if opts == nil {
|
|
opts = new(TestControllerOpts)
|
|
}
|
|
|
|
// Base server
|
|
tc.b = base.NewServer(nil)
|
|
tc.b.Command = &base.Command{
|
|
ShutdownCh: make(chan struct{}),
|
|
}
|
|
|
|
// Get dev config, or use a provided one
|
|
var err error
|
|
if opts.Config == nil {
|
|
opts.Config, err = config.DevController()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
opts.Config.Controller.Name = opts.Name
|
|
}
|
|
|
|
if opts.DefaultAuthMethodId != "" {
|
|
tc.b.DevAuthMethodId = opts.DefaultAuthMethodId
|
|
}
|
|
if opts.DefaultLoginName != "" {
|
|
tc.b.DevLoginName = opts.DefaultLoginName
|
|
}
|
|
if opts.DefaultPassword != "" {
|
|
tc.b.DevPassword = opts.DefaultPassword
|
|
}
|
|
|
|
// Start a logger
|
|
tc.b.Logger = opts.Logger
|
|
if tc.b.Logger == nil {
|
|
tc.b.Logger = hclog.New(&hclog.LoggerOptions{
|
|
Level: hclog.Trace,
|
|
})
|
|
}
|
|
|
|
if opts.Config.Controller == nil {
|
|
opts.Config.Controller = new(config.Controller)
|
|
}
|
|
if opts.Config.Controller.Name == "" {
|
|
opts.Config.Controller.Name, err = base62.Random(5)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
tc.b.Logger.Info("controller name generated", "name", opts.Config.Controller.Name)
|
|
}
|
|
tc.name = opts.Config.Controller.Name
|
|
|
|
// Set up KMSes
|
|
switch {
|
|
case opts.RootKms != nil && opts.WorkerAuthKms != nil:
|
|
tc.b.RootKms = opts.RootKms
|
|
tc.b.WorkerAuthKms = opts.WorkerAuthKms
|
|
case opts.RootKms == nil && opts.WorkerAuthKms == nil:
|
|
if err := tc.b.SetupKMSes(nil, opts.Config); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
default:
|
|
t.Fatal("either controller and worker auth KMS must both be set, or neither")
|
|
}
|
|
if opts.RecoveryKms != nil {
|
|
tc.b.RecoveryKms = opts.RecoveryKms
|
|
}
|
|
|
|
// Ensure the listeners use random port allocation
|
|
for _, listener := range opts.Config.Listeners {
|
|
listener.RandomPort = true
|
|
}
|
|
if err := tc.b.SetupListeners(nil, opts.Config.SharedConfig, []string{"api", "cluster"}); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if opts.DatabaseUrl != "" {
|
|
tc.b.DatabaseUrl = opts.DatabaseUrl
|
|
if err := tc.b.ConnectToDatabase("postgres"); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
} else if !opts.DisableDatabaseCreation {
|
|
var createOpts []base.Option
|
|
if opts.DisableAuthMethodCreation {
|
|
createOpts = append(createOpts, base.WithSkipAuthMethodCreation())
|
|
}
|
|
if err := tc.b.CreateDevDatabase("postgres", createOpts...); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
conf := &Config{
|
|
RawConfig: opts.Config,
|
|
Server: tc.b,
|
|
DisableAuthorizationFailures: opts.DisableAuthorizationFailures,
|
|
}
|
|
|
|
tc.c, err = New(conf)
|
|
if err != nil {
|
|
tc.Shutdown()
|
|
t.Fatal(err)
|
|
}
|
|
|
|
tc.buildClient()
|
|
|
|
if !opts.DisableAutoStart {
|
|
if err := tc.c.Start(); err != nil {
|
|
tc.Shutdown()
|
|
t.Fatal(err)
|
|
}
|
|
}
|
|
|
|
return tc
|
|
}
|
|
|
|
func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestControllerOpts) *TestController {
|
|
if opts == nil {
|
|
opts = new(TestControllerOpts)
|
|
}
|
|
nextOpts := &TestControllerOpts{
|
|
DatabaseUrl: tc.c.conf.DatabaseUrl,
|
|
DefaultAuthMethodId: tc.c.conf.DevAuthMethodId,
|
|
RootKms: tc.c.conf.RootKms,
|
|
WorkerAuthKms: tc.c.conf.WorkerAuthKms,
|
|
RecoveryKms: tc.c.conf.RecoveryKms,
|
|
Name: opts.Name,
|
|
Logger: tc.c.conf.Logger,
|
|
}
|
|
if opts.Logger != nil {
|
|
nextOpts.Logger = opts.Logger
|
|
}
|
|
if nextOpts.Name == "" {
|
|
var err error
|
|
nextOpts.Name, err = base62.Random(5)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
nextOpts.Logger.Info("controller name generated", "name", nextOpts.Name)
|
|
}
|
|
return NewTestController(t, nextOpts)
|
|
}
|