diff --git a/builtin/provisioners/file/resource_provisioner.go b/builtin/provisioners/file/resource_provisioner.go index 9b9e8a97be..30ed5e3594 100644 --- a/builtin/provisioners/file/resource_provisioner.go +++ b/builtin/provisioners/file/resource_provisioner.go @@ -4,9 +4,7 @@ import ( "context" "fmt" "io/ioutil" - "log" "os" - "time" "github.com/hashicorp/terraform/communicator" "github.com/hashicorp/terraform/helper/schema" @@ -50,6 +48,9 @@ func applyFn(ctx context.Context) error { return err } + ctx, cancel := context.WithTimeout(ctx, comm.Timeout()) + defer cancel() + // Get the source src, deleteSource, err := getSrc(data) if err != nil { @@ -61,21 +62,11 @@ func applyFn(ctx context.Context) error { // Begin the file copy dst := data.Get("destination").(string) - resultCh := make(chan error, 1) - go func() { - resultCh <- copyFiles(comm, src, dst) - }() - - // Allow the file copy to complete unless there is an interrupt. - // If there is an interrupt we make no attempt to cleanly close - // the connection currently. We just abruptly exit. Because Terraform - // taints the resource, this is fine. - select { - case err := <-resultCh: + + if err := copyFiles(ctx, comm, src, dst); err != nil { return err - case <-ctx.Done(): - return fmt.Errorf("file transfer interrupted") } + return nil } func validateFn(c *terraform.ResourceConfig) (ws []string, es []error) { @@ -107,9 +98,9 @@ func getSrc(data *schema.ResourceData) (string, bool, error) { } // copyFiles is used to copy the files from a source to a destination -func copyFiles(comm communicator.Communicator, src, dst string) error { +func copyFiles(ctx context.Context, comm communicator.Communicator, src, dst string) error { // Wait and retry until we establish the connection - err := retryFunc(comm.Timeout(), func() error { + err := communicator.Retry(ctx, func() error { err := comm.Connect(nil) return err }) @@ -144,21 +135,3 @@ func copyFiles(comm communicator.Communicator, src, dst string) error { } return err } - -// retryFunc is used to retry a function for a given duration -func retryFunc(timeout time.Duration, f func() error) error { - finish := time.After(timeout) - for { - err := f() - if err == nil { - return nil - } - log.Printf("Retryable error: %v", err) - - select { - case <-finish: - return err - case <-time.After(3 * time.Second): - } - } -}