diff --git a/packer/plugin/command.go b/packer/plugin/command.go index da1d4ca56..3f5a66b0e 100644 --- a/packer/plugin/command.go +++ b/packer/plugin/command.go @@ -2,6 +2,7 @@ package plugin import ( "bytes" + "errors" "github.com/mitchellh/packer/packer" "net/rpc" "os/exec" @@ -24,17 +25,27 @@ func Command(cmd *exec.Cmd) (result packer.Command, err error) { return } - // TODO: timeout - // TODO: check that command is even running - address := "" - for { - line, err := out.ReadBytes('\n') - if err == nil { - address = strings.TrimSpace(string(line)) - break - } + cmdExited := make(chan bool) + go func() { + cmd.Wait() + cmdExited <- true + }() + + var address string + for done := false; !done; { + select { + case <-cmdExited: + err = errors.New("plugin exited before we could connect") + return + case <-time.After(10 * time.Millisecond): + if line, err := out.ReadBytes('\n'); err == nil { + address = strings.TrimSpace(string(line)) + done = true + } - time.Sleep(10 * time.Millisecond) + // Make sure to reset err to nil + err = nil + } } client, err := rpc.Dial("tcp", address) diff --git a/packer/plugin/command_test.go b/packer/plugin/command_test.go index 2088b6cd5..9dfe56464 100644 --- a/packer/plugin/command_test.go +++ b/packer/plugin/command_test.go @@ -22,3 +22,10 @@ func TestCommand_Good(t *testing.T) { result := command.Synopsis() assert.Equal(result, "1", "should return result") } + +func TestCommand_CommandExited(t *testing.T) { + assert := asserts.NewTestingAsserts(t, true) + + _, err := Command(helperProcess("im-a-command-that-doesnt-work")) + assert.NotNil(err, "should have an error") +} diff --git a/packer/plugin/plugin_test.go b/packer/plugin/plugin_test.go index 630bf3eda..e8397085e 100644 --- a/packer/plugin/plugin_test.go +++ b/packer/plugin/plugin_test.go @@ -1,10 +1,12 @@ package plugin import ( + "fmt" "github.com/mitchellh/packer/packer" "os" "os/exec" "testing" + "time" ) type helperCommand byte @@ -38,5 +40,30 @@ func TestHelperProcess(*testing.T) { return } - ServeCommand(new(helperCommand)) + args := os.Args + for len(args) > 0 { + if args[0] == "--" { + args = args[1:] + break + } + + args = args[1:] + } + + if len(args) == 0 { + fmt.Fprintf(os.Stderr, "No command\n") + os.Exit(2) + } + + cmd, args := args[0], args[1:] + switch cmd { + case "command": + ServeCommand(new(helperCommand)) + case "start-timeout": + time.Sleep(1 * time.Minute) + os.Exit(1) + default: + fmt.Fprintf(os.Stderr, "Unknown command: %q\n", cmd) + os.Exit(2) + } }