backport communicator/ssh from 0.12

pull/22184/head
James Bardin 7 years ago
parent 03b0df2dcb
commit 1700da5740

@ -3,6 +3,7 @@ package ssh
import (
"bufio"
"bytes"
"context"
"errors"
"fmt"
"io"
@ -17,6 +18,7 @@ import (
"sync"
"time"
"github.com/hashicorp/errwrap"
"github.com/hashicorp/terraform/communicator/remote"
"github.com/hashicorp/terraform/terraform"
"golang.org/x/crypto/ssh"
@ -28,20 +30,30 @@ const (
DefaultShebang = "#!/bin/sh\n"
)
// randShared is a global random generator object that is shared.
// This must be shared since it is seeded by the current time and
// creating multiple can result in the same values. By using a shared
// RNG we assure different numbers per call.
var randLock sync.Mutex
var randShared *rand.Rand
var (
// randShared is a global random generator object that is shared. This must be
// shared since it is seeded by the current time and creating multiple can
// result in the same values. By using a shared RNG we assure different numbers
// per call.
randLock sync.Mutex
randShared *rand.Rand
// enable ssh keeplive probes by default
keepAliveInterval = 2 * time.Second
// max time to wait for for a KeepAlive response before considering the
// connection to be dead.
maxKeepAliveDelay = 120 * time.Second
)
// Communicator represents the SSH communicator
type Communicator struct {
connInfo *connectionInfo
client *ssh.Client
config *sshConfig
conn net.Conn
address string
connInfo *connectionInfo
client *ssh.Client
config *sshConfig
conn net.Conn
address string
cancelKeepAlive context.CancelFunc
lock sync.Mutex
}
@ -125,11 +137,13 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
" User: %s\n"+
" Password: %t\n"+
" Private key: %t\n"+
" Certificate: %t\n"+
" SSH Agent: %t\n"+
" Checking Host Key: %t",
c.connInfo.Host, c.connInfo.User,
c.connInfo.Password != "",
c.connInfo.PrivateKey != "",
c.connInfo.Certificate != "",
c.connInfo.Agent,
c.connInfo.HostKey != "",
))
@ -152,7 +166,8 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
}
}
log.Printf("[DEBUG] connecting to TCP connection for SSH")
hostAndPort := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port)
log.Printf("[DEBUG] Connecting to %s for SSH", hostAndPort)
c.conn, err = c.config.connection()
if err != nil {
// Explicitly set this to the REAL nil. Connection() can return
@ -167,10 +182,11 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
return err
}
log.Printf("[DEBUG] handshaking with SSH")
host := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port)
sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, host, c.config.config)
log.Printf("[DEBUG] Connection established. Handshaking for user %v", c.connInfo.User)
sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, hostAndPort, c.config.config)
if err != nil {
err = errwrap.Wrapf(fmt.Sprintf("SSH authentication failed (%s@%s): {{err}}", c.connInfo.User, hostAndPort), err)
// While in theory this should be a fatal error, some hosts may start
// the ssh service before it is properly configured, or before user
// authentication data is available.
@ -203,11 +219,69 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) {
}
}
if err != nil {
return err
}
if o != nil {
o.Output("Connected!")
}
return err
ctx, cancelKeepAlive := context.WithCancel(context.TODO())
c.cancelKeepAlive = cancelKeepAlive
// Start a keepalive goroutine to help maintain the connection for
// long-running commands.
log.Printf("[DEBUG] starting ssh KeepAlives")
go func() {
defer cancelKeepAlive()
// Along with the KeepAlives generating packets to keep the tcp
// connection open, we will use the replies to verify liveness of the
// connection. This will prevent dead connections from blocking the
// provisioner indefinitely.
respCh := make(chan error, 1)
go func() {
t := time.NewTicker(keepAliveInterval)
defer t.Stop()
for {
select {
case <-t.C:
_, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil)
respCh <- err
case <-ctx.Done():
return
}
}
}()
after := time.NewTimer(maxKeepAliveDelay)
defer after.Stop()
for {
select {
case err := <-respCh:
if err != nil {
log.Printf("[ERROR] ssh keepalive: %s", err)
sshConn.Close()
return
}
case <-after.C:
// abort after too many missed keepalives
log.Println("[ERROR] no reply from ssh server")
sshConn.Close()
return
case <-ctx.Done():
return
}
if !after.Stop() {
<-after.C
}
after.Reset(maxKeepAliveDelay)
}
}()
return nil
}
// Disconnect implementation of communicator.Communicator interface
@ -215,6 +289,10 @@ func (c *Communicator) Disconnect() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.cancelKeepAlive != nil {
c.cancelKeepAlive()
}
if c.config.sshAgent != nil {
if err := c.config.sshAgent.Close(); err != nil {
return err
@ -481,6 +559,13 @@ func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Re
// our data and has completed. Or has errored.
log.Println("[DEBUG] Waiting for SSH session to complete.")
err = session.Wait()
// log any stderr before exiting on an error
scpErr := stderr.String()
if len(scpErr) > 0 {
log.Printf("[ERROR] scp stderr: %q", stderr)
}
if err != nil {
if exitErr, ok := err.(*ssh.ExitError); ok {
// Otherwise, we have an ExitErorr, meaning we can just read
@ -499,11 +584,6 @@ func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Re
return err
}
scpErr := stderr.String()
if len(scpErr) > 0 {
log.Printf("[ERROR] scp stderr: %q", stderr)
}
return nil
}

@ -58,10 +58,10 @@ const testServerHostCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC1
const testCAPublicKey = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCrozyZIhdEvalCn+eSzHH94cO9ykiywA13ntWI7mJcHBwYTeCYWG8E9zGXyp2iDOjCGudM0Tdt8o0OofKChk9Z/qiUN0G8y1kmaXBlBM3qA5R9NPpvMYMNkYLfX6ivtZCnqrsbzaoqN2Oc/7H2StHzJWh/XCGu9otQZA6vdv1oSmAsZOjw/xIGaGQqDUaLq21J280PP1qSbdJHf76iSHE+TWe3YpqV946JWM5tCh0DykZ10VznvxYpUjzhr07IN3tVKxOXbPnnU7lX6IaLIWgfzLqwSyheeux05c3JLF9iF4sFu8ou4hwQz1iuUTU1jxgwZP0w/bkXgFFs0949lW81`
func newMockLineServer(t *testing.T, signer ssh.Signer) string {
func newMockLineServer(t *testing.T, signer ssh.Signer, pubKey string) string {
serverConfig := &ssh.ServerConfig{
PasswordCallback: acceptUserPass("user", "pass"),
PublicKeyCallback: acceptPublicKey(testClientPublicKey),
PublicKeyCallback: acceptPublicKey(pubKey),
}
var err error
@ -100,16 +100,19 @@ func newMockLineServer(t *testing.T, signer ssh.Signer) string {
go func(in <-chan *ssh.Request) {
for req := range in {
// since this channel's requests are serviced serially,
// this will block keepalive probes, and can simulate a
// hung connection.
if bytes.Contains(req.Payload, []byte("sleep")) {
time.Sleep(time.Second)
}
if req.WantReply {
req.Reply(true, nil)
}
}
}(requests)
go func(newChannel ssh.NewChannel) {
conn.OpenChannel(newChannel.ChannelType(), nil)
}(newChannel)
defer channel.Close()
}
conn.Close()
@ -119,7 +122,7 @@ func newMockLineServer(t *testing.T, signer ssh.Signer) string {
}
func TestNew_Invalid(t *testing.T) {
address := newMockLineServer(t, nil)
address := newMockLineServer(t, nil, testClientPublicKey)
parts := strings.Split(address, ":")
r := &terraform.InstanceState{
@ -147,7 +150,7 @@ func TestNew_Invalid(t *testing.T) {
}
func TestStart(t *testing.T) {
address := newMockLineServer(t, nil)
address := newMockLineServer(t, nil, testClientPublicKey)
parts := strings.Split(address, ":")
r := &terraform.InstanceState{
@ -179,8 +182,99 @@ func TestStart(t *testing.T) {
}
}
// TestKeepAlives verifies that the keepalive messages don't interfere with
// normal operation of the client.
func TestKeepAlives(t *testing.T) {
ivl := keepAliveInterval
keepAliveInterval = 250 * time.Millisecond
defer func() { keepAliveInterval = ivl }()
address := newMockLineServer(t, nil, testClientPublicKey)
parts := strings.Split(address, ":")
r := &terraform.InstanceState{
Ephemeral: terraform.EphemeralState{
ConnInfo: map[string]string{
"type": "ssh",
"user": "user",
"password": "pass",
"host": parts[0],
"port": parts[1],
},
},
}
c, err := New(r)
if err != nil {
t.Fatalf("error creating communicator: %s", err)
}
if err := c.Connect(nil); err != nil {
t.Fatal(err)
}
var cmd remote.Cmd
stdout := new(bytes.Buffer)
cmd.Command = "sleep"
cmd.Stdout = stdout
// wait a bit before executing the command, so that at least 1 keepalive is sent
time.Sleep(500 * time.Millisecond)
err = c.Start(&cmd)
if err != nil {
t.Fatalf("error executing remote command: %s", err)
}
}
// TestDeadConnection verifies that failed keepalive messages will eventually
// kill the connection.
func TestFailedKeepAlives(t *testing.T) {
ivl := keepAliveInterval
del := maxKeepAliveDelay
maxKeepAliveDelay = 500 * time.Millisecond
keepAliveInterval = 250 * time.Millisecond
defer func() {
keepAliveInterval = ivl
maxKeepAliveDelay = del
}()
address := newMockLineServer(t, nil, testClientPublicKey)
parts := strings.Split(address, ":")
r := &terraform.InstanceState{
Ephemeral: terraform.EphemeralState{
ConnInfo: map[string]string{
"type": "ssh",
"user": "user",
"password": "pass",
"host": parts[0],
"port": parts[1],
},
},
}
c, err := New(r)
if err != nil {
t.Fatalf("error creating communicator: %s", err)
}
if err := c.Connect(nil); err != nil {
t.Fatal(err)
}
var cmd remote.Cmd
stdout := new(bytes.Buffer)
cmd.Command = "sleep"
cmd.Stdout = stdout
err = c.Start(&cmd)
if err == nil {
t.Fatal("expected connection error")
}
}
func TestLostConnection(t *testing.T) {
address := newMockLineServer(t, nil)
address := newMockLineServer(t, nil, testClientPublicKey)
parts := strings.Split(address, ":")
r := &terraform.InstanceState{
@ -229,11 +323,11 @@ func TestHostKey(t *testing.T) {
// get the server's public key
signer, err := ssh.ParsePrivateKey([]byte(testServerPrivateKey))
if err != nil {
panic("unable to parse private key: " + err.Error())
t.Fatalf("unable to parse private key: %v", err)
}
pubKey := fmt.Sprintf("ssh-rsa %s", base64.StdEncoding.EncodeToString(signer.PublicKey().Marshal()))
address := newMockLineServer(t, nil)
address := newMockLineServer(t, nil, testClientPublicKey)
host, p, _ := net.SplitHostPort(address)
port, _ := strconv.Atoi(p)
@ -269,7 +363,7 @@ func TestHostKey(t *testing.T) {
}
// now check with the wrong HostKey
address = newMockLineServer(t, nil)
address = newMockLineServer(t, nil, testClientPublicKey)
_, p, _ = net.SplitHostPort(address)
port, _ = strconv.Atoi(p)
@ -308,7 +402,7 @@ func TestHostCert(t *testing.T) {
t.Fatal(err)
}
address := newMockLineServer(t, signer)
address := newMockLineServer(t, signer, testClientPublicKey)
host, p, _ := net.SplitHostPort(address)
port, _ := strconv.Atoi(p)
@ -344,7 +438,7 @@ func TestHostCert(t *testing.T) {
}
// now check with the wrong HostKey
address = newMockLineServer(t, signer)
address = newMockLineServer(t, signer, testClientPublicKey)
_, p, _ = net.SplitHostPort(address)
port, _ = strconv.Atoi(p)
@ -367,6 +461,105 @@ func TestHostCert(t *testing.T) {
}
}
const SERVER_PEM = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEA8CkDr7uxCFt6lQUVwS8NyPO+fQNxORoGnMnN/XhVJZvpqyKR
Uji9R0d8D66bYxUUsabXjP2y4HTVzbZtnvXFZZshk0cOtJjjekpYJaLK2esPR/iX
wvSltNkrDQDPN/RmgEEMIevW8AgrPsqrnybFHxTpd7rEUHXBOe4nMNRIg3XHykB6
jZk8q5bBPUe3I/f0DK5TJEBpTc6dO3P/j93u55VUqr39/SPRHnld2mCw+c8v6UOh
sssO/DIZFPScD3DYqsk2N+/nz9zXfcOTdWGhawgxuIo1DTokrNQbG3pDrLqcWgqj
13vqJFCmRA0O2CQIwJePd6+Np/XO3Uh/KL6FlQIDAQABAoIBAQCmvQMXNmvCDqk7
30zsVDvw4fHGH+azK3Od1aqTqcEMHISOUbCtckFPxLzIsoSltRQqB1kuRVG07skm
Stsu+xny4lLcSwBVuLRuykEK2EyYIc/5Owo6y9pkhkaSf5ZfFes4bnD6+B/BhRpp
PRMMq0E+xCkX/G6iIi9mhgdlqm0x/vKtjzQeeshw9+gRcRLUpX+UeKFKXMXcDayx
qekr1bAaQKNBhTK+CbZjcqzG4f+BXVGRTZ9nsPAV+yTnWUCU0TghwPmtthHbebqa
9hlkum7qik/bQj/tjJ8/b0vTfHQSVxhtPG/ZV2Tn9ZuL/vrkYqeyMU8XkJ/uaEvH
WPyOcB4BAoGBAP5o5JSEtPog+U3JFrLNSRjz5ofZNVkJzice+0XyqlzJDHhX5tF8
mriYQZLLXYhckBm4IdkhTn/dVbXNQTzyy2WVuO5nU8bkCMvGL9CGpW4YGqwGf7NX
e4H3emtRjLv8VZpUHe/RUUDhmYvMSt1qmXuskfpROuGfLhQBUd6A4J+BAoGBAPGp
UcMKjrxZ5qjYU6DLgS+xeca4Eu70HgdbSQbRo45WubXjyXvTRFij36DrpxJWf1D7
lIsyBifoTra/lAuC1NQXGYWjTCdk2ey8Ll5qOgiXvE6lINHABr+U/Z90/g6LuML2
VzaZbq/QLcT3yVsdyTogKckzCaKsCpusyHE1CXAVAoGAd6kMglKc8N0bhZukgnsN
+5+UeacPcY6sGTh4RWErAjNKGzx1A2lROKvcg9gFaULoQECcIw2IZ5nKW5VsLueg
BWrTrcaJ4A2XmYjhKnp6SvspaGoyHD90hx/Iw7t6r1yzQsB3yDmytwqldtyjBdvC
zynPC2azhDWjraMlR7tka4ECgYAxwvLiHa9sm3qCtCDsUFtmrb3srITBjaUNUL/F
1q8+JR+Sk7gudj9xnTT0VvINNaB71YIt83wPBagHu4VJpYQbtDH+MbUBu6OgOtO1
f1w53rzY2OncJxV8p7pd9mJGLoE6LC2jQY7oRw7Vq0xcJdME1BCmrIrEY3a/vaF8
pjYuTQKBgQCIOH23Xita8KmhH0NdlWxZfcQt1j3AnOcKe6UyN4BsF8hqS7eTA52s
WjG5X2IBl7gs1eMM1qkqR8npS9nwfO/pBmZPwjiZoilypXxWj+c+P3vwre2yija4
bXgFVj4KFBwhr1+8KcobxC0SAPEouMvSkxzjjw+gnebozUtPlud9jA==
-----END RSA PRIVATE KEY-----
`
const CLIENT_CERT_SIGNED_BY_SERVER = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgbMDNUn4M2TtzrSH7MOT2QsvLzZWjehJ5TYrBOp9p+lwAAAADAQABAAABAQCyu57E7zIWRyEWuaiOiikOSZKFjbwLkpE9fboFfLLsNUJj4zw+5bZUJtzWK8roPjgL8s1oPncro5wuTtI2Nu4fkpeFK0Hb33o6Eyksuj4Om4+6Uemn1QEcb0bZqK8Zyg9Dg9deP7LeE0v78b5/jZafFgwxv+/sMhM0PRD34NCDYcYmkkHlvQtQWFAdbPXCgghObedZyYdoqZVuhTsiPMWtQS/cc9M4tv6mPOuQlhZt3R/Oh/kwUyu45oGRb5bhO4JicozFS3oeClpU+UMbgslkzApJqxZBWN7+PDFSZhKk2GslyeyP4sH3E30Z00yVi/lQYgmQsB+Hg6ClemNQMNu/AAAAAAAAAAAAAAACAAAABHVzZXIAAAAIAAAABHVzZXIAAAAAWzBjXAAAAAB/POfPAAAAAAAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEA8CkDr7uxCFt6lQUVwS8NyPO+fQNxORoGnMnN/XhVJZvpqyKRUji9R0d8D66bYxUUsabXjP2y4HTVzbZtnvXFZZshk0cOtJjjekpYJaLK2esPR/iXwvSltNkrDQDPN/RmgEEMIevW8AgrPsqrnybFHxTpd7rEUHXBOe4nMNRIg3XHykB6jZk8q5bBPUe3I/f0DK5TJEBpTc6dO3P/j93u55VUqr39/SPRHnld2mCw+c8v6UOhsssO/DIZFPScD3DYqsk2N+/nz9zXfcOTdWGhawgxuIo1DTokrNQbG3pDrLqcWgqj13vqJFCmRA0O2CQIwJePd6+Np/XO3Uh/KL6FlQAAAQ8AAAAHc3NoLXJzYQAAAQC6sKEQHyl954BQn2BXuTgOB3NkENBxN7SD8ZaS8PNkDESytLjSIqrzoE6m7xuzprA+G23XRrCY/um3UvM7+7+zbwig2NIBbGbp3QFliQHegQKW6hTZP09jAQZk5jRrrEr/QT/s+gtHPmjxJK7XOQYxhInDKj+aJg62ExcwpQlP/0ATKNOIkdzTzzq916p0UOnnVaaPMKibh5Lv69GafIhKJRZSuuLN9fvs1G1RuUbxn/BNSeoRCr54L++Ztg09fJxunoyELs8mwgzCgB3pdZoUR2Z6ak05W4mvH3lkSz2BKUrlwxI6mterxhJy1GuN1K/zBG0gEMl2UTLajGK3qKM8 itbitloaner@MacBook-Pro-4.fios-router.home`
const CLIENT_PEM = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAsruexO8yFkchFrmojoopDkmShY28C5KRPX26BXyy7DVCY+M8
PuW2VCbc1ivK6D44C/LNaD53K6OcLk7SNjbuH5KXhStB2996OhMpLLo+DpuPulHp
p9UBHG9G2aivGcoPQ4PXXj+y3hNL+/G+f42WnxYMMb/v7DITND0Q9+DQg2HGJpJB
5b0LUFhQHWz1woIITm3nWcmHaKmVboU7IjzFrUEv3HPTOLb+pjzrkJYWbd0fzof5
MFMruOaBkW+W4TuCYnKMxUt6HgpaVPlDG4LJZMwKSasWQVje/jwxUmYSpNhrJcns
j+LB9xN9GdNMlYv5UGIJkLAfh4OgpXpjUDDbvwIDAQABAoIBAEu2ctFVyk/pnbi0
uRR4rl+hBvKQUeJNGj2ELvL4Ggs5nIAX2IOEZ7JKLC6FqpSrFq7pEd5g57aSvixX
s3DH4CN7w7fj1ShBCNPlHgIWewdRGpeA74vrDWdwNAEsFdDE6aZeCTOhpDGy1vNJ
OrtpzS5i9pN0jTvvEneEjtWSZIHiiVlN+0hsFaiwZ6KXON+sDccZPmnP6Fzwj5Rc
WS0dKSwnxnx0otWgwWFs8nr306nSeMsNmQkHsS9lz4DEVpp9owdzrX1JmbQvNYAV
ohmB3ET4JYFgerqPXJfed9poueGuWCP6MYhsjNeHN35QhofxdO5/0i3JlZfqwZei
tNq/0oECgYEA6SqjRqDiIp3ajwyB7Wf0cIQG/P6JZDyN1jl//htgniliIH5UP1Tm
uAMG5MincV6X9lOyXyh6Yofu5+NR0yt9SqbDZVJ3ZCxKTun7pxJvQFd7wl5bMkiJ
qVfS08k6gQHHDoO+eel+DtpIfWc+e3tvX0aihSU0GZEMqDXYkkphLGECgYEAxDxb
+JwJ3N5UEjjkuvFBpuJnmjIaN9HvQkTv3inlx1gLE4iWBZXXsu4aWF8MCUeAAZyP
42hQDSkCYX/A22tYCEn/jfrU6A+6rkWBTjdUlYLvlSkhosSnO+117WEItb5cUE95
hF4UY7LNs1AsDkV4WE87f/EjpxSwUAjB2Lfd/B8CgYAJ/JiHsuZcozQ0Qk3iVDyF
ATKnbWOHFozgqw/PW27U92LLj32eRM2o/gAylmGNmoaZt1YBe2NaiwXxiqv7hnZU
VzYxRcn1UWxRWvY7Xq/DKrwTRCVVzwOObEOMbKcD1YaoGX50DEso6bKHJH/pnAzW
INlfKIvFuI+5OK0w/tyQoQKBgQCf/jpaOxaLfrV62eobRQJrByLDBGB97GsvU7di
IjTWz8DQH0d5rE7d8uWF8ZCFrEcAiV6DYZQK9smbJqbd/uoacAKtBro5rkFdPwwK
8m/DKqsdqRhkdgOHh7bjYH7Sdy8ax4Fi27WyB6FQtmgFBrz0+zyetsODwQlzZ4Bs
qpSRrwKBgQC0vWHrY5aGIdF+b8EpP0/SSLLALpMySHyWhDyxYcPqdhszYbjDcavv
xrrLXNUD2duBHKPVYE+7uVoDkpZXLUQ4x8argo/IwQM6Kh2ma1y83TYMT6XhL1+B
5UPcl6RXZBCkiU7nFIG6/0XKFqVWc3fU8e09X+iJwXIJ5Jatywtg+g==
-----END RSA PRIVATE KEY-----
`
func TestCertificateBasedAuth(t *testing.T) {
signer, err := ssh.ParsePrivateKey([]byte(SERVER_PEM))
if err != nil {
t.Fatalf("unable to parse private key: %v", err)
}
address := newMockLineServer(t, signer, CLIENT_CERT_SIGNED_BY_SERVER)
host, p, _ := net.SplitHostPort(address)
port, _ := strconv.Atoi(p)
connInfo := &connectionInfo{
User: "user",
Host: host,
PrivateKey: CLIENT_PEM,
Certificate: CLIENT_CERT_SIGNED_BY_SERVER,
Port: port,
Timeout: "30s",
}
cfg, err := prepareSSHConfig(connInfo)
if err != nil {
t.Fatal(err)
}
c := &Communicator{
connInfo: connInfo,
config: cfg,
}
var cmd remote.Cmd
stdout := new(bytes.Buffer)
cmd.Command = "echo foo"
cmd.Stdout = stdout
if err := c.Start(&cmd); err != nil {
t.Fatal(err)
}
if err := c.Disconnect(); err != nil {
t.Fatal(err)
}
}
func TestAccUploadFile(t *testing.T) {
// use the local ssh server and scp binary to check uploads
if ok := os.Getenv("SSH_UPLOAD_TEST"); ok == "" {
@ -572,11 +765,12 @@ func acceptUserPass(goodUser, goodPass string) func(ssh.ConnMetadata, []byte) (*
}
func acceptPublicKey(keystr string) func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error) {
goodkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keystr))
if err != nil {
panic(fmt.Errorf("error parsing key: %s", err))
}
return func(_ ssh.ConnMetadata, inkey ssh.PublicKey) (*ssh.Permissions, error) {
goodkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keystr))
if err != nil {
return nil, fmt.Errorf("error parsing key: %v", err)
}
if bytes.Equal(inkey.Marshal(), goodkey.Marshal()) {
return nil, nil
}

@ -3,6 +3,7 @@ package ssh
import (
"bytes"
"encoding/pem"
"errors"
"fmt"
"io/ioutil"
"log"
@ -40,16 +41,17 @@ const (
// only keys we look at. If a PrivateKey is given, that is used instead
// of a password.
type connectionInfo struct {
User string
Password string
PrivateKey string `mapstructure:"private_key"`
Host string
HostKey string `mapstructure:"host_key"`
Port int
Agent bool
Timeout string
ScriptPath string `mapstructure:"script_path"`
TimeoutVal time.Duration `mapstructure:"-"`
User string
Password string
PrivateKey string `mapstructure:"private_key"`
Certificate string `mapstructure:"certificate"`
Host string
HostKey string `mapstructure:"host_key"`
Port int
Agent bool
Timeout string
ScriptPath string `mapstructure:"script_path"`
TimeoutVal time.Duration `mapstructure:"-"`
BastionUser string `mapstructure:"bastion_user"`
BastionPassword string `mapstructure:"bastion_password"`
@ -150,12 +152,13 @@ func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) {
host := fmt.Sprintf("%s:%d", connInfo.Host, connInfo.Port)
sshConf, err := buildSSHClientConfig(sshClientConfigOpts{
user: connInfo.User,
host: host,
privateKey: connInfo.PrivateKey,
password: connInfo.Password,
hostKey: connInfo.HostKey,
sshAgent: sshAgent,
user: connInfo.User,
host: host,
privateKey: connInfo.PrivateKey,
password: connInfo.Password,
hostKey: connInfo.HostKey,
certificate: connInfo.Certificate,
sshAgent: sshAgent,
})
if err != nil {
return nil, err
@ -191,12 +194,13 @@ func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) {
}
type sshClientConfigOpts struct {
privateKey string
password string
sshAgent *sshAgent
user string
host string
hostKey string
privateKey string
password string
sshAgent *sshAgent
certificate string
user string
host string
hostKey string
}
func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) {
@ -234,11 +238,23 @@ func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) {
}
if opts.privateKey != "" {
pubKeyAuth, err := readPrivateKey(opts.privateKey)
if err != nil {
return nil, err
if opts.certificate != "" {
log.Println("using client certificate for authentication")
certSigner, err := signCertWithPrivateKey(opts.privateKey, opts.certificate)
if err != nil {
return nil, err
}
conf.Auth = append(conf.Auth, certSigner)
} else {
log.Println("using private key for authentication")
pubKeyAuth, err := readPrivateKey(opts.privateKey)
if err != nil {
return nil, err
}
conf.Auth = append(conf.Auth, pubKeyAuth)
}
conf.Auth = append(conf.Auth, pubKeyAuth)
}
if opts.password != "" {
@ -254,22 +270,47 @@ func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) {
return conf, nil
}
// Create a Cert Signer and return ssh.AuthMethod
func signCertWithPrivateKey(pk string, certificate string) (ssh.AuthMethod, error) {
rawPk, err := ssh.ParseRawPrivateKey([]byte(pk))
if err != nil {
return nil, fmt.Errorf("failed to parse private key %q: %s", pk, err)
}
pcert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(certificate))
if err != nil {
return nil, fmt.Errorf("failed to parse certificate %q: %s", certificate, err)
}
usigner, err := ssh.NewSignerFromKey(rawPk)
if err != nil {
return nil, fmt.Errorf("failed to create signer from raw private key %q: %s", rawPk, err)
}
ucertSigner, err := ssh.NewCertSigner(pcert.(*ssh.Certificate), usigner)
if err != nil {
return nil, fmt.Errorf("failed to create cert signer %q: %s", usigner, err)
}
return ssh.PublicKeys(ucertSigner), nil
}
func readPrivateKey(pk string) (ssh.AuthMethod, error) {
// We parse the private key on our own first so that we can
// show a nicer error if the private key has a password.
block, _ := pem.Decode([]byte(pk))
if block == nil {
return nil, fmt.Errorf("Failed to read key %q: no key found", pk)
return nil, errors.New("Failed to read ssh private key: no key found")
}
if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
return nil, fmt.Errorf(
"Failed to read key %q: password protected keys are\n"+
"not supported. Please decrypt the key prior to use.", pk)
return nil, errors.New(
"Failed to read ssh private key: password protected keys are\n" +
"not supported. Please decrypt the key prior to use.")
}
signer, err := ssh.ParsePrivateKey([]byte(pk))
if err != nil {
return nil, fmt.Errorf("Failed to parse key file %q: %s", pk, err)
return nil, fmt.Errorf("Failed to parse ssh private key: %s", err)
}
return ssh.PublicKeys(signer), nil

@ -14,6 +14,7 @@ func TestProvisioner_connInfo(t *testing.T) {
"user": "root",
"password": "supersecret",
"private_key": "someprivatekeycontents",
"certificate": "somecertificate",
"host": "127.0.0.1",
"port": "22",
"timeout": "30s",
@ -37,6 +38,9 @@ func TestProvisioner_connInfo(t *testing.T) {
if conf.PrivateKey != "someprivatekeycontents" {
t.Fatalf("bad: %v", conf)
}
if conf.Certificate != "somecertificate" {
t.Fatalf("bad: %v", conf)
}
if conf.Host != "127.0.0.1" {
t.Fatalf("bad: %v", conf)
}
@ -74,6 +78,7 @@ func TestProvisioner_connInfoIpv6(t *testing.T) {
"user": "root",
"password": "supersecret",
"private_key": "someprivatekeycontents",
"certificate": "somecertificate",
"host": "::1",
"port": "22",
"timeout": "30s",
@ -101,14 +106,13 @@ func TestProvisioner_connInfoHostname(t *testing.T) {
r := &terraform.InstanceState{
Ephemeral: terraform.EphemeralState{
ConnInfo: map[string]string{
"type": "ssh",
"user": "root",
"password": "supersecret",
"private_key": "someprivatekeycontents",
"host": "example.com",
"port": "22",
"timeout": "30s",
"type": "ssh",
"user": "root",
"password": "supersecret",
"private_key": "someprivatekeycontents",
"host": "example.com",
"port": "22",
"timeout": "30s",
"bastion_host": "example.com",
},
},

Loading…
Cancel
Save