From 2ef6b9247ddfcd17762fd9eac0e52bbedb2981ab Mon Sep 17 00:00:00 2001 From: Stephen Fox Date: Mon, 25 Feb 2019 17:16:24 -0500 Subject: [PATCH] Initial take on code review feedback from @azr. Do not use builder pattern or interfaces; stick to structs and some basic functions. --- .../virtualbox/common/step_ssh_key_pair.go | 32 +- helper/ssh/dsa_key_pair.go | 38 -- helper/ssh/ecdsa_key_pair.go | 38 -- helper/ssh/ed25519_key_pair.go | 38 -- helper/ssh/key_pair.go | 323 ++++------ helper/ssh/key_pair_test.go | 568 +++++++----------- helper/ssh/rsa_key_pair.go | 38 -- 7 files changed, 357 insertions(+), 718 deletions(-) delete mode 100644 helper/ssh/dsa_key_pair.go delete mode 100644 helper/ssh/ecdsa_key_pair.go delete mode 100644 helper/ssh/ed25519_key_pair.go delete mode 100644 helper/ssh/rsa_key_pair.go diff --git a/builder/virtualbox/common/step_ssh_key_pair.go b/builder/virtualbox/common/step_ssh_key_pair.go index 167030deb..20de4655b 100644 --- a/builder/virtualbox/common/step_ssh_key_pair.go +++ b/builder/virtualbox/common/step_ssh_key_pair.go @@ -36,19 +36,19 @@ func (s *StepSshKeyPair) Run(_ context.Context, state multistep.StateBag) multis return multistep.ActionHalt } - kp, err := ssh.NewKeyPairBuilder(). - SetPrivateKey(privateKeyBytes). - SetName(fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID())). - Build() + kp, err := ssh.KeyPairFromPrivateKey(ssh.FromPrivateKeyConfig{ + RawPrivateKeyPemBlock: privateKeyBytes, + Name: fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID()), + }) if err != nil { state.Put("error", err) return multistep.ActionHalt } s.Comm.SSHPrivateKey = privateKeyBytes - s.Comm.SSHKeyPairName = kp.Name() - s.Comm.SSHTemporaryKeyPairName = kp.Name() - s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine(ssh.UnixNewLine) + s.Comm.SSHKeyPairName = kp.Name + s.Comm.SSHTemporaryKeyPairName = kp.Name + s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine return multistep.ActionContinue } @@ -60,21 +60,21 @@ func (s *StepSshKeyPair) Run(_ context.Context, state multistep.StateBag) multis ui.Say("Creating ephemeral key pair for SSH communicator...") - kp, err := ssh.NewKeyPairBuilder(). - SetName(fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID())). - Build() + kp, err := ssh.NewKeyPair(ssh.CreateKeyPairConfig{ + Name: fmt.Sprintf("packer_%s", uuid.TimeOrderedUUID()), + }) if err != nil { state.Put("error", fmt.Errorf("Error creating temporary keypair: %s", err)) return multistep.ActionHalt } - s.Comm.SSHKeyPairName = kp.Name() - s.Comm.SSHTemporaryKeyPairName = kp.Name() - s.Comm.SSHPrivateKey = kp.PrivateKeyPemBlock() - s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine(ssh.UnixNewLine) + s.Comm.SSHKeyPairName = kp.Name + s.Comm.SSHTemporaryKeyPairName = kp.Name + s.Comm.SSHPrivateKey = kp.PrivateKeyPemBlock + s.Comm.SSHPublicKey = kp.PublicKeyAuthorizedKeysLine s.Comm.SSHClearAuthorizedKeys = true - ui.Say(fmt.Sprintf("Created ephemeral SSH key pair of type %s", kp.Description())) + ui.Say("Created ephemeral SSH key pair for communicator") // If we're in debug mode, output the private key to the working // directory. @@ -90,7 +90,7 @@ func (s *StepSshKeyPair) Run(_ context.Context, state multistep.StateBag) multis defer f.Close() // Write the key out - if _, err := f.Write(kp.PrivateKeyPemBlock()); err != nil { + if _, err := f.Write(kp.PrivateKeyPemBlock); err != nil { state.Put("error", fmt.Errorf("Error saving debug key: %s", err)) return multistep.ActionHalt } diff --git a/helper/ssh/dsa_key_pair.go b/helper/ssh/dsa_key_pair.go deleted file mode 100644 index 01f0c7007..000000000 --- a/helper/ssh/dsa_key_pair.go +++ /dev/null @@ -1,38 +0,0 @@ -package ssh - -import ( - "crypto/dsa" - - gossh "golang.org/x/crypto/ssh" -) - -type dsaKeyPair struct { - privateKey *dsa.PrivateKey - publicKey gossh.PublicKey - name string - privatePemBlock []byte -} - -func (o dsaKeyPair) Type() KeyPairType { - return Dsa -} - -func (o dsaKeyPair) Bits() int { - return 1024 -} - -func (o dsaKeyPair) Name() string { - return o.name -} - -func (o dsaKeyPair) Description() string { - return description(o) -} - -func (o dsaKeyPair) PrivateKeyPemBlock() []byte { - return o.privatePemBlock -} - -func (o dsaKeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte { - return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl) -} diff --git a/helper/ssh/ecdsa_key_pair.go b/helper/ssh/ecdsa_key_pair.go deleted file mode 100644 index 9f3dde341..000000000 --- a/helper/ssh/ecdsa_key_pair.go +++ /dev/null @@ -1,38 +0,0 @@ -package ssh - -import ( - "crypto/ecdsa" - - gossh "golang.org/x/crypto/ssh" -) - -type ecdsaKeyPair struct { - privateKey *ecdsa.PrivateKey - publicKey gossh.PublicKey - name string - privatePemBlock []byte -} - -func (o ecdsaKeyPair) Type() KeyPairType { - return Ecdsa -} - -func (o ecdsaKeyPair) Bits() int { - return o.privateKey.Curve.Params().BitSize -} - -func (o ecdsaKeyPair) Name() string { - return o.name -} - -func (o ecdsaKeyPair) Description() string { - return description(o) -} - -func (o ecdsaKeyPair) PrivateKeyPemBlock() []byte { - return o.privatePemBlock -} - -func (o ecdsaKeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte { - return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl) -} diff --git a/helper/ssh/ed25519_key_pair.go b/helper/ssh/ed25519_key_pair.go deleted file mode 100644 index 91a0a3654..000000000 --- a/helper/ssh/ed25519_key_pair.go +++ /dev/null @@ -1,38 +0,0 @@ -package ssh - -import ( - "golang.org/x/crypto/ed25519" - - gossh "golang.org/x/crypto/ssh" -) - -type ed25519KeyPair struct { - privateKey *ed25519.PrivateKey - publicKey gossh.PublicKey - name string - privatePemBlock []byte -} - -func (o ed25519KeyPair) Type() KeyPairType { - return Ed25519 -} - -func (o ed25519KeyPair) Bits() int { - return 256 -} - -func (o ed25519KeyPair) Name() string { - return o.name -} - -func (o ed25519KeyPair) Description() string { - return description(o) -} - -func (o ed25519KeyPair) PrivateKeyPemBlock() []byte { - return o.privatePemBlock -} - -func (o ed25519KeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte { - return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl) -} diff --git a/helper/ssh/key_pair.go b/helper/ssh/key_pair.go index 60c2c22e5..11f6b6386 100644 --- a/helper/ssh/key_pair.go +++ b/helper/ssh/key_pair.go @@ -9,9 +9,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" - "errors" "fmt" - "strconv" "strings" "golang.org/x/crypto/ed25519" @@ -38,175 +36,130 @@ func (o KeyPairType) String() string { return string(o) } -const ( - // UnixNewLine is a unix new line. - UnixNewLine NewLineOption = "\n" - - // WindowsNewLine is a Windows new line. - WindowsNewLine NewLineOption = "\r\n" - - // NoNewLine will not append a new line. - NoNewLine NewLineOption = "" -) +// CreateKeyPairConfig describes how an SSH key pair should be created. +type CreateKeyPairConfig struct { + // Type describes the key pair's type. + Type KeyPairType -// NewLineOption specifies the type of new line to append to a string. -// See the 'const' block for choices. -type NewLineOption string - -func (o NewLineOption) String() string { - return string(o) -} + // Bits represents the key pair's bits of entropy. E.g., 4096 for + // a 4096 bit RSA key pair, or 521 for a ECDSA key pair with a + // 521-bit curve. + Bits int -func (o NewLineOption) Bytes() []byte { - return []byte(o) + // Name is the resulting key pair's name. This is used to identify + // the key pair in the SSH server's 'authorized_keys'. + Name string } -// KeyPairBuilder builds SSH key pairs. -// It can generate new keys of type RSA and ECDSA. -// It can parse user supplied keys of type DSA, RSA, ECDSA, -// and ED25519. -type KeyPairBuilder interface { - // SetType sets the key pair type. - SetType(KeyPairType) KeyPairBuilder - - // SetBits sets the key pair's bits of entropy. - SetBits(int) KeyPairBuilder - - // SetName sets the name of the key pair. This is primarily - // used to identify the public key in the authorized_keys file. - SetName(string) KeyPairBuilder - - // SetPrivateKey takes an existing private key in PEM format. - // It overrides key generation details specified by SetType() - // and SetBits(). - SetPrivateKey([]byte) KeyPairBuilder - - // Build returns a SSH key pair. - // - // The following defaults are used if not specified: - // Default type: ECDSA - // Default bits of entropy: - // - RSA: 4096 - // - ECDSA: 521 - // Default name: (empty string) - Build() (KeyPair, error) -} - -type defaultKeyPairBuilder struct { - // kind describes the resulting key pair's type. - kind KeyPairType - - // bits is the resulting key pair's bits of entropy. - bits int - - // name is the resulting key pair's name. - name string - - // privatePemBytes is the supplied key data when the builder - // is working from a preallocated key. - privatePemBytes []byte -} +// FromPrivateKeyConfig describes how an SSH key pair should be loaded from an +// existing private key. +type FromPrivateKeyConfig struct { + // RawPrivateKeyPemBlock is the raw private key that the key pair + // should be loaded from. + RawPrivateKeyPemBlock []byte -func (o *defaultKeyPairBuilder) SetType(kind KeyPairType) KeyPairBuilder { - o.kind = kind - return o + // Name is the resulting key pair's name. This is used to identify + // the key pair in the SSH server's 'authorized_keys'. + Name string } -func (o *defaultKeyPairBuilder) SetBits(bits int) KeyPairBuilder { - o.bits = bits - return o +// KeyPair represents an SSH key pair. +// TODO: Maybe a field for a description? Maybe save the type? +type KeyPair struct { + // PrivateKeyPemBlock represents the key pair's private key in + // ASN.1 Distinguished Encoding Rules (DER) format in a + // Privacy-Enhanced Mail (PEM) block. + PrivateKeyPemBlock []byte + + // PublicKeyAuthorizedKeysLine represents the key pair's public key + // as a line in OpenSSH authorized_keys. + PublicKeyAuthorizedKeysLine []byte + + // Name is the key pair's name. This is used to identify + // the key pair in the SSH server's 'authorized_keys'. + Name string } -func (o *defaultKeyPairBuilder) SetName(name string) KeyPairBuilder { - o.name = name - return o -} - -func (o *defaultKeyPairBuilder) SetPrivateKey(privateBytes []byte) KeyPairBuilder { - o.privatePemBytes = privateBytes - return o -} - -func (o *defaultKeyPairBuilder) Build() (KeyPair, error) { - if o.privatePemBytes != nil { - return o.preallocatedKeyPair() - } - - switch o.kind { - case Rsa: - return o.newRsaKeyPair() - case Ecdsa, Default: - return o.newEcdsaKeyPair() - } - - return nil, fmt.Errorf("Cannot generate SSH key pair - unsupported key pair type: %s", o.kind.String()) -} - -// preallocatedKeyPair returns an SSH key pair based on user -// supplied PEM data. -func (o *defaultKeyPairBuilder) preallocatedKeyPair() (KeyPair, error) { - privateKey, err := gossh.ParseRawPrivateKey(o.privatePemBytes) +// KeyPairFromPrivateKey returns a KeyPair loaded from an existing private key. +// +// Supported key pair types include: +// - DSA +// - ECDSA +// - ED25519 +// - RSA +func KeyPairFromPrivateKey(config FromPrivateKeyConfig) (KeyPair, error) { + privateKey, err := gossh.ParseRawPrivateKey(config.RawPrivateKeyPemBlock) if err != nil { - return nil, err + return KeyPair{}, err } switch pk := privateKey.(type) { case *rsa.PrivateKey: publicKey, err := gossh.NewPublicKey(&pk.PublicKey) if err != nil { - return nil, err + return KeyPair{}, err } - return &rsaKeyPair{ - privateKey: pk, - publicKey: publicKey, - name: o.name, - privatePemBlock: o.privatePemBytes, + return KeyPair{ + PrivateKeyPemBlock: config.RawPrivateKeyPemBlock, + PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name), }, nil case *ecdsa.PrivateKey: publicKey, err := gossh.NewPublicKey(&pk.PublicKey) if err != nil { - return nil, err + return KeyPair{}, err } - return &ecdsaKeyPair{ - privateKey: pk, - publicKey: publicKey, - name: o.name, - privatePemBlock: o.privatePemBytes, + return KeyPair{ + PrivateKeyPemBlock: config.RawPrivateKeyPemBlock, + PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name), }, nil case *dsa.PrivateKey: publicKey, err := gossh.NewPublicKey(&pk.PublicKey) if err != nil { - return nil, err + return KeyPair{}, err } - return &dsaKeyPair{ - privateKey: pk, - publicKey: publicKey, - name: o.name, - privatePemBlock: o.privatePemBytes, + return KeyPair{ + PrivateKeyPemBlock: config.RawPrivateKeyPemBlock, + PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name), }, nil case *ed25519.PrivateKey: publicKey, err := gossh.NewPublicKey(pk.Public()) if err != nil { - return nil, err + return KeyPair{}, err } - return &ed25519KeyPair{ - privateKey: pk, - publicKey: publicKey, - name: o.name, - privatePemBlock: o.privatePemBytes, + return KeyPair{ + PrivateKeyPemBlock: config.RawPrivateKeyPemBlock, + PublicKeyAuthorizedKeysLine: authorizedKeysLine(publicKey, config.Name), }, nil } - return nil, fmt.Errorf("Cannot parse preallocated key pair - unknown ssh key pair type") + return KeyPair{}, fmt.Errorf("Cannot parse existing SSH key pair - unknown key pair type") +} + +// NewKeyPair generates a new SSH key pair using the specified +// CreateKeyPairConfig. +func NewKeyPair(config CreateKeyPairConfig) (KeyPair, error) { + if config.Type == Default { + config.Type = Ecdsa + } + + switch config.Type { + case Ecdsa: + return newEcdsaKeyPair(config) + case Rsa: + return newRsaKeyPair(config) + } + + return KeyPair{}, fmt.Errorf("Unable to generate new key pair, type %s is not supported", + config.Type.String()) } // newEcdsaKeyPair returns a new ECDSA SSH key pair. -func (o *defaultKeyPairBuilder) newEcdsaKeyPair() (KeyPair, error) { +func newEcdsaKeyPair(config CreateKeyPairConfig) (KeyPair, error) { var curve elliptic.Curve - switch o.bits { + switch config.Bits { case 0: - o.bits = 521 + config.Bits = 521 fallthrough case 521: curve = elliptic.P521() @@ -216,26 +169,24 @@ func (o *defaultKeyPairBuilder) newEcdsaKeyPair() (KeyPair, error) { curve = elliptic.P256() case 224: // Not supported by "golang.org/x/crypto/ssh". - return &ecdsaKeyPair{}, errors.New("golang.org/x/crypto/ssh does not support " + - strconv.Itoa(o.bits) + " bits") + return KeyPair{}, fmt.Errorf("golang.org/x/crypto/ssh does not support %d bits", config.Bits) default: - return &ecdsaKeyPair{}, errors.New("crypto/elliptic does not support " + - strconv.Itoa(o.bits) + " bits") + return KeyPair{}, fmt.Errorf("crypto/elliptic does not support %d bits", config.Bits) } privateKey, err := ecdsa.GenerateKey(curve, rand.Reader) if err != nil { - return &ecdsaKeyPair{}, err + return KeyPair{}, err } sshPublicKey, err := gossh.NewPublicKey(&privateKey.PublicKey) if err != nil { - return &ecdsaKeyPair{}, err + return KeyPair{}, err } privateRaw, err := x509.MarshalECPrivateKey(privateKey) if err != nil { - return &ecdsaKeyPair{}, err + return KeyPair{}, err } privatePem, err := rawPemBlock(&pem.Block{ @@ -244,31 +195,30 @@ func (o *defaultKeyPairBuilder) newEcdsaKeyPair() (KeyPair, error) { Bytes: privateRaw, }) if err != nil { - return &ecdsaKeyPair{}, err + return KeyPair{}, err } - return &ecdsaKeyPair{ - privateKey: privateKey, - publicKey: sshPublicKey, - name: o.name, - privatePemBlock: privatePem, + return KeyPair{ + PrivateKeyPemBlock: privatePem, + PublicKeyAuthorizedKeysLine: authorizedKeysLine(sshPublicKey, config.Name), + Name: config.Name, }, nil } // newRsaKeyPair returns a new RSA SSH key pair. -func (o *defaultKeyPairBuilder) newRsaKeyPair() (KeyPair, error) { - if o.bits == 0 { - o.bits = defaultRsaBits +func newRsaKeyPair(config CreateKeyPairConfig) (KeyPair, error) { + if config.Bits == 0 { + config.Bits = defaultRsaBits } - privateKey, err := rsa.GenerateKey(rand.Reader, o.bits) + privateKey, err := rsa.GenerateKey(rand.Reader, config.Bits) if err != nil { - return &rsaKeyPair{}, err + return KeyPair{}, err } sshPublicKey, err := gossh.NewPublicKey(&privateKey.PublicKey) if err != nil { - return &rsaKeyPair{}, err + return KeyPair{}, err } privatePemBlock, err := rawPemBlock(&pem.Block{ @@ -277,48 +227,16 @@ func (o *defaultKeyPairBuilder) newRsaKeyPair() (KeyPair, error) { Bytes: x509.MarshalPKCS1PrivateKey(privateKey), }) if err != nil { - return &rsaKeyPair{}, err + return KeyPair{}, err } - return &rsaKeyPair{ - privateKey: privateKey, - publicKey: sshPublicKey, - name: o.name, - privatePemBlock: privatePemBlock, + return KeyPair{ + PrivateKeyPemBlock: privatePemBlock, + PublicKeyAuthorizedKeysLine: authorizedKeysLine(sshPublicKey, config.Name), + Name: config.Name, }, nil } -// KeyPair represents a SSH key pair. -type KeyPair interface { - // Type returns the key pair's type. - Type() KeyPairType - - // Bits returns the bits of entropy. - Bits() int - - // Name returns the key pair's name. An empty string is - // returned if no name was specified. - Name() string - - // Description returns a brief description of the key pair that - // is suitable for log messages or printing. - Description() string - - // PrivateKeyPemBlock returns a slice of bytes representing - // the private key in ASN.1 Distinguished Encoding Rules (DER) - // format in a Privacy-Enhanced Mail (PEM) block. - PrivateKeyPemBlock() []byte - - // PublicKeyAuthorizedKeysLine returns a slice of bytes - // representing the public key as a line in OpenSSH authorized_keys - // format with the specified new line. - PublicKeyAuthorizedKeysLine(NewLineOption) []byte -} - -func NewKeyPairBuilder() KeyPairBuilder { - return &defaultKeyPairBuilder{} -} - // rawPemBlock encodes a pem.Block to a slice of bytes. func rawPemBlock(block *pem.Block) ([]byte, error) { buffer := bytes.NewBuffer(nil) @@ -331,26 +249,10 @@ func rawPemBlock(block *pem.Block) ([]byte, error) { return buffer.Bytes(), nil } -// description returns a string describing a key pair. -func description(kp KeyPair) string { - buffer := bytes.NewBuffer(nil) - - buffer.WriteString(strconv.Itoa(kp.Bits())) - buffer.WriteString(" bit ") - buffer.WriteString(kp.Type().String()) - - if len(kp.Name()) > 0 { - buffer.WriteString(" named ") - buffer.WriteString(kp.Name()) - } - - return buffer.String() -} - -// publicKeyAuthorizedKeysLine returns a slice of bytes representing a SSH -// public key as a line in OpenSSH authorized_keys format. -func publicKeyAuthorizedKeysLine(publicKey gossh.PublicKey, name string, nl NewLineOption) []byte { - result := gossh.MarshalAuthorizedKey(publicKey) +// authorizedKeysLine returns a slice of bytes representing an SSH public key +// as a line in OpenSSH authorized_keys format. No line break is appended. +func authorizedKeysLine(sshPublicKey gossh.PublicKey, name string) []byte { + result := gossh.MarshalAuthorizedKey(sshPublicKey) // Remove the mandatory unix new line. // Awful, but the go ssh library automatically appends @@ -362,14 +264,5 @@ func publicKeyAuthorizedKeysLine(publicKey gossh.PublicKey, name string, nl NewL result = append(result, name...) } - switch nl { - case WindowsNewLine: - result = append(result, nl.Bytes()...) - case UnixNewLine: - // This is how all the other "SSH key pair" code works in - // the different builders. - result = append(result, UnixNewLine.Bytes()...) - } - return result } diff --git a/helper/ssh/key_pair_test.go b/helper/ssh/key_pair_test.go index 5a6e23b6c..eddb9a3cb 100644 --- a/helper/ssh/key_pair_test.go +++ b/helper/ssh/key_pair_test.go @@ -2,24 +2,17 @@ package ssh import ( "bytes" - "crypto/rand" - "errors" - "strconv" + "crypto/dsa" + "crypto/ecdsa" + "crypto/rsa" + "fmt" "testing" "github.com/hashicorp/packer/common/uuid" + "golang.org/x/crypto/ed25519" gossh "golang.org/x/crypto/ssh" ) -// expected contains the data that the key pair should contain. -type expected struct { - kind KeyPairType - bits int - desc string - name string - data []byte -} - const ( pemRsa1024 = `-----BEGIN RSA PRIVATE KEY----- MIICWwIBAAKBgQDJEMFPpTBiWNDb3qEIPTSeEnIP8FZdBpG8njOrclcMoQQNhzZ+ @@ -151,401 +144,306 @@ QBAgM= ` ) -func (o expected) matches(kp KeyPair) error { - if o.kind.String() == "" { - return errors.New("expected kind's value cannot be empty") - } - - if o.bits <= 0 { - return errors.New("expected bits' value cannot be less than or equal to 0") - } - - if o.desc == "" { - return errors.New("expected description's value cannot be empty") +func TestNewKeyPair_Default(t *testing.T) { + kp, err := NewKeyPair(CreateKeyPairConfig{}) + if err != nil { + t.Fatal(err.Error()) } - if len(o.data) == 0 { - return errors.New("expected random data value cannot be nothing") + err = verifyEcdsaKeyPair(kp, expectedData{ + bits: 521, + }) + if err != nil { + t.Fatal(err.Error()) } +} - if kp.Type() != o.kind { - return errors.New("key pair type should be " + o.kind.String() + - " - got '" + kp.Type().String() + "'") +func TestNewKeyPair_ECDSA_Default(t *testing.T) { + kp, err := NewKeyPair(CreateKeyPairConfig{ + Type: Ecdsa, + }) + if err != nil { + t.Fatal(err.Error()) } - if kp.Bits() != o.bits { - return errors.New("key pair bits should be " + strconv.Itoa(o.bits) + - " - got " + strconv.Itoa(kp.Bits())) + err = verifyEcdsaKeyPair(kp, expectedData{ + bits: 521, + }) + if err != nil { + t.Fatal(err.Error()) } +} - if len(o.name) > 0 && kp.Name() != o.name { - return errors.New("key pair name should be '" + o.name + - "' - got '" + kp.Name() + "'") - } +func TestNewKeyPair_ECDSA_Positive(t *testing.T) { + for _, bits := range []int{521, 384, 256} { + config := CreateKeyPairConfig{ + Type: Ecdsa, + Bits: bits, + Name: uuid.TimeOrderedUUID(), + } - if kp.Description() != o.desc { - return errors.New("key pair description should be '" + - o.desc + "' - got '" + kp.Description() + "'") - } + kp, err := NewKeyPair(config) + if err != nil { + t.Fatal(err.Error()) + } - err := o.verifyPublicKeyAuthorizedKeysFormat(kp) - if err != nil { - return err + err = verifyEcdsaKeyPair(kp, expectedData{ + bits: bits, + name: config.Name, + }) + if err != nil { + t.Fatal(err.Error()) + } } +} - err = o.verifyKeyPair(kp) - if err != nil { - return err +func TestNewKeyPair_ECDSA_Negative(t *testing.T) { + for _, bits := range []int{224, 1, 2, 3} { + _, err := NewKeyPair(CreateKeyPairConfig{ + Type: Ecdsa, + Bits: bits, + }) + if err == nil { + t.Fatalf("expected key pair generation to fail for %d bits", bits) + } } - - return nil } -func (o expected) verifyPublicKeyAuthorizedKeysFormat(kp KeyPair) error { - newLines := []NewLineOption{ - UnixNewLine, - NoNewLine, - WindowsNewLine, - } +func TestNewKeyPair_RSA_Positive(t *testing.T) { + for _, bits := range []int{4096, 2048} { + config := CreateKeyPairConfig{ + Type: Rsa, + Bits: bits, + Name: uuid.TimeOrderedUUID(), + } - for _, nl := range newLines { - publicKeyAk := kp.PublicKeyAuthorizedKeysLine(nl) + kp, err := NewKeyPair(config) + if err != nil { + t.Fatal(err.Error()) + } - if len(publicKeyAk) < 2 { - return errors.New("expected public key in authorized keys format to be at least 2 bytes") + err = verifyRsaKeyPair(kp, expectedData{ + bits: config.Bits, + name: config.Name, + }) + if err != nil { + t.Fatal(err.Error()) } + } +} - switch nl { - case NoNewLine: - if publicKeyAk[len(publicKeyAk)-1] == '\n' { - return errors.New("public key in authorized keys format has trailing new line when none was specified") - } - case UnixNewLine: - if publicKeyAk[len(publicKeyAk)-1] != '\n' { - return errors.New("public key in authorized keys format does not have unix new line when unix was specified") - } - if string(publicKeyAk[len(publicKeyAk)-2:]) == WindowsNewLine.String() { - return errors.New("public key in authorized keys format has windows new line when unix was specified") - } - case WindowsNewLine: - if string(publicKeyAk[len(publicKeyAk)-2:]) != WindowsNewLine.String() { - return errors.New("public key in authorized keys format does not have windows new line when windows was specified") - } +func TestKeyPairFromPrivateKey(t *testing.T) { + m := map[string]fromPrivateExpectedData{ + pemRsa1024: { + t: Rsa, + d: expectedData{ + bits: 1024, + }, + }, + pemRsa2048: { + t: Rsa, + d: expectedData{ + bits: 2048, + }, + }, + pemOpenSshRsa1024: { + t: Rsa, + d: expectedData{ + bits: 1024, + }, + }, + pemOpenSshRsa2048: { + t: Rsa, + d: expectedData{ + bits: 2048, + }, + }, + pemDsa: { + t: Dsa, + d: expectedData{ + bits: 1024, + }, + }, + pemEcdsa384: { + t: Ecdsa, + d: expectedData{ + bits: 384, + }, + }, + pemEcdsa521: { + t: Ecdsa, + d: expectedData{ + bits: 521, + }, + }, + pemOpenSshEd25519: { + t: Ed25519, + d: expectedData{ + bits: 256, + }, + }, + } + + for rawPrivateKey, expected := range m { + kp, err := KeyPairFromPrivateKey(FromPrivateKeyConfig{ + RawPrivateKeyPemBlock: []byte(rawPrivateKey), + }) + if err != nil { + t.Fatal(err.Error()) } - if len(o.name) > 0 { - if len(publicKeyAk) < len(o.name) { - return errors.New("public key in authorized keys format is shorter than the key pair's name") - } - - suffix := []byte{' '} - suffix = append(suffix, o.name...) - suffix = append(suffix, nl.Bytes()...) - if !bytes.HasSuffix(publicKeyAk, suffix) { - return errors.New("public key in authorized keys format with name does not have name in suffix - got '" + - string(publicKeyAk) + "'") - } + switch expected.t { + case Dsa: + err = verifyDsaKeyPair(kp, expected) + case Ecdsa: + err = verifyEcdsaKeyPair(kp, expected.d) + case Ed25519: + err = verifyEd25519KeyPair(kp, expected) + case Rsa: + err = verifyRsaKeyPair(kp, expected.d) + default: + err = fmt.Errorf("unexected SSH key pair type %s", expected.t.String()) + } + if err != nil { + t.Fatal(err.Error()) } } +} - return nil +type fromPrivateExpectedData struct { + t KeyPairType + d expectedData } -func (o expected) verifyKeyPair(kp KeyPair) error { - signer, err := gossh.ParsePrivateKey(kp.PrivateKeyPemBlock()) +type expectedData struct { + bits int + name string +} + +func verifyEcdsaKeyPair(kp KeyPair, e expectedData) error { + privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock) if err != nil { - return errors.New("failed to parse private key during verification - " + err.Error()) + return err } - signature, err := signer.Sign(rand.Reader, o.data) - if err != nil { - return errors.New("failed to sign test data during verification - " + err.Error()) + pk, ok := privateKey.(*ecdsa.PrivateKey) + if !ok { + return fmt.Errorf("private key should be *ecdsa.PrivateKey") } - err = signer.PublicKey().Verify(o.data, signature) + if pk.Curve.Params().BitSize != e.bits { + return fmt.Errorf("bit size should be %d - got %d", e.bits, pk.Curve.Params().BitSize) + } + + publicKey, err := gossh.NewPublicKey(&pk.PublicKey) if err != nil { - return errors.New("failed to verify test data - " + err.Error()) + return err + } + + expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n")) + if len(e.name) > 0 { + expectedBytes = append(expectedBytes, ' ') + expectedBytes = append(expectedBytes, e.name...) + } + + if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) { + return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'", + string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine)) } return nil } -func TestDefaultKeyPairBuilder_Build_Default(t *testing.T) { - kp, err := NewKeyPairBuilder().Build() +func verifyRsaKeyPair(kp KeyPair, e expectedData) error { + privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock) if err != nil { - t.Fatal(err.Error()) + return err } - err = expected{ - kind: Ecdsa, - bits: 521, - desc: "521 bit ECDSA", - data: []byte(uuid.TimeOrderedUUID()), - }.matches(kp) - if err != nil { - t.Fatal(err.Error()) + pk, ok := privateKey.(*rsa.PrivateKey) + if !ok { + return fmt.Errorf("private key should be *rsa.PrivateKey") } -} -func TestDefaultKeyPairBuilder_Build_EcdsaDefault(t *testing.T) { - kp, err := NewKeyPairBuilder(). - SetType(Ecdsa). - Build() - if err != nil { - t.Fatal(err.Error()) + if pk.N.BitLen() != e.bits { + return fmt.Errorf("bit size should be %d - got %d", e.bits, pk.N.BitLen()) } - err = expected{ - kind: Ecdsa, - bits: 521, - desc: "521 bit ECDSA", - data: []byte(uuid.TimeOrderedUUID()), - }.matches(kp) + publicKey, err := gossh.NewPublicKey(&pk.PublicKey) if err != nil { - t.Fatal(err.Error()) + return err } -} -func TestDefaultKeyPairBuilder_Build_EcdsaSupportedCurves(t *testing.T) { - supportedBits := []int{ - 521, - 384, - 256, + expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n")) + if len(e.name) > 0 { + expectedBytes = append(expectedBytes, ' ') + expectedBytes = append(expectedBytes, e.name...) } - for _, bits := range supportedBits { - kp, err := NewKeyPairBuilder(). - SetType(Ecdsa). - SetBits(bits). - Build() - if err != nil { - t.Fatal(err.Error()) - } - - err = expected{ - kind: Ecdsa, - bits: bits, - desc: strconv.Itoa(bits) + " bit ECDSA", - data: []byte(uuid.TimeOrderedUUID()), - }.matches(kp) - if err != nil { - t.Fatal(err.Error()) - } + if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) { + return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'", + string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine)) } + + return nil } -func TestDefaultKeyPairBuilder_Build_RsaDefault(t *testing.T) { - kp, err := NewKeyPairBuilder(). - SetType(Rsa). - Build() +func verifyDsaKeyPair(kp KeyPair, e fromPrivateExpectedData) error { + privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock) if err != nil { - t.Fatal(err.Error()) + return err } - err = expected{ - kind: Rsa, - bits: 4096, - desc: "4096 bit RSA", - data: []byte(uuid.TimeOrderedUUID()), - }.matches(kp) - if err != nil { - t.Fatal(err.Error()) + pk, ok := privateKey.(*dsa.PrivateKey) + if !ok { + return fmt.Errorf("private key should be *rsa.PrivateKey") } -} -func TestDefaultKeyPairBuilder_Build_NamedEcdsa(t *testing.T) { - name := uuid.TimeOrderedUUID() - - kp, err := NewKeyPairBuilder(). - SetType(Ecdsa). - SetName(name). - Build() + publicKey, err := gossh.NewPublicKey(&pk.PublicKey) if err != nil { - t.Fatal(err.Error()) + return err } - err = expected{ - kind: Ecdsa, - bits: 521, - desc: "521 bit ECDSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - name: name, - }.matches(kp) - if err != nil { - t.Fatal(err.Error()) + expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n")) + if len(e.d.name) > 0 { + expectedBytes = append(expectedBytes, ' ') + expectedBytes = append(expectedBytes, e.d.name...) } -} - -func TestDefaultKeyPairBuilder_Build_NamedRsa(t *testing.T) { - name := uuid.TimeOrderedUUID() - kp, err := NewKeyPairBuilder(). - SetType(Rsa). - SetName(name). - Build() - if err != nil { - t.Fatal(err.Error()) + if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) { + return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'", + string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine)) } - err = expected{ - kind: Rsa, - bits: 4096, - desc: "4096 bit RSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - name: name, - }.matches(kp) - if err != nil { - t.Fatal(err.Error()) - } + return nil } -func TestDefaultKeyPairBuilder_SetPrivateKey(t *testing.T) { - name := uuid.TimeOrderedUUID() - pemData := make(map[string]expected) - pemData[pemRsa1024] = expected{ - bits: 1024, - kind: Rsa, - name: name, - desc: "1024 bit RSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemRsa2048] = expected{ - bits: 2048, - kind: Rsa, - name: name, - desc: "2048 bit RSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemOpenSshRsa1024] = expected{ - bits: 1024, - kind: Rsa, - name: name, - desc: "1024 bit RSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemOpenSshRsa2048] = expected{ - bits: 2048, - kind: Rsa, - name: name, - desc: "2048 bit RSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemDsa] = expected{ - bits: 1024, - kind: Dsa, - name: name, - desc: "1024 bit DSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemEcdsa384] = expected{ - bits: 384, - kind: Ecdsa, - name: name, - desc: "384 bit ECDSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemEcdsa521] = expected{ - bits: 521, - kind: Ecdsa, - name: name, - desc: "521 bit ECDSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemOpenSshEd25519] = expected{ - bits: 256, - kind: Ed25519, - name: name, - desc: "256 bit ED25519 named " + name, - data: []byte(uuid.TimeOrderedUUID()), +func verifyEd25519KeyPair(kp KeyPair, e fromPrivateExpectedData) error { + privateKey, err := gossh.ParseRawPrivateKey(kp.PrivateKeyPemBlock) + if err != nil { + return err } - for s, l := range pemData { - kp, err := NewKeyPairBuilder().SetPrivateKey([]byte(s)).SetName(name).Build() - if err != nil { - t.Fatal(err) - } - err = l.matches(kp) - if err != nil { - t.Fatal(err) - } + pk, ok := privateKey.(*ed25519.PrivateKey) + if !ok { + return fmt.Errorf("private key should be *rsa.PrivateKey") } -} -func TestDefaultKeyPairBuilder_SetPrivateKey_Override(t *testing.T) { - name := uuid.TimeOrderedUUID() - pemData := make(map[string]expected) - pemData[pemRsa1024] = expected{ - bits: 1024, - kind: Rsa, - name: name, - desc: "1024 bit RSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemRsa2048] = expected{ - bits: 2048, - kind: Rsa, - name: name, - desc: "2048 bit RSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemOpenSshRsa1024] = expected{ - bits: 1024, - kind: Rsa, - name: name, - desc: "1024 bit RSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemOpenSshRsa2048] = expected{ - bits: 2048, - kind: Rsa, - name: name, - desc: "2048 bit RSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemDsa] = expected{ - bits: 1024, - kind: Dsa, - name: name, - desc: "1024 bit DSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemEcdsa384] = expected{ - bits: 384, - kind: Ecdsa, - name: name, - desc: "384 bit ECDSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), - } - pemData[pemEcdsa521] = expected{ - bits: 521, - kind: Ecdsa, - name: name, - desc: "521 bit ECDSA named " + name, - data: []byte(uuid.TimeOrderedUUID()), + publicKey, err := gossh.NewPublicKey(pk.Public()) + if err != nil { + return err } - pemData[pemOpenSshEd25519] = expected{ - bits: 256, - kind: Ed25519, - name: name, - desc: "256 bit ED25519 named " + name, - data: []byte(uuid.TimeOrderedUUID()), + + expectedBytes := bytes.TrimSuffix(gossh.MarshalAuthorizedKey(publicKey), []byte("\n")) + if len(e.d.name) > 0 { + expectedBytes = append(expectedBytes, ' ') + expectedBytes = append(expectedBytes, e.d.name...) } - supportedKeyTypes := []KeyPairType{Rsa, Dsa} - for _, keyType := range supportedKeyTypes { - for pemString, expectedResult := range pemData { - kp, err := NewKeyPairBuilder(). - SetPrivateKey([]byte(pemString)). - SetName(name). - SetType(keyType). - Build() - if err != nil { - t.Fatal(err) - } - err = expectedResult.matches(kp) - if err != nil { - t.Fatal(err) - } - } + if !bytes.Equal(expectedBytes, kp.PublicKeyAuthorizedKeysLine) { + return fmt.Errorf("authorized keys line should be:\n'%s'\nGot:\n'%s'", + string(expectedBytes), string(kp.PublicKeyAuthorizedKeysLine)) } + + return nil } diff --git a/helper/ssh/rsa_key_pair.go b/helper/ssh/rsa_key_pair.go deleted file mode 100644 index 9a83f193b..000000000 --- a/helper/ssh/rsa_key_pair.go +++ /dev/null @@ -1,38 +0,0 @@ -package ssh - -import ( - "crypto/rsa" - - gossh "golang.org/x/crypto/ssh" -) - -type rsaKeyPair struct { - privateKey *rsa.PrivateKey - publicKey gossh.PublicKey - name string - privatePemBlock []byte -} - -func (o rsaKeyPair) Type() KeyPairType { - return Rsa -} - -func (o rsaKeyPair) Bits() int { - return o.privateKey.N.BitLen() -} - -func (o rsaKeyPair) Name() string { - return o.name -} - -func (o rsaKeyPair) Description() string { - return description(o) -} - -func (o rsaKeyPair) PrivateKeyPemBlock() []byte { - return o.privatePemBlock -} - -func (o rsaKeyPair) PublicKeyAuthorizedKeysLine(nl NewLineOption) []byte { - return publicKeyAuthorizedKeysLine(o.publicKey, o.name, nl) -}