diff --git a/internal/auth/password/argon2.go b/internal/auth/password/argon2.go index 933e7a0791..ec219f296e 100644 --- a/internal/auth/password/argon2.go +++ b/internal/auth/password/argon2.go @@ -11,7 +11,6 @@ import ( "github.com/hashicorp/boundary/internal/auth/password/store" "github.com/hashicorp/boundary/internal/errors" "github.com/hashicorp/boundary/internal/oplog" - random "github.com/hashicorp/boundary/internal/securerandom" wrapping "github.com/hashicorp/go-kms-wrapping/v2" "github.com/hashicorp/go-kms-wrapping/v2/extras/structwrapping" "golang.org/x/crypto/argon2" @@ -153,7 +152,7 @@ type Argon2Credential struct { tableName string } -func newArgon2Credential(ctx context.Context, accountId string, password string, conf *Argon2Configuration) (*Argon2Credential, error) { +func newArgon2Credential(ctx context.Context, accountId string, password string, conf *Argon2Configuration, opt ...Option) (*Argon2Credential, error) { const op = "password.newArgon2Credential" if accountId == "" { return nil, errors.New(ctx, errors.InvalidParameter, op, "missing accountId") @@ -165,6 +164,8 @@ func newArgon2Credential(ctx context.Context, accountId string, password string, return nil, errors.New(ctx, errors.InvalidParameter, op, "missing argon2 configuration") } + opts := GetOpts(opt...) + id, err := newArgon2CredentialId(ctx) if err != nil { return nil, errors.Wrap(ctx, err, op) @@ -181,7 +182,7 @@ func newArgon2Credential(ctx context.Context, accountId string, password string, // Generate a random salt salt := make([]byte, conf.SaltLength) - if _, err := io.ReadFull(random.SecureRandomReader(), salt); err != nil { + if _, err := io.ReadFull(opts.withRandomReader, salt); err != nil { return nil, errors.Wrap(ctx, err, op, errors.WithCode(errors.Io)) } c.Salt = salt diff --git a/internal/auth/password/options.go b/internal/auth/password/options.go index b07a53fdc0..591ae265bf 100644 --- a/internal/auth/password/options.go +++ b/internal/auth/password/options.go @@ -3,7 +3,12 @@ package password -import "github.com/hashicorp/boundary/internal/pagination" +import ( + "crypto/rand" + "io" + + "github.com/hashicorp/boundary/internal/pagination" +) // GetOpts - iterate the inbound Options and return a struct. func GetOpts(opt ...Option) options { @@ -30,11 +35,13 @@ type options struct { withOrderByCreateTime bool ascending bool withStartPageAfterItem pagination.Item + withRandomReader io.Reader } func getDefaultOptions() options { return options{ - withConfig: NewArgon2Configuration(), + withConfig: NewArgon2Configuration(), + withRandomReader: rand.Reader, } } @@ -99,6 +106,15 @@ func WithOrderByCreateTime(ascending bool) Option { } } +// WithRandomReader provides an option to specify a random reader for generating +// cryptographic salts and other random data. If not specified, crypto/rand.Reader +// will be used. This is primarily useful for testing purposes. +func WithRandomReader(reader io.Reader) Option { + return func(o *options) { + o.withRandomReader = reader + } +} + // WithStartPageAfterItem is used to paginate over the results. // The next page will start after the provided item. func WithStartPageAfterItem(item pagination.Item) Option { diff --git a/internal/auth/password/options_test.go b/internal/auth/password/options_test.go index e9b4bae739..633157749d 100644 --- a/internal/auth/password/options_test.go +++ b/internal/auth/password/options_test.go @@ -4,6 +4,8 @@ package password import ( + "crypto/rand" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -68,4 +70,18 @@ func Test_GetOpts(t *testing.T) { testOpts.ascending = true assert.Equal(opts, testOpts) }) + t.Run("WithRandomReader", func(t *testing.T) { + assert := assert.New(t) + customReader := strings.NewReader("test-random-data") + opts := GetOpts(WithRandomReader(customReader)) + testOpts := getDefaultOptions() + testOpts.withRandomReader = customReader + assert.Equal(opts, testOpts) + }) + t.Run("default-random-reader", func(t *testing.T) { + assert := assert.New(t) + opts := GetOpts() + // Verify the default is crypto/rand.Reader + assert.Equal(rand.Reader, opts.withRandomReader) + }) } diff --git a/internal/auth/password/repository_account.go b/internal/auth/password/repository_account.go index 63ab713406..ec4cd379f0 100644 --- a/internal/auth/password/repository_account.go +++ b/internal/auth/password/repository_account.go @@ -84,7 +84,7 @@ func (r *Repository) CreateAccount(ctx context.Context, scopeId string, a *Accou if cc.MinPasswordLength > len(opts.password) { return nil, errors.New(ctx, errors.PasswordTooShort, op, fmt.Sprintf("must be longer than %v", cc.MinPasswordLength)) } - if cred, err = newArgon2Credential(ctx, a.PublicId, opts.password, cc.argon2()); err != nil { + if cred, err = newArgon2Credential(ctx, a.PublicId, opts.password, cc.argon2(), WithRandomReader(opts.withRandomReader)); err != nil { return nil, errors.Wrap(ctx, err, op) } }