diff --git a/CHANGELOG.md b/CHANGELOG.md index 43e19cf48..bf28de474 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ BUG FIXES: * builder/openstack: Properly scrub password from logs [GH-554] * common/uuid: Use cryptographically secure PRNG when generating UUIDs. [GH-552] +* communicator/ssh: File uploads that exceed the size of memory no longer + cause crashes. [GH-561] ## 0.3.10 (October 20, 2013) diff --git a/communicator/ssh/communicator.go b/communicator/ssh/communicator.go index 780dc425a..17a616029 100644 --- a/communicator/ssh/communicator.go +++ b/communicator/ssh/communicator.go @@ -8,6 +8,7 @@ import ( "fmt" "github.com/mitchellh/packer/packer" "io" + "io/ioutil" "log" "net" "os" @@ -362,30 +363,49 @@ func checkSCPStatus(r *bufio.Reader) error { } func scpUploadFile(dst string, src io.Reader, w io.Writer, r *bufio.Reader) error { - // Determine the length of the upload content by copying it - // into an in-memory buffer. Note that this means what we upload - // must fit into memory. - log.Println("Copying input data into in-memory buffer so we can get the length") - inputBuf := new(bytes.Buffer) - if _, err := io.Copy(inputBuf, src); err != nil { + // Create a temporary file where we can copy the contents of the src + // so that we can determine the length, since SCP is length-prefixed. + tf, err := ioutil.TempFile("", "packer-upload") + if err != nil { + return fmt.Errorf("Error creating temporary file for upload: %s", err) + } + defer os.Remove(tf.Name()) + defer tf.Close() + + log.Println("Copying input data into temporary file so we can read the length") + if _, err := io.Copy(tf, src); err != nil { return err } + // Sync the file so that the contents are definitely on disk, then + // read the length of it. + if err := tf.Sync(); err != nil { + return fmt.Errorf("Error creating temporary file for upload: %s", err) + } + + // Seek the file to the beginning so we can re-read all of it + if _, err := tf.Seek(0, 0); err != nil { + return fmt.Errorf("Error creating temporary file for upload: %s", err) + } + + fi, err := tf.Stat() + if err != nil { + return fmt.Errorf("Error creating temporary file for upload: %s", err) + } + // Start the protocol log.Println("Beginning file upload...") - fmt.Fprintln(w, "C0644", inputBuf.Len(), dst) - err := checkSCPStatus(r) - if err != nil { + fmt.Fprintln(w, "C0644", fi.Size(), dst) + if err := checkSCPStatus(r); err != nil { return err } - if _, err := io.Copy(w, inputBuf); err != nil { + if _, err := io.Copy(w, tf); err != nil { return err } fmt.Fprint(w, "\x00") - err = checkSCPStatus(r) - if err != nil { + if err := checkSCPStatus(r); err != nil { return err }