diff --git a/builder/azure/common/devicelogin.go b/builder/azure/common/devicelogin.go index 8d053f802..ea177638a 100644 --- a/builder/azure/common/devicelogin.go +++ b/builder/azure/common/devicelogin.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "os" + "os/user" "path/filepath" "regexp" "strings" @@ -15,7 +16,6 @@ import ( "github.com/Azure/go-autorest/autorest/azure" "github.com/Azure/go-autorest/autorest/to" "github.com/hashicorp/packer/helper/useragent" - "github.com/mitchellh/go-homedir" ) var ( @@ -148,9 +148,13 @@ func tokenFromDeviceFlow(say func(string), oauthCfg adal.OAuthConfig, clientID, // tokenCachePath returns the full path the OAuth 2.0 token should be saved at // for given tenant ID. func tokenCachePath(tenantID string) string { - dir, err := homedir.Dir() - if err != nil { + var dir string + + u, err := user.Current() + if err != nil || u.HomeDir == "" { dir, _ = filepath.Abs(os.Args[0]) + } else { + dir = u.HomeDir } return filepath.Join(dir, ".azure", "packer", fmt.Sprintf("oauth-%s.json", tenantID)) diff --git a/builder/oracle/oci/config.go b/builder/oracle/oci/config.go index b7db218db..d83fa6c15 100644 --- a/builder/oracle/oci/config.go +++ b/builder/oracle/oci/config.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "log" "os" + "os/user" "path/filepath" "strings" @@ -16,8 +17,6 @@ import ( "github.com/hashicorp/packer/packer" "github.com/hashicorp/packer/template/interpolate" ocicommon "github.com/oracle/oci-go-sdk/common" - - "github.com/mitchellh/go-homedir" ) type Config struct { @@ -95,7 +94,7 @@ func NewConfig(raws ...interface{}) (*Config, error) { var keyContent []byte if c.KeyFile != "" { - path, err := homedir.Expand(c.KeyFile) + path, err := packer.ExpandUser(c.KeyFile) if err != nil { return nil, err } @@ -249,15 +248,19 @@ func NewConfig(raws ...interface{}) (*Config, error) { return c, nil } -// getDefaultOCISettingsPath uses mitchellh/go-homedir to compute the default +// getDefaultOCISettingsPath uses os/user to compute the default // config file location ($HOME/.oci/config). func getDefaultOCISettingsPath() (string, error) { - home, err := homedir.Dir() + u, err := user.Current() if err != nil { return "", err } - path := filepath.Join(home, ".oci", "config") + if u.HomeDir == "" { + return "", fmt.Errorf("Unable to determine the home directory for the current user.") + } + + path := filepath.Join(u.HomeDir, ".oci", "config") if _, err := os.Stat(path); err != nil { return "", err } diff --git a/helper/communicator/config.go b/helper/communicator/config.go index fa4e053c3..7d66254b5 100644 --- a/helper/communicator/config.go +++ b/helper/communicator/config.go @@ -11,9 +11,9 @@ import ( packerssh "github.com/hashicorp/packer/communicator/ssh" "github.com/hashicorp/packer/helper/multistep" helperssh "github.com/hashicorp/packer/helper/ssh" + "github.com/hashicorp/packer/packer" "github.com/hashicorp/packer/template/interpolate" "github.com/masterzen/winrm" - "github.com/mitchellh/go-homedir" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" ) @@ -73,10 +73,11 @@ func (c *Config) ReadSSHPrivateKeyFile() ([]byte, error) { var privateKey []byte if c.SSHPrivateKeyFile != "" { - keyPath, err := homedir.Expand(c.SSHPrivateKeyFile) + keyPath, err := packer.ExpandUser(c.SSHPrivateKeyFile) if err != nil { - return privateKey, fmt.Errorf("Error expanding path for SSH private key: %s", err) + return []byte{}, fmt.Errorf("Error expanding path for SSH private key: %s", err) } + privateKey, err = ioutil.ReadFile(keyPath) if err != nil { return privateKey, fmt.Errorf("Error on reading SSH private key: %s", err) @@ -261,7 +262,7 @@ func (c *Config) prepareSSH(ctx *interpolate.Context) []error { } if c.SSHPrivateKeyFile != "" { - path, err := homedir.Expand(c.SSHPrivateKeyFile) + path, err := packer.ExpandUser(c.SSHPrivateKeyFile) if err != nil { errs = append(errs, fmt.Errorf( "ssh_private_key_file is invalid: %s", err)) @@ -279,7 +280,7 @@ func (c *Config) prepareSSH(ctx *interpolate.Context) []error { errs = append(errs, errors.New( "ssh_bastion_password or ssh_bastion_private_key_file must be specified")) } else if c.SSHBastionPrivateKeyFile != "" { - path, err := homedir.Expand(c.SSHBastionPrivateKeyFile) + path, err := packer.ExpandUser(c.SSHBastionPrivateKeyFile) if err != nil { errs = append(errs, fmt.Errorf( "ssh_bastion_private_key_file is invalid: %s", err)) diff --git a/helper/communicator/step_connect_ssh.go b/helper/communicator/step_connect_ssh.go index 6517a04a7..37db32eae 100644 --- a/helper/communicator/step_connect_ssh.go +++ b/helper/communicator/step_connect_ssh.go @@ -14,7 +14,6 @@ import ( "github.com/hashicorp/packer/helper/multistep" helperssh "github.com/hashicorp/packer/helper/ssh" "github.com/hashicorp/packer/packer" - "github.com/mitchellh/go-homedir" gossh "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "golang.org/x/net/proxy" @@ -227,11 +226,12 @@ func sshBastionConfig(config *Config) (*gossh.ClientConfig, error) { } if config.SSHBastionPrivateKeyFile != "" { - path, err := homedir.Expand(config.SSHBastionPrivateKeyFile) + path, err := packer.ExpandUser(config.SSHBastionPrivateKeyFile) if err != nil { return nil, fmt.Errorf( "Error expanding path for SSH bastion private key: %s", err) } + signer, err := helperssh.FileSigner(path) if err != nil { return nil, err diff --git a/packer/config_file.go b/packer/config_file.go index 043b40ae8..f95914230 100644 --- a/packer/config_file.go +++ b/packer/config_file.go @@ -5,6 +5,7 @@ import ( "os" "os/user" "path/filepath" + "strings" ) // ConfigFile returns the default path to the configuration file. On @@ -84,3 +85,49 @@ func configDir() (string, error) { return filepath.Join(dir, defaultConfigDir), nil } + +// Given a path, check to see if it's using ~ to reference a user directory. +// If so, then replace that component with the requested user directory. +// In "~/", "~" gets replaced by current user's home dir. +// In "~root/", "~user" gets replaced by root's home dir. +// ~ has to be the first character of path for ExpandUser change it. +func ExpandUser(path string) (string, error) { + var ( + u *user.User + err error + ) + + // refuse to do anything with a zero-length path + if len(path) == 0 { + return path, nil + } + + // If no expansion was specified, then refuse that too + if path[0] != '~' { + return path, nil + } + + // Grab everything up to the first filepath.Separator + idx := strings.IndexAny(path, `/\`) + if idx == -1 { + idx = len(path) + } + + // Now we should be able to extract the username + username := path[:idx] + + // Check if the current user was requested + if username == "~" { + u, err = user.Current() + } else { + u, err = user.Lookup(username[1:]) + } + + // If we couldn't figure that out, then fail here + if err != nil { + return "", err + } + + // Now we can replace the path with u.HomeDir + return filepath.Join(u.HomeDir, path[idx:]), nil +} diff --git a/packer/config_file_test.go b/packer/config_file_test.go new file mode 100644 index 000000000..ba275793b --- /dev/null +++ b/packer/config_file_test.go @@ -0,0 +1,161 @@ +package packer + +import ( + "fmt" + "os/user" + "path/filepath" + "runtime" + "testing" +) + +// Depending on the platform, find a valid username to use +func platform_user() string { + // XXX: We make an assumption here that there's an Administrator user + // on the windows platform, whereas the correct way is to use + // the api or to scrape `net user`. + if runtime.GOOS == "windows" { + return "Administrator" + } + return "root" +} + +func homedir_current() (string, error) { + u, err := user.Current() + if err != nil { + return "", err + } + + return u.HomeDir, nil +} + +func homedir_user(username string) (string, error) { + u, err := user.Lookup(username) + if err != nil { + return "", err + } + + return u.HomeDir, nil +} + +// Begin the actual tests and stuff +func TestExpandUser_Empty(t *testing.T) { + var path, expected string + + // Try an invalid user + path, err := ExpandUser("~invalid-user-that-should-not-exist") + if err == nil { + t.Fatalf("expected failure") + } + + // Try an empty string + expected = "" + if path, err = ExpandUser(""); err != nil { + t.Fatalf("err: %s", err) + } + + if path != expected { + t.Fatalf("err: %v != %v", path, expected) + } + + // Try an absolute path + expected = "/etc/shadow" + if path, err = ExpandUser("/etc/shadow"); err != nil { + t.Fatalf("err: %s", err) + } + + if path != expected { + t.Fatalf("err: %v != %v", path, expected) + } + + // Try a relative path + expected = "tmp/foo" + if path, err = ExpandUser("tmp/foo"); err != nil { + t.Fatalf("err: %s", err) + } + + if path != expected { + t.Fatalf("err: %v != %v", path, expected) + } +} + +func TestExpandUser_Current(t *testing.T) { + var path, expected string + + // Grab the current user's home directory to verify ExpandUser works + homedir, err := homedir_current() + if err != nil { + t.Fatalf("err: %s", err) + } + + // Try just a tilde + expected = homedir + if path, err = ExpandUser("~"); err != nil { + t.Fatalf("err: %s", err) + } + + if path != expected { + t.Fatalf("err: %v != %v", path, expected) + } + + // Try as a directory + expected = filepath.Join(homedir, "") + if path, err = ExpandUser("~/"); err != nil { + t.Fatalf("err: %s", err) + } + + if path != expected { + t.Fatalf("err: %v != %v", path, expected) + } + + // Try as a file + expected = filepath.Join(homedir, "foo") + if path, err = ExpandUser("~/foo"); err != nil { + t.Fatalf("err: %s", err) + } + + if path != expected { + t.Fatalf("err: %v != %v", path, expected) + } +} + +func TestExpandUser_User(t *testing.T) { + var path, expected string + + username := platform_user() + + // Grab the current user's home directory to verify ExpandUser works + homedir, err := homedir_user(username) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Try just a tilde + expected = homedir + if path, err = ExpandUser(fmt.Sprintf("~%s", username)); err != nil { + t.Fatalf("err: %s", err) + } + + if path != expected { + t.Fatalf("err: %v != %v", path, expected) + } + + // Try as a directory + expected = filepath.Join(homedir, "") + if path, err = ExpandUser(fmt.Sprintf("~%s/", username)); err != nil { + t.Fatalf("err: %s", err) + } + + if path != expected { + t.Fatalf("err: %v != %v", path, expected) + } + + // Try as a file + expected = filepath.Join(homedir, "foo") + if path, err = ExpandUser(fmt.Sprintf("~%s/foo", username)); err != nil { + t.Fatalf("err: %s", err) + } + + if path != expected { + t.Fatalf("err: %v != %v", path, expected) + } +}