diff --git a/common/config.go b/common/config.go index 9527af55b..70ebc8452 100644 --- a/common/config.go +++ b/common/config.go @@ -43,7 +43,7 @@ func SupportedProtocol(u *url.URL) bool { // build a dummy NewDownloadClient since this is the only place that valid // protocols are actually exposed. - cli := NewDownloadClient(&DownloadConfig{}, new(packer.NoopProgressBar)) + cli := NewDownloadClient(&DownloadConfig{}, new(packer.NoopUi)) // Iterate through each downloader to see if a protocol was found. ok := false @@ -175,7 +175,7 @@ func FileExistsLocally(original string) bool { // First create a dummy downloader so we can figure out which // protocol to use. - cli := NewDownloadClient(&DownloadConfig{}, new(packer.NoopProgressBar)) + cli := NewDownloadClient(&DownloadConfig{}, new(packer.NoopUi)) d, ok := cli.config.DownloaderMap[u.Scheme] if !ok { return false diff --git a/common/download.go b/common/download.go index de018f9bd..39ffa5e17 100644 --- a/common/download.go +++ b/common/download.go @@ -56,8 +56,7 @@ type DownloadConfig struct { // A DownloadClient helps download, verify checksums, etc. type DownloadClient struct { - config *DownloadConfig - downloader Downloader + config *DownloadConfig } // HashForType returns the Hash implementation for the given string @@ -79,23 +78,21 @@ func HashForType(t string) hash.Hash { // NewDownloadClient returns a new DownloadClient for the given // configuration. -func NewDownloadClient(c *DownloadConfig, bar packer.ProgressBar) *DownloadClient { - const mtu = 1500 /* ethernet */ - 20 /* ipv4 */ - 20 /* tcp */ - +func NewDownloadClient(c *DownloadConfig, ui packer.Ui) *DownloadClient { // Create downloader map if it hasn't been specified already. if c.DownloaderMap == nil { + log.Printf("instantiating. ui: %#v", ui) c.DownloaderMap = map[string]Downloader{ - "file": &FileDownloader{progressBar: bar, bufferSize: nil}, - "http": &HTTPDownloader{progressBar: bar, userAgent: c.UserAgent}, - "https": &HTTPDownloader{progressBar: bar, userAgent: c.UserAgent}, - "smb": &SMBDownloader{progressBar: bar, bufferSize: nil}, + "file": &FileDownloader{Ui: ui, bufferSize: nil}, + "http": &HTTPDownloader{Ui: ui, userAgent: c.UserAgent}, + "https": &HTTPDownloader{Ui: ui, userAgent: c.UserAgent}, + "smb": &SMBDownloader{Ui: ui, bufferSize: nil}, } } return &DownloadClient{config: c} } -// A downloader implements the ability to transfer a file, and cancel or resume -// it. +// Downloader defines what capabilities a downloader should have. type Downloader interface { Resume() Cancel() @@ -142,17 +139,18 @@ func (d *DownloadClient) Get() (string, error) { var finalPath string var ok bool - d.downloader, ok = d.config.DownloaderMap[u.Scheme] + downloader, ok := d.config.DownloaderMap[u.Scheme] if !ok { return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme) } + log.Printf("downloader: %#v", downloader) - remote, ok := d.downloader.(RemoteDownloader) + remote, ok := downloader.(RemoteDownloader) if !ok { - return "", fmt.Errorf("Unable to treat uri scheme %s as a Downloader. : %T", u.Scheme, d.downloader) + return "", fmt.Errorf("Unable to treat uri scheme %s as a Downloader. : %T", u.Scheme, downloader) } - local, ok := d.downloader.(LocalDownloader) + local, ok := downloader.(LocalDownloader) if !ok && !d.config.CopyFile { d.config.CopyFile = true } @@ -167,7 +165,6 @@ func (d *DownloadClient) Get() (string, error) { return "", err } - log.Printf("[DEBUG] Downloading: %s", u.String()) err = remote.Download(f, u) f.Close() if err != nil { @@ -227,7 +224,7 @@ func (d *DownloadClient) VerifyChecksum(path string) (bool, error) { type HTTPDownloader struct { userAgent string - progressBar packer.ProgressBar + Ui packer.Ui } func (d *HTTPDownloader) Cancel() { @@ -349,8 +346,8 @@ func (d *HTTPDownloader) Download(dst *os.File, src *url.URL) error { type FileDownloader struct { bufferSize *uint - active bool - progressBar packer.ProgressBar + active bool + Ui packer.Ui } func (d *FileDownloader) Cancel() { @@ -469,8 +466,8 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error { type SMBDownloader struct { bufferSize *uint - active bool - progressBar packer.ProgressBar + active bool + Ui packer.Ui } func (d *SMBDownloader) Cancel() { @@ -566,6 +563,6 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error { return err } -func (d *HTTPDownloader) ProgressBar() packer.ProgressBar { return d.progressBar } -func (d *FileDownloader) ProgressBar() packer.ProgressBar { return d.progressBar } -func (d *SMBDownloader) ProgressBar() packer.ProgressBar { return d.progressBar } +func (d *HTTPDownloader) ProgressBar() packer.ProgressBar { return d.Ui.ProgressBar() } +func (d *FileDownloader) ProgressBar() packer.ProgressBar { return d.Ui.ProgressBar() } +func (d *SMBDownloader) ProgressBar() packer.ProgressBar { return d.Ui.ProgressBar() } diff --git a/common/step_download.go b/common/step_download.go index 26f6d6b65..16da1626f 100644 --- a/common/step_download.go +++ b/common/step_download.go @@ -63,9 +63,6 @@ func (s *StepDownload) Run(_ context.Context, state multistep.StateBag) multiste ui.Say(fmt.Sprintf("Retrieving %s", s.Description)) - // Get a progress bar from the ui so we can hand it off to the download client - bar := ui.ProgressBar() - // First try to use any already downloaded file // If it fails, proceed to regular download logic @@ -99,7 +96,7 @@ func (s *StepDownload) Run(_ context.Context, state multistep.StateBag) multiste } downloadConfigs[i] = config - if match, _ := NewDownloadClient(config, bar).VerifyChecksum(config.TargetPath); match { + if match, _ := NewDownloadClient(config, ui).VerifyChecksum(config.TargetPath); match { ui.Message(fmt.Sprintf("Found already downloaded, initial checksum matched, no download needed: %s", url)) finalPath = config.TargetPath break @@ -141,14 +138,14 @@ func (s *StepDownload) Cleanup(multistep.StateBag) {} func (s *StepDownload) download(config *DownloadConfig, state multistep.StateBag) (string, error, bool) { var path string - ui := state.Get("ui").(packer.Ui) - - // Get a progress bar and hand it off to the download client - bar := ui.ProgressBar() - log.Printf("new progress bar: %#v, %t", bar, bar == nil) + v, ok := state.GetOk("ui") + if !ok { + return "", nil, false + } + ui := v.(packer.Ui) - // Create download client with config and progress bar - download := NewDownloadClient(config, bar) + // Create download client with config + download := NewDownloadClient(config, ui) downloadCompleteCh := make(chan error, 1) go func() { @@ -160,7 +157,6 @@ func (s *StepDownload) download(config *DownloadConfig, state multistep.StateBag for { select { case err := <-downloadCompleteCh: - bar.Finish() if err != nil { return "", err, true @@ -175,7 +171,6 @@ func (s *StepDownload) download(config *DownloadConfig, state multistep.StateBag case <-time.After(1 * time.Second): if _, ok := state.GetOk(multistep.StateCancelled); ok { - bar.Finish() ui.Say("Interrupt received. Cancelling download...") return "", nil, false } diff --git a/packer/progressbar.go b/packer/progressbar.go index 5a70f25a4..2a1ed97d6 100644 --- a/packer/progressbar.go +++ b/packer/progressbar.go @@ -2,13 +2,14 @@ package packer import ( "io" + "sync" + "sync/atomic" "github.com/cheggaaa/pb" ) // ProgressBar allows to graphically display // a self refreshing progress bar. -// No-op When in machine readable mode. type ProgressBar interface { Start(total uint64) Add(current uint64) @@ -16,10 +17,46 @@ type ProgressBar interface { Finish() } +type StackableProgressBar struct { + total uint64 + started bool + BasicProgressBar + startOnce sync.Once + group sync.WaitGroup +} + +var _ ProgressBar = new(StackableProgressBar) + +func (spb *StackableProgressBar) start() { + spb.BasicProgressBar.ProgressBar = pb.New(0) + spb.BasicProgressBar.ProgressBar.SetUnits(pb.U_BYTES) + + spb.BasicProgressBar.ProgressBar.Start() + go func() { + spb.group.Wait() + spb.BasicProgressBar.ProgressBar.Finish() + spb.startOnce = sync.Once{} + spb.BasicProgressBar.ProgressBar = nil + }() +} + +func (spb *StackableProgressBar) Start(total uint64) { + atomic.AddUint64(&spb.total, total) + spb.group.Add(1) + spb.startOnce.Do(spb.start) + spb.SetTotal64(int64(spb.total)) +} + +func (spb *StackableProgressBar) Finish() { + spb.group.Done() +} + type BasicProgressBar struct { *pb.ProgressBar } +var _ ProgressBar = new(BasicProgressBar) + func (bpb *BasicProgressBar) Start(total uint64) { bpb.SetTotal64(int64(total)) bpb.ProgressBar.Start() @@ -41,20 +78,18 @@ func (bpb *BasicProgressBar) NewProxyReadCloser(r io.ReadCloser) io.ReadCloser { } } -var _ ProgressBar = new(BasicProgressBar) - // NoopProgressBar is a silent progress bar type NoopProgressBar struct { } +var _ ProgressBar = new(NoopProgressBar) + func (npb *NoopProgressBar) Start(uint64) {} func (npb *NoopProgressBar) Add(uint64) {} func (npb *NoopProgressBar) Finish() {} func (npb *NoopProgressBar) NewProxyReader(r io.Reader) io.Reader { return r } func (npb *NoopProgressBar) NewProxyReadCloser(r io.ReadCloser) io.ReadCloser { return r } -var _ ProgressBar = new(NoopProgressBar) - // ProxyReader implements io.ReadCloser but sends // count of read bytes to progress bar type ProxyReader struct { diff --git a/packer/ui.go b/packer/ui.go index e9a98b24c..340c56827 100644 --- a/packer/ui.go +++ b/packer/ui.go @@ -15,8 +15,6 @@ import ( "syscall" "time" "unicode" - - "github.com/cheggaaa/pb" ) type UiColor uint @@ -42,6 +40,17 @@ type Ui interface { ProgressBar() ProgressBar } +type NoopUi struct{} + +var _ Ui = new(NoopUi) + +func (*NoopUi) Ask(string) (string, error) { return "", errors.New("this is a noop ui") } +func (*NoopUi) Say(string) { return } +func (*NoopUi) Message(string) { return } +func (*NoopUi) Error(string) { return } +func (*NoopUi) Machine(string, ...string) { return } +func (*NoopUi) ProgressBar() ProgressBar { return new(NoopProgressBar) } + // ColoredUi is a UI that is colored using terminal colors. type ColoredUi struct { Color UiColor @@ -73,13 +82,13 @@ type BasicUi struct { l sync.Mutex interrupted bool scanner *bufio.Scanner + StackableProgressBar } var _ Ui = new(BasicUi) func (bu *BasicUi) ProgressBar() ProgressBar { - log.Printf("hehey !") - return &BasicProgressBar{ProgressBar: pb.New(0)} + return &bu.StackableProgressBar } // MachineReadableUi is a UI that only outputs machine-readable output