Initial worker porting steps (#232)

pull/257/head^2
Jeff Mitchell 6 years ago committed by GitHub
parent 73a38b1433
commit 203e2b5dc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -27,6 +27,9 @@ dev: build-ui-ifne
@echo "==> Building Boundary with dev and UI features enabled"
@CGO_ENABLED=$(CGO_ENABLED) BUILD_TAGS='$(BUILD_TAGS)' BOUNDARY_DEV_BUILD=1 sh -c "'$(CURDIR)/scripts/build.sh'"
fmt:
goimports -w $$(find . -name '*.go' | grep -v pb.go | grep -v pb.gw.go)
build-ui:
@scripts/uigen.sh

@ -421,6 +421,14 @@ func NewClient(c *Config) (*Client, error) {
}, nil
}
// Addr returns the current (parsed) address
// the client is configured to talk to. A read lock is taken so this is safe
// to call concurrently with SetAddr.
func (c *Client) Addr() string {
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
return c.config.Addr
}
// Sets the address of Boundary in the client. The format of address should
// be "<Scheme>://<Host>:<Port>". Setting this on a client will override the
// value of the BOUNDARY_ADDR environment variable.
@ -433,8 +441,8 @@ func (c *Client) SetAddr(addr string) error {
// ScopeId fetches the scope ID the client will use by default
func (c *Client) ScopeId() string {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
return c.config.ScopeId
}
@ -504,8 +512,8 @@ func (c *Client) SetOutputCurlString(curl bool) {
// Token gets the configured token.
func (c *Client) Token() string {
c.modifyLock.Lock()
defer c.modifyLock.Unlock()
c.modifyLock.RLock()
defer c.modifyLock.RUnlock()
return c.config.Token
}

@ -3,6 +3,9 @@ module github.com/hashicorp/boundary
go 1.13
require (
github.com/alecthomas/assert v0.0.0-20170929043011-405dbfeb8e38
github.com/alecthomas/colour v0.1.0 // indirect
github.com/alecthomas/repr v0.0.0-20200325044227-4184120f674c // indirect
github.com/armon/go-metrics v0.3.3
github.com/bufbuild/buf v0.20.5
github.com/fatih/color v1.9.0
@ -44,13 +47,16 @@ require (
github.com/mitchellh/mapstructure v1.3.3
github.com/oligot/go-mod-upgrade v0.2.1
github.com/ory/dockertest/v3 v3.6.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pires/go-proxyproto v0.1.3
github.com/pkg/errors v0.9.1
github.com/posener/complete v1.2.3
github.com/ryanuber/columnize v2.1.0+incompatible
github.com/ryanuber/go-glob v1.0.0
github.com/sergi/go-diff v1.1.0 // indirect
github.com/stretchr/testify v1.6.1
github.com/zalando/go-keyring v0.1.0
go.uber.org/atomic v1.6.0
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de
golang.org/x/time v0.0.0-20200630173020-3af7569d3a1e
golang.org/x/tools v0.0.0-20200807210451-92211316783d

266
go.sum

File diff suppressed because it is too large Load Diff

@ -20,19 +20,24 @@ import (
"github.com/hashicorp/vault/internalshared/reloadutil"
"github.com/mitchellh/cli"
"github.com/pires/go-proxyproto"
"google.golang.org/grpc"
)
// ServerListener groups everything associated with a single configured
// listener: the ALPN mux it was built from, its configuration block, the
// HTTP and/or gRPC servers serving on it (either may be nil depending on the
// listener's purpose), and the underlying net.Listener.
type ServerListener struct {
Mux *alpnmux.ALPNMux
Config *configutil.Listener
HTTPServer *http.Server
GrpcServer *grpc.Server
ALPNListener net.Listener
}
type WorkerAuthCertInfo struct {
CACertPEM []byte `json:"ca_cert"`
CertPEM []byte `json:"cert"`
KeyPEM []byte `json:"key"`
// WorkerAuthInfo carries the material a worker uses to authenticate to a
// controller: the CA certificate, the worker's own certificate/key pair, and
// identifying metadata.
type WorkerAuthInfo struct {
// CACertPEM is the PEM-encoded CA certificate.
CACertPEM []byte `json:"ca_cert"`
// CertPEM/KeyPEM are the worker's PEM-encoded certificate and private key.
CertPEM []byte `json:"cert"`
KeyPEM []byte `json:"key"`
// Name and Description identify the worker to the controller.
Name string `json:"name"`
Description string `json:"description"`
// ConnectionNonce is presumably used to bind the authenticated connection
// to this auth info — NOTE(review): confirm against the dialing code.
ConnectionNonce string `json:"connection_nonce"`
}
// Factory is the factory function to create a listener.
@ -56,10 +61,15 @@ func NewListener(l *configutil.Listener, logger hclog.Logger, ui cli.Ui) (*alpnm
func tcpListenerFactory(l *configutil.Listener, logger hclog.Logger, ui cli.Ui) (*alpnmux.ALPNMux, map[string]string, reloadutil.ReloadFunc, error) {
if l.Address == "" {
if len(l.Purpose) == 1 && l.Purpose[0] == "cluster" {
l.Address = "127.0.0.1:9201"
} else {
l.Address = "127.0.0.1:9200"
if len(l.Purpose) == 1 {
switch l.Purpose[0] {
case "cluster":
l.Address = "127.0.0.1:9201"
case "worker-alpn-tls":
l.Address = "127.0.0.1:9202"
default:
l.Address = "127.0.0.1:9200"
}
}
}
@ -102,9 +112,6 @@ func tcpListenerFactory(l *configutil.Listener, logger hclog.Logger, ui cli.Ui)
alpnMux := alpnmux.New(ln, logger)
if l.TLSDisable {
if _, err = alpnMux.RegisterProto(alpnmux.NoProto, nil); err != nil {
return nil, nil, nil, err
}
return alpnMux, props, nil, nil
}

@ -29,6 +29,7 @@ import (
"github.com/hashicorp/vault/sdk/helper/base62"
"github.com/hashicorp/vault/sdk/helper/logging"
"github.com/hashicorp/vault/sdk/helper/mlock"
"github.com/hashicorp/vault/sdk/helper/strutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/jinzhu/gorm"
"github.com/mitchellh/cli"
@ -66,7 +67,7 @@ type Server struct {
DevLoginName string
DevPassword string
DevDatabaseUrl string
DatabaseUrl string
DevDatabaseCleanupFunc func() error
Database *gorm.DB
@ -207,7 +208,7 @@ func (b *Server) PrintInfo(ui cli.Ui, mode string) {
}
// Server configuration output
padding := 24
padding := 36
sort.Strings(b.InfoKeys)
ui.Output(fmt.Sprintf("==> Boundary %s configuration:\n", mode))
for _, k := range b.InfoKeys {
@ -225,7 +226,7 @@ func (b *Server) PrintInfo(ui cli.Ui, mode string) {
}
}
func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig) error {
func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig, allowedPurposes []string) error {
// Initialize the listeners
b.Listeners = make([]*ServerListener, 0, len(config.Listeners))
// Make sure we close everything before we exit
@ -242,6 +243,13 @@ func (b *Server) SetupListeners(ui cli.Ui, config *configutil.SharedConfig) erro
defer b.ReloadFuncsLock.Unlock()
for i, lnConfig := range config.Listeners {
for _, purpose := range lnConfig.Purpose {
purpose = strings.ToLower(purpose)
if !strutil.StrListContains(allowedPurposes, purpose) {
return fmt.Errorf("Unknown listener purpose %q", purpose)
}
}
// Override for now
// TODO: Way to configure
lnConfig.TLSCipherSuites = []uint16{
@ -382,6 +390,18 @@ func (b *Server) RunShutdownFuncs() error {
return mErr.ErrorOrNil()
}
// ConnectToDatabase opens a gorm connection to b.DatabaseUrl using the given
// dialect, wires gorm's log formatting and logger to the server's logger,
// and stores the connection on the server.
func (b *Server) ConnectToDatabase(dialect string) error {
conn, err := gorm.Open(dialect, b.DatabaseUrl)
if err != nil {
return fmt.Errorf("unable to create db object with dialect %s: %w", dialect, err)
}
gorm.LogFormatter = db.GetGormLogFormatter(b.Logger)
conn.SetLogger(db.GetGormLogger(b.Logger))
b.Database = conn
return nil
}
func (b *Server) CreateDevDatabase(dialect string) error {
c, url, container, err := db.InitDbInDocker(dialect)
// In case of an error, run the cleanup function. If we pass all errors, c should be set to a noop
@ -397,23 +417,19 @@ func (b *Server) CreateDevDatabase(dialect string) error {
}
b.DevDatabaseCleanupFunc = c
b.DevDatabaseUrl = url
b.DatabaseUrl = url
b.InfoKeys = append(b.InfoKeys, "dev database url")
b.Info["dev database url"] = b.DevDatabaseUrl
b.Info["dev database url"] = b.DatabaseUrl
if container != "" {
b.InfoKeys = append(b.InfoKeys, "dev database container")
b.Info["dev database container"] = strings.TrimPrefix(container, "/")
}
dbase, err := gorm.Open(dialect, url)
if err != nil {
return fmt.Errorf("unable to create db object with dialect %s: %w", dialect, err)
if err := b.ConnectToDatabase(dialect); err != nil {
return err
}
b.Database = dbase
gorm.LogFormatter = db.GetGormLogFormatter(b.Logger)
b.Database.SetLogger(db.GetGormLogger(b.Logger))
b.Database.LogMode(true)
rw := db.New(b.Database)
@ -547,5 +563,5 @@ func (b *Server) DestroyDevDatabase() error {
if b.DevDatabaseCleanupFunc != nil {
return b.DevDatabaseCleanupFunc()
}
return errors.New("no dev database cleanup function found")
return nil
}

@ -238,6 +238,8 @@ func (c *Command) Run(args []string) int {
foundCluster = true
case "api":
foundAPI = true
case "worker-alpn-tls":
// Do nothing; in dev mode we might see this purpose here
default:
c.UI.Error(fmt.Sprintf("Unknown listener purpose %q", lnConfig.Purpose[0]))
return 1
@ -264,7 +266,7 @@ func (c *Command) Run(args []string) int {
c.UI.Error("No listener marked for cluster purpose found, but listener explicitly marked for api was found")
return 1
}
if err := c.SetupListeners(c.UI, c.Config.SharedConfig); err != nil {
if err := c.SetupListeners(c.UI, c.Config.SharedConfig, []string{"api", "cluster"}); err != nil {
c.UI.Error(err.Error())
return 1
}
@ -351,7 +353,11 @@ func (c *Command) ParseFlagsAndConfig(args []string) int {
}
} else {
c.Config, err = config.DevController()
if len(c.flagConfig) == 0 {
c.Config, err = config.DevController()
} else {
c.Config, err = config.LoadFile(c.flagConfig, c.configKMS)
}
if err != nil {
c.UI.Error(fmt.Errorf("Error creating dev config: %w", err).Error())
return 1
@ -430,7 +436,7 @@ func (c *Command) Start() error {
if err := c.controller.Start(); err != nil {
retErr := fmt.Errorf("Error starting controller: %w", err)
if err := c.controller.Shutdown(); err != nil {
if err := c.controller.Shutdown(false); err != nil {
c.UI.Error(retErr.Error())
retErr = fmt.Errorf("Error with controller shutdown: %w", err)
}
@ -454,7 +460,7 @@ func (c *Command) WaitForInterrupt() int {
case <-shutdownCh:
c.UI.Output("==> Boundary controller shutdown triggered")
if err := c.controller.Shutdown(); err != nil {
if err := c.controller.Shutdown(false); err != nil {
c.UI.Error(fmt.Errorf("Error with controller shutdown: %w", err).Error())
}

@ -166,7 +166,7 @@ func (c *Command) Run(args []string) int {
childShutdownCh := make(chan struct{})
devConfig, err := config.DevController()
devConfig, err := config.DevCombined()
if err != nil {
c.UI.Error(fmt.Errorf("Error creating controller dev config: %w", err).Error())
return 1
@ -248,9 +248,17 @@ func (c *Command) Run(args []string) int {
c.UI.Error("Worker Auth KMS not found after parsing KMS blocks")
return 1
}
c.InfoKeys = append(c.InfoKeys, "[Controller] AEAD Key Bytes")
c.Info["[Controller] AEAD Key Bytes"] = devConfig.Controller.DevControllerKey
c.InfoKeys = append(c.InfoKeys, "[Worker-Auth] AEAD Key Bytes")
c.Info["[Worker-Auth] AEAD Key Bytes"] = devConfig.Controller.DevWorkerAuthKey
if c.WorkerAuthKMS == nil {
c.UI.Error("Worker Auth KMS not found after parsing KMS blocks")
return 1
}
// Initialize the listeners
if err := c.SetupListeners(c.UI, devConfig.SharedConfig); err != nil {
if err := c.SetupListeners(c.UI, devConfig.SharedConfig, []string{"api", "cluster", "worker-alpn-tls"}); err != nil {
c.UI.Error(err.Error())
return 1
}

@ -168,7 +168,7 @@ func (c *Command) Run(args []string) int {
"in a Docker container, provide the IPC_LOCK cap to the container."))
}
if err := c.SetupListeners(c.UI, c.Config.SharedConfig); err != nil {
if err := c.SetupListeners(c.UI, c.Config.SharedConfig, []string{"worker-alpn-tls"}); err != nil {
c.UI.Error(err.Error())
return 1
}
@ -235,7 +235,11 @@ func (c *Command) ParseFlagsAndConfig(args []string) int {
}
} else {
c.Config, err = config.DevWorker()
if len(c.flagConfig) == 0 {
c.Config, err = config.DevWorker()
} else {
c.Config, err = config.LoadFile(c.flagConfig, c.configKMS)
}
if err != nil {
c.UI.Error(fmt.Errorf("Error creating dev config: %s", err).Error())
return 1
@ -265,7 +269,7 @@ func (c *Command) Start() error {
if err := c.worker.Start(); err != nil {
retErr := fmt.Errorf("Error starting worker: %w", err)
if err := c.worker.Shutdown(); err != nil {
if err := c.worker.Shutdown(false); err != nil {
c.UI.Error(retErr.Error())
retErr = fmt.Errorf("Error with worker shutdown: %w", err)
}
@ -289,7 +293,7 @@ func (c *Command) WaitForInterrupt() int {
case <-shutdownCh:
c.UI.Output("==> Boundary worker shutdown triggered")
if err := c.worker.Shutdown(); err != nil {
if err := c.worker.Shutdown(false); err != nil {
c.UI.Error(fmt.Errorf("Error with worker shutdown: %w", err).Error())
}

@ -24,6 +24,11 @@ telemetry {
`
devControllerExtraConfig = `
controller {
name = "dev-controller"
description = "A default controller created in dev mode"
}
kms "aead" {
purpose = "controller"
aead_type = "aes-gcm"
@ -73,10 +78,18 @@ worker {
type Config struct {
*configutil.SharedConfig `hcl:"-"`
DevController bool `hcl:"-"`
DefaultOrgId string `hcl:"default_org_id"`
PassthroughDirectory string `hcl:"-"`
Worker *Worker `hcl:"worker"`
DevController bool `hcl:"-"`
DefaultOrgId string `hcl:"default_org_id"`
PassthroughDirectory string `hcl:"-"`
Worker *Worker `hcl:"worker"`
Controller *Controller `hcl:"controller"`
}
type Controller struct {
Name string `hcl:"name"`
Description string `hcl:"description"`
DevControllerKey string `hcl:"-"`
DevWorkerAuthKey string `hcl:"-"`
}
type Worker struct {
@ -123,6 +136,8 @@ func DevController() (*Config, error) {
return nil, fmt.Errorf("error parsing dev config: %w", err)
}
parsed.DevController = true
parsed.Controller.DevControllerKey = controllerKey
parsed.Controller.DevWorkerAuthKey = workerAuthKey
return parsed, nil
}
@ -134,6 +149,8 @@ func DevCombined() (*Config, error) {
return nil, fmt.Errorf("error parsing dev config: %w", err)
}
parsed.DevController = true
parsed.Controller.DevControllerKey = controllerKey
parsed.Controller.DevWorkerAuthKey = workerAuthKey
return parsed, nil
}

@ -69,6 +69,10 @@ func TestDevController(t *testing.T) {
MaximumGaugeCardinality: 500,
},
},
Controller: &Controller{
Name: "dev-controller",
Description: "A default controller created in dev mode",
},
DevController: true,
}
@ -76,6 +80,8 @@ func TestDevController(t *testing.T) {
exp.Listeners[1].RawConfig = actual.Listeners[1].RawConfig
exp.Seals[0].Config["key"] = actual.Seals[0].Config["key"]
exp.Seals[1].Config["key"] = actual.Seals[1].Config["key"]
exp.Controller.DevControllerKey = actual.Seals[0].Config["key"]
exp.Controller.DevWorkerAuthKey = actual.Seals[1].Config["key"]
assert.Equal(t, exp, actual)
}
@ -125,7 +131,6 @@ func TestDevWorker(t *testing.T) {
}
func TestConfigDecrypt(t *testing.T) {
const (
clr = `
kms "aead" {

@ -1321,6 +1321,54 @@ begin;
commit;
`),
},
"migrations/08_servers.down.sql": {
name: "08_servers.down.sql",
// BUG FIX: the down migration must reverse 08_servers.up.sql, which
// creates the "servers" table; the previous "workers"/"controllers"
// drops referenced tables this migration never creates and would fail.
bytes: []byte(`
begin;
drop table servers;
commit;
`),
},
"migrations/08_servers.up.sql": {
name: "08_servers.up.sql",
bytes: []byte(`
begin;
-- For now at least the IDs will be the same as the name, because this allows us
-- to not have to persist some generated ID to worker and controller nodes.
-- Eventually we may want them to diverge, so we have both here for now.
create table servers (
private_id text,
type text,
name text not null unique,
description text,
address text,
create_time wt_timestamp,
update_time wt_timestamp,
primary key (private_id, type)
);
create trigger
immutable_columns
before
update on servers
for each row execute procedure immutable_columns('create_time');
create trigger
default_create_time_column
before
insert on servers
for each row execute procedure default_create_time();
commit;
`),
},
"migrations/10_static_host.down.sql": {

@ -0,0 +1,6 @@
begin;
-- Reverse 08_servers.up.sql, which creates the "servers" table. The previous
-- version dropped "workers" and "controllers", tables the up migration never
-- creates, so running the down migration would fail.
drop table servers;
commit;

@ -0,0 +1,30 @@
begin;
-- Single table tracking both controller and worker nodes; the "type" column
-- distinguishes them.
-- For now at least the IDs will be the same as the name, because this allows us
-- to not have to persist some generated ID to worker and controller nodes.
-- Eventually we may want them to diverge, so we have both here for now.
create table servers (
private_id text,
type text,
name text not null unique,
description text,
address text,
-- wt_timestamp is a domain type defined in an earlier migration; values are
-- managed by the triggers below.
create_time wt_timestamp,
update_time wt_timestamp,
-- Composite key: the same private_id may appear once per server type.
primary key (private_id, type)
);
-- immutable_columns (procedure defined elsewhere) is expected to reject
-- updates that change create_time after insert — confirm its signature.
create trigger
immutable_columns
before
update on servers
for each row execute procedure immutable_columns('create_time');
-- default_create_time (procedure defined elsewhere) presumably stamps
-- create_time with the current time on insert.
create trigger
default_create_time_column
before
insert on servers
for each row execute procedure default_create_time();
commit;

@ -3049,6 +3049,25 @@
}
}
},
"controller.api.services.v1.StatusResponse": {
"type": "object",
"properties": {
"controllers": {
"type": "array",
"items": {
"$ref": "#/definitions/controller.servers.v1.Server"
},
"description": "Active controllers. This can be used (eventually) for connection\nmanagement."
},
"cancel_job_ids": {
"type": "array",
"items": {
"type": "string"
},
"description": "Jobs that should be canceled: ones assigned to this worker that have been\nreported as active but are in canceling state in the database. Once the\nworker cancels the job, it will no longer show up in active_jobs in the\nnext heartbeat, and we can move the job to canceled state."
}
}
},
"controller.api.services.v1.UpdateAccountResponse": {
"type": "object",
"properties": {
@ -3121,6 +3140,50 @@
}
}
},
"controller.servers.v1.Server": {
"type": "object",
"properties": {
"private_id": {
"type": "string",
"title": "Private ID of the resource"
},
"type": {
"type": "string",
"title": "Type of the resource (controller, worker)"
},
"name": {
"type": "string",
"title": "Name of the resource"
},
"description": {
"type": "string",
"title": "Description of the resource"
},
"address": {
"type": "string",
"title": "Address for the server"
},
"create_time": {
"$ref": "#/definitions/controller.storage.timestamp.v1.Timestamp",
"title": "First seen time from the RDBMS"
},
"update_time": {
"$ref": "#/definitions/controller.storage.timestamp.v1.Timestamp",
"title": "Last time there was an update"
}
},
"title": "Server contains all fields related to a Controller or Worker resource"
},
"controller.storage.timestamp.v1.Timestamp": {
"type": "object",
"properties": {
"timestamp": {
"type": "string",
"format": "date-time"
}
},
"title": "Timestamp for storage messages. We've defined a new local type wrapper\nof google.protobuf.Timestamp so we can implement sql.Scanner and sql.Valuer\ninterfaces. See:\nhttps://golang.org/pkg/database/sql/#Scanner\nhttps://golang.org/pkg/database/sql/driver/#Valuer"
},
"google.protobuf.FieldMask": {
"type": "object",
"properties": {

@ -0,0 +1,352 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.25.0
// protoc v3.12.4
// source: controller/api/services/v1/worker_service.proto
package services
import (
context "context"
proto "github.com/golang/protobuf/proto"
servers "github.com/hashicorp/boundary/internal/servers"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
type StatusRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// The worker info. We could use information from the TLS connection but this
// is easier and going the other route doesn't provide much benefit -- if you
// get access to the key and spoof the connection, you're already compromised.
Worker *servers.Server `protobuf:"bytes,10,opt,name=worker,proto3" json:"worker,omitempty"`
// Jobs currently active on this worker.
ActiveJobIds []string `protobuf:"bytes,20,rep,name=active_job_ids,json=activeJobIds,proto3" json:"active_job_ids,omitempty"`
}
func (x *StatusRequest) Reset() {
*x = StatusRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_controller_api_services_v1_worker_service_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *StatusRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StatusRequest) ProtoMessage() {}
func (x *StatusRequest) ProtoReflect() protoreflect.Message {
mi := &file_controller_api_services_v1_worker_service_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StatusRequest.ProtoReflect.Descriptor instead.
func (*StatusRequest) Descriptor() ([]byte, []int) {
return file_controller_api_services_v1_worker_service_proto_rawDescGZIP(), []int{0}
}
func (x *StatusRequest) GetWorker() *servers.Server {
if x != nil {
return x.Worker
}
return nil
}
func (x *StatusRequest) GetActiveJobIds() []string {
if x != nil {
return x.ActiveJobIds
}
return nil
}
type StatusResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Active controllers. This can be used (eventually) for connection
// management.
Controllers []*servers.Server `protobuf:"bytes,10,rep,name=controllers,proto3" json:"controllers,omitempty"`
// Jobs that should be canceled: ones assigned to this worker that have been
// reported as active but are in canceling state in the database. Once the
// worker cancels the job, it will no longer show up in active_jobs in the
// next heartbeat, and we can move the job to canceled state.
CancelJobIds []string `protobuf:"bytes,20,rep,name=cancel_job_ids,json=cancelJobIds,proto3" json:"cancel_job_ids,omitempty"`
}
func (x *StatusResponse) Reset() {
*x = StatusResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_controller_api_services_v1_worker_service_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *StatusResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StatusResponse) ProtoMessage() {}
func (x *StatusResponse) ProtoReflect() protoreflect.Message {
mi := &file_controller_api_services_v1_worker_service_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StatusResponse.ProtoReflect.Descriptor instead.
func (*StatusResponse) Descriptor() ([]byte, []int) {
return file_controller_api_services_v1_worker_service_proto_rawDescGZIP(), []int{1}
}
func (x *StatusResponse) GetControllers() []*servers.Server {
if x != nil {
return x.Controllers
}
return nil
}
func (x *StatusResponse) GetCancelJobIds() []string {
if x != nil {
return x.CancelJobIds
}
return nil
}
var File_controller_api_services_v1_worker_service_proto protoreflect.FileDescriptor
var file_controller_api_services_v1_worker_service_proto_rawDesc = []byte{
0x0a, 0x2f, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2f, 0x61, 0x70, 0x69,
0x2f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2f, 0x76, 0x31, 0x2f, 0x77, 0x6f, 0x72,
0x6b, 0x65, 0x72, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74,
0x6f, 0x12, 0x1a, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x61, 0x70,
0x69, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2e, 0x76, 0x31, 0x1a, 0x23, 0x63,
0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x73, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x2e, 0x70, 0x72, 0x6f,
0x74, 0x6f, 0x22, 0x6c, 0x0a, 0x0d, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75,
0x65, 0x73, 0x74, 0x12, 0x35, 0x0a, 0x06, 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x18, 0x0a, 0x20,
0x01, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72,
0x2e, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x72, 0x76,
0x65, 0x72, 0x52, 0x06, 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x12, 0x24, 0x0a, 0x0e, 0x61, 0x63,
0x74, 0x69, 0x76, 0x65, 0x5f, 0x6a, 0x6f, 0x62, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x14, 0x20, 0x03,
0x28, 0x09, 0x52, 0x0c, 0x61, 0x63, 0x74, 0x69, 0x76, 0x65, 0x4a, 0x6f, 0x62, 0x49, 0x64, 0x73,
0x22, 0x77, 0x0a, 0x0e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e,
0x73, 0x65, 0x12, 0x3f, 0x0a, 0x0b, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72,
0x73, 0x18, 0x0a, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x1d, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f,
0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x2e, 0x76, 0x31, 0x2e,
0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x0b, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c,
0x65, 0x72, 0x73, 0x12, 0x24, 0x0a, 0x0e, 0x63, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x5f, 0x6a, 0x6f,
0x62, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x14, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0c, 0x63, 0x61, 0x6e,
0x63, 0x65, 0x6c, 0x4a, 0x6f, 0x62, 0x49, 0x64, 0x73, 0x32, 0x72, 0x0a, 0x0d, 0x57, 0x6f, 0x72,
0x6b, 0x65, 0x72, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x61, 0x0a, 0x06, 0x53, 0x74,
0x61, 0x74, 0x75, 0x73, 0x12, 0x29, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65,
0x72, 0x2e, 0x61, 0x70, 0x69, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2e, 0x76,
0x31, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a,
0x2a, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x61, 0x70, 0x69,
0x2e, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x74, 0x61,
0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x4d, 0x5a,
0x4b, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68,
0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x62, 0x6f, 0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2f, 0x69,
0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x63, 0x6f, 0x6e, 0x74,
0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x69,
0x63, 0x65, 0x73, 0x3b, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x73, 0x62, 0x06, 0x70, 0x72,
0x6f, 0x74, 0x6f, 0x33,
}
var (
file_controller_api_services_v1_worker_service_proto_rawDescOnce sync.Once
file_controller_api_services_v1_worker_service_proto_rawDescData = file_controller_api_services_v1_worker_service_proto_rawDesc
)
func file_controller_api_services_v1_worker_service_proto_rawDescGZIP() []byte {
file_controller_api_services_v1_worker_service_proto_rawDescOnce.Do(func() {
file_controller_api_services_v1_worker_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_controller_api_services_v1_worker_service_proto_rawDescData)
})
return file_controller_api_services_v1_worker_service_proto_rawDescData
}
var file_controller_api_services_v1_worker_service_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_controller_api_services_v1_worker_service_proto_goTypes = []interface{}{
(*StatusRequest)(nil), // 0: controller.api.services.v1.StatusRequest
(*StatusResponse)(nil), // 1: controller.api.services.v1.StatusResponse
(*servers.Server)(nil), // 2: controller.servers.v1.Server
}
var file_controller_api_services_v1_worker_service_proto_depIdxs = []int32{
2, // 0: controller.api.services.v1.StatusRequest.worker:type_name -> controller.servers.v1.Server
2, // 1: controller.api.services.v1.StatusResponse.controllers:type_name -> controller.servers.v1.Server
0, // 2: controller.api.services.v1.WorkerService.Status:input_type -> controller.api.services.v1.StatusRequest
1, // 3: controller.api.services.v1.WorkerService.Status:output_type -> controller.api.services.v1.StatusResponse
3, // [3:4] is the sub-list for method output_type
2, // [2:3] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
}
func init() { file_controller_api_services_v1_worker_service_proto_init() }
func file_controller_api_services_v1_worker_service_proto_init() {
if File_controller_api_services_v1_worker_service_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_controller_api_services_v1_worker_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*StatusRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_controller_api_services_v1_worker_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*StatusResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_controller_api_services_v1_worker_service_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_controller_api_services_v1_worker_service_proto_goTypes,
DependencyIndexes: file_controller_api_services_v1_worker_service_proto_depIdxs,
MessageInfos: file_controller_api_services_v1_worker_service_proto_msgTypes,
}.Build()
File_controller_api_services_v1_worker_service_proto = out.File
file_controller_api_services_v1_worker_service_proto_rawDesc = nil
file_controller_api_services_v1_worker_service_proto_goTypes = nil
file_controller_api_services_v1_worker_service_proto_depIdxs = nil
}
// Reference imports to suppress errors if they are not otherwise used.
var _ context.Context
var _ grpc.ClientConnInterface
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
const _ = grpc.SupportPackageIsVersion6
// WorkerServiceClient is the client API for WorkerService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
type WorkerServiceClient interface {
Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error)
}
type workerServiceClient struct {
cc grpc.ClientConnInterface
}
func NewWorkerServiceClient(cc grpc.ClientConnInterface) WorkerServiceClient {
return &workerServiceClient{cc}
}
func (c *workerServiceClient) Status(ctx context.Context, in *StatusRequest, opts ...grpc.CallOption) (*StatusResponse, error) {
out := new(StatusResponse)
err := c.cc.Invoke(ctx, "/controller.api.services.v1.WorkerService/Status", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// WorkerServiceServer is the server API for WorkerService service.
type WorkerServiceServer interface {
Status(context.Context, *StatusRequest) (*StatusResponse, error)
}
// UnimplementedWorkerServiceServer can be embedded to have forward compatible implementations.
type UnimplementedWorkerServiceServer struct {
}
func (*UnimplementedWorkerServiceServer) Status(context.Context, *StatusRequest) (*StatusResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method Status not implemented")
}
func RegisterWorkerServiceServer(s *grpc.Server, srv WorkerServiceServer) {
s.RegisterService(&_WorkerService_serviceDesc, srv)
}
func _WorkerService_Status_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StatusRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(WorkerServiceServer).Status(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/controller.api.services.v1.WorkerService/Status",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(WorkerServiceServer).Status(ctx, req.(*StatusRequest))
}
return interceptor(ctx, in, info, handler)
}
var _WorkerService_serviceDesc = grpc.ServiceDesc{
ServiceName: "controller.api.services.v1.WorkerService",
HandlerType: (*WorkerServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Status",
Handler: _WorkerService_Status_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "controller/api/services/v1/worker_service.proto",
}

@ -0,0 +1,33 @@
syntax = "proto3";
package controller.api.services.v1;
option go_package = "github.com/hashicorp/boundary/internal/gen/controller/api/services;services";
import "controller/servers/v1/servers.proto";
// WorkerService is the controller-side service that workers call to report
// their status.
service WorkerService {
  rpc Status(StatusRequest) returns (StatusResponse) {}
}

// StatusRequest is sent by a worker on each status heartbeat.
message StatusRequest {
  // The worker info. We could use information from the TLS connection but this
  // is easier and going the other route doesn't provide much benefit -- if you
  // get access to the key and spoof the connection, you're already compromised.
  servers.v1.Server worker = 10;

  // Jobs currently active on this worker.
  repeated string active_job_ids = 20;
}

// StatusResponse is returned to the worker on each heartbeat.
message StatusResponse {
  // Active controllers. This can be used (eventually) for connection
  // management.
  repeated servers.v1.Server controllers = 10;

  // Jobs that should be canceled: ones assigned to this worker that have been
  // reported as active but are in canceling state in the database. Once the
  // worker cancels the job, it will no longer show up in active_jobs in the
  // next heartbeat, and we can move the job to canceled state.
  repeated string cancel_job_ids = 20;
}

@ -0,0 +1,30 @@
syntax = "proto3";
package controller.servers.v1;
option go_package = "github.com/hashicorp/boundary/internal/servers;servers";
import "controller/storage/timestamp/v1/timestamp.proto";
// Server contains all fields related to a Controller or Worker resource;
// both kinds share this shape and are distinguished by the type field.
message Server {
  // Private ID of the resource
  string private_id = 10;

  // Type of the resource (controller, worker)
  string type = 20;

  // Name of the resource
  string name = 30;

  // Description of the resource
  string description = 40;

  // Address for the server
  string address = 50;

  // First seen time from the RDBMS
  storage.timestamp.v1.Timestamp create_time = 60;

  // Last time there was an update
  storage.timestamp.v1.Timestamp update_time = 70;
}

@ -5,11 +5,13 @@ import (
"github.com/hashicorp/boundary/internal/authtoken"
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/servers"
)
type (
IamRepoFactory func() (*iam.Repository, error)
StaticRepoFactory func() (*static.Repository, error)
AuthTokenRepoFactory func() (*authtoken.Repository, error)
ServersRepoFactory func() (*servers.Repository, error)
PasswordAuthRepoFactory func() (*password.Repository, error)
)

@ -4,15 +4,21 @@ import (
"context"
"crypto/rand"
"fmt"
"sync"
"github.com/hashicorp/boundary/internal/auth/password"
"github.com/hashicorp/boundary/internal/authtoken"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/host/static"
"github.com/hashicorp/boundary/internal/iam"
"github.com/hashicorp/boundary/internal/servers"
"github.com/hashicorp/boundary/internal/servers/controller/common"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/base62"
"github.com/hashicorp/vault/sdk/helper/mlock"
"github.com/patrickmn/go-cache"
ua "go.uber.org/atomic"
)
type Controller struct {
@ -21,25 +27,46 @@ type Controller struct {
baseContext context.Context
baseCancel context.CancelFunc
started ua.Bool
// Repo factory methods
IamRepoFn common.IamRepoFactory
StaticHostRepoFn common.StaticRepoFactory
AuthTokenRepoFn common.AuthTokenRepoFactory
workerAuthCache *cache.Cache
// Used for testing
workerStatusUpdateTimes *sync.Map
// Repo factory methods
IamRepoFn common.IamRepoFactory
StaticHostRepoFn common.StaticRepoFactory
AuthTokenRepoFn common.AuthTokenRepoFactory
ServersRepoFn common.ServersRepoFactory
PasswordAuthRepoFn common.PasswordAuthRepoFactory
clusterAddress string
}
func New(conf *Config) (*Controller, error) {
c := &Controller{
conf: conf,
logger: conf.Logger.Named("controller"),
conf: conf,
logger: conf.Logger.Named("controller"),
workerStatusUpdateTimes: new(sync.Map),
}
c.started.Store(false)
if conf.SecureRandomReader == nil {
conf.SecureRandomReader = rand.Reader
}
var err error
if conf.RawConfig.Controller == nil {
conf.RawConfig.Controller = new(config.Controller)
}
if conf.RawConfig.Controller.Name == "" {
if conf.RawConfig.Controller.Name, err = base62.Random(10); err != nil {
return nil, fmt.Errorf("error auto-generating controller name: %w", err)
}
}
if !conf.RawConfig.DisableMlock {
// Ensure our memory usage is locked into physical RAM
if err := mlock.LockMemory(); err != nil {
@ -56,8 +83,6 @@ func New(conf *Config) (*Controller, error) {
}
}
c.baseContext, c.baseCancel = context.WithCancel(context.Background())
// Set up repo stuff
dbase := db.New(c.conf.Database)
c.IamRepoFn = func() (*iam.Repository, error) {
@ -69,23 +94,53 @@ func New(conf *Config) (*Controller, error) {
c.AuthTokenRepoFn = func() (*authtoken.Repository, error) {
return authtoken.NewRepository(dbase, dbase, c.conf.ControllerKMS)
}
c.ServersRepoFn = func() (*servers.Repository, error) {
return servers.NewRepository(c.logger.Named("servers.repository"), dbase, dbase, c.conf.ControllerKMS)
}
c.PasswordAuthRepoFn = func() (*password.Repository, error) {
return password.NewRepository(dbase, dbase, c.conf.ControllerKMS)
}
c.workerAuthCache = cache.New(0, 0)
return c, nil
}
// Start brings up the controller's listeners and begins background status
// ticking. It is idempotent: calling Start on an already-started controller
// is a logged no-op.
func (c *Controller) Start() error {
	if c.started.Load() {
		c.logger.Info("already started, skipping")
		return nil
	}
	// Recreate the base context so the controller can be restarted after a
	// Shutdown (which cancels the previous context).
	c.baseContext, c.baseCancel = context.WithCancel(context.Background())
	if err := c.startListeners(); err != nil {
		return fmt.Errorf("error starting controller listeners: %w", err)
	}
	// Kick off the background loop that periodically records this
	// controller's own status.
	c.startStatusTicking(c.baseContext)
	c.started.Store(true)
	return nil
}
func (c *Controller) Shutdown() error {
if err := c.stopListeners(); err != nil {
// Shutdown stops the controller. When serversOnly is true only the HTTP/gRPC
// servers are stopped and the listener muxes are left open (stopListeners
// returns early in that case); otherwise the listeners are torn down as well.
// Calling Shutdown on a stopped controller is a logged no-op.
func (c *Controller) Shutdown(serversOnly bool) error {
	if !c.started.Load() {
		c.logger.Info("already shut down, skipping")
		return nil
	}
	// Cancel the base context first so background goroutines (e.g. status
	// ticking) observe shutdown before the listeners disappear.
	c.baseCancel()
	if err := c.stopListeners(serversOnly); err != nil {
		return fmt.Errorf("error stopping controller listeners: %w", err)
	}
	// Clear the advertised cluster address; it is set again during the next
	// Start via the cluster listener setup.
	c.clusterAddress = ""
	c.started.Store(false)
	return nil
}
// WorkerStatusUpdateTimes returns the map, which specifically is held in _this_
// controller, not the DB. It's used in tests to verify that a given controller
// is receiving updates from an expected set of workers, to test out balancing
// and auto reconnection. Keys are worker names; values are the last time a
// status report from that worker was received by this controller.
func (c *Controller) WorkerStatusUpdateTimes() *sync.Map {
	return c.workerStatusUpdateTimes
}

@ -32,6 +32,7 @@ import (
type HandlerProperties struct {
ListenerConfig *configutil.Listener
CancelCtx context.Context
}
// Handler returns an http.Handler for the services. This can be used on
@ -40,7 +41,7 @@ func (c *Controller) handler(props HandlerProperties) (http.Handler, error) {
// Create the muxer to handle the actual endpoints
mux := http.NewServeMux()
h, err := handleGrpcGateway(c)
h, err := handleGrpcGateway(c, props)
if err != nil {
return nil, err
}
@ -54,10 +55,10 @@ func (c *Controller) handler(props HandlerProperties) (http.Handler, error) {
return commonWrappedHandler, nil
}
func handleGrpcGateway(c *Controller) (http.Handler, error) {
// Register*ServiceHandlerServer methods ignore the passed in ctx. Using the baseContext now just in case this changes
// in the future, at which point we'll want to be using the baseContext.
ctx := c.baseContext
func handleGrpcGateway(c *Controller, props HandlerProperties) (http.Handler, error) {
// Register*ServiceHandlerServer methods ignore the passed in ctx. Using
// the a context now just in case this changes in the future
ctx := props.CancelCtx
mux := runtime.NewServeMux(
runtime.WithMarshalerOption(runtime.MIMEWildcard, &runtime.HTTPBodyMarshaler{
Marshaler: &runtime.JSONPb{

@ -0,0 +1,47 @@
package workers
import (
"context"
"sync"
"time"
pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services"
"github.com/hashicorp/boundary/internal/servers/controller/common"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/go-hclog"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// workerServiceServer implements the WorkerService gRPC service, receiving
// periodic status reports from workers.
type workerServiceServer struct {
	logger hclog.Logger
	// repoFn returns a servers repository used to persist worker status.
	repoFn common.ServersRepoFactory
	// updateTimes records, per worker name, the time of the most recent
	// status report; shared with the controller for test observability.
	updateTimes *sync.Map
}
// NewWorkerServiceServer returns a worker service server that records status
// report times in updateTimes and persists worker records through repoFn.
func NewWorkerServiceServer(logger hclog.Logger, repoFn common.ServersRepoFactory, updateTimes *sync.Map) *workerServiceServer {
	ws := workerServiceServer{
		logger:      logger,
		repoFn:      repoFn,
		updateTimes: updateTimes,
	}
	return &ws
}
// Status is the worker heartbeat handler. It records the report time for the
// worker, upserts the worker's server record, and returns the current set of
// live controllers. Returns InvalidArgument if the request carries no worker
// info, and Internal if the repository cannot be acquired or the upsert fails.
func (ws *workerServiceServer) Status(ctx context.Context, req *pbs.StatusRequest) (*pbs.StatusResponse, error) {
	// Guard against a malformed request; dereferencing req.Worker below
	// would otherwise panic the handler.
	if req.Worker == nil || req.Worker.Name == "" {
		return &pbs.StatusResponse{}, status.Error(codes.InvalidArgument, "worker information with a name is required")
	}
	ws.logger.Trace("got status request from worker", "name", req.Worker.Name)
	ws.updateTimes.Store(req.Worker.Name, time.Now())
	repo, err := ws.repoFn()
	if err != nil {
		ws.logger.Error("error getting servers repo", "error", err)
		return &pbs.StatusResponse{}, status.Errorf(codes.Internal, "Error acquiring repo to store worker status: %v", err)
	}
	// Force the type server-side; the worker's self-reported type is not
	// trusted.
	req.Worker.Type = resource.Worker.String()
	controllers, _, err := repo.Upsert(ctx, req.Worker)
	if err != nil {
		ws.logger.Error("error storing worker status", "error", err)
		return &pbs.StatusResponse{}, status.Errorf(codes.Internal, "Error storing worker status: %v", err)
	}
	return &pbs.StatusResponse{
		Controllers: controllers,
	}, nil
}

@ -0,0 +1,89 @@
package controller
import (
"errors"
"fmt"
"net"
)
// interceptingListener allows us to validate the nonce from a connection before
// handing it off to the gRPC server. It is expected that the first thing a
// connection sends after successful TLS validation is the nonce that was
// contained in the KMS-encrypted TLS info. The reason is for replay attacks --
// since the ALPN NextProtos are sent in the clear, we don't want anyone
// sniffing the connection to simply replay the hello message and gain access.
// By requiring the information encrypted within that message to match the first
// bytes sent in the connection itself, we require that the service making the
// incoming connection had to either be the service that did the initial
// encryption, or had to be able to decrypt that against the same KMS key. This
// means that KMS access is a requirement, and simple replay itself is not
// sufficient.
//
// Note that this is semi-weak against a scenario where the value is decrypted
// later since a controller restart would clear the cache. We could store a list
// of seen nonces in the database, but since the original certificate was only
// good for 3 minutes and 30 seconds anyways, the decryption would need to
// happen within a short time window instead of much later. We can adjust this
// window if we want (or even make it tunable), or store values in the DB as
// well until the certificate expiration.
type interceptingListener struct {
	// baseLn is the underlying listener whose accepted connections are
	// nonce-validated before being handed onward.
	baseLn net.Listener
	// c provides access to the worker auth nonce cache and the logger.
	c *Controller
}
// newInterceptingListener wraps baseLn so that each accepted connection must
// first present a known worker-auth nonce (see Accept).
func newInterceptingListener(c *Controller, baseLn net.Listener) *interceptingListener {
	return &interceptingListener{
		c:      c,
		baseLn: baseLn,
	}
}
// Accept waits for a connection from the base listener and validates that the
// first 20 bytes the peer sends match a nonce previously cached during TLS
// negotiation. On success the connection is recorded on the cached worker
// auth entry and returned; on any failure the connection is closed and an
// error returned.
func (m *interceptingListener) Accept() (net.Conn, error) {
	conn, err := m.baseLn.Accept()
	if err != nil {
		if conn != nil {
			if err := conn.Close(); err != nil {
				m.c.logger.Error("error closing worker connection", "error", err)
			}
		}
		return nil, err
	}
	// closeConn closes the connection on a validation failure, logging (but
	// not returning) any close error.
	closeConn := func() {
		if err := conn.Close(); err != nil {
			m.c.logger.Error("error closing worker connection", "error", err)
		}
	}
	if m.c.logger.IsTrace() {
		m.c.logger.Trace("got connection", "addr", conn.RemoteAddr())
	}
	// Read exactly len(nonce) bytes. A single conn.Read may legitimately
	// return fewer bytes than requested (short read), so loop until the
	// whole nonce has arrived rather than rejecting slow peers.
	nonce := make([]byte, 20)
	for read := 0; read < len(nonce); {
		n, err := conn.Read(nonce[read:])
		if err != nil {
			closeConn()
			return nil, fmt.Errorf("error reading nonce from connection: %w", err)
		}
		read += n
	}
	workerInfoRaw, found := m.c.workerAuthCache.Get(string(nonce))
	if !found {
		closeConn()
		return nil, errors.New("did not find valid nonce for incoming worker")
	}
	// Record the authed connection on the cached entry so other components
	// can use it.
	workerInfo := workerInfoRaw.(*workerAuthEntry)
	workerInfo.conn = conn
	m.c.logger.Info("worker successfully authed", "name", workerInfo.Name)
	return conn, nil
}
// Close closes the underlying listener.
func (m *interceptingListener) Close() error {
	return m.baseLn.Close()
}
// Addr returns the address of the underlying listener.
func (m *interceptingListener) Addr() net.Addr {
	return m.baseLn.Addr()
}

@ -5,19 +5,21 @@ import (
"crypto/tls"
"errors"
"fmt"
"math"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/gen/controller/api/services"
"github.com/hashicorp/boundary/internal/servers/controller/handlers/workers"
"github.com/hashicorp/go-alpnmux"
"github.com/hashicorp/go-multierror"
"google.golang.org/grpc"
)
func (c *Controller) startListeners() error {
var retErr *multierror.Error
servers := make([]func(), 0, len(c.conf.Listeners))
configureForAPI := func(ln *base.ServerListener) error {
@ -44,6 +46,10 @@ func (c *Controller) startListeners() error {
}
*/
// Resolve it here to avoid race conditions if the base context is
// replaced
cancelCtx := c.baseContext
server := &http.Server{
Handler: handler,
ReadHeaderTimeout: 10 * time.Second,
@ -51,7 +57,7 @@ func (c *Controller) startListeners() error {
IdleTimeout: 5 * time.Minute,
ErrorLog: c.logger.StandardLogger(nil),
BaseContext: func(net.Listener) context.Context {
return c.baseContext
return cancelCtx
},
}
ln.HTTPServer = server
@ -71,7 +77,10 @@ func (c *Controller) startListeners() error {
switch ln.Config.TLSDisable {
case true:
l := ln.Mux.GetListener(alpnmux.NoProto)
l, err := ln.Mux.RegisterProto(alpnmux.NoProto, nil)
if err != nil {
return fmt.Errorf("error getting non-tls listener: %w", err)
}
if l == nil {
return errors.New("could not get non-tls listener")
}
@ -84,8 +93,7 @@ func (c *Controller) startListeners() error {
for _, v := range protos {
l := ln.Mux.GetListener(v)
if l == nil {
retErr = multierror.Append(retErr, fmt.Errorf("could not get tls proto %q listener", v))
continue
return fmt.Errorf("could not get tls proto %q listener", v)
}
servers = append(servers, func() {
go server.Serve(l)
@ -97,36 +105,30 @@ func (c *Controller) startListeners() error {
}
configureForCluster := func(ln *base.ServerListener) error {
// Clear out in case this is a second start of the controller
ln.Mux.UnregisterProto(alpnmux.DefaultProto)
l, err := ln.Mux.RegisterProto(alpnmux.DefaultProto, &tls.Config{
GetConfigForClient: c.validateWorkerTLS,
})
if err != nil {
return fmt.Errorf("error getting sub-listener for worker proto: %w", err)
}
ln.ALPNListener = l
c.clusterAddress = l.Addr().String()
c.logger.Info("cluster address", "addr", c.clusterAddress)
// TODO: Pass this to a handler, e.g. a grpc server, in the mean time
// just accepting what comes
go func() {
for {
conn, err := ln.ALPNListener.Accept()
if err != nil {
if !strings.Contains(err.Error(), "use of closed network connection") {
c.logger.Info("default alpn listener errored, exiting", "error", err)
}
return
}
_, err = conn.Read(make([]byte, 3))
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("error reading test string from worker for worker auth: %w", err))
}
_, err = conn.Write([]byte("bar"))
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("error writing test string to worker for worker auth: %w", err))
}
conn.Close()
}
}()
workerServer := grpc.NewServer(
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.MaxSendMsgSize(math.MaxInt32),
)
services.RegisterWorkerServiceServer(workerServer, workers.NewWorkerServiceServer(c.logger.Named("worker-handler"), c.ServersRepoFn, c.workerStatusUpdateTimes))
interceptor := newInterceptingListener(c, l)
ln.ALPNListener = interceptor
ln.GrpcServer = workerServer
servers = append(servers, func() {
go workerServer.Serve(interceptor)
})
return nil
}
@ -137,23 +139,20 @@ func (c *Controller) startListeners() error {
case "api":
err = configureForAPI(ln)
case "cluster":
err = configureForCluster(ln)
if c.clusterAddress != "" {
err = errors.New("more than one cluster listener found")
} else {
err = configureForCluster(ln)
}
case "worker-alpn-tls":
// Do nothing, in a dev mode we might see it here
default:
err = fmt.Errorf("unknown listener purpose %q", purpose)
}
if err != nil {
break
return err
}
}
if err != nil {
retErr = multierror.Append(retErr, err)
continue
}
}
err := retErr.ErrorOrNil()
if err != nil {
return err
}
for _, s := range servers {
@ -163,23 +162,34 @@ func (c *Controller) startListeners() error {
return nil
}
func (c *Controller) stopListeners() error {
func (c *Controller) stopListeners(serversOnly bool) error {
serverWg := new(sync.WaitGroup)
for _, ln := range c.conf.Listeners {
if ln.HTTPServer == nil {
continue
}
localLn := ln
serverWg.Add(1)
go func() {
defer serverWg.Done()
shutdownKill, shutdownKillCancel := context.WithTimeout(c.baseContext, localLn.Config.MaxRequestDuration)
defer shutdownKillCancel()
defer serverWg.Done()
localLn.HTTPServer.Shutdown(shutdownKill)
if localLn.GrpcServer != nil {
// Deal with the worst case
go func() {
<-shutdownKill.Done()
localLn.GrpcServer.Stop()
}()
localLn.GrpcServer.GracefulStop()
}
if localLn.HTTPServer != nil {
localLn.HTTPServer.Shutdown(shutdownKill)
}
}()
}
serverWg.Wait()
if serversOnly {
return nil
}
var retErr *multierror.Error
for _, ln := range c.conf.Listeners {
if err := ln.Mux.Close(); err != nil {

@ -0,0 +1,68 @@
package controller_test
import (
"testing"
"time"
"github.com/hashicorp/boundary/api/authmethods"
"github.com/hashicorp/boundary/api/scopes"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestAuthenticationMulti verifies that two controllers sharing a database
// and KMSes accept tokens minted by either one and see the same scope data.
// Consistency fix: use require.NoError for error values (as the later
// assertions already do) instead of require.Nil, for better failure output.
func TestAuthenticationMulti(t *testing.T) {
	assert, require := assert.New(t), require.New(t)
	amId := "paum_1234567890"
	user := "user"
	password := "passpass"
	orgId := "o_1234567890"
	logger := hclog.New(&hclog.LoggerOptions{
		Level: hclog.Trace,
	})
	c1 := controller.NewTestController(t, &controller.TestControllerOpts{
		DefaultOrgId:        orgId,
		DefaultAuthMethodId: amId,
		DefaultLoginName:    user,
		DefaultPassword:     password,
		Logger:              logger.Named("c1"),
	})
	defer c1.Shutdown()
	// Second controller joins the first's cluster (same DB and KMSes).
	c2 := c1.AddClusterControllerMember(t, &controller.TestControllerOpts{
		Logger: logger.Named("c2"),
	})
	defer c2.Shutdown()

	auth := authmethods.NewAuthMethodsClient(c1.Client())
	token1, apiErr, err := auth.Authenticate(c1.Context(), amId, user, password)
	require.NoError(err)
	require.Nil(apiErr)
	require.NotNil(token1)

	// Give the controllers time to exchange status before hitting c2.
	time.Sleep(5 * time.Second)

	auth = authmethods.NewAuthMethodsClient(c2.Client())
	token2, apiErr, err := auth.Authenticate(c2.Context(), amId, user, password)
	require.NoError(err)
	require.Nil(apiErr)
	require.NotNil(token2)
	// Each controller mints its own token for the same credentials.
	assert.NotEqual(token1.Token, token2.Token)

	c1.Client().SetToken(token1.Token)
	c2.Client().SetToken(token1.Token) // Same token, as it should work on both

	// Create a project via c1, then read it back through c2.
	proj, apiErr, err := scopes.NewScopesClient(c1.Client()).Create(c1.Context(), orgId)
	require.NoError(err)
	require.Nil(apiErr)
	require.NotNil(proj)
	projId := proj.Id

	proj, apiErr, err = scopes.NewScopesClient(c2.Client()).Read(c2.Context(), projId)
	require.NoError(err)
	require.Nil(apiErr)
	require.NotNil(proj)
}

@ -0,0 +1,48 @@
package controller
import (
"context"
"time"
"github.com/hashicorp/boundary/internal/servers"
"github.com/hashicorp/boundary/internal/types/resource"
)
// In the future we could make this configurable
const (
	// statusInterval is how often a controller upserts its own server
	// record to signal liveness.
	statusInterval = 10 * time.Second
)

// startStatusTicking starts a goroutine that periodically upserts this
// controller's server record (advertising its cluster address). The initial
// timer of 0 makes the first update fire immediately; the goroutine exits
// when cancelCtx is canceled.
func (c *Controller) startStatusTicking(cancelCtx context.Context) {
	go func() {
		// A single timer is reused via Reset (rather than time.After)
		// to avoid allocating a new timer each iteration.
		timer := time.NewTimer(0)
		for {
			select {
			case <-cancelCtx.Done():
				c.logger.Info("status ticking shutting down")
				return
			case <-timer.C:
				// Build the record fresh each tick; clusterAddress
				// can change across Start/Shutdown cycles.
				server := &servers.Server{
					PrivateId:   c.conf.RawConfig.Controller.Name,
					Name:        c.conf.RawConfig.Controller.Name,
					Type:        resource.Controller.String(),
					Description: c.conf.RawConfig.Controller.Description,
					Address:     c.clusterAddress,
				}
				repo, err := c.ServersRepoFn()
				if err != nil {
					c.logger.Error("error fetching repository for status update", "error", err)
				} else {
					_, _, err = repo.Upsert(cancelCtx, server)
					if err != nil {
						c.logger.Error("error performing status update", "error", err)
					} else {
						c.logger.Trace("controller status successfully saved")
					}
				}
				// Errors are logged, not fatal: ticking continues.
				timer.Reset(statusInterval)
			}
		}
	}()
}

@ -10,19 +10,23 @@ import (
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/go-hclog"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/vault/sdk/helper/base62"
)
// TestController wraps a base.Server and Controller to provide a
// fully-programmatic controller for tests. Error checking (for instance, for
// valid config) is not stringent at the moment.
type TestController struct {
b *base.Server
c *Controller
t *testing.T
addrs []string // The address the Controller API is listening on
client *api.Client
ctx context.Context
cancel context.CancelFunc
b *base.Server
c *Controller
t *testing.T
apiAddrs []string // The address the Controller API is listening on
clusterAddrs []string
client *api.Client
ctx context.Context
cancel context.CancelFunc
name string
}
// Controller returns the underlying controller
@ -30,6 +34,10 @@ func (tc *TestController) Controller() *Controller {
return tc.c
}
func (tc *TestController) Config() *Config {
return tc.c.conf
}
func (tc *TestController) Client() *api.Client {
return tc.client
}
@ -42,23 +50,52 @@ func (tc *TestController) Cancel() {
tc.cancel()
}
// Name returns the controller's configured name.
func (tc *TestController) Name() string {
	return tc.name
}
func (tc *TestController) ApiAddrs() []string {
if tc.addrs != nil {
return tc.addrs
return tc.addrs("api")
}
// ClusterAddrs returns the addresses of the listeners serving the "cluster"
// purpose.
func (tc *TestController) ClusterAddrs() []string {
	return tc.addrs("cluster")
}
func (tc *TestController) addrs(purpose string) []string {
var prefix string
switch purpose {
case "api":
if tc.apiAddrs != nil {
return tc.apiAddrs
}
prefix = "http://"
case "cluster":
if tc.clusterAddrs != nil {
return tc.clusterAddrs
}
}
addrs := make([]string, 0, len(tc.b.Listeners))
for _, listener := range tc.b.Listeners {
if listener.Config.Purpose[0] == "api" {
if listener.Config.Purpose[0] == purpose {
tcpAddr, ok := listener.Mux.Addr().(*net.TCPAddr)
if !ok {
tc.t.Fatal("could not parse address as a TCP addr")
}
addr := fmt.Sprintf("http://%s:%d", tcpAddr.IP.String(), tcpAddr.Port)
tc.addrs = append(tc.addrs, addr)
addr := fmt.Sprintf("%s%s:%d", prefix, tcpAddr.IP.String(), tcpAddr.Port)
addrs = append(addrs, addr)
}
}
return tc.addrs
switch purpose {
case "api":
tc.apiAddrs = addrs
case "cluster":
tc.clusterAddrs = addrs
}
return addrs
}
func (tc *TestController) buildClient() {
@ -90,7 +127,7 @@ func (tc *TestController) Shutdown() {
tc.cancel()
if tc.c != nil {
if err := tc.c.Shutdown(); err != nil {
if err := tc.c.Shutdown(false); err != nil {
tc.t.Error(err)
}
}
@ -127,6 +164,10 @@ type TestControllerOpts struct {
// database
DisableDatabaseCreation bool
// If set, instead of creating a dev database, it will connect to an
// existing database given the url
DatabaseUrl string
// If true, the controller will not be started
DisableAutoStart bool
@ -134,6 +175,19 @@ type TestControllerOpts struct {
// performed but they won't cause 403 Forbidden. Useful for API-level
// testing to avoid a lot of faff.
DisableAuthorizationFailures bool
// The controller KMS to use, or one will be created
ControllerKMS wrapping.Wrapper
// The worker auth KMS to use, or one will be created
WorkerAuthKMS wrapping.Wrapper
// The name to use for the controller, otherwise one will be randomly
// generated, unless provided in a non-nil Config
Name string
// The logger to use, or one will be created
Logger hclog.Logger
}
func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
@ -162,6 +216,7 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
if err != nil {
t.Fatal(err)
}
opts.Config.Controller.Name = opts.Name
}
// Set default org ID, preferring one passed in from opts over config
@ -182,24 +237,52 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
}
// Start a logger
tc.b.Logger = hclog.New(&hclog.LoggerOptions{
Level: hclog.Trace,
})
tc.b.Logger = opts.Logger
if tc.b.Logger == nil {
tc.b.Logger = hclog.New(&hclog.LoggerOptions{
Level: hclog.Trace,
})
}
if opts.Config.Controller == nil {
opts.Config.Controller = new(config.Controller)
}
if opts.Config.Controller.Name == "" {
opts.Config.Controller.Name, err = base62.Random(5)
if err != nil {
t.Fatal(err)
}
tc.b.Logger.Info("controller name generated", "name", opts.Config.Controller.Name)
}
tc.name = opts.Config.Controller.Name
// Set up KMSes
if err := tc.b.SetupKMSes(nil, opts.Config.SharedConfig, []string{"controller", "worker-auth"}); err != nil {
t.Fatal(err)
switch {
case opts.ControllerKMS != nil && opts.WorkerAuthKMS != nil:
tc.b.ControllerKMS = opts.ControllerKMS
tc.b.WorkerAuthKMS = opts.WorkerAuthKMS
case opts.ControllerKMS == nil && opts.WorkerAuthKMS == nil:
if err := tc.b.SetupKMSes(nil, opts.Config.SharedConfig, []string{"controller", "worker-auth"}); err != nil {
t.Fatal(err)
}
default:
t.Fatal("either controller and worker auth KMS must both be set, or neither")
}
// Ensure the listeners use random port allocation
for _, listener := range opts.Config.Listeners {
listener.RandomPort = true
}
if err := tc.b.SetupListeners(nil, opts.Config.SharedConfig); err != nil {
if err := tc.b.SetupListeners(nil, opts.Config.SharedConfig, []string{"api", "cluster"}); err != nil {
t.Fatal(err)
}
if !opts.DisableDatabaseCreation {
if opts.DatabaseUrl != "" {
tc.b.DatabaseUrl = opts.DatabaseUrl
if err := tc.b.ConnectToDatabase("postgres"); err != nil {
t.Fatal(err)
}
} else if !opts.DisableDatabaseCreation {
if err := tc.b.CreateDevDatabase("postgres"); err != nil {
t.Fatal(err)
}
@ -228,3 +311,30 @@ func NewTestController(t *testing.T, opts *TestControllerOpts) *TestController {
return tc
}
// AddClusterControllerMember starts a new TestController that shares this
// controller's database, KMSes, and default IDs, forming a multi-controller
// cluster for tests. From opts only Name and Logger are honored as overrides;
// a random name is generated when none is provided.
func (tc *TestController) AddClusterControllerMember(t *testing.T, opts *TestControllerOpts) *TestController {
	if opts == nil {
		opts = new(TestControllerOpts)
	}
	nextOpts := &TestControllerOpts{
		DatabaseUrl:         tc.c.conf.DatabaseUrl,
		DefaultAuthMethodId: tc.c.conf.DevAuthMethodId,
		DefaultOrgId:        tc.c.conf.DefaultOrgId,
		ControllerKMS:       tc.c.conf.ControllerKMS,
		WorkerAuthKMS:       tc.c.conf.WorkerAuthKMS,
		Name:                opts.Name,
		Logger:              tc.c.conf.Logger,
	}
	if opts.Logger != nil {
		nextOpts.Logger = opts.Logger
	}
	if nextOpts.Name == "" {
		var err error
		nextOpts.Name, err = base62.Random(5)
		if err != nil {
			t.Fatal(err)
		}
		// Guard the log call: neither the parent's conf nor opts is
		// guaranteed to carry a logger, and calling Info on a nil
		// interface would panic the test.
		if nextOpts.Logger != nil {
			nextOpts.Logger.Info("controller name generated", "name", nextOpts.Name)
		}
	}
	return NewTestController(t, nextOpts)
}

@ -7,6 +7,7 @@ import (
"encoding/base64"
"encoding/json"
"errors"
"net"
"strings"
"github.com/hashicorp/boundary/internal/cmd/base"
@ -14,17 +15,29 @@ import (
"google.golang.org/protobuf/proto"
)
// workerAuthEntry is cached (keyed by connection nonce, see validateWorkerTLS)
// when a worker's TLS hello is decrypted and validated.
type workerAuthEntry struct {
	*base.WorkerAuthInfo
	// conn is filled in by the intercepting listener once the worker's
	// cleartext nonce has been verified.
	conn net.Conn
}
func (c Controller) validateWorkerTLS(hello *tls.ClientHelloInfo) (*tls.Config, error) {
for _, p := range hello.SupportedProtos {
switch {
case strings.HasPrefix(p, "v1workerauth-"):
return c.v1WorkerAuthConfig(hello.SupportedProtos)
tlsConf, workerInfo, err := c.v1WorkerAuthConfig(hello.SupportedProtos)
if err == nil {
// Set the info we need to prevent replays
c.workerAuthCache.Set(workerInfo.ConnectionNonce, &workerAuthEntry{
WorkerAuthInfo: workerInfo,
}, 0)
}
return tlsConf, err
}
}
return nil, nil
}
func (c Controller) v1WorkerAuthConfig(protos []string) (*tls.Config, error) {
func (c Controller) v1WorkerAuthConfig(protos []string) (*tls.Config, *base.WorkerAuthInfo, error) {
var firstMatchProto string
var encString string
for _, p := range protos {
@ -37,32 +50,32 @@ func (c Controller) v1WorkerAuthConfig(protos []string) (*tls.Config, error) {
}
}
if firstMatchProto == "" {
return nil, errors.New("no matching proto found")
return nil, nil, errors.New("no matching proto found")
}
marshaledEncInfo, err := base64.RawStdEncoding.DecodeString(encString)
if err != nil {
return nil, err
return nil, nil, err
}
encInfo := new(wrapping.EncryptedBlobInfo)
if err := proto.Unmarshal(marshaledEncInfo, encInfo); err != nil {
return nil, err
return nil, nil, err
}
marshaledInfo, err := c.conf.WorkerAuthKMS.Decrypt(context.Background(), encInfo, nil)
if err != nil {
return nil, err
return nil, nil, err
}
info := new(base.WorkerAuthCertInfo)
info := new(base.WorkerAuthInfo)
if err := json.Unmarshal(marshaledInfo, info); err != nil {
return nil, err
return nil, nil, err
}
rootCAs := x509.NewCertPool()
if ok := rootCAs.AppendCertsFromPEM(info.CACertPEM); !ok {
return nil, errors.New("unable to add ca cert to cert pool")
return nil, info, errors.New("unable to add ca cert to cert pool")
}
tlsCert, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM)
if err != nil {
return nil, err
return nil, info, err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{tlsCert},
@ -73,5 +86,5 @@ func (c Controller) v1WorkerAuthConfig(protos []string) (*tls.Config, error) {
}
tlsConfig.BuildNameToCertificate()
return tlsConfig, nil
return tlsConfig, info, nil
}

@ -0,0 +1,45 @@
package servers
import "time"
// getOpts - iterate the inbound Options and return a struct
func getOpts(opt ...Option) options {
opts := getDefaultOptions()
for _, o := range opt {
o(&opts)
}
return opts
}
// Option - how Options are passed as arguments
type Option func(*options)
// options = how options are represented
type options struct {
withLimit int
withLiveness time.Duration
}
func getDefaultOptions() options {
return options{
withLimit: 0,
withLiveness: 0,
}
}
// WithLimit provides an option to provide a limit. Intentionally allowing
// negative integers. If WithLimit < 0, then unlimited results are returned.
// If WithLimit == 0, then default limits are used for results.
func WithLimit(limit int) Option {
return func(o *options) {
o.withLimit = limit
}
}
// WithSkipVetForWrite provides an option to allow skipping vet checks to allow
// testing lower-level SQL triggers and constraints
func WithLiveness(liveness time.Duration) Option {
return func(o *options) {
o.withLiveness = liveness
}
}

@ -0,0 +1,144 @@
package servers
import (
"context"
"errors"
"fmt"
"time"
"github.com/hashicorp/boundary/internal/db"
timestamp "github.com/hashicorp/boundary/internal/db/timestamp"
"github.com/hashicorp/boundary/internal/types/resource"
"github.com/hashicorp/go-hclog"
wrapping "github.com/hashicorp/go-kms-wrapping"
)
const (
	// defaultLiveness is how recently a server must have reported status
	// to be returned by List when no WithLiveness option is supplied.
	defaultLiveness = 15 * time.Second
)

// Repository is the servers database repository, used to store and list
// controller/worker status records.
type Repository struct {
	logger  hclog.Logger
	reader  db.Reader
	writer  db.Writer
	wrapper wrapping.Wrapper
}
// NewRepository creates a new servers Repository. Supports the options:
// WithLimit which sets a default limit on results returned by repo
// operations. The reader, writer, and wrapper must all be non-nil.
func NewRepository(logger hclog.Logger, r db.Reader, w db.Writer, wrapper wrapping.Wrapper) (*Repository, error) {
	switch {
	case r == nil:
		return nil, errors.New("error creating db repository with nil reader")
	case w == nil:
		return nil, errors.New("error creating db repository with nil writer")
	case wrapper == nil:
		return nil, errors.New("error creating db repository with nil wrapper")
	}
	repo := Repository{
		logger:  logger,
		reader:  r,
		writer:  w,
		wrapper: wrapper,
	}
	return &repo, nil
}
// List returns the servers of the given type that have reported status within
// the liveness window (WithLiveness option, defaulting to defaultLiveness).
// Fixed: the sql.Rows handle was never closed, leaking the underlying
// connection whenever the scan loop broke early.
func (r *Repository) List(ctx context.Context, serverType string, opt ...Option) ([]*Server, error) {
	opts := getOpts(opt...)
	liveness := opts.withLiveness
	if liveness == 0 {
		liveness = defaultLiveness
	}
	// Only servers updated after this cutoff are considered live.
	updateTime := time.Now().Add(-1 * liveness)
	q := `
	select * from servers
	where
		type = $1
		and
		update_time > $2;`
	underlying, err := r.reader.DB()
	if err != nil {
		return nil, fmt.Errorf("error fetching underlying DB for server list operation: %w", err)
	}
	rows, err := underlying.QueryContext(ctx, q,
		serverType,
		updateTime.Format(time.RFC3339))
	if err != nil {
		return nil, fmt.Errorf("error performing server list: %w", err)
	}
	// Always release the rows (and their connection) back to the pool.
	defer rows.Close()
	results := make([]*Server, 0, 3)
	for rows.Next() {
		server := &Server{
			CreateTime: new(timestamp.Timestamp),
			UpdateTime: new(timestamp.Timestamp),
		}
		if err := rows.Scan(
			&server.PrivateId,
			&server.Type,
			&server.Name,
			&server.Description,
			&server.Address,
			server.CreateTime,
			server.UpdateTime,
		); err != nil {
			// Best-effort: log the bad row and return what we have.
			r.logger.Error("error scanning server row", "error", err)
			break
		}
		results = append(results, server)
	}
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("error performing scan over server rows: %w", err)
	}
	return results, nil
}
// Upsert inserts or updates the given server's status row, keyed by private
// ID (which is forced equal to the server's name). For a controller the
// returned slice is nil and the int is the number of rows affected; for any
// other type it also returns the currently-live controllers so workers can
// learn about the cluster.
func (r *Repository) Upsert(ctx context.Context, server *Server, opt ...Option) ([]*Server, int, error) {
	if server == nil {
		return nil, db.NoRowsAffected, errors.New("cannot update server that is nil")
	}
	// Ensure, for now at least, the private ID is always equivalent to the name
	server.PrivateId = server.Name
	// Build query
	q := `
	insert into servers
		(private_id, type, name, description, address, update_time)
	values
		($1, $2, $3, $4, $5, $6)
	on conflict on constraint servers_pkey
	do update set
		name = $3,
		description = $4,
		address = $5,
		update_time = $6;
	`
	underlying, err := r.writer.DB()
	if err != nil {
		return nil, db.NoRowsAffected, fmt.Errorf("error fetching underlying DB for upsert operation: %w", err)
	}
	result, err := underlying.ExecContext(ctx, q,
		server.PrivateId,
		server.Type,
		server.Name,
		server.Description,
		server.Address,
		time.Now().Format(time.RFC3339))
	if err != nil {
		return nil, db.NoRowsAffected, fmt.Errorf("error performing status upsert: %w", err)
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return nil, db.NoRowsAffected, fmt.Errorf("unable to fetch number of rows affected from query: %w", err)
	}
	// If updating a controller, done
	if server.Type == resource.Controller.String() {
		return nil, int(rowsAffected), nil
	}
	// Fetch current controllers to feed to the workers
	controllers, err := r.List(ctx, resource.Controller.String())
	return controllers, len(controllers), err
}

@ -0,0 +1,231 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.25.0
// protoc v3.12.4
// source: controller/servers/v1/servers.proto
package servers
import (
proto "github.com/golang/protobuf/proto"
timestamp "github.com/hashicorp/boundary/internal/db/timestamp"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// This is a compile-time assertion that a sufficiently up-to-date version
// of the legacy proto package is being used.
const _ = proto.ProtoPackageIsVersion4
// Server contains all fields related to a Controller or Worker resource
type Server struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
// Private ID of the resource
PrivateId string `protobuf:"bytes,10,opt,name=private_id,json=privateId,proto3" json:"private_id,omitempty"`
// Type of the resource (controller, worker)
Type string `protobuf:"bytes,20,opt,name=type,proto3" json:"type,omitempty"`
// Name of the resource
Name string `protobuf:"bytes,30,opt,name=name,proto3" json:"name,omitempty"`
// Description of the resource
Description string `protobuf:"bytes,40,opt,name=description,proto3" json:"description,omitempty"`
// Address for the server
Address string `protobuf:"bytes,50,opt,name=address,proto3" json:"address,omitempty"`
// First seen time from the RDBMS
CreateTime *timestamp.Timestamp `protobuf:"bytes,60,opt,name=create_time,json=createTime,proto3" json:"create_time,omitempty"`
// Last time there was an update
UpdateTime *timestamp.Timestamp `protobuf:"bytes,70,opt,name=update_time,json=updateTime,proto3" json:"update_time,omitempty"`
}
func (x *Server) Reset() {
*x = Server{}
if protoimpl.UnsafeEnabled {
mi := &file_controller_servers_v1_servers_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Server) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Server) ProtoMessage() {}
func (x *Server) ProtoReflect() protoreflect.Message {
mi := &file_controller_servers_v1_servers_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Server.ProtoReflect.Descriptor instead.
func (*Server) Descriptor() ([]byte, []int) {
return file_controller_servers_v1_servers_proto_rawDescGZIP(), []int{0}
}
func (x *Server) GetPrivateId() string {
if x != nil {
return x.PrivateId
}
return ""
}
func (x *Server) GetType() string {
if x != nil {
return x.Type
}
return ""
}
func (x *Server) GetName() string {
if x != nil {
return x.Name
}
return ""
}
func (x *Server) GetDescription() string {
if x != nil {
return x.Description
}
return ""
}
func (x *Server) GetAddress() string {
if x != nil {
return x.Address
}
return ""
}
func (x *Server) GetCreateTime() *timestamp.Timestamp {
if x != nil {
return x.CreateTime
}
return nil
}
func (x *Server) GetUpdateTime() *timestamp.Timestamp {
if x != nil {
return x.UpdateTime
}
return nil
}
var File_controller_servers_v1_servers_proto protoreflect.FileDescriptor
var file_controller_servers_v1_servers_proto_rawDesc = []byte{
0x0a, 0x23, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2f, 0x73, 0x65, 0x72,
0x76, 0x65, 0x72, 0x73, 0x2f, 0x76, 0x31, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x2e,
0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x15, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65,
0x72, 0x2e, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x2e, 0x76, 0x31, 0x1a, 0x2f, 0x63, 0x6f,
0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2f, 0x73, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65,
0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2f, 0x76, 0x31, 0x2f, 0x74, 0x69,
0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa5, 0x02,
0x0a, 0x06, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x69, 0x76,
0x61, 0x74, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x72,
0x69, 0x76, 0x61, 0x74, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18,
0x14, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x6e,
0x61, 0x6d, 0x65, 0x18, 0x1e, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12,
0x20, 0x0a, 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x28,
0x20, 0x01, 0x28, 0x09, 0x52, 0x0b, 0x64, 0x65, 0x73, 0x63, 0x72, 0x69, 0x70, 0x74, 0x69, 0x6f,
0x6e, 0x12, 0x18, 0x0a, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x18, 0x32, 0x20, 0x01,
0x28, 0x09, 0x52, 0x07, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x4b, 0x0a, 0x0b, 0x63,
0x72, 0x65, 0x61, 0x74, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x3c, 0x20, 0x01, 0x28, 0x0b,
0x32, 0x2a, 0x2e, 0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x73, 0x74,
0x6f, 0x72, 0x61, 0x67, 0x65, 0x2e, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e,
0x76, 0x31, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0a, 0x63, 0x72,
0x65, 0x61, 0x74, 0x65, 0x54, 0x69, 0x6d, 0x65, 0x12, 0x4b, 0x0a, 0x0b, 0x75, 0x70, 0x64, 0x61,
0x74, 0x65, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x46, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x2a, 0x2e,
0x63, 0x6f, 0x6e, 0x74, 0x72, 0x6f, 0x6c, 0x6c, 0x65, 0x72, 0x2e, 0x73, 0x74, 0x6f, 0x72, 0x61,
0x67, 0x65, 0x2e, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x76, 0x31, 0x2e,
0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0a, 0x75, 0x70, 0x64, 0x61, 0x74,
0x65, 0x54, 0x69, 0x6d, 0x65, 0x42, 0x38, 0x5a, 0x36, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e,
0x63, 0x6f, 0x6d, 0x2f, 0x68, 0x61, 0x73, 0x68, 0x69, 0x63, 0x6f, 0x72, 0x70, 0x2f, 0x62, 0x6f,
0x75, 0x6e, 0x64, 0x61, 0x72, 0x79, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x2f,
0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x3b, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x62,
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_controller_servers_v1_servers_proto_rawDescOnce sync.Once
file_controller_servers_v1_servers_proto_rawDescData = file_controller_servers_v1_servers_proto_rawDesc
)
func file_controller_servers_v1_servers_proto_rawDescGZIP() []byte {
file_controller_servers_v1_servers_proto_rawDescOnce.Do(func() {
file_controller_servers_v1_servers_proto_rawDescData = protoimpl.X.CompressGZIP(file_controller_servers_v1_servers_proto_rawDescData)
})
return file_controller_servers_v1_servers_proto_rawDescData
}
var file_controller_servers_v1_servers_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_controller_servers_v1_servers_proto_goTypes = []interface{}{
(*Server)(nil), // 0: controller.servers.v1.Server
(*timestamp.Timestamp)(nil), // 1: controller.storage.timestamp.v1.Timestamp
}
var file_controller_servers_v1_servers_proto_depIdxs = []int32{
1, // 0: controller.servers.v1.Server.create_time:type_name -> controller.storage.timestamp.v1.Timestamp
1, // 1: controller.servers.v1.Server.update_time:type_name -> controller.storage.timestamp.v1.Timestamp
2, // [2:2] is the sub-list for method output_type
2, // [2:2] is the sub-list for method input_type
2, // [2:2] is the sub-list for extension type_name
2, // [2:2] is the sub-list for extension extendee
0, // [0:2] is the sub-list for field type_name
}
func init() { file_controller_servers_v1_servers_proto_init() }
func file_controller_servers_v1_servers_proto_init() {
if File_controller_servers_v1_servers_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_controller_servers_v1_servers_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*Server); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_controller_servers_v1_servers_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_controller_servers_v1_servers_proto_goTypes,
DependencyIndexes: file_controller_servers_v1_servers_proto_depIdxs,
MessageInfos: file_controller_servers_v1_servers_proto_msgTypes,
}.Build()
File_controller_servers_v1_servers_proto = out.File
file_controller_servers_v1_servers_proto_rawDesc = nil
file_controller_servers_v1_servers_proto_goTypes = nil
file_controller_servers_v1_servers_proto_depIdxs = nil
}

@ -1,146 +0,0 @@
package worker
import (
"context"
"crypto/ed25519"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math/big"
mathrand "math/rand"
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/vault/sdk/helper/base62"
"google.golang.org/protobuf/proto"
)
// workerAuthTLSConfig builds an ephemeral TLS configuration used to
// authenticate this worker to a controller. It generates a short-lived
// self-signed ed25519 CA and leaf certificate, records the PEM material
// in a base.WorkerAuthCertInfo, encrypts that bundle with the
// worker-auth KMS, and encodes the ciphertext into numbered, chunked
// ALPN protocol names ("v1workerauth-NN-<chunk>") so the controller can
// recover it during the TLS handshake.
func (c Worker) workerAuthTLSConfig() (*tls.Config, error) {
	info := new(base.WorkerAuthCertInfo)
	// Ephemeral CA key; it exists only for the duration of this handshake.
	_, caKey, err := ed25519.GenerateKey(c.conf.SecureRandomReader)
	if err != nil {
		return nil, err
	}
	// Random "host" name — there is no real DNS identity behind this CA.
	caHost, err := base62.Random(20)
	if err != nil {
		return nil, err
	}
	// Deliberately tiny validity window (3 minutes, with 30s of backdating
	// for clock skew) since the CA only has to cover a single handshake.
	caCertTemplate := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: caHost,
		},
		DNSNames:              []string{caHost},
		KeyUsage:              x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
		SerialNumber:          big.NewInt(mathrand.Int63()),
		NotBefore:             time.Now().Add(-30 * time.Second),
		NotAfter:              time.Now().Add(3 * time.Minute),
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	caBytes, err := x509.CreateCertificate(c.conf.SecureRandomReader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
	if err != nil {
		return nil, err
	}
	caCertPEMBlock := &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: caBytes,
	}
	info.CACertPEM = pem.EncodeToMemory(caCertPEMBlock)
	caCert, err := x509.ParseCertificate(caBytes)
	if err != nil {
		return nil, err
	}
	//
	// Certs generation
	//
	_, key, err := ed25519.GenerateKey(c.conf.SecureRandomReader)
	if err != nil {
		return nil, err
	}
	host, err := base62.Random(20)
	if err != nil {
		return nil, err
	}
	// The leaf is valid for both server and client auth since the same
	// material is used on both ends of the ALPN-based exchange.
	certTemplate := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: host,
		},
		DNSNames: []string{host},
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
			x509.ExtKeyUsageClientAuth,
		},
		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
		SerialNumber: big.NewInt(mathrand.Int63()),
		NotBefore:    time.Now().Add(-30 * time.Second),
		NotAfter:     time.Now().Add(2 * time.Minute),
	}
	certBytes, err := x509.CreateCertificate(c.conf.SecureRandomReader, certTemplate, caCert, key.Public(), caKey)
	if err != nil {
		return nil, err
	}
	certPEMBlock := &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certBytes,
	}
	info.CertPEM = pem.EncodeToMemory(certPEMBlock)
	marshaledKey, err := x509.MarshalPKCS8PrivateKey(key)
	if err != nil {
		return nil, err
	}
	keyPEMBlock := &pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: marshaledKey,
	}
	info.KeyPEM = pem.EncodeToMemory(keyPEMBlock)
	// Marshal and encrypt
	marshaledInfo, err := json.Marshal(info)
	if err != nil {
		return nil, err
	}
	encInfo, err := c.conf.WorkerAuthKMS.Encrypt(context.Background(), marshaledInfo, nil)
	if err != nil {
		return nil, err
	}
	marshaledEncInfo, err := proto.Marshal(encInfo)
	if err != nil {
		return nil, err
	}
	// Split the encoded blob into numbered 230-byte chunks — presumably to
	// stay under the ALPN protocol-name length limit; confirm against the
	// controller-side reassembly.
	b64alpn := base64.RawStdEncoding.EncodeToString(marshaledEncInfo)
	var nextProtos []string
	var count int
	for i := 0; i < len(b64alpn); i += 230 {
		end := i + 230
		if end > len(b64alpn) {
			end = len(b64alpn)
		}
		nextProtos = append(nextProtos, fmt.Sprintf("v1workerauth-%02d-%s", count, b64alpn[i:end]))
		count++
	}
	// Build local tls config
	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(caCert)
	tlsCert, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM)
	if err != nil {
		return nil, err
	}
	tlsConfig := &tls.Config{
		ServerName:   host,
		Certificates: []tls.Certificate{tlsCert},
		RootCAs:      rootCAs,
		NextProtos:   nextProtos,
		MinVersion:   tls.VersionTLS13,
	}
	// NOTE(review): BuildNameToCertificate is deprecated since Go 1.14; the
	// crypto/tls package now selects certificates automatically.
	tlsConfig.BuildNameToCertificate()
	return tlsConfig, nil
}

@ -6,6 +6,10 @@ import (
)
// Config holds the information a Worker needs at construction time.
type Config struct {
	// The base Server object, containing things shared between Controllers and
	// Workers
	*base.Server
	// The underlying configuration, passed in here to avoid duplicating values
	// everywhere
	RawConfig *config.Config
}

@ -0,0 +1,263 @@
package worker
import (
"context"
"crypto/ed25519"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"math"
"math/big"
mathrand "math/rand"
"net"
"strings"
"time"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/gen/controller/api/services"
"github.com/hashicorp/vault/sdk/helper/base62"
"google.golang.org/grpc"
"google.golang.org/grpc/resolver"
"google.golang.org/protobuf/proto"
)
// controllerConnection pairs a controller's address with the gRPC
// WorkerService client connected to it.
type controllerConnection struct {
	controllerAddr string
	client services.WorkerServiceClient
}
// newControllerConnection builds a controllerConnection for the given
// controller address and WorkerService client.
func newControllerConnection(controllerAddr string, client services.WorkerServiceClient) *controllerConnection {
	return &controllerConnection{
		controllerAddr: controllerAddr,
		client:         client,
	}
}
// startControllerConnections seeds the worker's gRPC resolver with the
// controller addresses from the configuration and opens an initial
// client connection to each one. Addresses without an explicit port
// default to 9201.
func (w *Worker) startControllerConnections() error {
	initialAddrs := make([]resolver.Address, 0, len(w.conf.RawConfig.Worker.Controllers))
	for _, addr := range w.conf.RawConfig.Worker.Controllers {
		host, port, err := net.SplitHostPort(addr)
		if err != nil && strings.Contains(err.Error(), "missing port in address") {
			// No port supplied; retry with the default cluster port appended.
			host, port, err = net.SplitHostPort(fmt.Sprintf("%s:%s", addr, "9201"))
		}
		if err != nil {
			return fmt.Errorf("error parsing controller address: %w", err)
		}
		initialAddrs = append(initialAddrs, resolver.Address{Addr: fmt.Sprintf("%s:%s", host, port)})
	}
	w.Resolver().InitialState(resolver.State{
		Addresses: initialAddrs,
	})
	for _, addr := range initialAddrs {
		if err := w.createClientConn(addr.Addr); err != nil {
			return fmt.Errorf("error making client connection to controller: %w", err)
		}
	}
	return nil
}
// controllerDialerFunc returns the dialer gRPC uses to reach a
// controller. Each dial builds a fresh worker-auth TLS config (new certs
// and connection nonce), opens a plain TCP connection, wraps it in a TLS
// client, and then writes the nonce over the TLS session so the
// controller can match this connection to the encrypted auth payload
// carried in the ALPN protocol names.
func (w Worker) controllerDialerFunc() func(context.Context, string) (net.Conn, error) {
	return func(ctx context.Context, addr string) (net.Conn, error) {
		tlsConf, authInfo, err := w.workerAuthTLSConfig()
		if err != nil {
			return nil, fmt.Errorf("error creating tls config for worker auth: %w", err)
		}
		dialer := &net.Dialer{}
		nonTlsConn, err := dialer.DialContext(ctx, "tcp", addr)
		if err != nil {
			return nil, fmt.Errorf("unable to dial to controller: %w", err)
		}
		tlsConn := tls.Client(nonTlsConn, tlsConf)
		// Write performs the TLS handshake first if it hasn't happened yet,
		// so the nonce travels encrypted inside the session.
		written, err := tlsConn.Write([]byte(authInfo.ConnectionNonce))
		if err != nil {
			if err := nonTlsConn.Close(); err != nil {
				w.logger.Error("error closing connection after writing failure", "error", err)
			}
			return nil, fmt.Errorf("unable to write connection nonce: %w", err)
		}
		if written != len(authInfo.ConnectionNonce) {
			if err := nonTlsConn.Close(); err != nil {
				w.logger.Error("error closing connection after writing failure", "error", err)
			}
			return nil, fmt.Errorf("expected to write %d bytes of connection nonce, wrote %d", len(authInfo.ConnectionNonce), written)
		}
		return tlsConn, nil
	}
}
// createClientConn dials the given controller address through the
// worker's resolver and caches a WorkerService client for it in
// w.controllerConns, keyed by address.
func (w *Worker) createClientConn(addr string) error {
	// Roughly one second; the extra Nanosecond makes the duration string
	// render with full precision in the service-config JSON below.
	defaultTimeout := (time.Second + time.Nanosecond).String()
	// Default service config: round-robin across resolved controllers,
	// apply the default timeout to all methods, and wait for a ready
	// connection rather than failing fast.
	defServiceConfig := fmt.Sprintf(`
	  {
		"loadBalancingConfig": [ { "round_robin": {} } ],
		"methodConfig": [
		  {
			"name": [],
			"timeout": %q,
			"waitForReady": true
		  }
		]
	  }
	  `, defaultTimeout)
	cc, err := grpc.DialContext(w.baseContext,
		fmt.Sprintf("%s:///%s", w.Resolver().Scheme(), addr),
		grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
		grpc.WithDefaultCallOptions(grpc.MaxCallSendMsgSize(math.MaxInt32)),
		// Transport security comes from the custom worker-auth TLS dialer,
		// so gRPC's own TLS layer is disabled here.
		grpc.WithContextDialer(w.controllerDialerFunc()),
		grpc.WithInsecure(),
		grpc.WithDefaultServiceConfig(defServiceConfig),
		// Don't have the resolver reach out for a service config from the
		// resolver, use the one specified as default
		grpc.WithDisableServiceConfig(),
	)
	if err != nil {
		return fmt.Errorf("error dialing controller for worker auth: %w", err)
	}
	client := services.NewWorkerServiceClient(cc)
	w.controllerConns.Store(addr, newControllerConnection(addr, client))
	w.logger.Info("connected to controller", "address", addr)
	return nil
}
// workerAuthTLSConfig builds an ephemeral TLS configuration used to
// authenticate this worker to a controller. It generates a short-lived
// self-signed ed25519 CA and leaf certificate, fills in a
// base.WorkerAuthInfo (the worker's name and description plus a random
// connection nonce), encrypts that bundle with the worker-auth KMS, and
// encodes the ciphertext into numbered, chunked ALPN protocol names so
// the controller can recover it during the TLS handshake. The
// WorkerAuthInfo is also returned so the caller can send the nonce over
// the established connection (see controllerDialerFunc).
func (w Worker) workerAuthTLSConfig() (*tls.Config, *base.WorkerAuthInfo, error) {
	var err error
	info := &base.WorkerAuthInfo{
		Name:        w.conf.RawConfig.Worker.Name,
		Description: w.conf.RawConfig.Worker.Description,
	}
	// Random nonce the controller uses to correlate the TLS connection with
	// this auth payload.
	if info.ConnectionNonce, err = base62.Random(20); err != nil {
		return nil, nil, err
	}
	// Ephemeral CA key; it exists only for the duration of this handshake.
	_, caKey, err := ed25519.GenerateKey(w.conf.SecureRandomReader)
	if err != nil {
		return nil, nil, err
	}
	// Random "host" name — there is no real DNS identity behind this CA.
	caHost, err := base62.Random(20)
	if err != nil {
		return nil, nil, err
	}
	// Deliberately tiny validity window (3 minutes, with 30s of backdating
	// for clock skew) since the CA only has to cover a single handshake.
	caCertTemplate := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: caHost,
		},
		DNSNames:              []string{caHost},
		KeyUsage:              x509.KeyUsage(x509.KeyUsageCertSign | x509.KeyUsageCRLSign),
		SerialNumber:          big.NewInt(mathrand.Int63()),
		NotBefore:             time.Now().Add(-30 * time.Second),
		NotAfter:              time.Now().Add(3 * time.Minute),
		BasicConstraintsValid: true,
		IsCA:                  true,
	}
	caBytes, err := x509.CreateCertificate(w.conf.SecureRandomReader, caCertTemplate, caCertTemplate, caKey.Public(), caKey)
	if err != nil {
		return nil, nil, err
	}
	caCertPEMBlock := &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: caBytes,
	}
	info.CACertPEM = pem.EncodeToMemory(caCertPEMBlock)
	caCert, err := x509.ParseCertificate(caBytes)
	if err != nil {
		return nil, nil, err
	}
	//
	// Certs generation
	//
	_, key, err := ed25519.GenerateKey(w.conf.SecureRandomReader)
	if err != nil {
		return nil, nil, err
	}
	host, err := base62.Random(20)
	if err != nil {
		return nil, nil, err
	}
	// The leaf is valid for both server and client auth since the same
	// material is used on both ends of the ALPN-based exchange.
	certTemplate := &x509.Certificate{
		Subject: pkix.Name{
			CommonName: host,
		},
		DNSNames: []string{host},
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
			x509.ExtKeyUsageClientAuth,
		},
		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageKeyAgreement,
		SerialNumber: big.NewInt(mathrand.Int63()),
		NotBefore:    time.Now().Add(-30 * time.Second),
		NotAfter:     time.Now().Add(2 * time.Minute),
	}
	certBytes, err := x509.CreateCertificate(w.conf.SecureRandomReader, certTemplate, caCert, key.Public(), caKey)
	if err != nil {
		return nil, nil, err
	}
	certPEMBlock := &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: certBytes,
	}
	info.CertPEM = pem.EncodeToMemory(certPEMBlock)
	marshaledKey, err := x509.MarshalPKCS8PrivateKey(key)
	if err != nil {
		return nil, nil, err
	}
	keyPEMBlock := &pem.Block{
		Type:  "PRIVATE KEY",
		Bytes: marshaledKey,
	}
	info.KeyPEM = pem.EncodeToMemory(keyPEMBlock)
	// Marshal and encrypt
	marshaledInfo, err := json.Marshal(info)
	if err != nil {
		return nil, nil, err
	}
	encInfo, err := w.conf.WorkerAuthKMS.Encrypt(context.Background(), marshaledInfo, nil)
	if err != nil {
		return nil, nil, err
	}
	marshaledEncInfo, err := proto.Marshal(encInfo)
	if err != nil {
		return nil, nil, err
	}
	// Split the encoded blob into numbered 230-byte chunks — presumably to
	// stay under the ALPN protocol-name length limit; confirm against the
	// controller-side reassembly.
	b64alpn := base64.RawStdEncoding.EncodeToString(marshaledEncInfo)
	var nextProtos []string
	var count int
	for i := 0; i < len(b64alpn); i += 230 {
		end := i + 230
		if end > len(b64alpn) {
			end = len(b64alpn)
		}
		nextProtos = append(nextProtos, fmt.Sprintf("v1workerauth-%02d-%s", count, b64alpn[i:end]))
		count++
	}
	// Build local tls config
	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(caCert)
	tlsCert, err := tls.X509KeyPair(info.CertPEM, info.KeyPEM)
	if err != nil {
		return nil, nil, err
	}
	tlsConfig := &tls.Config{
		ServerName:   host,
		Certificates: []tls.Certificate{tlsCert},
		RootCAs:      rootCAs,
		NextProtos:   nextProtos,
		MinVersion:   tls.VersionTLS13,
	}
	// NOTE(review): BuildNameToCertificate is deprecated since Go 1.14; the
	// crypto/tls package now selects certificates automatically.
	tlsConfig.BuildNameToCertificate()
	return tlsConfig, info, nil
}

@ -1,89 +1,38 @@
package worker
import (
"context"
"crypto/tls"
"errors"
"fmt"
"sync"
"github.com/hashicorp/go-multierror"
"github.com/hashicorp/vault/sdk/helper/strutil"
)
func (c *Worker) startListeners() error {
var retErr *multierror.Error
servers := make([]func(), 0, len(c.conf.Listeners))
for _, ln := range c.conf.Listeners {
switch c.conf.RawConfig.DevController {
case false:
// TODO: We'll eventually need to configure HTTP listening here for
// org-provided certificate handling, and configure the mux's
// defaultproto for accepting client connections for ALPN-based
// auth
default:
// TODO: We'll need to go through any listeners marked for api
// usage and add our websocket handlers to the server. Eventually
// we may want to make the config function able to handle arbitrary
// ALPNs in a dynamic way, so that in dev mode we can also register
// alpn mode client auth handling via the cluster ports.
// TODO again because...I like that idea and don't want to forget
// about it :-)
// For now just testing out the ability to authorize via the ALPN handler
if strutil.StrListContains(ln.Config.Purpose, "cluster") {
tlsConf, err := c.workerAuthTLSConfig()
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("error creating tls config for worker auth: %w", err))
continue
}
if ln.ALPNListener != nil {
conn, err := tls.Dial(ln.ALPNListener.Addr().Network(), ln.ALPNListener.Addr().String(), tlsConf)
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("error dialing controller for worker auth: %w", err))
continue
}
_, err = conn.Write([]byte("foo"))
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("error writing test string to controller for worker auth: %w", err))
continue
}
_, err = conn.Read(make([]byte, 3))
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("error reading test string from controller for worker auth: %w", err))
continue
}
c.logger.Info("done good writing/reading")
conn.Close()
newTLSConf, _ := c.workerAuthTLSConfig()
tlsConf.Certificates = newTLSConf.Certificates
conn, err = tls.Dial(ln.ALPNListener.Addr().Network(), ln.ALPNListener.Addr().String(), tlsConf)
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("error dialing controller for worker auth: %w", err))
continue
}
_, err = conn.Write([]byte("foo"))
if err != nil {
retErr = multierror.Append(retErr, fmt.Errorf("error writing test string to controller for worker auth: %w", err))
continue
}
_, err = conn.Read(make([]byte, 3))
if err == nil {
retErr = multierror.Append(retErr, errors.New("expected error reading test string from controller for worker auth"))
continue
}
c.logger.Info("done bad writing/reading")
conn.Close()
func (w *Worker) startListeners() error {
servers := make([]func(), 0, len(w.conf.Listeners))
for _, ln := range w.conf.Listeners {
var err error
for _, purpose := range ln.Config.Purpose {
switch purpose {
case "api", "cluster":
// Do nothing, in dev mode we might see it here
case "worker-alpn-tls":
if w.listeningAddress != "" {
return errors.New("more than one listening address found")
}
w.listeningAddress = ln.Config.Address
w.logger.Info("reporting listening address to controllers", "address", w.listeningAddress)
// TODO: other stuff
// TODO: once we have an actual listener, record in w.listeningAddress the actual address with port
default:
err = fmt.Errorf("unknown listener purpose %q", purpose)
}
if err != nil {
return err
}
}
}
err := retErr.ErrorOrNil()
if err != nil {
return err
}
for _, s := range servers {
s()
}
@ -91,37 +40,16 @@ func (c *Worker) startListeners() error {
return nil
}
func (c *Worker) stopListeners() error {
serverWg := new(sync.WaitGroup)
for _, ln := range c.conf.Listeners {
if c.conf.RawConfig.DevController {
// These will get closed by the controller's dev instance
continue
}
if ln.HTTPServer == nil {
continue
}
localLn := ln
serverWg.Add(1)
go func() {
shutdownKill, shutdownKillCancel := context.WithTimeout(c.baseContext, localLn.Config.MaxRequestDuration)
defer shutdownKillCancel()
defer serverWg.Done()
localLn.HTTPServer.Shutdown(shutdownKill)
}()
}
serverWg.Wait()
func (w *Worker) stopListeners() error {
var retErr *multierror.Error
for _, ln := range c.conf.Listeners {
for _, ln := range w.conf.Listeners {
if ln.ALPNListener != nil {
if err := ln.ALPNListener.Close(); err != nil {
retErr = multierror.Append(retErr, err)
}
}
if !c.conf.RawConfig.DevController {
if !w.conf.RawConfig.DevController {
if err := ln.Mux.Close(); err != nil {
retErr = multierror.Append(retErr, err)
}

@ -0,0 +1,82 @@
package worker
import (
"context"
"math/rand"
"time"
pbs "github.com/hashicorp/boundary/internal/gen/controller/api/services"
"github.com/hashicorp/boundary/internal/servers"
"github.com/hashicorp/boundary/internal/types/resource"
"google.golang.org/grpc/resolver"
)
// In the future we could make this configurable
const (
	// statusInterval is the base period between status reports to
	// controllers; the actual interval is jittered around this value (see
	// startStatusTicking).
	statusInterval = 2 * time.Second
)
// startStatusTicking spawns a goroutine that periodically reports this
// worker's status to every connected controller until cancelCtx is
// canceled. A successful report refreshes the resolver's controller
// address list from the response and records the time in
// w.lastStatusSuccess.
func (w *Worker) startStatusTicking(cancelCtx context.Context) {
	go func() {
		r := rand.New(rand.NewSource(time.Now().UnixNano()))
		// This function exists to desynchronize calls to controllers from
		// workers so we aren't always getting status updates at the exact same
		// intervals, to ease the load on the DB.
		getRandomInterval := func() time.Duration {
			// 0 to 0.5 adjustment to the base
			f := r.Float64() / 2
			// Half a chance to be faster, not slower
			if r.Float32() > 0.5 {
				f = -1 * f
			}
			return statusInterval + time.Duration(f*float64(time.Second))
		}
		// Zero-duration timer: fire immediately on startup, then reset to a
		// jittered interval after each round.
		timer := time.NewTimer(0)
		for {
			select {
			case <-cancelCtx.Done():
				w.logger.Info("status ticking shutting down")
				return
			case <-timer.C:
				w.controllerConns.Range(func(_, v interface{}) bool {
					// If something is removed from the map while ranging, ignore it
					if v == nil {
						return true
					}
					c := v.(*controllerConnection)
					result, err := c.client.Status(cancelCtx, &pbs.StatusRequest{
						Worker: &servers.Server{
							// PrivateId mirrors Name; see servers.Repository.Upsert.
							PrivateId:   w.conf.RawConfig.Worker.Name,
							Name:        w.conf.RawConfig.Worker.Name,
							Type:        resource.Worker.String(),
							Description: w.conf.RawConfig.Worker.Description,
							Address:     w.listeningAddress,
						},
					})
					if err != nil {
						w.logger.Error("error making status request to controller", "error", err)
					} else {
						w.logger.Trace("successfully sent status to controller")
						// Feed the controllers reported in the response back into
						// the gRPC resolver so connections track the live set.
						addrs := make([]resolver.Address, 0, len(result.Controllers))
						strAddrs := make([]string, 0, len(result.Controllers))
						for _, v := range result.Controllers {
							addrs = append(addrs, resolver.Address{Addr: v.Address})
							strAddrs = append(strAddrs, v.Address)
						}
						w.Resolver().UpdateState(resolver.State{Addresses: addrs})
						w.logger.Trace("found controllers", "addresses", strAddrs)
						w.lastStatusSuccess.Store(time.Now())
					}
					return true
				})
				timer.Reset(getRandomInterval())
			}
		}
	}()
}
// LastStatusSuccess returns the time of the last successful status
// report to a controller; it is the zero time.Time if no report has
// succeeded yet (the value is seeded in New).
func (w *Worker) LastStatusSuccess() time.Time {
	return w.lastStatusSuccess.Load().(time.Time)
}

@ -0,0 +1,240 @@
package worker
import (
"context"
"fmt"
"net"
"testing"
"github.com/hashicorp/boundary/internal/cmd/base"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/go-hclog"
wrapping "github.com/hashicorp/go-kms-wrapping"
"github.com/hashicorp/vault/sdk/helper/base62"
)
// TestWorker wraps a base.Server and Worker to provide a
// fully-programmatic worker for tests. Error checking (for instance, for
// valid config) is not stringent at the moment.
type TestWorker struct {
	b *base.Server
	w *Worker
	t *testing.T
	addrs []string // The address the worker proxies are listening on
	ctx context.Context // Base context for the test; canceled by Cancel/Shutdown
	cancel context.CancelFunc
	name string // The worker's configured name
}
// Worker returns the underlying worker (the previous comment said
// "controller", which was incorrect)
func (tw *TestWorker) Worker() *Worker {
	return tw.w
}
// Config returns the configuration of the underlying worker
func (tw *TestWorker) Config() *Config {
	return tw.w.conf
}
// Context returns the test worker's base context
func (tw *TestWorker) Context() context.Context {
	return tw.ctx
}
// Cancel cancels the test worker's base context
func (tw *TestWorker) Cancel() {
	tw.cancel()
}
// Name returns the worker's configured name
func (tw *TestWorker) Name() string {
	return tw.name
}
// ControllerAddrs returns the addresses of the controllers this worker
// currently holds connections to.
func (tw *TestWorker) ControllerAddrs() []string {
	var addrs []string
	tw.w.controllerConns.Range(func(_, v interface{}) bool {
		// Entries removed from the map mid-range show up as nil; skip them.
		cc, ok := v.(*controllerConnection)
		if !ok {
			return true
		}
		addrs = append(addrs, cc.controllerAddr)
		return true
	})
	return addrs
}
// ProxyAddrs returns the host:port addresses of the worker's
// "worker-alpn-tls" listeners, computing them on first call and caching
// them for subsequent calls.
func (tw *TestWorker) ProxyAddrs() []string {
	if tw.addrs != nil {
		return tw.addrs
	}
	for _, listener := range tw.b.Listeners {
		// Guard against a listener configured without any purpose; indexing
		// Purpose[0] unconditionally would panic on an empty slice.
		if len(listener.Config.Purpose) == 0 || listener.Config.Purpose[0] != "worker-alpn-tls" {
			continue
		}
		tcpAddr, ok := listener.Mux.Addr().(*net.TCPAddr)
		if !ok {
			tw.t.Fatal("could not parse address as a TCP addr")
		}
		addr := fmt.Sprintf("%s:%d", tcpAddr.IP.String(), tcpAddr.Port)
		tw.addrs = append(tw.addrs, addr)
	}
	return tw.addrs
}
// Shutdown runs any cleanup functions; be sure to run this after your test is
// done. Ordering: the base server's shutdown channel is closed first,
// then the test context is canceled, then the worker is shut down, and
// finally the base server's shutdown funcs run. The nil checks make this
// safe to call when construction failed partway through NewTestWorker.
func (tw *TestWorker) Shutdown() {
	if tw.b != nil {
		close(tw.b.ShutdownCh)
	}
	tw.cancel()
	if tw.w != nil {
		if err := tw.w.Shutdown(false); err != nil {
			tw.t.Error(err)
		}
	}
	if tw.b != nil {
		if err := tw.b.RunShutdownFuncs(); err != nil {
			tw.t.Error(err)
		}
	}
}
// TestWorkerOpts holds options for configuring a TestWorker created by
// NewTestWorker or AddClusterWorkerMember.
type TestWorkerOpts struct {
	// Config; if not provided a dev one will be created
	Config *config.Config
	// Sets initial controller addresses
	InitialControllers []string
	// If true, the worker will not be started
	DisableAutoStart bool
	// The worker auth KMS to use, or one will be created
	WorkerAuthKMS wrapping.Wrapper
	// The name to use for the worker, otherwise one will be randomly
	// generated, unless provided in a non-nil Config
	Name string
	// The logger to use, or one will be created
	Logger hclog.Logger
}
// NewTestWorker creates a TestWorker: it builds a base.Server, loads a
// provided config or generates a dev one, sets up logging, KMSes, and
// listeners, constructs the Worker, and — unless DisableAutoStart is set
// — starts it. Failures are reported via t.Fatal.
func NewTestWorker(t *testing.T, opts *TestWorkerOpts) *TestWorker {
	ctx, cancel := context.WithCancel(context.Background())
	tw := &TestWorker{
		t:      t,
		ctx:    ctx,
		cancel: cancel,
	}
	if opts == nil {
		opts = new(TestWorkerOpts)
	}
	// Base server
	tw.b = base.NewServer(nil)
	tw.b.Command = &base.Command{
		ShutdownCh: make(chan struct{}),
	}
	// Get dev config, or use a provided one
	var err error
	generatedConfig := opts.Config == nil
	if generatedConfig {
		opts.Config, err = config.DevWorker()
		if err != nil {
			t.Fatal(err)
		}
	}
	// Guard the worker stanza before any dereference below. Previously this
	// check ran after Name/Controllers were assigned, which would panic on
	// a config without a worker section.
	if opts.Config.Worker == nil {
		opts.Config.Worker = new(config.Worker)
	}
	if generatedConfig {
		// Only apply opts.Name when we generated the config; a provided
		// config's name wins (see TestWorkerOpts.Name).
		opts.Config.Worker.Name = opts.Name
	}
	if len(opts.InitialControllers) > 0 {
		opts.Config.Worker.Controllers = opts.InitialControllers
	}
	// Start a logger
	tw.b.Logger = opts.Logger
	if tw.b.Logger == nil {
		tw.b.Logger = hclog.New(&hclog.LoggerOptions{
			Level: hclog.Trace,
		})
	}
	if opts.Config.Worker.Name == "" {
		opts.Config.Worker.Name, err = base62.Random(5)
		if err != nil {
			t.Fatal(err)
		}
		tw.b.Logger.Info("worker name generated", "name", opts.Config.Worker.Name)
	}
	tw.name = opts.Config.Worker.Name
	// Set up KMSes
	switch {
	case opts.WorkerAuthKMS != nil:
		tw.b.WorkerAuthKMS = opts.WorkerAuthKMS
	default:
		if err := tw.b.SetupKMSes(nil, opts.Config.SharedConfig, []string{"worker-auth"}); err != nil {
			t.Fatal(err)
		}
	}
	// Ensure the listeners use random port allocation
	for _, listener := range opts.Config.Listeners {
		listener.RandomPort = true
	}
	if err := tw.b.SetupListeners(nil, opts.Config.SharedConfig, []string{"worker-alpn-tls"}); err != nil {
		t.Fatal(err)
	}
	conf := &Config{
		RawConfig: opts.Config,
		Server:    tw.b,
	}
	tw.w, err = New(conf)
	if err != nil {
		tw.Shutdown()
		t.Fatal(err)
	}
	if !opts.DisableAutoStart {
		if err := tw.w.Start(); err != nil {
			tw.Shutdown()
			t.Fatal(err)
		}
	}
	return tw
}
// AddClusterWorkerMember spins up an additional test worker that joins the
// same cluster as the receiver: it inherits the receiver's worker-auth KMS
// and its current controller addresses. The supplied opts may override the
// name and logger; a missing name is generated randomly.
func (tw *TestWorker) AddClusterWorkerMember(t *testing.T, opts *TestWorkerOpts) *TestWorker {
	if opts == nil {
		opts = new(TestWorkerOpts)
	}

	// Prefer a caller-supplied logger, falling back to the receiver's.
	logger := opts.Logger
	if logger == nil {
		logger = tw.w.conf.Logger
	}

	// Prefer a caller-supplied name, generating one otherwise.
	name := opts.Name
	if name == "" {
		var err error
		name, err = base62.Random(5)
		if err != nil {
			t.Fatal(err)
		}
		logger.Info("worker name generated", "name", name)
	}

	return NewTestWorker(t, &TestWorkerOpts{
		WorkerAuthKMS:      tw.w.conf.WorkerAuthKMS,
		Name:               name,
		InitialControllers: tw.ControllerAddrs(),
		Logger:             logger,
	})
}

@ -4,9 +4,17 @@ import (
"context"
"crypto/rand"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/hashicorp/boundary/internal/cmd/config"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/base62"
"github.com/hashicorp/vault/sdk/helper/mlock"
ua "go.uber.org/atomic"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
)
type Worker struct {
@ -15,18 +23,46 @@ type Worker struct {
baseContext context.Context
baseCancel context.CancelFunc
started ua.Bool
controllerConns *sync.Map
lastStatusSuccess *atomic.Value
listeningAddress string
controllerResolver *atomic.Value
controllerResolverCleanup *atomic.Value
}
func New(conf *Config) (*Worker, error) {
c := &Worker{
conf: conf,
logger: conf.Logger.Named("worker"),
w := &Worker{
conf: conf,
logger: conf.Logger.Named("worker"),
controllerConns: new(sync.Map),
lastStatusSuccess: new(atomic.Value),
controllerResolver: new(atomic.Value),
controllerResolverCleanup: new(atomic.Value),
}
w.lastStatusSuccess.Store(time.Time{})
w.started.Store(false)
w.controllerResolver.Store((*manual.Resolver)(nil))
w.controllerResolverCleanup.Store(func() {})
if conf.SecureRandomReader == nil {
conf.SecureRandomReader = rand.Reader
}
var err error
if conf.RawConfig.Worker == nil {
conf.RawConfig.Worker = new(config.Worker)
}
if conf.RawConfig.Worker.Name == "" {
if conf.RawConfig.Worker.Name, err = base62.Random(10); err != nil {
return nil, fmt.Errorf("error auto-generating worker name: %w", err)
}
}
if !conf.RawConfig.DisableMlock {
// Ensure our memory usage is locked into physical RAM
if err := mlock.LockMemory(); err != nil {
@ -43,21 +79,60 @@ func New(conf *Config) (*Worker, error) {
}
}
c.baseContext, c.baseCancel = context.WithCancel(context.Background())
return c, nil
return w, nil
}
func (c *Worker) Start() error {
if err := c.startListeners(); err != nil {
// Start brings the worker to a running state: it creates a fresh base
// context, registers a manual gRPC resolver used for controller
// addressing, starts the proxy listeners, dials the configured
// controllers, and begins the periodic status loop. Calling Start on an
// already-started worker is a no-op.
func (w *Worker) Start() error {
	if w.started.Load() {
		w.logger.Info("already started, skipping")
		return nil
	}
	// A fresh context per start so cancellation from a prior Shutdown
	// does not leak into this run.
	w.baseContext, w.baseCancel = context.WithCancel(context.Background())
	// Store the resolver and its cleanup func before anything dials, so
	// that connection logic (and a later Shutdown) can find them.
	controllerResolver, controllerResolverCleanup := manual.GenerateAndRegisterManualResolver()
	w.controllerResolver.Store(controllerResolver)
	w.controllerResolverCleanup.Store(controllerResolverCleanup)
	if err := w.startListeners(); err != nil {
		return fmt.Errorf("error starting worker listeners: %w", err)
	}
	if err := w.startControllerConnections(); err != nil {
		return fmt.Errorf("error making controller connections: %w", err)
	}
	// Status ticking runs until baseContext is canceled in Shutdown.
	w.startStatusTicking(w.baseContext)
	w.started.Store(true)
	return nil
}
func (c *Worker) Shutdown() error {
if err := c.stopListeners(); err != nil {
return fmt.Errorf("error stopping worker listeners: %w", err)
// Shutdown shuts down the workers. skipListeners can be used to not stop
// listeners, useful for tests if we want to stop and start a worker. In order
// to create new listeners we'd have to migrate listener setup logic here --
// doable, but work for later.
//
// Calling Shutdown on a worker that is not started is a no-op. On success
// the worker can be started again via Start.
func (w *Worker) Shutdown(skipListeners bool) error {
	if !w.started.Load() {
		w.logger.Info("already shut down, skipping")
		return nil
	}
	// Clear all controller addresses out of the resolver, then run the
	// cleanup registered when the resolver was created in Start.
	w.Resolver().UpdateState(resolver.State{Addresses: []resolver.Address{}})
	w.controllerResolverCleanup.Load().(func())()
	// Cancel the base context; this stops work hung off it, such as the
	// status ticking started in Start.
	w.baseCancel()
	if !skipListeners {
		if err := w.stopListeners(); err != nil {
			return fmt.Errorf("error stopping worker listeners: %w", err)
		}
	}
	w.listeningAddress = ""
	w.started.Store(false)
	return nil
}
// Resolver returns the manual gRPC resolver stored at Start time, or
// panics if no resolver has been stored.
func (w *Worker) Resolver() *manual.Resolver {
	switch res := w.controllerResolver.Load().(type) {
	case *manual.Resolver:
		return res
	default:
		panic("nil resolver")
	}
}

@ -0,0 +1,100 @@
package cluster
import (
"testing"
"time"
"github.com/alecthomas/assert"
"github.com/hashicorp/boundary/internal/servers/controller"
"github.com/hashicorp/boundary/internal/servers/worker"
"github.com/hashicorp/go-hclog"
"github.com/stretchr/testify/require"
)
// TestMultiControllerMultiWorkerConnections is an integration test that
// brings up a two-controller cluster, adds workers, and verifies — via the
// controllers' worker status-update times — that each controller sees
// exactly the expected set of live workers as workers and controllers are
// stopped and restarted.
func TestMultiControllerMultiWorkerConnections(t *testing.T) {
	// NOTE(review): the locals shadow the assert/require packages; they
	// refer to the per-test instances below this line.
	assert, require := assert.New(t), require.New(t)
	amId := "paum_1234567890"
	user := "user"
	password := "passpass"
	orgId := "o_1234567890"
	logger := hclog.New(&hclog.LoggerOptions{
		Level: hclog.Trace,
	})
	// First controller, plus a second joined to it as a cluster member.
	c1 := controller.NewTestController(t, &controller.TestControllerOpts{
		DefaultOrgId:        orgId,
		DefaultAuthMethodId: amId,
		DefaultLoginName:    user,
		DefaultPassword:     password,
		Logger:              logger.Named("c1"),
	})
	defer c1.Shutdown()
	c2 := c1.AddClusterControllerMember(t, &controller.TestControllerOpts{
		Logger: c1.Config().Logger.ResetNamed("c2"),
	})
	defer c2.Shutdown()
	// expectWorkers asserts that controller c has a recent (within 7s)
	// status update from every listed worker and from no other live
	// worker. Entries for workers we are not expecting are tolerated as
	// stale leftovers, since updateTimes entries are never removed.
	expectWorkers := func(c *controller.TestController, workers ...*worker.TestWorker) {
		updateTimes := c.Controller().WorkerStatusUpdateTimes()
		workerMap := map[string]*worker.TestWorker{}
		for _, w := range workers {
			workerMap[w.Name()] = w
		}
		updateTimes.Range(func(k, v interface{}) bool {
			require.NotNil(k)
			require.NotNil(v)
			if workerMap[k.(string)] == nil {
				// We don't remove from updateTimes currently so if we're not
				// expecting it we'll see an out-of-date entry
				return true
			}
			assert.WithinDuration(time.Now(), v.(time.Time), 7*time.Second)
			delete(workerMap, k.(string))
			return true
		})
		assert.Empty(workerMap)
	}
	// No workers yet.
	expectWorkers(c1)
	expectWorkers(c2)
	// First worker, pointed at c1's cluster addresses and sharing its
	// worker-auth KMS.
	w1 := worker.NewTestWorker(t, &worker.TestWorkerOpts{
		WorkerAuthKMS:      c1.Config().WorkerAuthKMS,
		InitialControllers: c1.ClusterAddrs(),
		Logger:             logger.Named("w1"),
	})
	defer w1.Shutdown()
	// Sleeps give the periodic status loop time to report — presumably
	// the status interval is well under the 7s freshness window checked
	// above; confirm against the worker's ticker settings.
	time.Sleep(10 * time.Second)
	expectWorkers(c1, w1)
	expectWorkers(c2, w1)
	// Second worker joins via the first.
	w2 := w1.AddClusterWorkerMember(t, &worker.TestWorkerOpts{
		Logger: logger.Named("w2"),
	})
	defer w2.Shutdown()
	time.Sleep(10 * time.Second)
	expectWorkers(c1, w1, w2)
	expectWorkers(c2, w1, w2)
	// Stop w1 (keeping its listeners so it can be restarted) and verify
	// both controllers stop seeing fresh status from it.
	require.NoError(w1.Worker().Shutdown(true))
	time.Sleep(10 * time.Second)
	expectWorkers(c1, w2)
	expectWorkers(c2, w2)
	// Restart w1; both controllers should see it again.
	require.NoError(w1.Worker().Start())
	time.Sleep(10 * time.Second)
	expectWorkers(c1, w1, w2)
	expectWorkers(c2, w1, w2)
	// Bounce controller c1; c2 keeps seeing the workers throughout, and
	// c1 sees them again after restart.
	require.NoError(c1.Controller().Shutdown(true))
	time.Sleep(10 * time.Second)
	expectWorkers(c2, w1, w2)
	require.NoError(c1.Controller().Start())
	time.Sleep(10 * time.Second)
	expectWorkers(c1, w1, w2)
	expectWorkers(c2, w1, w2)
}

@ -17,6 +17,8 @@ const (
HostSet Type = 10
Host Type = 11
Target Type = 12
Controller Type = 13
Worker Type = 14
)
func (r Type) String() string {
@ -34,6 +36,8 @@ func (r Type) String() string {
"host-set",
"host",
"target",
"controller",
"worker",
}[r]
}
@ -51,4 +55,6 @@ var Map = map[string]Type{
HostSet.String(): HostSet,
Host.String(): Host,
Target.String(): Target,
Controller.String(): Controller,
Worker.String(): Worker,
}

@ -64,6 +64,14 @@ func Test_Resource(t *testing.T) {
typeString: "target",
want: Target,
},
{
typeString: "controller",
want: Controller,
},
{
typeString: "worker",
want: Worker,
},
}
for _, tt := range tests {
t.Run(tt.typeString, func(t *testing.T) {

Loading…
Cancel
Save