From fd7cb47adcf8d8c40204dfdb65ca5cf9da2303ed Mon Sep 17 00:00:00 2001 From: Adrien Delorme Date: Wed, 5 Sep 2018 12:50:53 +0200 Subject: [PATCH] use proxy reader for download progress & stop storing total/current in downloaders --- common/download.go | 118 ++++++++++---------------------- packer/progressbar.go | 50 ++++++++++++-- packer/rpc/ui.go | 15 ++-- provisioner/file/provisioner.go | 2 +- 4 files changed, 90 insertions(+), 95 deletions(-) diff --git a/common/download.go b/common/download.go index 21ff25165..de018f9bd 100644 --- a/common/download.go +++ b/common/download.go @@ -85,10 +85,10 @@ func NewDownloadClient(c *DownloadConfig, bar packer.ProgressBar) *DownloadClien // Create downloader map if it hasn't been specified already. if c.DownloaderMap == nil { c.DownloaderMap = map[string]Downloader{ - "file": &FileDownloader{progress: bar, bufferSize: nil}, - "http": &HTTPDownloader{progress: bar, userAgent: c.UserAgent}, - "https": &HTTPDownloader{progress: bar, userAgent: c.UserAgent}, - "smb": &SMBDownloader{progress: bar, bufferSize: nil}, + "file": &FileDownloader{progressBar: bar, bufferSize: nil}, + "http": &HTTPDownloader{progressBar: bar, userAgent: c.UserAgent}, + "https": &HTTPDownloader{progressBar: bar, userAgent: c.UserAgent}, + "smb": &SMBDownloader{progressBar: bar, bufferSize: nil}, } } return &DownloadClient{config: c} @@ -99,8 +99,7 @@ func NewDownloadClient(c *DownloadConfig, bar packer.ProgressBar) *DownloadClien type Downloader interface { Resume() Cancel() - Progress() uint64 - Total() uint64 + ProgressBar() packer.ProgressBar } // A LocalDownloader is responsible for converting a uri to a local path @@ -226,11 +225,9 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) { // HTTPDownloader is an implementation of Downloader that downloads // files over HTTP. type HTTPDownloader struct { - current uint64 - total uint64 userAgent string - progress packer.ProgressBar + progressBar packer.ProgressBar } func (d *HTTPDownloader) Cancel() { @@ -249,8 +246,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { return err } - // Reset our progress - d.current = 0 + var current uint64 // Make the request. We first make a HEAD request so we can check // if the server supports range queries. If the server/URL doesn't @@ -294,7 +290,7 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { if _, err = dst.Seek(0, os.SEEK_END); err == nil { req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) - d.current = uint64(fi.Size()) + current = uint64(fi.Size()) } } } @@ -321,24 +317,21 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { return fmt.Errorf("HTTP error: %s", err.Error()) } - d.total = d.current + uint64(resp.ContentLength) + total := current + uint64(resp.ContentLength) - bar := d.progress - log.Printf("this %#v", bar) - log.Printf("that") - bar.Start(d.total) - bar.Set(d.current) + bar := d.ProgressBar() + bar.Start(total) + bar.Add(current) + + body := bar.NewProxyReader(resp.Body) var buffer [4096]byte for { - n, err := resp.Body.Read(buffer[:]) + n, err := body.Read(buffer[:]) if err != nil && err != io.EOF { return err } - d.current += uint64(n) - bar.Set(d.current) - if _, werr := dst.Write(buffer[:n]); werr != nil { return werr } @@ -351,32 +344,13 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { return nil } -func (d *HTTPDownloader) Progress() uint64 { - return d.current -} - -func (d *HTTPDownloader) Total() uint64 { - return d.total -} - // FileDownloader is an implementation of Downloader that downloads // files using the regular filesystem. type FileDownloader struct { bufferSize *uint - active bool - current uint64 - total uint64 - - progress packer.ProgressBar -} - -func (d *FileDownloader) Progress() uint64 { - return d.current -} - -func (d *FileDownloader) Total() uint64 { - return d.total + active bool + progressBar packer.ProgressBar } func (d *FileDownloader) Cancel() { @@ -443,7 +417,6 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { } /* download the file using the operating system's facilities */ - d.current = 0 d.active = true f, err := os.Open(realpath) @@ -457,38 +430,31 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { if err != nil { return err } - d.total = uint64(fi.Size()) - bar := d.progress + bar := d.ProgressBar() - bar.Start(d.total) - bar.Set(d.current) + bar.Start(uint64(fi.Size())) + fProxy := bar.NewProxyReader(f) // no bufferSize specified, so copy synchronously. if d.bufferSize == nil { - var n int64 - n, err = io.Copy(dst, f) + _, err = io.Copy(dst, fProxy) d.active = false - d.current += uint64(n) - bar.Set(d.current) - // use a goro in case someone else wants to enable cancel/resume } else { errch := make(chan error) go func(d *FileDownloader, r io.Reader, w io.Writer, e chan error) { for d.active { - n, err := io.CopyN(w, r, int64(*d.bufferSize)) + _, err := io.CopyN(w, r, int64(*d.bufferSize)) if err != nil { break } - d.current += uint64(n) - bar.Set(d.current) } d.active = false e <- err - }(d, f, dst, errch) + }(d, fProxy, dst, errch) // ...and we spin until it's done err = <-errch @@ -503,19 +469,8 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { type SMBDownloader struct { bufferSize *uint - active bool - current uint64 - total uint64 - - progress packer.ProgressBar -} - -func (d *SMBDownloader) Progress() uint64 { - return d.current -} - -func (d *SMBDownloader) Total() uint64 { - return d.total + active bool + progressBar packer.ProgressBar } func (d *SMBDownloader) Cancel() { @@ -564,7 +519,6 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { } /* Open up the "\\"-prefixed path using the Windows filesystem */ - d.current = 0 d.active = true f, err := os.Open(realpath) @@ -578,37 +532,31 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { if err != nil { return err } - d.total = uint64(fi.Size()) - bar := d.progress + bar := d.ProgressBar() - bar.Start(d.current) + bar.Start(uint64(fi.Size())) + fProxy := bar.NewProxyReader(f) // no bufferSize specified, so copy synchronously. if d.bufferSize == nil { - var n int64 - n, err = io.Copy(dst, f) + _, err = io.Copy(dst, fProxy) d.active = false - d.current += uint64(n) - bar.Set(d.current) - // use a goro in case someone else wants to enable cancel/resume } else { errch := make(chan error) go func(d *SMBDownloader, r io.Reader, w io.Writer, e chan error) { for d.active { - n, err := io.CopyN(w, r, int64(*d.bufferSize)) + _, err := io.CopyN(w, r, int64(*d.bufferSize)) if err != nil { break } - d.current += uint64(n) - bar.Set(d.current) } d.active = false e <- err - }(d, f, dst, errch) + }(d, fProxy, dst, errch) // ...and as usual we spin until it's done err = <-errch @@ -617,3 +565,7 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { f.Close() return err } + +func (d *HTTPDownloader) ProgressBar() packer.ProgressBar { return d.progressBar } +func (d *FileDownloader) ProgressBar() packer.ProgressBar { return d.progressBar } +func (d *SMBDownloader) ProgressBar() packer.ProgressBar { return d.progressBar } diff --git a/packer/progressbar.go b/packer/progressbar.go index 1f3784fc1..5a70f25a4 100644 --- a/packer/progressbar.go +++ b/packer/progressbar.go @@ -1,6 +1,8 @@ package packer import ( + "io" + "github.com/cheggaaa/pb" ) @@ -9,7 +11,8 @@ import ( // No-op When in machine readable mode. type ProgressBar interface { Start(total uint64) - Set(current uint64) + Add(current uint64) + NewProxyReader(r io.Reader) (proxy io.Reader) Finish() } @@ -22,8 +25,20 @@ func (bpb *BasicProgressBar) Start(total uint64) { bpb.ProgressBar.Start() } -func (bpb *BasicProgressBar) Set(current uint64) { - bpb.ProgressBar.Set64(int64(current)) +func (bpb *BasicProgressBar) Add(current uint64) { + bpb.ProgressBar.Add64(int64(current)) +} +func (bpb *BasicProgressBar) NewProxyReader(r io.Reader) io.Reader { + return &ProxyReader{ + Reader: r, + ProgressBar: bpb, + } +} +func (bpb *BasicProgressBar) NewProxyReadCloser(r io.ReadCloser) io.ReadCloser { + return &ProxyReader{ + Reader: r, + ProgressBar: bpb, + } } var _ ProgressBar = new(BasicProgressBar) @@ -32,8 +47,31 @@ var _ ProgressBar = new(BasicProgressBar) type NoopProgressBar struct { } -func (bpb *NoopProgressBar) Start(_ uint64) {} -func (bpb *NoopProgressBar) Set(_ uint64) {} -func (bpb *NoopProgressBar) Finish() {} +func (npb *NoopProgressBar) Start(uint64) {} +func (npb *NoopProgressBar) Add(uint64) {} +func (npb *NoopProgressBar) Finish() {} +func (npb *NoopProgressBar) NewProxyReader(r io.Reader) io.Reader { return r } +func (npb *NoopProgressBar) NewProxyReadCloser(r io.ReadCloser) io.ReadCloser { return r } var _ ProgressBar = new(NoopProgressBar) + +// ProxyReader implements io.ReadCloser but sends +// count of read bytes to progress bar +type ProxyReader struct { + io.Reader + ProgressBar +} + +func (r *ProxyReader) Read(p []byte) (n int, err error) { + n, err = r.Reader.Read(p) + r.ProgressBar.Add(uint64(n)) + return +} + +// Close the reader if it implements io.Closer +func (r *ProxyReader) Close() (err error) { + if closer, ok := r.Reader.(io.Closer); ok { + return closer.Close() + } + return +} diff --git a/packer/rpc/ui.go b/packer/rpc/ui.go index 50dce61fa..c1ed2435f 100644 --- a/packer/rpc/ui.go +++ b/packer/rpc/ui.go @@ -1,6 +1,7 @@ package rpc import ( + "io" "log" "math/rand" "net/rpc" @@ -88,14 +89,18 @@ func (pb *RemoteProgressBarClient) Start(total uint64) { pb.client.Call(pb.id+".Start", total, new(interface{})) } -func (pb *RemoteProgressBarClient) Set(current uint64) { - pb.client.Call(pb.id+".Set", current, new(interface{})) +func (pb *RemoteProgressBarClient) Add(current uint64) { + pb.client.Call(pb.id+".Add", current, new(interface{})) } func (pb *RemoteProgressBarClient) Finish() { pb.client.Call(pb.id+".Finish", nil, new(interface{})) } +func (pb *RemoteProgressBarClient) NewProxyReader(r io.Reader) io.Reader { + return &packer.ProxyReader{Reader: r, ProgressBar: pb} +} + func (u *UiServer) Ask(query string, reply *string) (err error) { *reply, err = u.ui.Ask(query) return @@ -128,7 +133,7 @@ func (u *UiServer) Say(message *string, reply *interface{}) error { return nil } -func RandStringBytes(n int) string { +func RandStringBytes(n int) string { // TODO(azr): remove before merging const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" b := make([]byte, n) @@ -167,7 +172,7 @@ func (pb *RemoteProgressBarServer) Start(total uint64, _ *interface{}) error { return nil } -func (pb *RemoteProgressBarServer) Set(current uint64, _ *interface{}) error { - pb.pb.Set(current) +func (pb *RemoteProgressBarServer) Add(current uint64, _ *interface{}) error { + pb.pb.Add(current) return nil } diff --git a/provisioner/file/provisioner.go b/provisioner/file/provisioner.go index a05c2ed9c..474e87596 100644 --- a/provisioner/file/provisioner.go +++ b/provisioner/file/provisioner.go @@ -177,7 +177,7 @@ func (p *Provisioner) ProvisionUpload(ui packer.Ui, comm packer.Communicator) er // Get a default progress bar bar := ui.ProgressBar() - bar.Start(0) + bar.Start(uint64(info.Size())) defer bar.Finish() // Create ProxyReader for the current progress