diff --git a/packer/rpc/communicator.go b/packer/rpc/communicator.go index 526f34f2f..df3765ff3 100644 --- a/packer/rpc/communicator.go +++ b/packer/rpc/communicator.go @@ -34,6 +34,11 @@ type CommunicatorStartResponse struct { RemoteCommandAddress string } +type CommunicatorDownloadArgs struct { + Path string + WriterAddress string +} + type CommunicatorUploadArgs struct { Path string ReaderAddress string @@ -93,6 +98,8 @@ func (c *communicator) Start(cmd string) (rc *packer.RemoteCommand, err error) { } func (c *communicator) Upload(path string, r io.Reader) (err error) { + // We need to create a server that can proxy the reader data + // over because we can't simply gob encode an io.Reader readerL := netListenerInRange(portRangeMin, portRangeMax) if readerL == nil { err = errors.New("couldn't allocate listener for upload reader") @@ -110,16 +117,32 @@ func (c *communicator) Upload(path string, r io.Reader) (err error) { readerL.Addr().String(), } - cerr := c.client.Call("Communicator.Upload", &args, &err) - if cerr != nil { - err = cerr - } - + err = c.client.Call("Communicator.Upload", &args, new(interface{})) return } -func (c *communicator) Download(string, io.Writer) error { - return nil +func (c *communicator) Download(path string, w io.Writer) (err error) { + // We need to create a server that can proxy that data downloaded + // into the writer because we can't gob encode a writer directly. + writerL := netListenerInRange(portRangeMin, portRangeMax) + if writerL == nil { + err = errors.New("couldn't allocate listener for download writer") + return + } + + // Make sure we close the listener once we're done because we'll be done + defer writerL.Close() + + // Serve a single connection and a single copy + go serveSingleCopy("downloadWriter", writerL, w, nil) + + args := CommunicatorDownloadArgs{ + path, + writerL.Addr().String(), + } + + err = c.client.Call("Communicator.Download", &args, new(interface{})) + return } func (c *CommunicatorServer) Start(cmd *string, reply *CommunicatorStartResponse) (err error) { @@ -172,6 +195,18 @@ func (c *CommunicatorServer) Upload(args *CommunicatorUploadArgs, reply *interfa return } +func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *interface{}) (err error) { + writerC, err := net.Dial("tcp", args.WriterAddress) + if err != nil { + return + } + + defer writerC.Close() + + err = c.c.Download(args.Path, writerC) + return +} + func (rc *RemoteCommandServer) Wait(args *interface{}, reply *int) error { rc.rc.Wait() *reply = rc.rc.ExitStatus diff --git a/packer/rpc/communicator_test.go b/packer/rpc/communicator_test.go index fb783c803..a1b066cda 100644 --- a/packer/rpc/communicator_test.go +++ b/packer/rpc/communicator_test.go @@ -21,7 +21,10 @@ type testCommunicator struct { uploadCalled bool uploadPath string - uploadReader io.Reader + uploadData string + + downloadCalled bool + downloadPath string } func (t *testCommunicator) Start(cmd string) (*packer.RemoteCommand, error) { @@ -49,15 +52,18 @@ func (t *testCommunicator) Start(cmd string) (*packer.RemoteCommand, error) { return rc, nil } -func (t *testCommunicator) Upload(path string, reader io.Reader) error { +func (t *testCommunicator) Upload(path string, reader io.Reader) (err error) { t.uploadCalled = true t.uploadPath = path - t.uploadReader = reader - - return nil + t.uploadData, err = bufio.NewReader(reader).ReadString('\n') + return } -func (t *testCommunicator) Download(string, io.Writer) error { +func (t *testCommunicator) Download(path string, writer io.Writer) error { + t.downloadCalled = true + t.downloadPath = path + writer.Write([]byte("download\n")) + return nil } @@ -110,18 +116,33 @@ func TestCommunicatorRPC(t *testing.T) { // Test that we can upload things uploadR, uploadW := io.Pipe() + go uploadW.Write([]byte("uploadfoo\n")) err = remote.Upload("foo", uploadR) assert.Nil(err, "should not error") assert.True(c.uploadCalled, "should be called") assert.Equal(c.uploadPath, "foo", "should be correct path") - return + assert.Equal(c.uploadData, "uploadfoo\n", "should have the proper data") + + // Test that we can download things + downloadR, downloadW := io.Pipe() + downloadDone := make(chan bool) + var downloadData string + var downloadErr error - // Test the upload reader - uploadW.Write([]byte("uploadfoo\n")) - bufUpR := bufio.NewReader(c.uploadReader) - data, err = bufUpR.ReadString('\n') + go func() { + bufDownR := bufio.NewReader(downloadR) + downloadData, downloadErr = bufDownR.ReadString('\n') + downloadDone <- true + }() + + err = remote.Download("bar", downloadW) assert.Nil(err, "should not error") - assert.Equal(data, "uploadfoo\n", "should have the proper data") + assert.True(c.downloadCalled, "should have called download") + assert.Equal(c.downloadPath, "bar", "should have correct download path") + + <-downloadDone + assert.Nil(downloadErr, "should not error reading download data") + assert.Equal(downloadData, "download\n", "should have the proper data") } func TestCommunicator_ImplementsCommunicator(t *testing.T) {