diff --git a/config/module/get_file.go b/config/module/get_file.go index 66ead7e51f..73cb858341 100644 --- a/config/module/get_file.go +++ b/config/module/get_file.go @@ -12,7 +12,14 @@ import ( type FileGetter struct{} func (g *FileGetter) Get(dst string, u *url.URL) error { - fi, err := os.Stat(dst) + // The source path must exist and be a directory to be usable. + if fi, err := os.Stat(u.Path); err != nil { + return fmt.Errorf("source path error: %s", err) + } else if !fi.IsDir() { + return fmt.Errorf("source path must be a directory") + } + + fi, err := os.Lstat(dst) if err != nil && !os.IsNotExist(err) { return err } @@ -20,16 +27,14 @@ func (g *FileGetter) Get(dst string, u *url.URL) error { // If the destination already exists, it must be a symlink if err == nil { mode := fi.Mode() - if mode&os.ModeSymlink != 0 { + if mode&os.ModeSymlink == 0 { return fmt.Errorf("destination exists and is not a symlink") } - } - // The source path must exist and be a directory to be usable. - if fi, err := os.Stat(u.Path); err != nil { - return fmt.Errorf("source path error: %s", err) - } else if !fi.IsDir() { - return fmt.Errorf("source path must be a directory") + // Remove the destination + if err := os.Remove(dst); err != nil { + return err + } } // Create all the parent directories diff --git a/config/module/get_file_test.go b/config/module/get_file_test.go index c805cd5c31..6cde4befc4 100644 --- a/config/module/get_file_test.go +++ b/config/module/get_file_test.go @@ -20,6 +20,15 @@ func TestFileGetter(t *testing.T) { t.Fatalf("err: %s", err) } + // Verify the destination folder is a symlink + fi, err := os.Lstat(dst) + if err != nil { + t.Fatalf("err: %s", err) + } + if fi.Mode()&os.ModeSymlink == 0 { + t.Fatal("destination is not a symlink") + } + // Verify the main file exists mainPath := filepath.Join(dst, "main.tf") if _, err := os.Stat(mainPath); err != nil { @@ -65,6 +74,36 @@ func TestFileGetter_dir(t *testing.T) { } } +func TestFileGetter_dirSymlink(t *testing.T) { + g := new(FileGetter) + dst := tempDir(t) + dst2 := tempDir(t) + + // Make parents + if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + t.Fatalf("err: %s", err) + } + if err := os.MkdirAll(dst2, 0755); err != nil { + t.Fatalf("err: %s", err) + } + + // Make a symlink + if err := os.Symlink(dst2, dst); err != nil { + t.Fatalf("err: %s") + } + + // With a dir that exists that isn't a symlink + if err := g.Get(dst, testModuleURL("basic")); err != nil { + t.Fatalf("err: %s", err) + } + + // Verify the main file exists + mainPath := filepath.Join(dst, "main.tf") + if _, err := os.Stat(mainPath); err != nil { + t.Fatalf("err: %s", err) + } +} + func testModuleURL(n string) *url.URL { u, err := url.Parse(testModule(n)) if err != nil {