From a0a78b68e85e495636ae04ffe7ea5264d86c0b1a Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Wed, 12 Jun 2013 17:41:44 -0700 Subject: [PATCH] builder/common: Create a downloader --- builder/common/download.go | 186 ++++++++++++++++++++++++++++ builder/common/download_test.go | 42 +++++++ builder/vmware/step_download_iso.go | 144 ++++++--------------- command/build/command.go | 2 +- 4 files changed, 265 insertions(+), 109 deletions(-) create mode 100644 builder/common/download.go create mode 100644 builder/common/download_test.go diff --git a/builder/common/download.go b/builder/common/download.go new file mode 100644 index 000000000..400d61b98 --- /dev/null +++ b/builder/common/download.go @@ -0,0 +1,186 @@ +package common + +import ( + "bytes" + "encoding/hex" + "errors" + "fmt" + "hash" + "io" + "net/http" + "net/url" + "os" +) + +// DownloadConfig is the configuration given to instantiate a new +// download instance. Once a configuration is used to instantiate +// a download client, it must not be modified. +type DownloadConfig struct { + // The source URL in the form of a string. + Url string + + // This is the path to download the file to. + TargetPath string + + // DownloaderMap maps a schema to a Download. + DownloaderMap map[string]Downloader + + // If true, this will copy even a local file to the target + // location. If false, then it will "download" the file by just + // returning the local path to the file. + CopyFile bool + + // The hashing implementation to use to checksum the downloaded file. + Hash hash.Hash + + // The checksum for the downloaded file. The hash implementation configuration + // for the downloader will be used to verify with this checksum after + // it is downloaded. + Checksum []byte +} + +// A DownloadClient helps download, verify checksums, etc. +type DownloadClient struct { + config *DownloadConfig + downloader Downloader +} + +// NewDownloadClient returns a new DownloadClient for the given +// configuration. +func NewDownloadClient(c *DownloadConfig) *DownloadClient { + if c.DownloaderMap == nil { + c.DownloaderMap = map[string]Downloader{ + "http": new(HTTPDownloader), + } + } + + return &DownloadClient{config: c} +} + +// A downloader is responsible for actually taking a remote URL and +// downloading it. +type Downloader interface { + Cancel() + Download(io.Writer, *url.URL) error + Progress() uint + Total() uint +} + +func (d *DownloadClient) Cancel() { + // TODO(mitchellh): Implement +} + +func (d *DownloadClient) Get() (string, error) { + // If we already have the file and it matches, then just return the target path. + if verify, _ := d.VerifyChecksum(d.config.TargetPath); verify { + return d.config.TargetPath, nil + } + + url, err := url.Parse(d.config.Url) + if err != nil { + return "", err + } + + // Files when we don't copy the file are special cased. + var finalPath string + if url.Scheme == "file" && !d.config.CopyFile { + finalPath = url.Path + } else { + var ok bool + d.downloader, ok = d.config.DownloaderMap[url.Scheme] + if !ok { + return "", fmt.Errorf("No downloader for scheme: %s", url.Scheme) + } + + // Otherwise, download using the downloader. + f, err := os.Create(d.config.TargetPath) + if err != nil { + return "", err + } + defer f.Close() + + err = d.downloader.Download(f, url) + } + + if d.config.Hash != nil { + var verify bool + verify, err = d.VerifyChecksum(finalPath) + if err == nil && !verify { + err = fmt.Errorf("checksums didn't match expected: %s", hex.EncodeToString(d.config.Checksum)) + } + } + + return finalPath, err +} + +// PercentProgress returns the download progress as a percentage. +func (d *DownloadClient) PercentProgress() uint { + return uint((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100) +} + +// VerifyChecksum tests that the path matches the checksum for the +// download. +func (d *DownloadClient) VerifyChecksum(path string) (bool, error) { + if d.config.Checksum == nil || d.config.Hash == nil { + return false, errors.New("Checksum or Hash isn't set on download.") + } + + f, err := os.Open(path) + if err != nil { + return false, err + } + defer f.Close() + + d.config.Hash.Reset() + io.Copy(d.config.Hash, f) + return bytes.Compare(d.config.Hash.Sum(nil), d.config.Checksum) == 0, nil +} + +// HTTPDownloader is an implementation of Downloader that downloads +// files over HTTP. +type HTTPDownloader struct { + progress uint + total uint +} + +func (*HTTPDownloader) Cancel() { + // TODO(mitchellh): Implement +} + +func (d *HTTPDownloader) Download(dst io.Writer, src *url.URL) error { + resp, err := http.Get(src.String()) + if err != nil { + return err + } + + d.progress = 0 + d.total = uint(resp.ContentLength) + + var buffer [4096]byte + for { + n, err := resp.Body.Read(buffer[:]) + if err != nil && err != io.EOF { + return err + } + + d.progress += uint(n) + + if _, werr := dst.Write(buffer[:n]); werr != nil { + return werr + } + + if err == io.EOF { + break + } + } + + return nil +} + +func (d *HTTPDownloader) Progress() uint { + return d.progress +} + +func (d *HTTPDownloader) Total() uint { + return d.total +} diff --git a/builder/common/download_test.go b/builder/common/download_test.go new file mode 100644 index 000000000..3e2a59ddf --- /dev/null +++ b/builder/common/download_test.go @@ -0,0 +1,42 @@ +package common + +import ( + "crypto/md5" + "encoding/hex" + "io/ioutil" + "os" + "testing" +) + +func TestDownloadClient_VerifyChecksum(t *testing.T) { + tf, err := ioutil.TempFile("", "packer") + if err != nil { + t.Fatalf("tempfile error: %s", err) + } + defer os.Remove(tf.Name()) + + // "foo" + checksum, err := hex.DecodeString("acbd18db4cc2f85cedef654fccc4a4d8") + if err != nil { + t.Fatalf("decode err: %s", err) + } + + // Write the file + tf.Write([]byte("foo")) + tf.Close() + + config := &DownloadConfig{ + Hash: md5.New(), + Checksum: checksum, + } + + d := NewDownloadClient(config) + result, err := d.VerifyChecksum(tf.Name()) + if err != nil { + t.Fatalf("Verify err: %s", err) + } + + if !result { + t.Fatal("didn't verify") + } +} diff --git a/builder/vmware/step_download_iso.go b/builder/vmware/step_download_iso.go index 4dbe66cb5..02d2afc8f 100644 --- a/builder/vmware/step_download_iso.go +++ b/builder/vmware/step_download_iso.go @@ -5,13 +5,9 @@ import ( "encoding/hex" "fmt" "github.com/mitchellh/multistep" + "github.com/mitchellh/packer/builder/common" "github.com/mitchellh/packer/packer" - "io" "log" - "net/http" - "net/url" - "os" - "strings" "time" ) @@ -31,62 +27,52 @@ func (s stepDownloadISO) Run(state map[string]interface{}) multistep.StepAction config := state["config"].(*config) ui := state["ui"].(packer.Ui) + checksum, err := hex.DecodeString(config.ISOMD5) + if err != nil { + ui.Error(fmt.Sprintf("Error parsing checksum: %s", err)) + return multistep.ActionHalt + } + log.Printf("Acquiring lock to download the ISO.") cachePath := cache.Lock(config.ISOUrl) defer cache.Unlock(config.ISOUrl) - err := s.checkMD5(cachePath, config.ISOMD5) - haveFile := err == nil - if err != nil { - if !os.IsNotExist(err) { - ui.Say(fmt.Sprintf("Error validating MD5 of ISO: %s", err)) - return multistep.ActionHalt - } + downloadConfig := &common.DownloadConfig{ + Url: config.ISOUrl, + TargetPath: cachePath, + CopyFile: false, + Hash: md5.New(), + Checksum: checksum, } - if !haveFile { - url, err := url.Parse(config.ISOUrl) - if err != nil { - ui.Error(fmt.Sprintf("Error parsing iso_url: %s", err)) - return multistep.ActionHalt - } + download := common.NewDownloadClient(downloadConfig) - // Start the download in a goroutine so that we cancel it and such. - var progress uint - downloadComplete := make(chan bool, 1) - go func() { - ui.Say("Copying or downloading ISO. Progress will be shown periodically.") - cachePath, err = s.downloadUrl(cachePath, url, &progress) - downloadComplete <- true - }() + downloadCompleteCh := make(chan error, 1) + go func() { + ui.Say("Copying or downloading ISO. Progress will be reported periodically.") + cachePath, err = download.Get() + downloadCompleteCh <- err + }() - progressTimer := time.NewTicker(15 * time.Second) - defer progressTimer.Stop() + progressTicker := time.NewTicker(5 * time.Second) + defer progressTicker.Stop() - DownloadWaitLoop: - for { - select { - case <-downloadComplete: - log.Println("Download of ISO completed.") - break DownloadWaitLoop - case <-progressTimer.C: - ui.Say(fmt.Sprintf("Download progress: %d%%", progress)) - case <-time.After(1 * time.Second): - if _, ok := state[multistep.StateCancelled]; ok { - ui.Say("Interrupt received. Cancelling download...") - return multistep.ActionHalt - } +DownloadWaitLoop: + for { + select { + case err := <-downloadCompleteCh: + if err != nil { + ui.Error(fmt.Sprintf("Error downloading ISO: %s", err)) } - } - - if err != nil { - ui.Error(fmt.Sprintf("Error downloading ISO: %s", err)) - return multistep.ActionHalt - } - if err = s.checkMD5(cachePath, config.ISOMD5); err != nil { - ui.Say(fmt.Sprintf("Error validating MD5 of ISO: %s", err)) - return multistep.ActionHalt + break DownloadWaitLoop + case <-progressTicker.C: + ui.Say(fmt.Sprintf("Download progress: %d%%", download.PercentProgress())) + case <-time.After(1 * time.Second): + if _, ok := state[multistep.StateCancelled]; ok { + ui.Say("Interrupt received. Cancelling download...") + return multistep.ActionHalt + } } } @@ -97,61 +83,3 @@ func (s stepDownloadISO) Run(state map[string]interface{}) multistep.StepAction } func (stepDownloadISO) Cleanup(map[string]interface{}) {} - -func (stepDownloadISO) checkMD5(path string, expected string) error { - f, err := os.Open(path) - if err != nil { - return err - } - - hash := md5.New() - io.Copy(hash, f) - result := strings.ToLower(hex.EncodeToString(hash.Sum(nil))) - if result != expected { - return fmt.Errorf("result != expected: %s != %s", result, expected) - } - - return nil -} - -func (stepDownloadISO) downloadUrl(path string, url *url.URL, progress *uint) (string, error) { - if url.Scheme == "file" { - // If it is just a file URL, then we already have the ISO - return url.Path, nil - } - - // Otherwise, it is an HTTP URL, and we must download it. - f, err := os.Create(path) - if err != nil { - return "", err - } - defer f.Close() - - log.Printf("Beginning download of ISO: %s", url.String()) - resp, err := http.Get(url.String()) - if err != nil { - return "", err - } - - var buffer [4096]byte - var totalRead int64 - for { - n, err := resp.Body.Read(buffer[:]) - if err != nil && err != io.EOF { - return "", err - } - - totalRead += int64(n) - *progress = uint((float64(totalRead) / float64(resp.ContentLength)) * 100) - - if _, werr := f.Write(buffer[:n]); werr != nil { - return "", werr - } - - if err == io.EOF { - break - } - } - - return path, nil -} diff --git a/command/build/command.go b/command/build/command.go index de5c6d3d5..8775e2e11 100644 --- a/command/build/command.go +++ b/command/build/command.go @@ -102,7 +102,7 @@ func (c Command) Run(env packer.Environment, args []string) int { ui = &packer.ColoredUi{ Color: colors[i%len(colors)], - Ui: env.Ui(), + Ui: env.Ui(), } ui = &packer.PrefixedUi{