diff --git a/common/download.go b/common/download.go index 16c0724c3..e4b4dc2e0 100644 --- a/common/download.go +++ b/common/download.go @@ -101,7 +101,7 @@ func (d *DownloadClient) Cancel() { func (d *DownloadClient) Get() (string, error) { // If we already have the file and it matches, then just return the target path. if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify { - log.Println("Initial checksum matched, no download needed.") + log.Println("[DEBUG] Initial checksum matched, no download needed.") return d.config.TargetPath, nil } @@ -115,13 +115,19 @@ func (d *DownloadClient) Get() (string, error) { // Files when we don't copy the file are special cased. var f *os.File var finalPath string + sourcePath := "" if url.Scheme == "file" && !d.config.CopyFile { + // This is a special case where we use a source file that already exists + // locally and we don't make a copy. Normally we would copy or download. finalPath = url.Path + log.Printf("[DEBUG] Using local file: %s", finalPath) // Remove forward slash on absolute Windows file URLs before processing if runtime.GOOS == "windows" && len(finalPath) > 0 && finalPath[0] == '/' { finalPath = finalPath[1:len(finalPath)] } + // Keep track of the source so we can make sure not to delete this later + sourcePath = finalPath } else { finalPath = d.config.TargetPath @@ -137,7 +143,7 @@ func (d *DownloadClient) Get() (string, error) { return "", err } - log.Printf("Downloading: %s", url.String()) + log.Printf("[DEBUG] Downloading: %s", url.String()) err = d.downloader.Download(f, url) f.Close() if err != nil { @@ -149,8 +155,10 @@ func (d *DownloadClient) Get() (string, error) { var verify bool verify, err = d.VerifyChecksum(finalPath) if err == nil && !verify { - // Delete the file - os.Remove(finalPath) + // Only delete the file if we made a copy or downloaded it + if sourcePath != finalPath { + os.Remove(finalPath) + } err = fmt.Errorf( "checksums didn't match expected: %s", diff --git a/common/download_test.go b/common/download_test.go index dc5bd29ed..51f6f270c 100644 --- a/common/download_test.go +++ b/common/download_test.go @@ -3,6 +3,7 @@ package common import ( "crypto/md5" "encoding/hex" + "fmt" "io/ioutil" "net/http" "net/http/httptest" @@ -338,3 +339,40 @@ func TestHashForType(t *testing.T) { t.Fatalf("fake hash is not nil") } } + +// TestDownloadFileUrl tests a special case where we use a local file for +// iso_url. In this case we can still verify the checksum but we should not +// delete the file if the checksum fails. Instead we'll just error and let the +// user fix the checksum. +func TestDownloadFileUrl(t *testing.T) { + cwd, err := os.Getwd() + if err != nil { + t.Fatalf("Unable to detect working directory: %s", err) + } + + // source_path is a file path and source is a network path + sourcePath := fmt.Sprintf("%s/test-fixtures/fileurl/%s", cwd, "cake") + source := fmt.Sprintf("file://" + sourcePath) + t.Logf("Trying to download %s", source) + + config := &DownloadConfig{ + Url: source, + // This should be wrong. We want to make sure we don't delete + Checksum: []byte("nope"), + Hash: HashForType("sha256"), + CopyFile: false, + } + + client := NewDownloadClient(config) + + // Verify that we fail to match the checksum + _, err = client.Get() + if err.Error() != "checksums didn't match expected: 6e6f7065" { + t.Fatalf("Unexpected failure; expected checksum not to match") + } + + if _, err = os.Stat(sourcePath); err != nil { + t.Errorf("Could not stat source file: %s", sourcePath) + } + +} diff --git a/common/test-fixtures/fileurl/cake b/common/test-fixtures/fileurl/cake new file mode 100644 index 000000000..e800d1ffb --- /dev/null +++ b/common/test-fixtures/fileurl/cake @@ -0,0 +1 @@ +delicious chocolate cake