pull/2906/head
Ali Rizvi-Santiago 9 years ago
parent 11ff4439a6
commit b1ff14714b

@ -2,11 +2,11 @@ package common
import (
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"net/url"
)
// PackerKeyEnv is used to specify the key interval (delay) between keystrokes
@ -51,7 +51,7 @@ func DownloadableURL(original string) (string, error) {
supported := []string{"file", "http", "https", "ftp", "smb"}
found := false
for _, s := range supported {
if strings.HasPrefix(strings.ToLower(original), s + "://") {
if strings.HasPrefix(strings.ToLower(original), s+"://") {
found = true
break
}
@ -63,8 +63,10 @@ func DownloadableURL(original string) (string, error) {
original = filepath.ToSlash(original)
// make sure that it can be parsed though..
uri,err := url.Parse(original)
if err != nil { return "", err }
uri, err := url.Parse(original)
if err != nil {
return "", err
}
uri.Scheme = strings.ToLower(uri.Scheme)
@ -72,16 +74,20 @@ func DownloadableURL(original string) (string, error) {
}
// If the file exists, then make it an absolute path
_,err := os.Stat(original)
_, err := os.Stat(original)
if err == nil {
original, err = filepath.Abs(filepath.FromSlash(original))
if err != nil { return "", err }
if err != nil {
return "", err
}
original, err = filepath.EvalSymlinks(original)
if err != nil { return "", err }
if err != nil {
return "", err
}
original = filepath.Clean(original)
original = filepath.ToSlash(original)
original = filepath.ToSlash(original)
}
// Since it wasn't properly prefixed, let's make it into a well-formed

@ -13,21 +13,20 @@ import (
"log"
"net/url"
"os"
"runtime"
"path"
"runtime"
"strings"
)
// imports related to each Downloader implementation
import (
"github.com/jlaffaye/ftp"
"gopkg.in/cheggaaa/pb.v1"
"io"
"path/filepath"
"net/http"
"gopkg.in/cheggaaa/pb.v1"
"github.com/jlaffaye/ftp"
"path/filepath"
)
// DownloadConfig is the configuration given to instantiate a new
// download instance. Once a configuration is used to instantiate
// a download client, it must not be modified.
@ -112,7 +111,7 @@ type Downloader interface {
// A LocalDownloader is responsible for converting a uri to a local path
// that the platform can open directly.
type LocalDownloader interface {
toPath(string, url.URL) (string,error)
toPath(string, url.URL) (string, error)
}
// A RemoteDownloader is responsible for actually taking a remote URL and
@ -134,12 +133,16 @@ func (d *DownloadClient) Get() (string, error) {
/* parse the configuration url into a net/url object */
u, err := url.Parse(d.config.Url)
if err != nil { return "", err }
if err != nil {
return "", err
}
log.Printf("Parsed URL: %#v", u)
/* use the current working directory as the base for relative uri's */
cwd,err := os.Getwd()
if err != nil { return "", err }
cwd, err := os.Getwd()
if err != nil {
return "", err
}
// Determine which is the correct downloader to use
var finalPath string
@ -150,13 +153,13 @@ func (d *DownloadClient) Get() (string, error) {
return "", fmt.Errorf("No downloader for scheme: %s", u.Scheme)
}
remote,ok := d.downloader.(RemoteDownloader)
remote, ok := d.downloader.(RemoteDownloader)
if !ok {
return "", fmt.Errorf("Unable to treat uri scheme %s as a Downloader : %T", u.Scheme, d.downloader)
}
local,ok := d.downloader.(LocalDownloader)
if !ok && !d.config.CopyFile{
local, ok := d.downloader.(LocalDownloader)
if !ok && !d.config.CopyFile {
return "", fmt.Errorf("Not allowed to use uri scheme %s in no copy file mode : %T", u.Scheme, d.downloader)
d.config.CopyFile = true
}
@ -167,18 +170,24 @@ func (d *DownloadClient) Get() (string, error) {
finalPath = d.config.TargetPath
f, err = os.OpenFile(finalPath, os.O_RDWR|os.O_CREATE, os.FileMode(0666))
if err != nil { return "", err }
if err != nil {
return "", err
}
log.Printf("[DEBUG] Downloading: %s", u.String())
err = remote.Download(f, u)
f.Close()
if err != nil { return "", err }
if err != nil {
return "", err
}
// Otherwise if our Downloader is a LocalDownloader we can just use the
// path after transforming it.
// Otherwise if our Downloader is a LocalDownloader we can just use the
// path after transforming it.
} else {
finalPath,err = local.toPath(cwd, *u)
if err != nil { return "", err }
finalPath, err = local.toPath(cwd, *u)
if err != nil {
return "", err
}
log.Printf("[DEBUG] Using local file: %s", finalPath)
}
@ -188,7 +197,9 @@ func (d *DownloadClient) Get() (string, error) {
verify, err = d.VerifyChecksum(finalPath)
if err == nil && !verify {
// Only delete the file if we made a copy or downloaded it
if d.config.CopyFile { os.Remove(finalPath) }
if d.config.CopyFile {
os.Remove(finalPath)
}
err = fmt.Errorf(
"checksums didn't match expected: %s",
@ -201,7 +212,9 @@ func (d *DownloadClient) Get() (string, error) {
// PercentProgress returns the download progress as a percentage.
func (d *DownloadClient) PercentProgress() int {
if d.downloader == nil { return -1 }
if d.downloader == nil {
return -1
}
return int((float64(d.downloader.Progress()) / float64(d.downloader.Total())) * 100)
}
@ -330,11 +343,11 @@ func (d *HTTPDownloader) Total() uint64 {
// files over FTP.
type FTPDownloader struct {
userInfo *url.Userinfo
mtu uint
mtu uint
active bool
active bool
progress uint64
total uint64
total uint64
}
func (d *FTPDownloader) Progress() uint64 {
@ -369,40 +382,52 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
var cli *ftp.ServerConn
log.Printf("Starting download over FTP: %s : %s\n", uri.Host, uri.Path)
cli,err := ftp.Dial(uri.Host)
if err != nil { return nil }
cli, err := ftp.Dial(uri.Host)
if err != nil {
return nil
}
defer cli.Quit()
// handle authentication
if uri.User != nil { userinfo = uri.User }
if uri.User != nil {
userinfo = uri.User
}
pass,ok := userinfo.Password()
if !ok { pass = "ftp@" }
pass, ok := userinfo.Password()
if !ok {
pass = "ftp@"
}
log.Printf("Authenticating to FTP server: %s : %s\n", userinfo.Username(), pass)
err = cli.Login(userinfo.Username(), pass)
if err != nil { return err }
if err != nil {
return err
}
// locate specified path
p := path.Dir(uri.Path)
log.Printf("Changing to FTP directory : %s\n", p)
err = cli.ChangeDir(p)
if err != nil { return nil }
if err != nil {
return nil
}
curpath,err := cli.CurrentDir()
if err != nil { return err }
curpath, err := cli.CurrentDir()
if err != nil {
return err
}
log.Printf("Current FTP directory : %s\n", curpath)
// collect stats about the specified file
var name string
var entry *ftp.Entry
_,name = path.Split(uri.Path)
_, name = path.Split(uri.Path)
entry = nil
entries,err := cli.List(curpath)
for _,e := range entries {
entries, err := cli.List(curpath)
for _, e := range entries {
if e.Type == ftp.EntryTypeFile && e.Name == name {
entry = e
break
@ -420,15 +445,19 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
// download specified file
d.active = true
reader,err := cli.RetrFrom(uri.Path, d.progress)
if err != nil { return nil }
reader, err := cli.RetrFrom(uri.Path, d.progress)
if err != nil {
return nil
}
// do it in a goro so that if someone wants to cancel it, they can
errch := make(chan error)
go func(d *FTPDownloader, r io.Reader, w io.Writer, e chan error) {
for ; d.active; {
n,err := io.CopyN(w, r, int64(d.mtu))
if err != nil { break }
for d.active {
n, err := io.CopyN(w, r, int64(d.mtu))
if err != nil {
break
}
d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))
@ -455,9 +484,9 @@ func (d *FTPDownloader) Download(dst *os.File, src *url.URL) error {
type FileDownloader struct {
bufferSize *uint
active bool
active bool
progress uint64
total uint64
total uint64
}
func (d *FileDownloader) Progress() uint64 {
@ -476,27 +505,27 @@ func (d *FileDownloader) Resume() {
// TODO: Implement
}
func (d *FileDownloader) toPath(base string, uri url.URL) (string,error) {
func (d *FileDownloader) toPath(base string, uri url.URL) (string, error) {
var result string
// absolute path -- file://c:/absolute/path -> c:/absolute/path
if strings.HasSuffix(uri.Host, ":") {
result = path.Join(uri.Host, uri.Path)
// semi-absolute path (current drive letter)
// -- file:///absolute/path -> /absolute/path
// semi-absolute path (current drive letter)
// -- file:///absolute/path -> /absolute/path
} else if uri.Host == "" && strings.HasPrefix(uri.Path, "/") {
result = path.Join(filepath.VolumeName(base), uri.Path)
// relative path -- file://./relative/path -> ./relative/path
// relative path -- file://./relative/path -> ./relative/path
} else if uri.Host == "." {
result = path.Join(base, uri.Path)
// relative path -- file://relative/path -> ./relative/path
// relative path -- file://relative/path -> ./relative/path
} else {
result = path.Join(base, uri.Host, uri.Path)
}
return filepath.ToSlash(result),nil
return filepath.ToSlash(result), nil
}
func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
@ -509,43 +538,53 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
uri := src
/* use the current working directory as the base for relative uri's */
cwd,err := os.Getwd()
if err != nil { return err }
cwd, err := os.Getwd()
if err != nil {
return err
}
/* determine which uri format is being used and convert to a real path */
realpath,err := d.toPath(cwd, *uri)
if err != nil { return err }
realpath, err := d.toPath(cwd, *uri)
if err != nil {
return err
}
/* download the file using the operating system's facilities */
d.progress = 0
d.active = true
f, err := os.Open(realpath)
if err != nil { return err }
if err != nil {
return err
}
defer f.Close()
// get the file size
fi, err := f.Stat()
if err != nil { return err }
if err != nil {
return err
}
d.total = uint64(fi.Size())
progressBar := pb.New64(int64(d.Total())).Start()
// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
var n int64
n,err = io.Copy(dst, f)
n, err = io.Copy(dst, f)
d.active = false
d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))
// use a goro in case someone else wants to enable cancel/resume
// use a goro in case someone else wants to enable cancel/resume
} else {
errch := make(chan error)
go func(d* FileDownloader, r io.Reader, w io.Writer, e chan error) {
for ; d.active; {
n,err := io.CopyN(w, r, int64(*d.bufferSize))
if err != nil { break }
go func(d *FileDownloader, r io.Reader, w io.Writer, e chan error) {
for d.active {
n, err := io.CopyN(w, r, int64(*d.bufferSize))
if err != nil {
break
}
d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))
@ -566,9 +605,9 @@ func (d *FileDownloader) Download(dst *os.File, src *url.URL) error {
type SMBDownloader struct {
bufferSize *uint
active bool
active bool
progress uint64
total uint64
total uint64
}
func (d *SMBDownloader) Progress() uint64 {
@ -587,11 +626,11 @@ func (d *SMBDownloader) Resume() {
// TODO: Implement
}
func (d *SMBDownloader) toPath(base string, uri url.URL) (string,error) {
const UNCPrefix = string(os.PathSeparator)+string(os.PathSeparator)
func (d *SMBDownloader) toPath(base string, uri url.URL) (string, error) {
const UNCPrefix = string(os.PathSeparator) + string(os.PathSeparator)
if runtime.GOOS != "windows" {
return "",fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
return "", fmt.Errorf("Support for SMB based uri's are not supported on %s", runtime.GOOS)
}
return UNCPrefix + filepath.ToSlash(path.Join(uri.Host, uri.Path)), nil
@ -607,43 +646,53 @@ func (d *SMBDownloader) Download(dst *os.File, src *url.URL) error {
uri := src
/* use the current working directory as the base for relative uri's */
cwd,err := os.Getwd()
if err != nil { return err }
cwd, err := os.Getwd()
if err != nil {
return err
}
/* convert uri to an smb-path */
realpath,err := d.toPath(cwd, *uri)
if err != nil { return err }
realpath, err := d.toPath(cwd, *uri)
if err != nil {
return err
}
/* Open up the "\\"-prefixed path using the Windows filesystem */
d.progress = 0
d.active = true
f, err := os.Open(realpath)
if err != nil { return err }
if err != nil {
return err
}
defer f.Close()
// get the file size (at the risk of performance)
fi, err := f.Stat()
if err != nil { return err }
if err != nil {
return err
}
d.total = uint64(fi.Size())
progressBar := pb.New64(int64(d.Total())).Start()
// no bufferSize specified, so copy synchronously.
if d.bufferSize == nil {
var n int64
n,err = io.Copy(dst, f)
n, err = io.Copy(dst, f)
d.active = false
d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))
// use a goro in case someone else wants to enable cancel/resume
// use a goro in case someone else wants to enable cancel/resume
} else {
errch := make(chan error)
go func(d* SMBDownloader, r io.Reader, w io.Writer, e chan error) {
for ; d.active; {
n,err := io.CopyN(w, r, int64(*d.bufferSize))
if err != nil { break }
go func(d *SMBDownloader, r io.Reader, w io.Writer, e chan error) {
for d.active {
n, err := io.CopyN(w, r, int64(*d.bufferSize))
if err != nil {
break
}
d.progress += uint64(n)
progressBar.Set64(int64(d.Progress()))

@ -8,9 +8,9 @@ import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"strings"
"path/filepath"
"testing"
)
@ -119,7 +119,7 @@ func TestDownloadClient_checksumGood(t *testing.T) {
TargetPath: tf.Name(),
Hash: HashForType("md5"),
Checksum: checksum,
CopyFile: true,
CopyFile: true,
})
path, err := client.Get()
if err != nil {
@ -398,7 +398,7 @@ func TestDownloadFileUrl(t *testing.T) {
// into a testable file path whilst ignoring a correct checksum match, stripping
// UNC path info, and then calling stat to ensure the correct file exists.
// (used by TestFileUriTransforms)
func SimulateFileUriDownload(t *testing.T, uri string) (string,error) {
func SimulateFileUriDownload(t *testing.T, uri string) (string, error) {
// source_path is a file path and source is a network path
source := fmt.Sprintf(uri)
t.Logf("Trying to download %s", source)
@ -430,7 +430,7 @@ func SimulateFileUriDownload(t *testing.T, uri string) (string,error) {
if _, err = os.Stat(path); err != nil {
t.Errorf("Could not stat source file: %s", path)
}
return path,err
return path, err
}
// TestFileUriTransforms tests the case where we use a local file uri
@ -470,10 +470,10 @@ func TestFileUriTransforms(t *testing.T) {
}
// all regular slashed testcases
for _,testcase := range testcases {
for _, testcase := range testcases {
uri := "file://" + fmt.Sprintf(testcase, testpath)
t.Logf("TestFileUriTransforms : Trying Uri '%s'", uri)
res,err := SimulateFileUriDownload(t, uri)
res, err := SimulateFileUriDownload(t, uri)
if err != nil {
t.Errorf("Unable to transform uri '%s' into a path : %v", uri, err)
}
@ -485,7 +485,7 @@ func TestFileUriTransforms(t *testing.T) {
testcase := host + "/" + share + "/" + cwd[1:] + "/%s"
uri := "smb://" + fmt.Sprintf(testcase, testpath)
t.Logf("TestFileUriTransforms : Trying Uri '%s'", uri)
res,err := SimulateFileUriDownload(t, uri)
res, err := SimulateFileUriDownload(t, uri)
if err != nil {
t.Errorf("Unable to transform uri '%s' into a path", uri)
return

Loading…
Cancel
Save