You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
boundary/internal/cmd/commands/server/controller_db_swap_test.go

459 lines
13 KiB

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package server
import (
	"context"
	"fmt"
	"os"
	"path/filepath"
	"sync"
	"testing"
	"time"

	"github.com/hashicorp/boundary/internal/cmd/base"
	"github.com/hashicorp/boundary/internal/cmd/config"
	"github.com/hashicorp/boundary/internal/daemon/controller"
	"github.com/hashicorp/boundary/internal/db"
	"github.com/hashicorp/boundary/testing/dbtest"
	"github.com/mitchellh/cli"
	"github.com/stretchr/testify/require"
)
// dbSwapConfig is an HCL server config template used by the controller
// database reload tests. It takes four %s arguments, in order: the controller
// database URL, the root KMS key, the worker-auth KMS key, and the recovery
// KMS key. The tests re-render it with a different database URL and send
// SIGHUP to exercise a controller database swap.
//
// NOTE(review): the api and cluster listeners bind fixed ports
// (127.0.0.1:9500 and 127.0.0.1:9501), so tests rendering this template
// presumably cannot run concurrently with each other or with anything else
// bound to those ports.
const dbSwapConfig = `
disable_mlock = true
telemetry {
prometheus_retention_time = "24h"
disable_hostname = true
}
controller {
name = "test-controller"
description = "A default controller created for tests"
database {
url = "%s"
}
}
kms "aead" {
purpose = "root"
aead_type = "aes-gcm"
key = "%s"
key_id = "global_root"
}
kms "aead" {
purpose = "worker-auth"
aead_type = "aes-gcm"
key = "%s"
key_id = "global_worker-auth"
}
kms "aead" {
purpose = "recovery"
aead_type = "aes-gcm"
key = "%s"
key_id = "global_recovery"
}
listener "tcp" {
purpose = "api"
address = "127.0.0.1:9500"
tls_disable = true
}
listener "tcp" {
address = "127.0.0.1:9501"
purpose = "cluster"
}
`
// getDatabaseLockQuery counts the granted advisory locks with objid
// 3865661975 in the database whose name is bound to $1. Per the comment
// inside the query, this objid is the key used by the schema manager. The
// reload tests expect 1 while a controller holds the schema-manager lock on
// that database and 0 after it has let go.
//
// pg_locks is a cluster-wide view, so the query can be issued over a
// connection to any database in the same Postgres instance.
//
// Note: the WHERE filter on pg_database.datname discards the NULL-extended
// rows, so the LEFT JOIN behaves like an inner join here.
const getDatabaseLockQuery = `
select count(*) from pg_locks
-- We need to constrain this query to the correct database because we're not
-- running two distinct Postgres instances, rather two separate databases in
-- the same Postgres instance.
left join pg_database on pg_locks.database = pg_database.oid -- pg_locks.database refers to the database oid, we need to check the name.
where
pg_locks.locktype = 'advisory'
and pg_locks.granted = true -- the lock must be granted, not awaited
and pg_locks.objid = 3865661975 -- magic number set by the schema manager
and pg_database.datname = $1
`
// TestReloadControllerDatabase verifies that rewriting the controller's
// database URL and sending SIGHUP swaps the controller from database A to
// database B. The test checks that:
//   - a new schema manager is created;
//   - the *db.DB wrapper is reused in place;
//   - the old *sql.DB is eventually closed;
//   - the schema-manager advisory lock moves from A to B.
func TestReloadControllerDatabase(t *testing.T) {
	td := t.TempDir()
	cfgPath := filepath.Join(td, "config.hcl")

	// Shorten the delay before the swapped-out database is closed so the test
	// doesn't wait long. This is package-level state, so restore it when the
	// test finishes to avoid leaking the override into other tests.
	origCloseDuration := db.CloseSwappedDbDuration
	t.Cleanup(func() { db.CloseSwappedDbDuration = origCloseDuration })
	db.CloseSwappedDbDuration = 5 * time.Second

	// Create and set up database A and B (both with KMS keys).
	controllerKey := config.DevKeyGeneration()
	closeA, urlA, dbNameA, err := getInitDatabase(t, controllerKey)
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, closeA()) })
	closeB, urlB, dbNameB, err := getInitDatabase(t, controllerKey)
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, closeB()) })

	cmd := testServerCommand(t, testServerCommandOpts{})
	workerAuthKey := config.DevKeyGeneration()
	recoveryKey := config.DevKeyGeneration()
	cfgHcl := fmt.Sprintf(dbSwapConfig, urlA, controllerKey, workerAuthKey, recoveryKey)
	require.NoError(t, os.WriteFile(cfgPath, []byte(cfgHcl), 0o644))

	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		args := []string{"-config", cfgPath}
		exitCode := cmd.Run(args)
		if exitCode != 0 {
			// Uses fmt rather than t.Logf. This goroutine can outlive the
			// test if the test fails early (e.g. on a timeout below), and
			// t.Logf after test completion panics.
			output := cmd.UI.(*cli.MockUi).ErrorWriter.String() + cmd.UI.(*cli.MockUi).OutputWriter.String()
			fmt.Printf("%s: got a non-zero exit status: %s\n", t.Name(), output)
		}
	}()

	// Wait until things are up and running (or timeout).
	select {
	case <-cmd.startedCh:
	case <-time.After(15 * time.Second):
		t.Fatal("timeout")
	}

	require.NotNil(t, cmd.schemaManager)
	sqlDB, err := cmd.Server.Database.SqlDB(context.Background())
	require.NoError(t, err)
	require.NotNil(t, sqlDB)

	// Assert we're connected to database A.
	var currentDB string
	row := sqlDB.QueryRow("select current_database();")
	require.NoError(t, row.Err())
	require.NoError(t, row.Scan(&currentDB))
	require.Equal(t, dbNameA, currentDB)

	// Assert we've grabbed a lock on database A.
	var lock int
	row = sqlDB.QueryRow(getDatabaseLockQuery, dbNameA)
	require.NoError(t, row.Err())
	require.NoError(t, row.Scan(&lock))
	require.Equal(t, 1, lock)

	// Get old values so we can test against them later.
	oldSchemaManager := cmd.schemaManager
	oldDB := cmd.Server.Database

	// Change config and SIGHUP.
	cfgHcl = fmt.Sprintf(dbSwapConfig, urlB, controllerKey, workerAuthKey, recoveryKey)
	require.NoError(t, os.WriteFile(cfgPath, []byte(cfgHcl), 0o644))
	cmd.SighupCh <- struct{}{}
	select {
	case <-cmd.reloadedCh:
	case <-time.After(15 * time.Second):
		t.Fatal("timeout")
	}

	// Assert that the schema manager ptr and value changed.
	if oldSchemaManager == cmd.schemaManager {
		t.Fatalf("schema manager pointers are equal (%p). expected difference", oldSchemaManager)
	}
	require.NotEqualValues(t, oldSchemaManager, cmd.schemaManager)

	// Assert that the db.DB hasn't changed.
	// It looks like everything is the same because we're not
	// replacing `Server.Database`, and we're not changing the
	// underlying pointer to *dbw.DB either.
	// We're actually replacing the value of `wrapped` (*dbw.DB)
	// in-place (without changing memory addr).
	// Since we can't access `db.DB.wrapped` here, that has to be
	// tested separately on the appropriate package.
	if oldDB != cmd.Server.Database {
		t.Fatalf("server *db.DB pointers differ, expected equal. old ptr %p | new ptr %p", oldDB, cmd.Server.Database)
	}
	require.EqualValues(t, oldDB, cmd.Server.Database)

	// Wait for the old connection pool to be closed. The close is expected
	// roughly CloseSwappedDbDuration after the swap. Polling with a generous
	// deadline is less timing-sensitive than sleeping for exactly that long.
	// `sqlDB` still points to database A here.
	require.Eventually(t, func() bool { return sqlDB.Ping() != nil },
		db.CloseSwappedDbDuration+10*time.Second, 100*time.Millisecond,
		"old database handle was not closed after the swap")
	row = sqlDB.QueryRow("select 1")
	require.ErrorContains(t, row.Err(), "database is closed")

	// Get underlying *sql.DB again. We only need to do this on the test
	// because we're getting a pointer to a *sql.DB to call Query directly
	// and we swap the database at a higher level.
	// At this point `sqlDB` is pointing to a memory address containing
	// the *sql.DB used for db A, so we need to call the function again
	// to update that reference.
	sqlDB, err = cmd.Server.Database.SqlDB(context.Background())
	require.NoError(t, err)
	require.NotNil(t, sqlDB)

	// Assert we're connected to database B.
	row = sqlDB.QueryRow("select current_database();")
	require.NoError(t, row.Err())
	require.NoError(t, row.Scan(&currentDB))
	require.Equal(t, dbNameB, currentDB)

	// Assert the lock on database A has been released.
	row = sqlDB.QueryRow(getDatabaseLockQuery, dbNameA)
	require.NoError(t, row.Err())
	require.NoError(t, row.Scan(&lock))
	require.Equal(t, 0, lock)

	// Assert we've grabbed a lock on database B.
	row = sqlDB.QueryRow(getDatabaseLockQuery, dbNameB)
	require.NoError(t, row.Err())
	require.NoError(t, row.Scan(&lock))
	require.Equal(t, 1, lock)

	cmd.ShutdownCh <- struct{}{}
	wg.Wait()
}
// TestReloadControllerDatabase_InvalidNewDatabaseState verifies that a
// SIGHUP pointing the controller at a database with no KMS set-up leaves the
// controller on its original database. Afterwards the schema manager must be
// unchanged, and the schema-manager advisory lock must still be held on
// database A.
func TestReloadControllerDatabase_InvalidNewDatabaseState(t *testing.T) {
	td := t.TempDir()
	cfgPath := filepath.Join(td, "config.hcl")

	// Create database A with its KMS keys set up. Also create a second
	// database from the same template with no KMS set-up; that second
	// database is the invalid reload target.
	controllerKey := config.DevKeyGeneration()
	closeA, urlA, dbNameA, err := getInitDatabase(t, controllerKey)
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, closeA()) })
	invalidDatabaseClose, invalidDatabaseUrl, _, err := dbtest.StartUsingTemplate("postgres") // no kms set-up.
	require.NoError(t, err)
	t.Cleanup(func() { require.NoError(t, invalidDatabaseClose()) })

	cmd := testServerCommand(t, testServerCommandOpts{})
	workerAuthKey := config.DevKeyGeneration()
	recoveryKey := config.DevKeyGeneration()
	cfgHcl := fmt.Sprintf(dbSwapConfig, urlA, controllerKey, workerAuthKey, recoveryKey)
	require.NoError(t, os.WriteFile(cfgPath, []byte(cfgHcl), 0o644))

	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		args := []string{"-config", cfgPath}
		exitCode := cmd.Run(args)
		if exitCode != 0 {
			// Uses fmt rather than t.Logf. This goroutine can outlive the
			// test if the test fails early (e.g. on a timeout below), and
			// t.Logf after test completion panics.
			output := cmd.UI.(*cli.MockUi).ErrorWriter.String() + cmd.UI.(*cli.MockUi).OutputWriter.String()
			fmt.Printf("%s: got a non-zero exit status: %s\n", t.Name(), output)
		}
	}()

	// Wait until things are up and running (or timeout).
	select {
	case <-cmd.startedCh:
	case <-time.After(15 * time.Second):
		t.Fatal("timeout")
	}

	require.NotNil(t, cmd.schemaManager)
	sqlDB, err := cmd.Server.Database.SqlDB(context.Background())
	require.NoError(t, err)
	require.NotNil(t, sqlDB)

	// Assert we're connected to database A.
	var currentDB string
	row := sqlDB.QueryRow("select current_database();")
	require.NoError(t, row.Err())
	require.NoError(t, row.Scan(&currentDB))
	require.Equal(t, dbNameA, currentDB)

	// Assert we've grabbed a lock on database A.
	var lock int
	row = sqlDB.QueryRow(getDatabaseLockQuery, dbNameA)
	require.NoError(t, row.Err())
	require.NoError(t, row.Scan(&lock))
	require.Equal(t, 1, lock)

	// Get old values so we can test against them later.
	oldSchemaManager := cmd.schemaManager

	// Point the config at the invalid database and SIGHUP.
	cfgHcl = fmt.Sprintf(dbSwapConfig, invalidDatabaseUrl, controllerKey, workerAuthKey, recoveryKey)
	require.NoError(t, os.WriteFile(cfgPath, []byte(cfgHcl), 0o644))
	cmd.SighupCh <- struct{}{}
	select {
	case <-cmd.reloadedCh:
	case <-time.After(15 * time.Second):
		t.Fatal("timeout")
	}

	// Assert that the schema manager ptr and value did not change.
	if oldSchemaManager != cmd.schemaManager {
		t.Fatalf("schema manager pointers are different (old: %p / new: %p). expected equal", oldSchemaManager, cmd.schemaManager)
	}
	require.EqualValues(t, oldSchemaManager, cmd.schemaManager)

	// Assert we're still connected to and locked on database A.
	sqlDB, err = cmd.Server.Database.SqlDB(context.Background())
	require.NoError(t, err)
	require.NotNil(t, sqlDB)
	row = sqlDB.QueryRow("select current_database();")
	require.NoError(t, row.Err())
	require.NoError(t, row.Scan(&currentDB))
	require.Equal(t, dbNameA, currentDB)
	row = sqlDB.QueryRow(getDatabaseLockQuery, dbNameA)
	require.NoError(t, row.Err())
	require.NoError(t, row.Scan(&lock))
	require.Equal(t, 1, lock)

	cmd.ShutdownCh <- struct{}{}
	wg.Wait()
}
func TestReloadControllerDatabase_VariousNilValues(t *testing.T) {
// There's not much we can test in these cases, however
// we can ensure things don't panic.
tests := []struct {
name string
cmd *Command
newConfig *config.Config
}{
{
name: "nilServer",
cmd: &Command{
Server: nil,
controller: &controller.Controller{},
Config: &config.Config{
Controller: &config.Controller{
Database: &config.Database{Url: "db_url"},
},
},
},
newConfig: &config.Config{
Controller: &config.Controller{
Database: &config.Database{Url: "db_url_new"},
},
},
},
{
name: "nilServerDatabase",
cmd: &Command{
Server: &base.Server{Database: nil},
controller: &controller.Controller{},
Config: &config.Config{
Controller: &config.Controller{
Database: &config.Database{Url: "db_url"},
},
},
},
newConfig: &config.Config{
Controller: &config.Controller{
Database: &config.Database{Url: "db_url_new"},
},
},
},
{
name: "nilController",
cmd: &Command{
Server: &base.Server{Database: &db.DB{}},
controller: nil,
Config: &config.Config{
Controller: &config.Controller{
Database: &config.Database{Url: "db_url"},
},
},
},
newConfig: &config.Config{
Controller: &config.Controller{
Database: &config.Database{Url: "db_url_new"},
},
},
},
{
name: "nilNewConfig",
cmd: &Command{
Server: &base.Server{Database: &db.DB{}},
controller: &controller.Controller{},
Config: &config.Config{
Controller: &config.Controller{
Database: &config.Database{Url: "db_url"},
},
},
},
newConfig: nil,
},
{
name: "nilNewConfigController",
cmd: &Command{
Server: &base.Server{Database: &db.DB{}},
controller: &controller.Controller{},
Config: &config.Config{
Controller: &config.Controller{
Database: &config.Database{Url: "db_url"},
},
},
},
newConfig: &config.Config{Controller: nil},
},
{
name: "nilNewConfigControllerDatabase",
cmd: &Command{
Server: &base.Server{Database: &db.DB{}},
controller: &controller.Controller{},
Config: &config.Config{
Controller: &config.Controller{
Database: &config.Database{Url: "db_url"},
},
},
},
newConfig: &config.Config{
Controller: &config.Controller{Database: nil},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.NotPanics(t, func() {
require.NoError(t, tt.cmd.reloadControllerDatabase(tt.newConfig))
})
})
}
}
// getInitDatabase creates a database from the "postgres" template via dbtest.
// It then creates the global KMS keys in that database, using kmsRootKey as
// the root KMS key.
//
// It returns, in order:
//   - a function that tears the database down;
//   - the database URL;
//   - the database name;
//   - an error, which is always nil.
//
// Setup failures are reported through t via require instead. If a setup step
// fails after the database has been created, the database is torn down
// before the test stops, because the caller never gets to register its own
// cleanup for it.
func getInitDatabase(t *testing.T, kmsRootKey string) (func() error, string, string, error) {
	ctx := context.Background()
	// Named closeFn rather than close to avoid shadowing the builtin.
	closeFn, url, dbName, err := dbtest.StartUsingTemplate("postgres")
	require.NoError(t, err)

	// A failing require.* calls t.FailNow, which runs deferred functions, so
	// this releases the database if any setup step below fails.
	initialized := false
	defer func() {
		if initialized {
			return
		}
		if cerr := closeFn(); cerr != nil {
			t.Logf("cleaning up database %q after failed setup: %v", dbName, cerr)
		}
	}()

	cmd := testServerCommand(t, testServerCommandOpts{})
	cmd.Server.DatabaseUrl = url
	kmsConfig := fmt.Sprintf(`
kms "aead" {
purpose = "root"
aead_type = "aes-gcm"
key = "%s"
key_id = "global_root"
}`, kmsRootKey)
	cfg, err := config.Parse(kmsConfig)
	require.NoError(t, err)
	require.NoError(t, cmd.SetupKMSes(ctx, cli.NewMockUi(), cfg))
	require.NoError(t, cmd.Server.OpenAndSetServerDatabase(ctx, "postgres"))
	require.NoError(t, cmd.Server.CreateGlobalKmsKeys(ctx))

	initialized = true
	return closeFn, url, dbName, nil
}