refactor : ICU-18284-Unify SecureRandomReader (#6303)

* Unify SecureRandomReader

* vault changes updated

* target changes added

* reviwed changes updated
pull/6310/head
Abhishek Manjegowda 5 months ago committed by GitHub
parent 11fc97622f
commit c4dc2a43cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -510,7 +510,7 @@ func (b *Server) CreateInitialTargetWithAddress(ctx context.Context) (target.Tar
return nil, fmt.Errorf("failed to add config keys to kms: %w", err)
}
targetRepo, err := target.NewRepository(ctx, rw, rw, kmsCache)
targetRepo, err := target.NewRepository(ctx, rw, rw, kmsCache, target.WithRandomReader(b.SecureRandomReader))
if err != nil {
return nil, fmt.Errorf("failed to create target repository: %w", err)
}

@ -503,9 +503,10 @@ func (c *Command) Run(args []string) int {
c.Config, err = config.DevController(
config.WithObservationsEnabled(true),
config.WithSysEventsEnabled(true),
config.WithRandomReader(c.SecureRandomReader),
)
default:
c.Config, err = config.DevCombined()
c.Config, err = config.DevCombined(config.WithRandomReader(c.SecureRandomReader))
}
if err != nil {
c.UI.Error(fmt.Errorf("Error creating controller dev config: %w", err).Error())
@ -905,7 +906,7 @@ func (c *Command) Run(args []string) int {
Worker: &store.Worker{
ScopeId: scope.Global.String(),
},
}, server.WithCreateControllerLedActivationToken(true))
}, server.WithCreateControllerLedActivationToken(true), server.WithRandomReader(c.SecureRandomReader))
if err != nil {
c.UI.Error(fmt.Errorf("Error creating worker in database: %w", err).Error())
if err := c.controller.Shutdown(); err != nil {

@ -12,6 +12,7 @@
package server
import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"fmt"
@ -92,9 +93,9 @@ func TestServer_ReloadListener(t *testing.T) {
td := t.TempDir()
controllerKey := config.DevKeyGeneration()
workerAuthKey := config.DevKeyGeneration()
recoveryKey := config.DevKeyGeneration()
controllerKey := config.DevKeyGeneration(config.WithRandomReader(rand.Reader))
workerAuthKey := config.DevKeyGeneration(config.WithRandomReader(rand.Reader))
recoveryKey := config.DevKeyGeneration(config.WithRandomReader(rand.Reader))
cmd := testServerCommand(t, testServerCommandOpts{
CreateDevDatabase: true,

@ -6,7 +6,6 @@ package config
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
@ -471,7 +470,7 @@ type License struct {
// workers. Supported options: WithObservationsEnabled, WithSysEventsEnabled,
// WithAuditEventsEnabled, TestWithErrorEventsEnabled
func DevWorker(opt ...Option) (*Config, error) {
workerAuthStorageKey := DevKeyGeneration()
workerAuthStorageKey := DevKeyGeneration(opt...)
opts, err := getOpts(opt...)
if err != nil {
return nil, fmt.Errorf("error parsing options: %w", err)
@ -491,11 +490,16 @@ func DevWorker(opt ...Option) (*Config, error) {
return parsed, nil
}
func DevKeyGeneration() string {
func DevKeyGeneration(opt ...Option) string {
var numBytes int64 = 32
randBuf := new(bytes.Buffer)
opts, err := getOpts(opt...)
if err != nil {
return fmt.Errorf("error parsing options: %w", err).Error()
}
n, err := randBuf.ReadFrom(&io.LimitedReader{
R: rand.Reader,
R: opts.withRandomReader,
N: numBytes,
})
if err != nil {
@ -516,10 +520,10 @@ func DevController(opt ...Option) (*Config, error) {
return nil, fmt.Errorf("error parsing options: %w", err)
}
controllerKey := DevKeyGeneration()
workerAuthKey := DevKeyGeneration()
bsrKey := DevKeyGeneration()
recoveryKey := DevKeyGeneration()
controllerKey := DevKeyGeneration(opt...)
workerAuthKey := DevKeyGeneration(opt...)
bsrKey := DevKeyGeneration(opt...)
recoveryKey := DevKeyGeneration(opt...)
hclStr := fmt.Sprintf(devConfig+devControllerExtraConfig, controllerKey, workerAuthKey, bsrKey, recoveryKey)
if opts.withIPv6Enabled {
@ -547,11 +551,11 @@ func DevCombined(opt ...Option) (*Config, error) {
return nil, fmt.Errorf("error parsing options: %w", err)
}
controllerKey := DevKeyGeneration()
workerAuthKey := DevKeyGeneration()
workerAuthStorageKey := DevKeyGeneration()
bsrKey := DevKeyGeneration()
recoveryKey := DevKeyGeneration()
controllerKey := DevKeyGeneration(opt...)
workerAuthKey := DevKeyGeneration(opt...)
workerAuthStorageKey := DevKeyGeneration(opt...)
bsrKey := DevKeyGeneration(opt...)
recoveryKey := DevKeyGeneration(opt...)
hclStr := fmt.Sprintf(devConfig+devControllerExtraConfig+devWorkerExtraConfig, controllerKey, workerAuthKey, bsrKey, recoveryKey, workerAuthStorageKey)
if opts.withIPv6Enabled {

@ -4,6 +4,7 @@
package config
import (
"crypto/rand"
"encoding/base64"
"fmt"
"net"
@ -892,7 +893,7 @@ func TestDevCombinedIpv6(t *testing.T) {
func TestDevKeyGeneration(t *testing.T) {
t.Parallel()
dk := DevKeyGeneration()
dk := DevKeyGeneration(WithRandomReader(rand.Reader))
buf, err := base64.StdEncoding.DecodeString(dk)
require.NoError(t, err)
require.Len(t, buf, 32)

@ -3,6 +3,8 @@
package config
import (
"crypto/rand"
"io"
"os"
"testing"
@ -37,6 +39,7 @@ type options struct {
withObservationsEnabled bool
withIPv6Enabled bool
testWithErrorEventsEnabled bool
withRandomReader io.Reader
}
func getDefaultOptions() (options, error) {
@ -72,6 +75,8 @@ func getDefaultOptions() (options, error) {
}
opts.testWithErrorEventsEnabled = errEvents
opts.withRandomReader = rand.Reader
return opts, nil
}
@ -107,6 +112,14 @@ func WithIPv6Enabled(enable bool) Option {
}
}
// WithRandomReader provides an option to specify a random reader.
func WithRandomReader(reader io.Reader) Option {
return func(o *options) error {
o.withRandomReader = reader
return nil
}
}
// TestWithErrorEventsEnabled provides an option for enabling error events
// during tests.
func TestWithErrorEventsEnabled(_ testing.TB, enable bool) Option {

@ -4,7 +4,9 @@
package credential
import (
"crypto/rand"
"errors"
"io"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/pagination"
@ -36,10 +38,13 @@ type options struct {
WithWriter db.Writer
WithLimit int
WithStartPageAfterItem pagination.Item
WithRandomReader io.Reader
}
func getDefaultOptions() *options {
return &options{}
return &options{
WithRandomReader: rand.Reader,
}
}
// WithTemplateData provides a way to pass in template information
@ -87,3 +92,11 @@ func WithStartPageAfterItem(item pagination.Item) Option {
return nil
}
}
// WithRandomReader provides an option to specify a random reader.
func WithRandomReader(reader io.Reader) Option {
return func(o *options) error {
o.WithRandomReader = reader
return nil
}
}

@ -4,6 +4,7 @@
package credential
import (
"strings"
"testing"
"time"
@ -119,4 +120,14 @@ func Test_GetOpts(t *testing.T) {
assert.Equal(opts.WithStartPageAfterItem.GetUpdateTime(), timestamp.New(updateTime))
assert.Equal(opts.WithStartPageAfterItem.GetCreateTime(), timestamp.New(createTime))
})
t.Run("WithRandomReader", func(t *testing.T) {
assert := assert.New(t)
reader := strings.NewReader("notrandom")
opts, err := GetOpts(WithRandomReader(reader))
require.NoError(t, err)
testOpts := getDefaultOptions()
testOpts.WithRandomReader = reader
assert.Equal(opts, testOpts)
})
}

@ -3,7 +3,12 @@
package vault
import "github.com/hashicorp/boundary/globals"
import (
"crypto/rand"
"io"
"github.com/hashicorp/boundary/globals"
)
// getOpts - iterate the inbound Options and return a struct
func getOpts(opt ...Option) options {
@ -46,10 +51,13 @@ type options struct {
withCriticalOptions string
withExtensions string
withAdditionalValidPrincipals []string
withRandomReader io.Reader
}
func getDefaultOptions() options {
return options{}
return options{
withRandomReader: rand.Reader,
}
}
// WithDescription provides an optional description.
@ -247,3 +255,10 @@ func WithAdditionalValidPrincipals(p []string) Option {
o.withAdditionalValidPrincipals = p
}
}
// WithRandomReader provides an option to specify a random reader.
func WithRandomReader(reader io.Reader) Option {
return func(o *options) {
o.withRandomReader = reader
}
}

@ -5,6 +5,7 @@ package vault
import (
"context"
"strings"
"testing"
"github.com/hashicorp/boundary/globals"
@ -131,4 +132,13 @@ func Test_GetOpts(t *testing.T) {
testOpts.withMappingOverride = unknownMapper(1)
assert.Equal(t, opts, testOpts)
})
t.Run("WithRandomReader", func(t *testing.T) {
assert := assert.New(t)
reader := strings.NewReader("notrandom")
opts := getOpts(WithRandomReader(reader))
testOpts := getDefaultOptions()
testOpts.withRandomReader = reader
assert.Equal(opts, testOpts)
})
}

@ -8,7 +8,6 @@ import (
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"database/sql"
@ -936,16 +935,21 @@ func (lib *sshCertIssuingCredentialLibrary) client(ctx context.Context) (vaultCl
return client, nil
}
func generatePublicPrivateKeys(ctx context.Context, keyType string, keyBits int) (string, []byte, error) {
func generatePublicPrivateKeys(ctx context.Context, keyType string, keyBits int, opt ...credential.Option) (string, []byte, error) {
const op = "vault.generatePublicPrivateKeys"
pemBlock := pem.Block{}
var sshKey ssh.PublicKey
opts, err := credential.GetOpts(opt...)
if err != nil {
return "", nil, errors.Wrap(ctx, err, op)
}
switch keyType {
case KeyTypeRsa:
pemBlock.Type = "RSA PRIVATE KEY" // these values are copied from the crypto ssh library in ssh/keys.go
key, err := rsa.GenerateKey(rand.Reader, keyBits)
key, err := rsa.GenerateKey(opts.WithRandomReader, keyBits)
if err != nil {
return "", nil, errors.Wrap(ctx, err, op)
}
@ -958,7 +962,7 @@ func generatePublicPrivateKeys(ctx context.Context, keyType string, keyBits int)
case KeyTypeEd25519:
pemBlock.Type = "OPENSSH PRIVATE KEY" // these values are copied from the crypto ssh library in ssh/keys.go
pubKey, privKey, err := ed25519.GenerateKey(rand.Reader)
pubKey, privKey, err := ed25519.GenerateKey(opts.WithRandomReader)
if err != nil {
return "", nil, errors.Wrap(ctx, err, op)
}
@ -985,7 +989,7 @@ func generatePublicPrivateKeys(ctx context.Context, keyType string, keyBits int)
return "", nil, errors.New(ctx, errors.InvalidParameter, op, "invalid KeyBits. when KeyType=ecdsa, KeyBits must be one of: 256, 384, or 521")
}
key, err := ecdsa.GenerateKey(curve, rand.Reader)
key, err := ecdsa.GenerateKey(curve, opts.WithRandomReader)
if err != nil {
return "", nil, errors.Wrap(ctx, err, op)
}
@ -1109,7 +1113,7 @@ func (lib *sshCertIssuingCredentialLibrary) retrieveCredential(ctx context.Conte
// by definition, if match exists, then match[1] == "sign" or "issue"
switch match[1] {
case "sign":
payload.PublicKey, privateKey, err = generatePublicPrivateKeys(ctx, lib.KeyType, lib.KeyBits)
payload.PublicKey, privateKey, err = generatePublicPrivateKeys(ctx, lib.KeyType, lib.KeyBits, credential.WithRandomReader(opts.WithRandomReader))
if err != nil {
return nil, errors.Wrap(ctx, err, op)
}

@ -5,6 +5,7 @@ package vault
import (
"context"
"io"
"github.com/hashicorp/boundary/internal/db"
"github.com/hashicorp/boundary/internal/errors"
@ -22,6 +23,7 @@ type Repository struct {
// defaultLimit provides a default for limiting the number of results
// returned from the repo
defaultLimit int
randomReader io.Reader
}
// NewRepository creates a new Repository. The returned repository should
@ -53,5 +55,6 @@ func NewRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms,
kms: kms,
scheduler: scheduler,
defaultLimit: opts.withLimit,
randomReader: opts.withRandomReader,
}, nil
}

@ -72,6 +72,9 @@ func (r *Repository) Issue(ctx context.Context, sessionId string, requests []cre
var creds []credential.Dynamic
var minLease time.Duration
runJobsInterval := r.scheduler.GetRunJobsInterval()
// passing SecureRandomReader to credential libraries
opt = append(opt, credential.WithRandomReader(r.randomReader))
for _, lib := range libs {
cred, err := lib.retrieveCredential(ctx, op, opt...)
if err != nil {

@ -5,6 +5,8 @@ package vault
import (
"context"
"crypto/rand"
"strings"
"testing"
"github.com/hashicorp/boundary/internal/db"
@ -22,6 +24,7 @@ func TestRepository_New(t *testing.T) {
wrapper := db.TestWrapper(t)
kmsCache := kms.TestKms(t, conn, wrapper)
sche := scheduler.TestScheduler(t, conn, wrapper)
testReader := strings.NewReader("notrandom")
type args struct {
r db.Reader
@ -51,6 +54,7 @@ func TestRepository_New(t *testing.T) {
kms: kmsCache,
scheduler: sche,
defaultLimit: db.DefaultLimit,
randomReader: rand.Reader,
},
},
{
@ -60,7 +64,8 @@ func TestRepository_New(t *testing.T) {
w: rw,
kms: kmsCache,
scheduler: sche,
opts: []Option{WithLimit(5)},
opts: []Option{WithLimit(5),
WithRandomReader(testReader)},
},
want: &Repository{
reader: rw,
@ -68,6 +73,7 @@ func TestRepository_New(t *testing.T) {
kms: kmsCache,
scheduler: sche,
defaultLimit: 5,
randomReader: testReader,
},
},
{

@ -433,7 +433,7 @@ func New(ctx context.Context, conf *Config) (*Controller, error) {
authtoken.WithTokenTimeToStaleDuration(c.conf.RawConfig.Controller.AuthTokenTimeToStaleDuration))
}
c.VaultCredentialRepoFn = func() (*vault.Repository, error) {
return vault.NewRepository(ctx, dbase, dbase, c.kms, c.scheduler)
return vault.NewRepository(ctx, dbase, dbase, c.kms, c.scheduler, vault.WithRandomReader(c.conf.SecureRandomReader))
}
c.StaticCredentialRepoFn = func() (*credstatic.Repository, error) {
return credstatic.NewRepository(ctx, dbase, dbase, c.kms)
@ -454,12 +454,13 @@ func New(ctx context.Context, conf *Config) (*Controller, error) {
return ldap.NewRepository(ctx, dbase, dbase, c.kms)
}
c.PasswordAuthRepoFn = func() (*password.Repository, error) {
return password.NewRepository(ctx, dbase, dbase, c.kms)
return password.NewRepository(ctx, dbase, dbase, c.kms, password.WithRandomReader(c.conf.SecureRandomReader))
}
c.AuthMethodRepoFn = func() (*auth.AuthMethodRepository, error) {
return auth.NewAuthMethodRepository(ctx, dbase, dbase, c.kms)
}
c.TargetRepoFn = func(o ...target.Option) (*target.Repository, error) {
o = append(o, target.WithRandomReader(c.conf.SecureRandomReader))
return target.NewRepository(ctx, dbase, dbase, c.kms, o...)
}
c.SessionRepoFn = func(opt ...session.Option) (*session.Repository, error) {

@ -5,6 +5,8 @@ package server
import (
"context"
"crypto/rand"
"io"
"time"
"github.com/hashicorp/boundary/internal/db"
@ -66,6 +68,7 @@ type options struct {
withFilterWorkersByLocalStorageState bool
WithReader db.Reader
WithWriter db.Writer
withRandomReader io.Reader
}
func getDefaultOptions() options {
@ -73,6 +76,7 @@ func getDefaultOptions() options {
withNewIdFunc: newWorkerId,
withOperationalState: ActiveOperationalState.String(),
withLocalStorageState: UnknownWorkerType.String(),
withRandomReader: rand.Reader,
}
}
@ -315,3 +319,10 @@ func WithReaderWriter(r db.Reader, w db.Writer) Option {
o.WithWriter = w
}
}
// WithRandomReader provides an option to specify a random reader.
func WithRandomReader(reader io.Reader) Option {
return func(o *options) {
o.withRandomReader = reader
}
}

@ -71,10 +71,13 @@ func (w WorkerList) SupportsFeature(f version.Feature) WorkerList {
// Shuffle returns a randomly-shuffled copy of the caller's Workers (using
// crypto/rand). If the caller's WorkerList has one element or less, this
// function is a no-op.
func (w WorkerList) Shuffle() (WorkerList, error) {
// Supported options:
// - WithRandomReader
func (w WorkerList) Shuffle(opt ...Option) (WorkerList, error) {
if len(w) <= 1 {
return w, nil
}
opts := GetOpts(opt...)
ret := make(WorkerList, len(w))
copy(ret, w)
@ -83,7 +86,7 @@ func (w WorkerList) Shuffle() (WorkerList, error) {
// math/rand.Shuffle, but using the crypto/rand package instead. The same
// caveats as math/rand.Shuffle apply.
for i := len(ret) - 1; i > 0; i-- {
j, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
j, err := rand.Int(opts.withRandomReader, big.NewInt(int64(i+1)))
if err != nil {
return nil, err
}

@ -4,6 +4,8 @@
package target
import (
"crypto/rand"
"io"
"net"
"time"
@ -58,6 +60,7 @@ type options struct {
WithAlias *talias.Alias
withAliases []*talias.Alias
withTargetId string
withRandomReader io.Reader
}
func getDefaultOptions() options {
@ -86,6 +89,7 @@ func getDefaultOptions() options {
WithAddress: "",
WithNetResolver: net.DefaultResolver,
withTargetId: "",
withRandomReader: rand.Reader,
}
}
@ -301,3 +305,10 @@ func WithTargetId(in string) Option {
o.withTargetId = in
}
}
// WithRandomReader provides an option to specify a random reader.
func WithRandomReader(reader io.Reader) Option {
return func(o *options) {
o.withRandomReader = reader
}
}

@ -5,6 +5,7 @@ package target
import (
"context"
"strings"
"testing"
"time"
@ -282,4 +283,12 @@ func Test_GetOpts(t *testing.T) {
testOpts.withTargetId = "testId"
assert.Equal(opts, testOpts)
})
t.Run("WithRandomReader", func(t *testing.T) {
assert := assert.New(t)
reader := strings.NewReader("notrandom")
opts := GetOpts(WithRandomReader(reader))
testOpts := getDefaultOptions()
testOpts.withRandomReader = reader
assert.Equal(opts, testOpts)
})
}

@ -7,6 +7,7 @@ import (
"context"
"database/sql"
"fmt"
"io"
"strings"
"time"
@ -48,7 +49,8 @@ type Repository struct {
// has access to in terms of actions and resources and we use it to build queries.
// These are passed in on the repository constructor using `WithPermissions`, meaning the
// `Repository` object is contextualized to whatever the request context is.
permissions []perms.Permission
permissions []perms.Permission
randomReader io.Reader
}
// NewRepository creates a new target Repository.
@ -86,6 +88,7 @@ func NewRepository(ctx context.Context, r db.Reader, w db.Writer, kms *kms.Kms,
kms: kms,
defaultLimit: opts.WithLimit,
permissions: opts.WithPermissions,
randomReader: opts.withRandomReader,
}, nil
}
@ -141,7 +144,7 @@ func (r *Repository) LookupTargetForSessionAuthorization(ctx context.Context, pu
}
if opts.WithAlias != nil {
cert, err = fetchTargetAliasProxyServerCertificate(ctx, read, w, target.PublicId, target.ProjectId, opts.WithAlias, databaseWrapper, target.GetSessionMaxSeconds())
cert, err = fetchTargetAliasProxyServerCertificate(ctx, read, w, target.PublicId, target.ProjectId, opts.WithAlias, databaseWrapper, target.GetSessionMaxSeconds(), WithRandomReader(r.randomReader))
if err != nil && !errors.IsNotFoundError(err) {
return errors.Wrap(ctx, err, op)
}

@ -108,7 +108,7 @@ func fetchTargetProxyServerCertificate(ctx context.Context, r db.Reader, w db.Wr
return maybeRegenerateCert(ctx, targetCert, w, wrapper, sessionMaxSeconds)
}
func fetchTargetAliasProxyServerCertificate(ctx context.Context, r db.Reader, w db.Writer, targetId, scopeId string, alias *talias.Alias, wrapper wrapping.Wrapper, sessionMaxSeconds uint32) (*ServerCertificate, error) {
func fetchTargetAliasProxyServerCertificate(ctx context.Context, r db.Reader, w db.Writer, targetId, scopeId string, alias *talias.Alias, wrapper wrapping.Wrapper, sessionMaxSeconds uint32, opt ...Option) (*ServerCertificate, error) {
const op = "target.fetchTargetProxyServerCert"
switch {
case wrapper == nil:
@ -146,7 +146,7 @@ func fetchTargetAliasProxyServerCertificate(ctx context.Context, r db.Reader, w
// Create the cert, if not found- alias certs are not created as part of target creation.
var err error
if aliasCert.Certificate == nil {
aliasCert, err = NewTargetAliasProxyCertificate(ctx, targetId, alias)
aliasCert, err = NewTargetAliasProxyCertificate(ctx, targetId, alias, opt...)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error creating new target alias proxy certificate"))
}

@ -5,6 +5,8 @@ package target
import (
"context"
"crypto/rand"
"strings"
"testing"
"time"
@ -22,6 +24,8 @@ func TestNewRepository(t *testing.T) {
rw := db.New(conn)
wrapper := db.TestWrapper(t)
testKms := kms.TestKms(t, conn, wrapper)
testReader := strings.NewReader("notrandom")
type args struct {
r db.Reader
w db.Writer
@ -47,6 +51,7 @@ func TestNewRepository(t *testing.T) {
writer: rw,
kms: testKms,
defaultLimit: db.DefaultLimit,
randomReader: rand.Reader,
},
wantErr: false,
},
@ -78,6 +83,9 @@ func TestNewRepository(t *testing.T) {
r: nil,
w: rw,
kms: testKms,
opts: []Option{
WithRandomReader(testReader),
},
},
want: nil,
wantErr: true,
@ -94,6 +102,7 @@ func TestNewRepository(t *testing.T) {
{GrantScopeId: "test1", Resource: resource.Target},
{GrantScopeId: "test2", Resource: resource.Target},
}),
WithRandomReader(testReader),
},
},
want: &Repository{
@ -105,6 +114,7 @@ func TestNewRepository(t *testing.T) {
{GrantScopeId: "test1", Resource: resource.Target},
{GrantScopeId: "test2", Resource: resource.Target},
},
randomReader: testReader,
},
wantErr: false,
},

@ -11,6 +11,7 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"io"
"math"
"math/big"
"net"
@ -27,10 +28,10 @@ import (
"google.golang.org/protobuf/proto"
)
func generatePrivAndPubKeys(ctx context.Context) (privKeyBytes []byte, pubKeyBytes []byte, err error) {
func generatePrivAndPubKeys(ctx context.Context, randomReader io.Reader) (privKeyBytes []byte, pubKeyBytes []byte, err error) {
const op = "target.generatePrivAndPubKeys"
// Generate a private key using the P521 curve
key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
key, err := ecdsa.GenerateKey(elliptic.P521(), randomReader)
if err != nil {
return nil, nil, errors.New(ctx, errors.InvalidParameter, op, "failed to generate ECDSA key")
}
@ -62,7 +63,7 @@ func generateTargetCert(ctx context.Context, privKey *ecdsa.PrivateKey, exp time
opts := GetOpts(opt...)
randomSerialNumber, err := rand.Int(rand.Reader, big.NewInt(int64(math.MaxInt64)))
randomSerialNumber, err := rand.Int(opts.withRandomReader, big.NewInt(int64(math.MaxInt64)))
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error generating random serial number"))
}
@ -87,7 +88,7 @@ func generateTargetCert(ctx context.Context, privKey *ecdsa.PrivateKey, exp time
template.DNSNames = append(template.DNSNames, opts.WithAlias.Value)
}
certBytes, err := x509.CreateCertificate(rand.Reader, template, template, &privKey.PublicKey, privKey)
certBytes, err := x509.CreateCertificate(opts.withRandomReader, template, template, &privKey.PublicKey, privKey)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.GenCert))
}
@ -97,7 +98,9 @@ func generateTargetCert(ctx context.Context, privKey *ecdsa.PrivateKey, exp time
func generateKeysAndCert(ctx context.Context, notValidAfter time.Time, opt ...Option) (privKey []byte, pubKey []byte, cert []byte, err error) {
const op = "target.generateKeysAndCert"
privKey, pubKey, err = generatePrivAndPubKeys(ctx)
opts := GetOpts(opt...)
privKey, pubKey, err = generatePrivAndPubKeys(ctx, opts.withRandomReader)
if err != nil {
return nil, nil, nil, errors.Wrap(ctx, err, op)
}
@ -128,7 +131,7 @@ func NewTargetProxyCertificate(ctx context.Context, opt ...Option) (*TargetProxy
opts := GetOpts(opt...)
notValidAfter := time.Now().AddDate(1, 0, 0) // 1 year from now
privKey, pubKey, cert, err := generateKeysAndCert(ctx, notValidAfter)
privKey, pubKey, cert, err := generateKeysAndCert(ctx, notValidAfter, opt...)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error generating target proxy cert and keys"))
}
@ -245,7 +248,7 @@ type TargetAliasProxyCertificate struct {
}
// NewTargetAliasProxyCertificate creates a new in memory TargetAliasProxyCertificate
func NewTargetAliasProxyCertificate(ctx context.Context, targetId string, alias *talias.Alias) (*TargetAliasProxyCertificate, error) {
func NewTargetAliasProxyCertificate(ctx context.Context, targetId string, alias *talias.Alias, opt ...Option) (*TargetAliasProxyCertificate, error) {
const op = "target.NewTargetAliasProxyCertificate"
switch {
case targetId == "":
@ -255,7 +258,9 @@ func NewTargetAliasProxyCertificate(ctx context.Context, targetId string, alias
}
notValidAfter := time.Now().AddDate(1, 0, 0) // 1 year from now
privKey, pubKey, cert, err := generateKeysAndCert(ctx, notValidAfter, WithAlias(alias))
opt = append(opt, WithAlias(alias))
privKey, pubKey, cert, err := generateKeysAndCert(ctx, notValidAfter, opt...)
if err != nil {
return nil, errors.Wrap(ctx, err, op, errors.WithMsg("error generating target proxy cert and keys"))
}

Loading…
Cancel
Save