diff --git a/internal/cmd/base/initial_resources.go b/internal/cmd/base/initial_resources.go index 8c2535c350..3e314e1cb2 100644 --- a/internal/cmd/base/initial_resources.go +++ b/internal/cmd/base/initial_resources.go @@ -510,7 +510,7 @@ func (b *Server) CreateInitialTargetWithAddress(ctx context.Context) (target.Tar return nil, fmt.Errorf("failed to add config keys to kms: %w", err) } - targetRepo, err := target.NewRepository(ctx, rw, rw, kmsCache) + targetRepo, err := target.NewRepository(ctx, rw, rw, kmsCache, target.WithRandomReader(b.SecureRandomReader)) if err != nil { return nil, fmt.Errorf("failed to create target repository: %w", err) } diff --git a/internal/cmd/commands/dev/dev.go b/internal/cmd/commands/dev/dev.go index 4df1c8dbfd..76a4b6f3ba 100644 --- a/internal/cmd/commands/dev/dev.go +++ b/internal/cmd/commands/dev/dev.go @@ -503,9 +503,10 @@ func (c *Command) Run(args []string) int { c.Config, err = config.DevController( config.WithObservationsEnabled(true), config.WithSysEventsEnabled(true), + config.WithRandomReader(c.SecureRandomReader), ) default: - c.Config, err = config.DevCombined() + c.Config, err = config.DevCombined(config.WithRandomReader(c.SecureRandomReader)) } if err != nil { c.UI.Error(fmt.Errorf("Error creating controller dev config: %w", err).Error()) @@ -905,7 +906,7 @@ func (c *Command) Run(args []string) int { Worker: &store.Worker{ ScopeId: scope.Global.String(), }, - }, server.WithCreateControllerLedActivationToken(true)) + }, server.WithCreateControllerLedActivationToken(true), server.WithRandomReader(c.SecureRandomReader)) if err != nil { c.UI.Error(fmt.Errorf("Error creating worker in database: %w", err).Error()) if err := c.controller.Shutdown(); err != nil { diff --git a/internal/cmd/commands/server/listener_reload_test.go b/internal/cmd/commands/server/listener_reload_test.go index a5472f4501..6650dec508 100644 --- a/internal/cmd/commands/server/listener_reload_test.go +++ b/internal/cmd/commands/server/listener_reload_test.go @@ -12,6 +12,7 @@ package server import ( + "crypto/rand" "crypto/tls" "crypto/x509" "fmt" @@ -92,9 +93,9 @@ func TestServer_ReloadListener(t *testing.T) { td := t.TempDir() - controllerKey := config.DevKeyGeneration() - workerAuthKey := config.DevKeyGeneration() - recoveryKey := config.DevKeyGeneration() + controllerKey := config.DevKeyGeneration(config.WithRandomReader(rand.Reader)) + workerAuthKey := config.DevKeyGeneration(config.WithRandomReader(rand.Reader)) + recoveryKey := config.DevKeyGeneration(config.WithRandomReader(rand.Reader)) cmd := testServerCommand(t, testServerCommandOpts{ CreateDevDatabase: true, diff --git a/internal/cmd/config/config.go b/internal/cmd/config/config.go index 7935977246..46ac502e60 100644 --- a/internal/cmd/config/config.go +++ b/internal/cmd/config/config.go @@ -6,7 +6,6 @@ package config import ( "bytes" "context" - "crypto/rand" "encoding/base64" "encoding/json" "errors" @@ -471,7 +470,7 @@ type License struct { // workers. Supported options: WithObservationsEnabled, WithSysEventsEnabled, // WithAuditEventsEnabled, TestWithErrorEventsEnabled func DevWorker(opt ...Option) (*Config, error) { - workerAuthStorageKey := DevKeyGeneration() + workerAuthStorageKey := DevKeyGeneration(opt...) opts, err := getOpts(opt...) if err != nil { return nil, fmt.Errorf("error parsing options: %w", err) @@ -491,11 +490,16 @@ func DevWorker(opt ...Option) (*Config, error) { return parsed, nil } -func DevKeyGeneration() string { +func DevKeyGeneration(opt ...Option) string { var numBytes int64 = 32 randBuf := new(bytes.Buffer) + opts, err := getOpts(opt...) + if err != nil { + return fmt.Errorf("error parsing options: %w", err).Error() + } n, err := randBuf.ReadFrom(&io.LimitedReader{ - R: rand.Reader, + R: opts.withRandomReader, + N: numBytes, }) if err != nil { @@ -516,10 +520,10 @@ func DevController(opt ...Option) (*Config, error) { return nil, fmt.Errorf("error parsing options: %w", err) } - controllerKey := DevKeyGeneration() - workerAuthKey := DevKeyGeneration() - bsrKey := DevKeyGeneration() - recoveryKey := DevKeyGeneration() + controllerKey := DevKeyGeneration(opt...) + workerAuthKey := DevKeyGeneration(opt...) + bsrKey := DevKeyGeneration(opt...) + recoveryKey := DevKeyGeneration(opt...) hclStr := fmt.Sprintf(devConfig+devControllerExtraConfig, controllerKey, workerAuthKey, bsrKey, recoveryKey) if opts.withIPv6Enabled { @@ -547,11 +551,11 @@ func DevCombined(opt ...Option) (*Config, error) { return nil, fmt.Errorf("error parsing options: %w", err) } - controllerKey := DevKeyGeneration() - workerAuthKey := DevKeyGeneration() - workerAuthStorageKey := DevKeyGeneration() - bsrKey := DevKeyGeneration() - recoveryKey := DevKeyGeneration() + controllerKey := DevKeyGeneration(opt...) + workerAuthKey := DevKeyGeneration(opt...) + workerAuthStorageKey := DevKeyGeneration(opt...) + bsrKey := DevKeyGeneration(opt...) + recoveryKey := DevKeyGeneration(opt...) hclStr := fmt.Sprintf(devConfig+devControllerExtraConfig+devWorkerExtraConfig, controllerKey, workerAuthKey, bsrKey, recoveryKey, workerAuthStorageKey) if opts.withIPv6Enabled { diff --git a/internal/cmd/config/config_test.go b/internal/cmd/config/config_test.go index faf05d77c8..c6109a5bb1 100644 --- a/internal/cmd/config/config_test.go +++ b/internal/cmd/config/config_test.go @@ -4,6 +4,7 @@ package config import ( + "crypto/rand" "encoding/base64" "fmt" "net" @@ -892,7 +893,7 @@ func TestDevCombinedIpv6(t *testing.T) { func TestDevKeyGeneration(t *testing.T) { t.Parallel() - dk := DevKeyGeneration() + dk := DevKeyGeneration(WithRandomReader(rand.Reader)) buf, err := base64.StdEncoding.DecodeString(dk) require.NoError(t, err) require.Len(t, buf, 32) diff --git a/internal/cmd/config/options.go b/internal/cmd/config/options.go index e95564f786..cb7fd03368 100644 --- a/internal/cmd/config/options.go +++ b/internal/cmd/config/options.go @@ -3,6 +3,8 @@ package config import ( + "crypto/rand" + "io" "os" "testing" @@ -37,6 +39,7 @@ type options struct { withObservationsEnabled bool withIPv6Enabled bool testWithErrorEventsEnabled bool + withRandomReader io.Reader } func getDefaultOptions() (options, error) { @@ -72,6 +75,8 @@ func getDefaultOptions() (options, error) { } opts.testWithErrorEventsEnabled = errEvents + opts.withRandomReader = rand.Reader + return opts, nil } @@ -107,6 +112,14 @@ func WithIPv6Enabled(enable bool) Option { } } +// WithRandomReader provides an option to specify a random reader. +func WithRandomReader(reader io.Reader) Option { + return func(o *options) error { + o.withRandomReader = reader + return nil + } +} + // TestWithErrorEventsEnabled provides an option for enabling error events // during tests. func TestWithErrorEventsEnabled(_ testing.TB, enable bool) Option { diff --git a/internal/credential/options.go b/internal/credential/options.go index 4f74833792..53de2e69d9 100644 --- a/internal/credential/options.go +++ b/internal/credential/options.go @@ -4,7 +4,9 @@ package credential import ( + "crypto/rand" "errors" + "io" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/pagination" @@ -36,10 +38,13 @@ type options struct { WithWriter db.Writer WithLimit int WithStartPageAfterItem pagination.Item + WithRandomReader io.Reader } func getDefaultOptions() *options { - return &options{} + return &options{ + WithRandomReader: rand.Reader, + } } // WithTemplateData provides a way to pass in template information @@ -87,3 +92,11 @@ func WithStartPageAfterItem(item pagination.Item) Option { return nil } } + +// WithRandomReader provides an option to specify a random reader. +func WithRandomReader(reader io.Reader) Option { + return func(o *options) error { + o.WithRandomReader = reader + return nil + } +} diff --git a/internal/credential/options_test.go b/internal/credential/options_test.go index dbb2c55f6e..545a5e1e3a 100644 --- a/internal/credential/options_test.go +++ b/internal/credential/options_test.go @@ -4,6 +4,7 @@ package credential import ( + "strings" "testing" "time" @@ -119,4 +120,14 @@ func Test_GetOpts(t *testing.T) { assert.Equal(opts.WithStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime)) assert.Equal(opts.WithStartPageAfterItem.GetCreateTime(), timestamp.New(createTime)) }) + + t.Run("WithRandomReader", func(t *testing.T) { + assert := assert.New(t) + reader := strings.NewReader("notrandom") + opts, err := GetOpts(WithRandomReader(reader)) + require.NoError(t, err) + testOpts := getDefaultOptions() + testOpts.WithRandomReader = reader + assert.Equal(opts, testOpts) + }) } diff --git a/internal/credential/vault/options.go b/internal/credential/vault/options.go index 271d4bcdf0..21f7e73bc8 100644 --- a/internal/credential/vault/options.go +++ b/internal/credential/vault/options.go @@ -3,7 +3,12 @@ package vault -import "github.com/hashicorp/boundary/globals" +import ( + "crypto/rand" + "io" + + "github.com/hashicorp/boundary/globals" +) // getOpts - iterate the inbound Options and return a struct func getOpts(opt ...Option) options { @@ -46,10 +51,13 @@ type options struct { withCriticalOptions string withExtensions string withAdditionalValidPrincipals []string + withRandomReader io.Reader } func getDefaultOptions() options { - return options{} + return options{ + withRandomReader: rand.Reader, + } } // WithDescription provides an optional description. @@ -247,3 +255,10 @@ func WithAdditionalValidPrincipals(p []string) Option { o.withAdditionalValidPrincipals = p } } + +// WithRandomReader provides an option to specify a random reader. +func WithRandomReader(reader io.Reader) Option { + return func(o *options) { + o.withRandomReader = reader + } +} diff --git a/internal/credential/vault/options_test.go b/internal/credential/vault/options_test.go index 77d1201b15..1e622a46ff 100644 --- a/internal/credential/vault/options_test.go +++ b/internal/credential/vault/options_test.go @@ -5,6 +5,7 @@ package vault import ( "context" + "strings" "testing" "github.com/hashicorp/boundary/globals" @@ -131,4 +132,13 @@ func Test_GetOpts(t *testing.T) { testOpts.withMappingOverride = unknownMapper(1) assert.Equal(t, opts, testOpts) }) + + t.Run("WithRandomReader", func(t *testing.T) { + assert := assert.New(t) + reader := strings.NewReader("notrandom") + opts := getOpts(WithRandomReader(reader)) + testOpts := getDefaultOptions() + testOpts.withRandomReader = reader + assert.Equal(opts, testOpts) + }) } diff --git a/internal/credential/vault/private_library.go b/internal/credential/vault/private_library.go index 0011a6df04..255728c627 100644 --- a/internal/credential/vault/private_library.go +++ b/internal/credential/vault/private_library.go @@ -8,7 +8,6 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" - "crypto/rand" "crypto/rsa" "crypto/x509" "database/sql" @@ -936,16 +935,21 @@ func (lib *sshCertIssuingCredentialLibrary) client(ctx context.Context) (vaultCl return client, nil } -func generatePublicPrivateKeys(ctx context.Context, keyType string, keyBits int) (string, []byte, error) { +func generatePublicPrivateKeys(ctx context.Context, keyType string, keyBits int, opt ...credential.Option) (string, []byte, error) { const op = "vault.generatePublicPrivateKeys" pemBlock := pem.Block{} var sshKey ssh.PublicKey + opts, err := credential.GetOpts(opt...) + if err != nil { + return "", nil, errors.Wrap(ctx, err, op) + } + switch keyType { case KeyTypeRsa: pemBlock.Type = "RSA PRIVATE KEY" // these values are copied from the crypto ssh library in ssh/keys.go - key, err := rsa.GenerateKey(rand.Reader, keyBits) + key, err := rsa.GenerateKey(opts.WithRandomReader, keyBits) if err != nil { return "", nil, errors.Wrap(ctx, err, op) } @@ -958,7 +962,7 @@ func generatePublicPrivateKeys(ctx context.Context, keyType string, keyBits int) case KeyTypeEd25519: pemBlock.Type = "OPENSSH PRIVATE KEY" // these values are copied from the crypto ssh library in ssh/keys.go - pubKey, privKey, err := ed25519.GenerateKey(rand.Reader) + pubKey, privKey, err := ed25519.GenerateKey(opts.WithRandomReader) if err != nil { return "", nil, errors.Wrap(ctx, err, op) } @@ -985,7 +989,7 @@ func generatePublicPrivateKeys(ctx context.Context, keyType string, keyBits int) return "", nil, errors.New(ctx, errors.InvalidParameter, op, "invalid KeyBits. when KeyType=ecdsa, KeyBits must be one of: 256, 384, or 521") } - key, err := ecdsa.GenerateKey(curve, rand.Reader) + key, err := ecdsa.GenerateKey(curve, opts.WithRandomReader) if err != nil { return "", nil, errors.Wrap(ctx, err, op) } @@ -1109,7 +1113,7 @@ func (lib *sshCertIssuingCredentialLibrary) retrieveCredential(ctx context.Conte // by definition, if match exists, then match[1] == "sign" or "issue" switch match[1] { case "sign": - payload.PublicKey, privateKey, err = generatePublicPrivateKeys(ctx, lib.KeyType, lib.KeyBits) + payload.PublicKey, privateKey, err = generatePublicPrivateKeys(ctx, lib.KeyType, lib.KeyBits, credential.WithRandomReader(opts.WithRandomReader)) if err != nil { return nil, errors.Wrap(ctx, err, op) } diff --git a/internal/credential/vault/repository.go b/internal/credential/vault/repository.go index 31d80050be..9c28059e85 100644 --- a/internal/credential/vault/repository.go +++ b/internal/credential/vault/repository.go @@ -5,6 +5,7 @@ package vault import ( "context" + "io" "github.com/hashicorp/boundary/internal/db" "github.com/hashicorp/boundary/internal/errors" @@ -22,6 +23,7 @@ type Repository struct { // defaultLimit provides a default for limiting the number of results // returned from the repo defaultLimit int + randomReader io.Reader } // NewRepository creates a new Repository. The returned repository should @@ -53,5 +55,6 @@ func NewRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms, kms: kms, scheduler: scheduler, defaultLimit: opts.withLimit, + randomReader: opts.withRandomReader, }, nil } diff --git a/internal/credential/vault/repository_credentials.go b/internal/credential/vault/repository_credentials.go index 968fc1332a..84f2ca0d73 100644 --- a/internal/credential/vault/repository_credentials.go +++ b/internal/credential/vault/repository_credentials.go @@ -72,6 +72,9 @@ func (r *Repository) Issue(ctx context.Context, sessionId string, requests []cre var creds []credential.Dynamic var minLease time.Duration runJobsInterval := r.scheduler.GetRunJobsInterval() + + // passing SecureRandomReader to credential libraries + opt = append(opt, credential.WithRandomReader(r.randomReader)) for _, lib := range libs { cred, err := lib.retrieveCredential(ctx, op, opt...) if err != nil { diff --git a/internal/credential/vault/repository_test.go b/internal/credential/vault/repository_test.go index c8129d2164..0c025d5453 100644 --- a/internal/credential/vault/repository_test.go +++ b/internal/credential/vault/repository_test.go @@ -5,6 +5,8 @@ package vault import ( "context" + "crypto/rand" + "strings" "testing" "github.com/hashicorp/boundary/internal/db" @@ -22,6 +24,7 @@ func TestRepository_New(t *testing.T) { wrapper := db.TestWrapper(t) kmsCache := kms.TestKms(t, conn, wrapper) sche := scheduler.TestScheduler(t, conn, wrapper) + testReader := strings.NewReader("notrandom") type args struct { r db.Reader @@ -51,6 +54,7 @@ func TestRepository_New(t *testing.T) { kms: kmsCache, scheduler: sche, defaultLimit: db.DefaultLimit, + randomReader: rand.Reader, }, }, { @@ -60,7 +64,8 @@ func TestRepository_New(t *testing.T) { w: rw, kms: kmsCache, scheduler: sche, - opts: []Option{WithLimit(5)}, + opts: []Option{WithLimit(5), + WithRandomReader(testReader)}, }, want: &Repository{ reader: rw, @@ -68,6 +73,7 @@ func TestRepository_New(t *testing.T) { kms: kmsCache, scheduler: sche, defaultLimit: 5, + randomReader: testReader, }, }, { diff --git a/internal/daemon/controller/controller.go b/internal/daemon/controller/controller.go index 57acc13605..4efcc4ab9b 100644 --- a/internal/daemon/controller/controller.go +++ b/internal/daemon/controller/controller.go @@ -433,7 +433,7 @@ func New(ctx context.Context, conf *Config) (*Controller, error) { authtoken.WithTokenTimeToStaleDuration(c.conf.RawConfig.Controller.AuthTokenTimeToStaleDuration)) } c.VaultCredentialRepoFn = func() (*vault.Repository, error) { - return vault.NewRepository(ctx, dbase, dbase, c.kms, c.scheduler) + return vault.NewRepository(ctx, dbase, dbase, c.kms, c.scheduler, vault.WithRandomReader(c.conf.SecureRandomReader)) } c.StaticCredentialRepoFn = func() (*credstatic.Repository, error) { return credstatic.NewRepository(ctx, dbase, dbase, c.kms) @@ -454,12 +454,13 @@ func New(ctx context.Context, conf *Config) (*Controller, error) { return ldap.NewRepository(ctx, dbase, dbase, c.kms) } c.PasswordAuthRepoFn = func() (*password.Repository, error) { - return password.NewRepository(ctx, dbase, dbase, c.kms) + return password.NewRepository(ctx, dbase, dbase, c.kms, password.WithRandomReader(c.conf.SecureRandomReader)) } c.AuthMethodRepoFn = func() (*auth.AuthMethodRepository, error) { return auth.NewAuthMethodRepository(ctx, dbase, dbase, c.kms) } c.TargetRepoFn = func(o ...target.Option) (*target.Repository, error) { + o = append(o, target.WithRandomReader(c.conf.SecureRandomReader)) return target.NewRepository(ctx, dbase, dbase, c.kms, o...) } c.SessionRepoFn = func(opt ...session.Option) (*session.Repository, error) { diff --git a/internal/server/options.go b/internal/server/options.go index a28a8c2288..083b1bd83f 100644 --- a/internal/server/options.go +++ b/internal/server/options.go @@ -5,6 +5,8 @@ package server import ( "context" + "crypto/rand" + "io" "time" "github.com/hashicorp/boundary/internal/db" @@ -66,6 +68,7 @@ type options struct { withFilterWorkersByLocalStorageState bool WithReader db.Reader WithWriter db.Writer + withRandomReader io.Reader } func getDefaultOptions() options { @@ -73,6 +76,7 @@ func getDefaultOptions() options { withNewIdFunc: newWorkerId, withOperationalState: ActiveOperationalState.String(), withLocalStorageState: UnknownWorkerType.String(), + withRandomReader: rand.Reader, } } @@ -315,3 +319,10 @@ func WithReaderWriter(r db.Reader, w db.Writer) Option { o.WithWriter = w } } + +// WithRandomReader provides an option to specify a random reader. +func WithRandomReader(reader io.Reader) Option { + return func(o *options) { + o.withRandomReader = reader + } +} diff --git a/internal/server/worker_list.go b/internal/server/worker_list.go index fb505694b7..1ca7d69088 100644 --- a/internal/server/worker_list.go +++ b/internal/server/worker_list.go @@ -71,10 +71,13 @@ func (w WorkerList) SupportsFeature(f version.Feature) WorkerList { // Shuffle returns a randomly-shuffled copy of the caller's Workers (using // crypto/rand). If the caller's WorkerList has one element or less, this // function is a no-op. -func (w WorkerList) Shuffle() (WorkerList, error) { +// Supported options: +// - WithRandomReader +func (w WorkerList) Shuffle(opt ...Option) (WorkerList, error) { if len(w) <= 1 { return w, nil } + opts := GetOpts(opt...) ret := make(WorkerList, len(w)) copy(ret, w) @@ -83,7 +86,7 @@ func (w WorkerList) Shuffle() (WorkerList, error) { // math/rand.Shuffle, but using the crypto/rand package instead. The same // caveats as math/rand.Shuffle apply. for i := len(ret) - 1; i > 0; i-- { - j, err := rand.Int(rand.Reader, big.NewInt(int64(i+1))) + j, err := rand.Int(opts.withRandomReader, big.NewInt(int64(i+1))) if err != nil { return nil, err } diff --git a/internal/target/options.go b/internal/target/options.go index af97d36723..59392d62b5 100644 --- a/internal/target/options.go +++ b/internal/target/options.go @@ -4,6 +4,8 @@ package target import ( + "crypto/rand" + "io" "net" "time" @@ -58,6 +60,7 @@ type options struct { WithAlias *talias.Alias withAliases []*talias.Alias withTargetId string + withRandomReader io.Reader } func getDefaultOptions() options { @@ -86,6 +89,7 @@ func getDefaultOptions() options { WithAddress: "", WithNetResolver: net.DefaultResolver, withTargetId: "", + withRandomReader: rand.Reader, } } @@ -301,3 +305,10 @@ func WithTargetId(in string) Option { o.withTargetId = in } } + +// WithRandomReader provides an option to specify a random reader. +func WithRandomReader(reader io.Reader) Option { + return func(o *options) { + o.withRandomReader = reader + } +} diff --git a/internal/target/options_test.go b/internal/target/options_test.go index 88b7d7fc95..49812af30a 100644 --- a/internal/target/options_test.go +++ b/internal/target/options_test.go @@ -5,6 +5,7 @@ package target import ( "context" + "strings" "testing" "time" @@ -282,4 +283,12 @@ func Test_GetOpts(t *testing.T) { testOpts.withTargetId = "testId" assert.Equal(opts, testOpts) }) + t.Run("WithRandomReader", func(t *testing.T) { + assert := assert.New(t) + reader := strings.NewReader("notrandom") + opts := GetOpts(WithRandomReader(reader)) + testOpts := getDefaultOptions() + testOpts.withRandomReader = reader + assert.Equal(opts, testOpts) + }) } diff --git a/internal/target/repository.go b/internal/target/repository.go index bd8b7aa035..a646c8f93a 100644 --- a/internal/target/repository.go +++ b/internal/target/repository.go @@ -7,6 +7,7 @@ import ( "context" "database/sql" "fmt" + "io" "strings" "time" @@ -48,7 +49,8 @@ type Repository struct { // has access to in terms of actions and resources and we use it to build queries. // These are passed in on the repository constructor using `WithPermissions`, meaning the // `Repository` object is contextualized to whatever the request context is. - permissions []perms.Permission + permissions []perms.Permission + randomReader io.Reader } // NewRepository creates a new target Repository. @@ -86,6 +88,7 @@ func NewRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms, kms: kms, defaultLimit: opts.WithLimit, permissions: opts.WithPermissions, + randomReader: opts.withRandomReader, }, nil } @@ -141,7 +144,7 @@ func (r *Repository) LookupTargetForSessionAuthorization(ctx context.Context, pu } if opts.WithAlias != nil { - cert, err = fetchTargetAliasProxyServerCertificate(ctx, read, w, target.PublicId, target.ProjectId, opts.WithAlias, databaseWrapper, target.GetSessionMaxSeconds()) + cert, err = fetchTargetAliasProxyServerCertificate(ctx, read, w, target.PublicId, target.ProjectId, opts.WithAlias, databaseWrapper, target.GetSessionMaxSeconds(), WithRandomReader(r.randomReader)) if err != nil && !errors.IsNotFoundError(err) { return errors.Wrap(ctx, err, op) } diff --git a/internal/target/repository_proxy_server_certificate.go b/internal/target/repository_proxy_server_certificate.go index e9b62bc33a..21048dfff9 100644 --- a/internal/target/repository_proxy_server_certificate.go +++ b/internal/target/repository_proxy_server_certificate.go @@ -108,7 +108,7 @@ func fetchTargetProxyServerCertificate(ctx context.Context, r db.Reader, w db.Wr return maybeRegenerateCert(ctx, targetCert, w, wrapper, sessionMaxSeconds) } -func fetchTargetAliasProxyServerCertificate(ctx context.Context, r db.Reader, w db.Writer, targetId, scopeId string, alias *talias.Alias, wrapper wrapping.Wrapper, sessionMaxSeconds uint32) (*ServerCertificate, error) { +func fetchTargetAliasProxyServerCertificate(ctx context.Context, r db.Reader, w db.Writer, targetId, scopeId string, alias *talias.Alias, wrapper wrapping.Wrapper, sessionMaxSeconds uint32, opt ...Option) (*ServerCertificate, error) { const op = "target.fetchTargetProxyServerCert" switch { case wrapper == nil: @@ -146,7 +146,7 @@ func fetchTargetAliasProxyServerCertificate(ctx context.Context, r db.Reader, w // Create the cert, if not found- alias certs are not created as part of target creation. var err error if aliasCert.Certificate == nil { - aliasCert, err = NewTargetAliasProxyCertificate(ctx, targetId, alias) + aliasCert, err = NewTargetAliasProxyCertificate(ctx, targetId, alias, opt...) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error creating new target alias proxy certificate")) } diff --git a/internal/target/repository_test.go b/internal/target/repository_test.go index 237537cd05..a1e02b9f45 100644 --- a/internal/target/repository_test.go +++ b/internal/target/repository_test.go @@ -5,6 +5,8 @@ package target import ( "context" + "crypto/rand" + "strings" "testing" "time" @@ -22,6 +24,8 @@ func TestNewRepository(t *testing.T) { rw := db.New(conn) wrapper := db.TestWrapper(t) testKms := kms.TestKms(t, conn, wrapper) + testReader := strings.NewReader("notrandom") + type args struct { r db.Reader w db.Writer @@ -47,6 +51,7 @@ func TestNewRepository(t *testing.T) { writer: rw, kms: testKms, defaultLimit: db.DefaultLimit, + randomReader: rand.Reader, }, wantErr: false, }, @@ -78,6 +83,9 @@ func TestNewRepository(t *testing.T) { r: nil, w: rw, kms: testKms, + opts: []Option{ + WithRandomReader(testReader), + }, }, want: nil, wantErr: true, @@ -94,6 +102,7 @@ func TestNewRepository(t *testing.T) { {GrantScopeId: "test1", Resource: resource.Target}, {GrantScopeId: "test2", Resource: resource.Target}, }), + WithRandomReader(testReader), }, }, want: &Repository{ @@ -105,6 +114,7 @@ func TestNewRepository(t *testing.T) { {GrantScopeId: "test1", Resource: resource.Target}, {GrantScopeId: "test2", Resource: resource.Target}, }, + randomReader: testReader, }, wantErr: false, }, diff --git a/internal/target/target_certificate.go b/internal/target/target_certificate.go index eb83e30b77..db5585dca9 100644 --- a/internal/target/target_certificate.go +++ b/internal/target/target_certificate.go @@ -11,6 +11,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "io" "math" "math/big" "net" @@ -27,10 +28,10 @@ import ( "google.golang.org/protobuf/proto" ) -func generatePrivAndPubKeys(ctx context.Context) (privKeyBytes []byte, pubKeyBytes []byte, err error) { +func generatePrivAndPubKeys(ctx context.Context, randomReader io.Reader) (privKeyBytes []byte, pubKeyBytes []byte, err error) { const op = "target.generatePrivAndPubKeys" // Generate a private key using the P521 curve - key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + key, err := ecdsa.GenerateKey(elliptic.P521(), randomReader) if err != nil { return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "failed to generate ECDSA key") } @@ -62,7 +63,7 @@ func generateTargetCert(ctx context.Context, privKey *ecdsa.PrivateKey, exp time opts := GetOpts(opt...) - randomSerialNumber, err := rand.Int(rand.Reader, big.NewInt(int64(math.MaxInt64))) + randomSerialNumber, err := rand.Int(opts.withRandomReader, big.NewInt(int64(math.MaxInt64))) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error generating random serial number")) } @@ -87,7 +88,7 @@ func generateTargetCert(ctx context.Context, privKey *ecdsa.PrivateKey, exp time template.DNSNames = append(template.DNSNames, opts.WithAlias.Value) } - certBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey) + certBytes, err := x509.CreateCertificate(opts.withRandomReader, template, template, &privKey.PublicKey, privKey) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.GenCert)) } @@ -97,7 +98,9 @@ func generateTargetCert(ctx context.Context, privKey *ecdsa.PrivateKey, exp time func generateKeysAndCert(ctx context.Context, notValidAfter time.Time, opt ...Option) (privKey []byte, pubKey []byte, cert []byte, err error) { const op = "target.generateKeysAndCert" - privKey, pubKey, err = generatePrivAndPubKeys(ctx) + opts := GetOpts(opt...) + + privKey, pubKey, err = generatePrivAndPubKeys(ctx, opts.withRandomReader) if err != nil { return nil, nil, nil, errors.Wrap(ctx, err, op) } @@ -128,7 +131,7 @@ func NewTargetProxyCertificate(ctx context.Context, opt ...Option) (*TargetProxy opts := GetOpts(opt...) notValidAfter := time.Now().AddDate(1, 0, 0) // 1 year from now - privKey, pubKey, cert, err := generateKeysAndCert(ctx, notValidAfter) + privKey, pubKey, cert, err := generateKeysAndCert(ctx, notValidAfter, opt...) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error generating target proxy cert and keys")) } @@ -245,7 +248,7 @@ type TargetAliasProxyCertificate struct { } // NewTargetAliasProxyCertificate creates a new in memory TargetAliasProxyCertificate -func NewTargetAliasProxyCertificate(ctx context.Context, targetId string, alias *talias.Alias) (*TargetAliasProxyCertificate, error) { +func NewTargetAliasProxyCertificate(ctx context.Context, targetId string, alias *talias.Alias, opt ...Option) (*TargetAliasProxyCertificate, error) { const op = "target.NewTargetAliasProxyCertificate" switch { case targetId == "": @@ -255,7 +258,9 @@ func NewTargetAliasProxyCertificate(ctx context.Context, targetId string, alias } notValidAfter := time.Now().AddDate(1, 0, 0) // 1 year from now - privKey, pubKey, cert, err := generateKeysAndCert(ctx, notValidAfter, WithAlias(alias)) + + opt = append(opt, WithAlias(alias)) + privKey, pubKey, cert, err := generateKeysAndCert(ctx, notValidAfter, opt...) if err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error generating target proxy cert and keys")) }