diff --git a/internal/census/census.go b/internal/census/census.go index a5d58049cc..921bfb57ca 100644 --- a/internal/census/census.go +++ b/internal/census/census.go @@ -6,6 +6,7 @@ package census import ( "context" "fmt" + "io" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" @@ -14,7 +15,7 @@ import ( ) // RegisterJob registers the census job with the provided scheduler. -func RegisterJob(ctx context.Context, s *scheduler.Scheduler, lurEnabled bool, r db.Reader, w db.Writer) error { +func RegisterJob(ctx context.Context, s *scheduler.Scheduler, lurEnabled bool, r db.Reader, w db.Writer, randomReader io.Reader) error { const op = "census.RegisterJob" if s == nil { return errors.New(ctx, errors.InvalidParameter, "nil scheduler", op, errors.WithoutEvent()) @@ -26,7 +27,7 @@ func RegisterJob(ctx context.Context, s *scheduler.Scheduler, lurEnabled bool, r return errors.New(ctx, errors.Internal, "nil DB writer", op, errors.WithoutEvent()) } - censusJob, err := NewCensusJobFn(ctx, lurEnabled, r, w) + censusJob, err := NewCensusJobFn(ctx, lurEnabled, r, w, randomReader) if err != nil { return fmt.Errorf("error creating census job: %w", err) } diff --git a/internal/census/census_job.go b/internal/census/census_job.go index 1c116cdc54..233a7eb1aa 100644 --- a/internal/census/census_job.go +++ b/internal/census/census_job.go @@ -5,6 +5,7 @@ package census import ( "context" + "io" "time" "github.com/hashicorp/boundary/internal/db" @@ -25,9 +26,10 @@ type censusJob struct { sessionsAgent any activeUsersAgent any eventCtx context.Context + randReader io.Reader } -func newCensusJob(ctx context.Context, lurEnabled bool, r db.Reader, w db.Writer) (*censusJob, error) { +func newCensusJob(ctx context.Context, lurEnabled bool, r db.Reader, w db.Writer, randomReader io.Reader) (*censusJob, error) { const op = "censusJob.newCensusJob" switch { case r == nil: @@ -44,6 +46,7 @@ func newCensusJob(ctx context.Context, lurEnabled bool, r db.Reader, w db.Writer sessionsAgent: nil, activeUsersAgent: nil, eventCtx: ctx, + randReader: randomReader, }, nil } diff --git a/internal/cmd/config/config.go b/internal/cmd/config/config.go index 46ac502e60..6793118ba8 100644 --- a/internal/cmd/config/config.go +++ b/internal/cmd/config/config.go @@ -499,7 +499,6 @@ func DevKeyGeneration(opt ...Option) string { } n, err := randBuf.ReadFrom(&io.LimitedReader{ R: opts.withRandomReader, - N: numBytes, }) if err != nil { diff --git a/internal/daemon/controller/controller.go b/internal/daemon/controller/controller.go index 4efcc4ab9b..34e324a90d 100644 --- a/internal/daemon/controller/controller.go +++ b/internal/daemon/controller/controller.go @@ -445,7 +445,7 @@ func New(ctx context.Context, conf *Config) (*Controller, error) { return host.NewCatalogRepository(ctx, dbase, dbase) } c.ServersRepoFn = func() (*server.Repository, error) { - return server.NewRepository(ctx, dbase, dbase, c.kms) + return server.NewRepository(ctx, dbase, dbase, c.kms, server.WithRandomReader(c.conf.SecureRandomReader)) } c.OidcRepoFn = func() (*oidc.Repository, error) { return oidc.NewRepository(ctx, dbase, dbase, c.kms) @@ -649,7 +649,7 @@ func (c *Controller) registerJobs() error { if err := snapshot.RegisterJob(c.baseContext, c.scheduler, rw, rw); err != nil { return err } - if err := census.RegisterJob(c.baseContext, c.scheduler, c.conf.RawConfig.Reporting.License.Enabled, rw, rw); err != nil { + if err := census.RegisterJob(c.baseContext, c.scheduler, c.conf.RawConfig.Reporting.License.Enabled, rw, rw, c.conf.SecureRandomReader); err != nil { return err } if err := purge.RegisterJobs(c.baseContext, c.scheduler, rw, rw); err != nil { diff --git a/internal/daemon/worker/handler.go b/internal/daemon/worker/handler.go index 9065e604d6..ad1c79b443 100644 --- a/internal/daemon/worker/handler.go +++ b/internal/daemon/worker/handler.go @@ -281,7 +281,7 @@ func (w *Worker) handleProxy(listenerCfg *listenerutil.ListenerConfig, sessionMa event.WriteError(ctx, op, err) } - handlerOpts := []proxyHandlers.Option{proxyHandlers.WithLogger(w.logger)} + handlerOpts := []proxyHandlers.Option{proxyHandlers.WithLogger(w.logger), proxyHandlers.WithRandomReader(w.conf.SecureRandomReader)} if cb := w.SshKnownHostsCallback.Load(); cb != nil { handlerOpts = append(handlerOpts, proxyHandlers.WithSshHostKeyCallback(*cb)) } diff --git a/internal/daemon/worker/proxy/options.go b/internal/daemon/worker/proxy/options.go index 7bdceecafb..c45bfebedd 100644 --- a/internal/daemon/worker/proxy/options.go +++ b/internal/daemon/worker/proxy/options.go @@ -4,6 +4,8 @@ package proxy import ( + "crypto/rand" + "io" "net" serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -32,6 +34,7 @@ type Options struct { WithTestKerberosServerHostname string WithLogger hclog.Logger WithSshHostKeyCallback ssh.HostKeyCallback + WithRandomReader io.Reader } func getDefaultOptions() Options { @@ -39,6 +42,7 @@ func getDefaultOptions() Options { WithInjectedApplicationCredentials: nil, WithPostConnectionHook: nil, WithLogger: hclog.NewNullLogger(), + WithRandomReader: rand.Reader, } } @@ -97,3 +101,10 @@ func WithSshHostKeyCallback(with ssh.HostKeyCallback) Option { o.WithSshHostKeyCallback = with } } + +// WithRandomReader provides an option to specify a random reader. +func WithRandomReader(reader io.Reader) Option { + return func(o *Options) { + o.WithRandomReader = reader + } +} diff --git a/internal/daemon/worker/proxy/options_test.go b/internal/daemon/worker/proxy/options_test.go index 1b20f9e5bd..402ad4e036 100644 --- a/internal/daemon/worker/proxy/options_test.go +++ b/internal/daemon/worker/proxy/options_test.go @@ -5,9 +5,11 @@ package proxy import ( "crypto/ed25519" + "io" "net" "reflect" "runtime" + "strings" "testing" serverpb "github.com/hashicorp/boundary/internal/gen/controller/servers/services" @@ -76,4 +78,11 @@ func Test_GetOpts(t *testing.T) { opts = GetOpts(WithSshHostKeyCallback(ssh.FixedHostKey(signer.PublicKey()))) assert.NotNil(opts.WithSshHostKeyCallback) }) + t.Run("WithRandomReader", func(t *testing.T) { + reader := io.Reader(&strings.Reader{}) + opts := GetOpts(WithRandomReader(reader)) + testOpts := getDefaultOptions() + testOpts.WithRandomReader = reader + assert.Equal(t, opts, testOpts) + }) } diff --git a/internal/server/options_test.go b/internal/server/options_test.go index 7f0584cd9f..65d7bc57fe 100644 --- a/internal/server/options_test.go +++ b/internal/server/options_test.go @@ -5,8 +5,10 @@ package server import ( "context" + "io" "reflect" "runtime" + "strings" "testing" "time" @@ -268,4 +270,13 @@ func Test_GetOpts(t *testing.T) { assert.Equal(t, writer, opts.WithWriter) assert.Equal(t, opts, testOpts) }) + t.Run("WithRandomReader", func(t *testing.T) { + reader := io.Reader(&strings.Reader{}) + opts := GetOpts(WithRandomReader(reader)) + testOpts := getDefaultOptions() + testOpts.withRandomReader = reader + opts.withNewIdFunc = nil + testOpts.withNewIdFunc = nil + assert.Equal(t, opts, testOpts) + }) } diff --git a/internal/server/repository.go b/internal/server/repository.go index c740783014..3a71034fda 100644 --- a/internal/server/repository.go +++ b/internal/server/repository.go @@ -5,6 +5,7 @@ package server import ( "context" + "io" "reflect" "time" @@ -26,6 +27,7 @@ type Repository struct { kms *kms.Kms // defaultLimit provides a default for limiting the number of results returned from the repo defaultLimit int + randomReader io.Reader } // NewRepository creates a new server Repository. Supports the options: WithLimit @@ -52,6 +54,7 @@ func NewRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms, writer: w, kms: kms, defaultLimit: opts.withLimit, + randomReader: opts.withRandomReader, }, nil }