diff --git a/common/download.go b/common/download.go index 81102069c..6cd2a7a15 100644 --- a/common/download.go +++ b/common/download.go @@ -99,8 +99,6 @@ func (d *DownloadClient) Cancel() { } func (d *DownloadClient) Get() (string, error) { - var f *os.File - // If we already have the file and it matches, then just return the target path. if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify { log.Println("Initial checksum matched, no download needed.") @@ -115,6 +113,7 @@ func (d *DownloadClient) Get() (string, error) { log.Printf("Parsed URL: %#v", url) // Files when we don't copy the file are special cased. + var f *os.File var finalPath string if url.Scheme == "file" && !d.config.CopyFile { finalPath = url.Path @@ -199,6 +198,15 @@ func (*HTTPDownloader) Cancel() { func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { log.Printf("Starting download: %s", src.String()) + + // Seek to the beginning by default + if _, err := dst.Seek(0, 0); err != nil { + return err + } + + // Make the request. We first make a HEAD request so we can check + // if the server supports range queries. If the server/URL doesn't + // support HEAD requests, we just fall back to GET. req, err := http.NewRequest("HEAD", src.String(), nil) if err != nil { return err @@ -215,41 +223,21 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { } resp, err := httpClient.Do(req) - if err != nil || resp.StatusCode != 200 { - req.Method = "GET" - resp, err = httpClient.Do(req) - if err != nil { - return err - } - } - - if resp.StatusCode != 200 { - log.Printf( - "Non-200 status code: %d. Getting error body.", resp.StatusCode) - if req.Method != "GET" { - req.Method = "GET" - resp, err = httpClient.Do(req) - if err != nil { - return err + if err == nil && (resp.StatusCode >= 200 && resp.StatusCode < 300) { + // If the HEAD request succeeded, then attempt to set the range + // query if we can. + if resp.Header.Get("Accept-Ranges") == "bytes" { + if fi, err := dst.Stat(); err == nil { + if _, err = dst.Seek(0, os.SEEK_END); err == nil { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) + d.progress = uint(fi.Size()) + } } } - errorBody := new(bytes.Buffer) - io.Copy(errorBody, resp.Body) - return fmt.Errorf("HTTP error '%d'! Remote side responded:\n%s", - resp.StatusCode, errorBody.String()) } + // Set the request to GET now, and redo the query to download req.Method = "GET" - d.progress = 0 - - if resp.Header.Get("Accept-Ranges") == "bytes" { - if fi, err := dst.Stat(); err == nil { - if _, err = dst.Seek(0, os.SEEK_END); err == nil { - req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) - d.progress = uint(fi.Size()) - } - } - } resp, err = httpClient.Do(req) if err != nil { @@ -257,7 +245,6 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { } d.total = uint(resp.ContentLength) - var buffer [4096]byte for { n, err := resp.Body.Read(buffer[:]) diff --git a/common/download_test.go b/common/download_test.go index effbf0059..dc5bd29ed 100644 --- a/common/download_test.go +++ b/common/download_test.go @@ -161,6 +161,41 @@ func TestDownloadClient_checksumNoDownload(t *testing.T) { } } +func TestDownloadClient_resume(t *testing.T) { + tf, _ := ioutil.TempFile("", "packer") + tf.Write([]byte("w")) + tf.Close() + + ts := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + rw.Header().Set("Accept-Ranges", "bytes") + rw.WriteHeader(204) + return + } + + http.ServeFile(rw, r, "./test-fixtures/root/basic.txt") + })) + defer ts.Close() + + client := NewDownloadClient(&DownloadConfig{ + Url: ts.URL, + TargetPath: tf.Name(), + }) + path, err := client.Get() + if err != nil { + t.Fatalf("err: %s", err) + } + + raw, err := ioutil.ReadFile(path) + if err != nil { + t.Fatalf("err: %s", err) + } + + if string(raw) != "wello\n" { + t.Fatalf("bad: %s", string(raw)) + } +} + func TestDownloadClient_usesDefaultUserAgent(t *testing.T) { tf, err := ioutil.TempFile("", "packer") if err != nil {