|
|
|
|
@ -7,6 +7,7 @@ import (
|
|
|
|
|
"fmt"
|
|
|
|
|
"github.com/mitchellh/packer/packer"
|
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
|
"golang.org/x/crypto/ssh/agent"
|
|
|
|
|
"io"
|
|
|
|
|
"io/ioutil"
|
|
|
|
|
"log"
|
|
|
|
|
@ -226,10 +227,59 @@ func (c *comm) reconnect() (err error) {
|
|
|
|
|
if sshConn != nil {
|
|
|
|
|
c.client = ssh.NewClient(sshConn, sshChan, req)
|
|
|
|
|
}
|
|
|
|
|
c.connectToAgent()
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *comm) connectToAgent() {
|
|
|
|
|
if c.client == nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// open connection to the local agent
|
|
|
|
|
socketLocation := os.Getenv("SSH_AUTH_SOCK")
|
|
|
|
|
if socketLocation == "" {
|
|
|
|
|
log.Printf("no local agent socket")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
agentConn, err := net.Dial("unix", socketLocation)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("could not connect to local agent socket: %s", socketLocation)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// create agent and add in auth
|
|
|
|
|
forwardingAgent := agent.NewClient(agentConn)
|
|
|
|
|
if forwardingAgent == nil {
|
|
|
|
|
log.Printf("could not create agent client")
|
|
|
|
|
agentConn.Close()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// add callback for forwarding agent to SSH config
|
|
|
|
|
// XXX - might want to handle reconnects appending multiple callbacks
|
|
|
|
|
auth := ssh.PublicKeysCallback(forwardingAgent.Signers)
|
|
|
|
|
c.config.SSHConfig.Auth = append(c.config.SSHConfig.Auth, auth)
|
|
|
|
|
agent.ForwardToAgent(c.client, forwardingAgent)
|
|
|
|
|
|
|
|
|
|
// Setup a session to request agent forwarding
|
|
|
|
|
session, err := c.newSession()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer session.Close()
|
|
|
|
|
|
|
|
|
|
err = agent.RequestAgentForwarding(session)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Printf("RequestAgentForwarding:", err)
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
log.Printf("agent forwarding enabled")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *comm) scpSession(scpCommand string, f func(io.Writer, *bufio.Reader) error) error {
|
|
|
|
|
session, err := c.newSession()
|
|
|
|
|
if err != nil {
|
|
|
|
|
|