diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index c829f4ac49..0d5aaf34be 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -3,6 +3,7 @@ package ssh import ( "bufio" "bytes" + "context" "errors" "fmt" "io" @@ -17,6 +18,7 @@ import ( "sync" "time" + "github.com/hashicorp/errwrap" "github.com/hashicorp/terraform/communicator/remote" "github.com/hashicorp/terraform/terraform" "golang.org/x/crypto/ssh" @@ -28,20 +30,30 @@ const ( DefaultShebang = "#!/bin/sh\n" ) -// randShared is a global random generator object that is shared. -// This must be shared since it is seeded by the current time and -// creating multiple can result in the same values. By using a shared -// RNG we assure different numbers per call. -var randLock sync.Mutex -var randShared *rand.Rand +var ( + // randShared is a global random generator object that is shared. This must be + // shared since it is seeded by the current time and creating multiple can + // result in the same values. By using a shared RNG we assure different numbers + // per call. + randLock sync.Mutex + randShared *rand.Rand + + // enable ssh keeplive probes by default + keepAliveInterval = 2 * time.Second + + // max time to wait for for a KeepAlive response before considering the + // connection to be dead. + maxKeepAliveDelay = 120 * time.Second +) // Communicator represents the SSH communicator type Communicator struct { - connInfo *connectionInfo - client *ssh.Client - config *sshConfig - conn net.Conn - address string + connInfo *connectionInfo + client *ssh.Client + config *sshConfig + conn net.Conn + address string + cancelKeepAlive context.CancelFunc lock sync.Mutex } @@ -125,11 +137,13 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) { " User: %s\n"+ " Password: %t\n"+ " Private key: %t\n"+ + " Certificate: %t\n"+ " SSH Agent: %t\n"+ " Checking Host Key: %t", c.connInfo.Host, c.connInfo.User, c.connInfo.Password != "", c.connInfo.PrivateKey != "", + c.connInfo.Certificate != "", c.connInfo.Agent, c.connInfo.HostKey != "", )) @@ -152,7 +166,8 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) { } } - log.Printf("[DEBUG] connecting to TCP connection for SSH") + hostAndPort := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) + log.Printf("[DEBUG] Connecting to %s for SSH", hostAndPort) c.conn, err = c.config.connection() if err != nil { // Explicitly set this to the REAL nil. Connection() can return @@ -167,10 +182,11 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) { return err } - log.Printf("[DEBUG] handshaking with SSH") - host := fmt.Sprintf("%s:%d", c.connInfo.Host, c.connInfo.Port) - sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, host, c.config.config) + log.Printf("[DEBUG] Connection established. Handshaking for user %v", c.connInfo.User) + sshConn, sshChan, req, err := ssh.NewClientConn(c.conn, hostAndPort, c.config.config) if err != nil { + err = errwrap.Wrapf(fmt.Sprintf("SSH authentication failed (%s@%s): {{err}}", c.connInfo.User, hostAndPort), err) + // While in theory this should be a fatal error, some hosts may start // the ssh service before it is properly configured, or before user // authentication data is available. @@ -203,11 +219,69 @@ func (c *Communicator) Connect(o terraform.UIOutput) (err error) { } } + if err != nil { + return err + } + if o != nil { o.Output("Connected!") } - return err + ctx, cancelKeepAlive := context.WithCancel(context.TODO()) + c.cancelKeepAlive = cancelKeepAlive + + // Start a keepalive goroutine to help maintain the connection for + // long-running commands. + log.Printf("[DEBUG] starting ssh KeepAlives") + go func() { + defer cancelKeepAlive() + // Along with the KeepAlives generating packets to keep the tcp + // connection open, we will use the replies to verify liveness of the + // connection. This will prevent dead connections from blocking the + // provisioner indefinitely. + respCh := make(chan error, 1) + + go func() { + t := time.NewTicker(keepAliveInterval) + defer t.Stop() + for { + select { + case <-t.C: + _, _, err := c.client.SendRequest("keepalive@terraform.io", true, nil) + respCh <- err + case <-ctx.Done(): + return + } + } + }() + + after := time.NewTimer(maxKeepAliveDelay) + defer after.Stop() + + for { + select { + case err := <-respCh: + if err != nil { + log.Printf("[ERROR] ssh keepalive: %s", err) + sshConn.Close() + return + } + case <-after.C: + // abort after too many missed keepalives + log.Println("[ERROR] no reply from ssh server") + sshConn.Close() + return + case <-ctx.Done(): + return + } + if !after.Stop() { + <-after.C + } + after.Reset(maxKeepAliveDelay) + } + }() + + return nil } // Disconnect implementation of communicator.Communicator interface @@ -215,6 +289,10 @@ func (c *Communicator) Disconnect() error { c.lock.Lock() defer c.lock.Unlock() + if c.cancelKeepAlive != nil { + c.cancelKeepAlive() + } + if c.config.sshAgent != nil { if err := c.config.sshAgent.Close(); err != nil { return err @@ -481,6 +559,13 @@ func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Re // our data and has completed. Or has errored. log.Println("[DEBUG] Waiting for SSH session to complete.") err = session.Wait() + + // log any stderr before exiting on an error + scpErr := stderr.String() + if len(scpErr) > 0 { + log.Printf("[ERROR] scp stderr: %q", stderr) + } + if err != nil { if exitErr, ok := err.(*ssh.ExitError); ok { // Otherwise, we have an ExitErorr, meaning we can just read @@ -499,11 +584,6 @@ func (c *Communicator) scpSession(scpCommand string, f func(io.Writer, *bufio.Re return err } - scpErr := stderr.String() - if len(scpErr) > 0 { - log.Printf("[ERROR] scp stderr: %q", stderr) - } - return nil } diff --git a/communicator/ssh/communicator_test.go b/communicator/ssh/communicator_test.go index 546d6f88cc..bbe8213630 100644 --- a/communicator/ssh/communicator_test.go +++ b/communicator/ssh/communicator_test.go @@ -58,10 +58,10 @@ const testServerHostCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC1 const testCAPublicKey = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCrozyZIhdEvalCn+eSzHH94cO9ykiywA13ntWI7mJcHBwYTeCYWG8E9zGXyp2iDOjCGudM0Tdt8o0OofKChk9Z/qiUN0G8y1kmaXBlBM3qA5R9NPpvMYMNkYLfX6ivtZCnqrsbzaoqN2Oc/7H2StHzJWh/XCGu9otQZA6vdv1oSmAsZOjw/xIGaGQqDUaLq21J280PP1qSbdJHf76iSHE+TWe3YpqV946JWM5tCh0DykZ10VznvxYpUjzhr07IN3tVKxOXbPnnU7lX6IaLIWgfzLqwSyheeux05c3JLF9iF4sFu8ou4hwQz1iuUTU1jxgwZP0w/bkXgFFs0949lW81` -func newMockLineServer(t *testing.T, signer ssh.Signer) string { +func newMockLineServer(t *testing.T, signer ssh.Signer, pubKey string) string { serverConfig := &ssh.ServerConfig{ PasswordCallback: acceptUserPass("user", "pass"), - PublicKeyCallback: acceptPublicKey(testClientPublicKey), + PublicKeyCallback: acceptPublicKey(pubKey), } var err error @@ -100,16 +100,19 @@ func newMockLineServer(t *testing.T, signer ssh.Signer) string { go func(in <-chan *ssh.Request) { for req := range in { + // since this channel's requests are serviced serially, + // this will block keepalive probes, and can simulate a + // hung connection. + if bytes.Contains(req.Payload, []byte("sleep")) { + time.Sleep(time.Second) + } + if req.WantReply { req.Reply(true, nil) } } }(requests) - go func(newChannel ssh.NewChannel) { - conn.OpenChannel(newChannel.ChannelType(), nil) - }(newChannel) - defer channel.Close() } conn.Close() @@ -119,7 +122,7 @@ func newMockLineServer(t *testing.T, signer ssh.Signer) string { } func TestNew_Invalid(t *testing.T) { - address := newMockLineServer(t, nil) + address := newMockLineServer(t, nil, testClientPublicKey) parts := strings.Split(address, ":") r := &terraform.InstanceState{ @@ -147,7 +150,7 @@ func TestNew_Invalid(t *testing.T) { } func TestStart(t *testing.T) { - address := newMockLineServer(t, nil) + address := newMockLineServer(t, nil, testClientPublicKey) parts := strings.Split(address, ":") r := &terraform.InstanceState{ @@ -179,8 +182,99 @@ func TestStart(t *testing.T) { } } +// TestKeepAlives verifies that the keepalive messages don't interfere with +// normal operation of the client. +func TestKeepAlives(t *testing.T) { + ivl := keepAliveInterval + keepAliveInterval = 250 * time.Millisecond + defer func() { keepAliveInterval = ivl }() + + address := newMockLineServer(t, nil, testClientPublicKey) + parts := strings.Split(address, ":") + + r := &terraform.InstanceState{ + Ephemeral: terraform.EphemeralState{ + ConnInfo: map[string]string{ + "type": "ssh", + "user": "user", + "password": "pass", + "host": parts[0], + "port": parts[1], + }, + }, + } + + c, err := New(r) + if err != nil { + t.Fatalf("error creating communicator: %s", err) + } + + if err := c.Connect(nil); err != nil { + t.Fatal(err) + } + + var cmd remote.Cmd + stdout := new(bytes.Buffer) + cmd.Command = "sleep" + cmd.Stdout = stdout + + // wait a bit before executing the command, so that at least 1 keepalive is sent + time.Sleep(500 * time.Millisecond) + + err = c.Start(&cmd) + if err != nil { + t.Fatalf("error executing remote command: %s", err) + } +} + +// TestDeadConnection verifies that failed keepalive messages will eventually +// kill the connection. +func TestFailedKeepAlives(t *testing.T) { + ivl := keepAliveInterval + del := maxKeepAliveDelay + maxKeepAliveDelay = 500 * time.Millisecond + keepAliveInterval = 250 * time.Millisecond + defer func() { + keepAliveInterval = ivl + maxKeepAliveDelay = del + }() + + address := newMockLineServer(t, nil, testClientPublicKey) + parts := strings.Split(address, ":") + + r := &terraform.InstanceState{ + Ephemeral: terraform.EphemeralState{ + ConnInfo: map[string]string{ + "type": "ssh", + "user": "user", + "password": "pass", + "host": parts[0], + "port": parts[1], + }, + }, + } + + c, err := New(r) + if err != nil { + t.Fatalf("error creating communicator: %s", err) + } + + if err := c.Connect(nil); err != nil { + t.Fatal(err) + } + var cmd remote.Cmd + stdout := new(bytes.Buffer) + cmd.Command = "sleep" + cmd.Stdout = stdout + + err = c.Start(&cmd) + if err == nil { + t.Fatal("expected connection error") + } +} + func TestLostConnection(t *testing.T) { - address := newMockLineServer(t, nil) + address := newMockLineServer(t, nil, testClientPublicKey) parts := strings.Split(address, ":") r := &terraform.InstanceState{ @@ -229,11 +323,11 @@ func TestHostKey(t *testing.T) { // get the server's public key signer, err := ssh.ParsePrivateKey([]byte(testServerPrivateKey)) if err != nil { - panic("unable to parse private key: " + err.Error()) + t.Fatalf("unable to parse private key: %v", err) } pubKey := fmt.Sprintf("ssh-rsa %s", base64.StdEncoding.EncodeToString(signer.PublicKey().Marshal())) - address := newMockLineServer(t, nil) + address := newMockLineServer(t, nil, testClientPublicKey) host, p, _ := net.SplitHostPort(address) port, _ := strconv.Atoi(p) @@ -269,7 +363,7 @@ func TestHostKey(t *testing.T) { } // now check with the wrong HostKey - address = newMockLineServer(t, nil) + address = newMockLineServer(t, nil, testClientPublicKey) _, p, _ = net.SplitHostPort(address) port, _ = strconv.Atoi(p) @@ -308,7 +402,7 @@ func TestHostCert(t *testing.T) { t.Fatal(err) } - address := newMockLineServer(t, signer) + address := newMockLineServer(t, signer, testClientPublicKey) host, p, _ := net.SplitHostPort(address) port, _ := strconv.Atoi(p) @@ -344,7 +438,7 @@ func TestHostCert(t *testing.T) { } // now check with the wrong HostKey - address = newMockLineServer(t, signer) + address = newMockLineServer(t, signer, testClientPublicKey) _, p, _ = net.SplitHostPort(address) port, _ = strconv.Atoi(p) @@ -367,6 +461,105 @@ func TestHostCert(t *testing.T) { } } +const SERVER_PEM = `-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEA8CkDr7uxCFt6lQUVwS8NyPO+fQNxORoGnMnN/XhVJZvpqyKR +Uji9R0d8D66bYxUUsabXjP2y4HTVzbZtnvXFZZshk0cOtJjjekpYJaLK2esPR/iX +wvSltNkrDQDPN/RmgEEMIevW8AgrPsqrnybFHxTpd7rEUHXBOe4nMNRIg3XHykB6 +jZk8q5bBPUe3I/f0DK5TJEBpTc6dO3P/j93u55VUqr39/SPRHnld2mCw+c8v6UOh +sssO/DIZFPScD3DYqsk2N+/nz9zXfcOTdWGhawgxuIo1DTokrNQbG3pDrLqcWgqj +13vqJFCmRA0O2CQIwJePd6+Np/XO3Uh/KL6FlQIDAQABAoIBAQCmvQMXNmvCDqk7 +30zsVDvw4fHGH+azK3Od1aqTqcEMHISOUbCtckFPxLzIsoSltRQqB1kuRVG07skm +Stsu+xny4lLcSwBVuLRuykEK2EyYIc/5Owo6y9pkhkaSf5ZfFes4bnD6+B/BhRpp +PRMMq0E+xCkX/G6iIi9mhgdlqm0x/vKtjzQeeshw9+gRcRLUpX+UeKFKXMXcDayx +qekr1bAaQKNBhTK+CbZjcqzG4f+BXVGRTZ9nsPAV+yTnWUCU0TghwPmtthHbebqa +9hlkum7qik/bQj/tjJ8/b0vTfHQSVxhtPG/ZV2Tn9ZuL/vrkYqeyMU8XkJ/uaEvH +WPyOcB4BAoGBAP5o5JSEtPog+U3JFrLNSRjz5ofZNVkJzice+0XyqlzJDHhX5tF8 +mriYQZLLXYhckBm4IdkhTn/dVbXNQTzyy2WVuO5nU8bkCMvGL9CGpW4YGqwGf7NX +e4H3emtRjLv8VZpUHe/RUUDhmYvMSt1qmXuskfpROuGfLhQBUd6A4J+BAoGBAPGp +UcMKjrxZ5qjYU6DLgS+xeca4Eu70HgdbSQbRo45WubXjyXvTRFij36DrpxJWf1D7 +lIsyBifoTra/lAuC1NQXGYWjTCdk2ey8Ll5qOgiXvE6lINHABr+U/Z90/g6LuML2 +VzaZbq/QLcT3yVsdyTogKckzCaKsCpusyHE1CXAVAoGAd6kMglKc8N0bhZukgnsN ++5+UeacPcY6sGTh4RWErAjNKGzx1A2lROKvcg9gFaULoQECcIw2IZ5nKW5VsLueg +BWrTrcaJ4A2XmYjhKnp6SvspaGoyHD90hx/Iw7t6r1yzQsB3yDmytwqldtyjBdvC +zynPC2azhDWjraMlR7tka4ECgYAxwvLiHa9sm3qCtCDsUFtmrb3srITBjaUNUL/F +1q8+JR+Sk7gudj9xnTT0VvINNaB71YIt83wPBagHu4VJpYQbtDH+MbUBu6OgOtO1 +f1w53rzY2OncJxV8p7pd9mJGLoE6LC2jQY7oRw7Vq0xcJdME1BCmrIrEY3a/vaF8 +pjYuTQKBgQCIOH23Xita8KmhH0NdlWxZfcQt1j3AnOcKe6UyN4BsF8hqS7eTA52s +WjG5X2IBl7gs1eMM1qkqR8npS9nwfO/pBmZPwjiZoilypXxWj+c+P3vwre2yija4 +bXgFVj4KFBwhr1+8KcobxC0SAPEouMvSkxzjjw+gnebozUtPlud9jA== +-----END RSA PRIVATE KEY----- +` +const CLIENT_CERT_SIGNED_BY_SERVER = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgbMDNUn4M2TtzrSH7MOT2QsvLzZWjehJ5TYrBOp9p+lwAAAADAQABAAABAQCyu57E7zIWRyEWuaiOiikOSZKFjbwLkpE9fboFfLLsNUJj4zw+5bZUJtzWK8roPjgL8s1oPncro5wuTtI2Nu4fkpeFK0Hb33o6Eyksuj4Om4+6Uemn1QEcb0bZqK8Zyg9Dg9deP7LeE0v78b5/jZafFgwxv+/sMhM0PRD34NCDYcYmkkHlvQtQWFAdbPXCgghObedZyYdoqZVuhTsiPMWtQS/cc9M4tv6mPOuQlhZt3R/Oh/kwUyu45oGRb5bhO4JicozFS3oeClpU+UMbgslkzApJqxZBWN7+PDFSZhKk2GslyeyP4sH3E30Z00yVi/lQYgmQsB+Hg6ClemNQMNu/AAAAAAAAAAAAAAACAAAABHVzZXIAAAAIAAAABHVzZXIAAAAAWzBjXAAAAAB/POfPAAAAAAAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEA8CkDr7uxCFt6lQUVwS8NyPO+fQNxORoGnMnN/XhVJZvpqyKRUji9R0d8D66bYxUUsabXjP2y4HTVzbZtnvXFZZshk0cOtJjjekpYJaLK2esPR/iXwvSltNkrDQDPN/RmgEEMIevW8AgrPsqrnybFHxTpd7rEUHXBOe4nMNRIg3XHykB6jZk8q5bBPUe3I/f0DK5TJEBpTc6dO3P/j93u55VUqr39/SPRHnld2mCw+c8v6UOhsssO/DIZFPScD3DYqsk2N+/nz9zXfcOTdWGhawgxuIo1DTokrNQbG3pDrLqcWgqj13vqJFCmRA0O2CQIwJePd6+Np/XO3Uh/KL6FlQAAAQ8AAAAHc3NoLXJzYQAAAQC6sKEQHyl954BQn2BXuTgOB3NkENBxN7SD8ZaS8PNkDESytLjSIqrzoE6m7xuzprA+G23XRrCY/um3UvM7+7+zbwig2NIBbGbp3QFliQHegQKW6hTZP09jAQZk5jRrrEr/QT/s+gtHPmjxJK7XOQYxhInDKj+aJg62ExcwpQlP/0ATKNOIkdzTzzq916p0UOnnVaaPMKibh5Lv69GafIhKJRZSuuLN9fvs1G1RuUbxn/BNSeoRCr54L++Ztg09fJxunoyELs8mwgzCgB3pdZoUR2Z6ak05W4mvH3lkSz2BKUrlwxI6mterxhJy1GuN1K/zBG0gEMl2UTLajGK3qKM8 itbitloaner@MacBook-Pro-4.fios-router.home` +const CLIENT_PEM = `-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAsruexO8yFkchFrmojoopDkmShY28C5KRPX26BXyy7DVCY+M8 +PuW2VCbc1ivK6D44C/LNaD53K6OcLk7SNjbuH5KXhStB2996OhMpLLo+DpuPulHp +p9UBHG9G2aivGcoPQ4PXXj+y3hNL+/G+f42WnxYMMb/v7DITND0Q9+DQg2HGJpJB +5b0LUFhQHWz1woIITm3nWcmHaKmVboU7IjzFrUEv3HPTOLb+pjzrkJYWbd0fzof5 +MFMruOaBkW+W4TuCYnKMxUt6HgpaVPlDG4LJZMwKSasWQVje/jwxUmYSpNhrJcns +j+LB9xN9GdNMlYv5UGIJkLAfh4OgpXpjUDDbvwIDAQABAoIBAEu2ctFVyk/pnbi0 +uRR4rl+hBvKQUeJNGj2ELvL4Ggs5nIAX2IOEZ7JKLC6FqpSrFq7pEd5g57aSvixX +s3DH4CN7w7fj1ShBCNPlHgIWewdRGpeA74vrDWdwNAEsFdDE6aZeCTOhpDGy1vNJ +OrtpzS5i9pN0jTvvEneEjtWSZIHiiVlN+0hsFaiwZ6KXON+sDccZPmnP6Fzwj5Rc +WS0dKSwnxnx0otWgwWFs8nr306nSeMsNmQkHsS9lz4DEVpp9owdzrX1JmbQvNYAV +ohmB3ET4JYFgerqPXJfed9poueGuWCP6MYhsjNeHN35QhofxdO5/0i3JlZfqwZei +tNq/0oECgYEA6SqjRqDiIp3ajwyB7Wf0cIQG/P6JZDyN1jl//htgniliIH5UP1Tm +uAMG5MincV6X9lOyXyh6Yofu5+NR0yt9SqbDZVJ3ZCxKTun7pxJvQFd7wl5bMkiJ +qVfS08k6gQHHDoO+eel+DtpIfWc+e3tvX0aihSU0GZEMqDXYkkphLGECgYEAxDxb ++JwJ3N5UEjjkuvFBpuJnmjIaN9HvQkTv3inlx1gLE4iWBZXXsu4aWF8MCUeAAZyP +42hQDSkCYX/A22tYCEn/jfrU6A+6rkWBTjdUlYLvlSkhosSnO+117WEItb5cUE95 +hF4UY7LNs1AsDkV4WE87f/EjpxSwUAjB2Lfd/B8CgYAJ/JiHsuZcozQ0Qk3iVDyF +ATKnbWOHFozgqw/PW27U92LLj32eRM2o/gAylmGNmoaZt1YBe2NaiwXxiqv7hnZU +VzYxRcn1UWxRWvY7Xq/DKrwTRCVVzwOObEOMbKcD1YaoGX50DEso6bKHJH/pnAzW +INlfKIvFuI+5OK0w/tyQoQKBgQCf/jpaOxaLfrV62eobRQJrByLDBGB97GsvU7di +IjTWz8DQH0d5rE7d8uWF8ZCFrEcAiV6DYZQK9smbJqbd/uoacAKtBro5rkFdPwwK +8m/DKqsdqRhkdgOHh7bjYH7Sdy8ax4Fi27WyB6FQtmgFBrz0+zyetsODwQlzZ4Bs +qpSRrwKBgQC0vWHrY5aGIdF+b8EpP0/SSLLALpMySHyWhDyxYcPqdhszYbjDcavv +xrrLXNUD2duBHKPVYE+7uVoDkpZXLUQ4x8argo/IwQM6Kh2ma1y83TYMT6XhL1+B +5UPcl6RXZBCkiU7nFIG6/0XKFqVWc3fU8e09X+iJwXIJ5Jatywtg+g== +-----END RSA PRIVATE KEY----- +` + +func TestCertificateBasedAuth(t *testing.T) { + signer, err := ssh.ParsePrivateKey([]byte(SERVER_PEM)) + if err != nil { + t.Fatalf("unable to parse private key: %v", err) + } + address := newMockLineServer(t, signer, CLIENT_CERT_SIGNED_BY_SERVER) + host, p, _ := net.SplitHostPort(address) + port, _ := strconv.Atoi(p) + + connInfo := &connectionInfo{ + User: "user", + Host: host, + PrivateKey: CLIENT_PEM, + Certificate: CLIENT_CERT_SIGNED_BY_SERVER, + Port: port, + Timeout: "30s", + } + + cfg, err := prepareSSHConfig(connInfo) + if err != nil { + t.Fatal(err) + } + + c := &Communicator{ + connInfo: connInfo, + config: cfg, + } + + var cmd remote.Cmd + stdout := new(bytes.Buffer) + cmd.Command = "echo foo" + cmd.Stdout = stdout + + if err := c.Start(&cmd); err != nil { + t.Fatal(err) + } + if err := c.Disconnect(); err != nil { + t.Fatal(err) + } +} + func TestAccUploadFile(t *testing.T) { // use the local ssh server and scp binary to check uploads if ok := os.Getenv("SSH_UPLOAD_TEST"); ok == "" { @@ -572,11 +765,12 @@ func acceptUserPass(goodUser, goodPass string) func(ssh.ConnMetadata, []byte) (* } func acceptPublicKey(keystr string) func(ssh.ConnMetadata, ssh.PublicKey) (*ssh.Permissions, error) { - goodkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keystr)) - if err != nil { - panic(fmt.Errorf("error parsing key: %s", err)) - } return func(_ ssh.ConnMetadata, inkey ssh.PublicKey) (*ssh.Permissions, error) { + goodkey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(keystr)) + if err != nil { + return nil, fmt.Errorf("error parsing key: %v", err) + } + if bytes.Equal(inkey.Marshal(), goodkey.Marshal()) { return nil, nil } diff --git a/communicator/ssh/provisioner.go b/communicator/ssh/provisioner.go index b6cb1b05fa..21d73b2761 100644 --- a/communicator/ssh/provisioner.go +++ b/communicator/ssh/provisioner.go @@ -3,6 +3,7 @@ package ssh import ( "bytes" "encoding/pem" + "errors" "fmt" "io/ioutil" "log" @@ -40,16 +41,17 @@ const ( // only keys we look at. If a PrivateKey is given, that is used instead // of a password. type connectionInfo struct { - User string - Password string - PrivateKey string `mapstructure:"private_key"` - Host string - HostKey string `mapstructure:"host_key"` - Port int - Agent bool - Timeout string - ScriptPath string `mapstructure:"script_path"` - TimeoutVal time.Duration `mapstructure:"-"` + User string + Password string + PrivateKey string `mapstructure:"private_key"` + Certificate string `mapstructure:"certificate"` + Host string + HostKey string `mapstructure:"host_key"` + Port int + Agent bool + Timeout string + ScriptPath string `mapstructure:"script_path"` + TimeoutVal time.Duration `mapstructure:"-"` BastionUser string `mapstructure:"bastion_user"` BastionPassword string `mapstructure:"bastion_password"` @@ -150,12 +152,13 @@ func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) { host := fmt.Sprintf("%s:%d", connInfo.Host, connInfo.Port) sshConf, err := buildSSHClientConfig(sshClientConfigOpts{ - user: connInfo.User, - host: host, - privateKey: connInfo.PrivateKey, - password: connInfo.Password, - hostKey: connInfo.HostKey, - sshAgent: sshAgent, + user: connInfo.User, + host: host, + privateKey: connInfo.PrivateKey, + password: connInfo.Password, + hostKey: connInfo.HostKey, + certificate: connInfo.Certificate, + sshAgent: sshAgent, }) if err != nil { return nil, err @@ -191,12 +194,13 @@ func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) { } type sshClientConfigOpts struct { - privateKey string - password string - sshAgent *sshAgent - user string - host string - hostKey string + privateKey string + password string + sshAgent *sshAgent + certificate string + user string + host string + hostKey string } func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) { @@ -234,11 +238,23 @@ func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) { } if opts.privateKey != "" { - pubKeyAuth, err := readPrivateKey(opts.privateKey) - if err != nil { - return nil, err + if opts.certificate != "" { + log.Println("using client certificate for authentication") + + certSigner, err := signCertWithPrivateKey(opts.privateKey, opts.certificate) + if err != nil { + return nil, err + } + conf.Auth = append(conf.Auth, certSigner) + } else { + log.Println("using private key for authentication") + + pubKeyAuth, err := readPrivateKey(opts.privateKey) + if err != nil { + return nil, err + } + conf.Auth = append(conf.Auth, pubKeyAuth) } - conf.Auth = append(conf.Auth, pubKeyAuth) } if opts.password != "" { @@ -254,22 +270,47 @@ func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) { return conf, nil } +// Create a Cert Signer and return ssh.AuthMethod +func signCertWithPrivateKey(pk string, certificate string) (ssh.AuthMethod, error) { + rawPk, err := ssh.ParseRawPrivateKey([]byte(pk)) + if err != nil { + return nil, fmt.Errorf("failed to parse private key %q: %s", pk, err) + } + + pcert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(certificate)) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate %q: %s", certificate, err) + } + + usigner, err := ssh.NewSignerFromKey(rawPk) + if err != nil { + return nil, fmt.Errorf("failed to create signer from raw private key %q: %s", rawPk, err) + } + + ucertSigner, err := ssh.NewCertSigner(pcert.(*ssh.Certificate), usigner) + if err != nil { + return nil, fmt.Errorf("failed to create cert signer %q: %s", usigner, err) + } + + return ssh.PublicKeys(ucertSigner), nil +} + func readPrivateKey(pk string) (ssh.AuthMethod, error) { // We parse the private key on our own first so that we can // show a nicer error if the private key has a password. block, _ := pem.Decode([]byte(pk)) if block == nil { - return nil, fmt.Errorf("Failed to read key %q: no key found", pk) + return nil, errors.New("Failed to read ssh private key: no key found") } if block.Headers["Proc-Type"] == "4,ENCRYPTED" { - return nil, fmt.Errorf( - "Failed to read key %q: password protected keys are\n"+ - "not supported. Please decrypt the key prior to use.", pk) + return nil, errors.New( + "Failed to read ssh private key: password protected keys are\n" + + "not supported. Please decrypt the key prior to use.") } signer, err := ssh.ParsePrivateKey([]byte(pk)) if err != nil { - return nil, fmt.Errorf("Failed to parse key file %q: %s", pk, err) + return nil, fmt.Errorf("Failed to parse ssh private key: %s", err) } return ssh.PublicKeys(signer), nil diff --git a/communicator/ssh/provisioner_test.go b/communicator/ssh/provisioner_test.go index 601d53e54d..9eb5af7c8a 100644 --- a/communicator/ssh/provisioner_test.go +++ b/communicator/ssh/provisioner_test.go @@ -14,6 +14,7 @@ func TestProvisioner_connInfo(t *testing.T) { "user": "root", "password": "supersecret", "private_key": "someprivatekeycontents", + "certificate": "somecertificate", "host": "127.0.0.1", "port": "22", "timeout": "30s", @@ -37,6 +38,9 @@ func TestProvisioner_connInfo(t *testing.T) { if conf.PrivateKey != "someprivatekeycontents" { t.Fatalf("bad: %v", conf) } + if conf.Certificate != "somecertificate" { + t.Fatalf("bad: %v", conf) + } if conf.Host != "127.0.0.1" { t.Fatalf("bad: %v", conf) } @@ -74,6 +78,7 @@ func TestProvisioner_connInfoIpv6(t *testing.T) { "user": "root", "password": "supersecret", "private_key": "someprivatekeycontents", + "certificate": "somecertificate", "host": "::1", "port": "22", "timeout": "30s", @@ -101,14 +106,13 @@ func TestProvisioner_connInfoHostname(t *testing.T) { r := &terraform.InstanceState{ Ephemeral: terraform.EphemeralState{ ConnInfo: map[string]string{ - "type": "ssh", - "user": "root", - "password": "supersecret", - "private_key": "someprivatekeycontents", - "host": "example.com", - "port": "22", - "timeout": "30s", - + "type": "ssh", + "user": "root", + "password": "supersecret", + "private_key": "someprivatekeycontents", + "host": "example.com", + "port": "22", + "timeout": "30s", "bastion_host": "example.com", }, },