diff --git a/common/step_provision.go b/common/step_provision.go index 893f13431..d069a5c0c 100644 --- a/common/step_provision.go +++ b/common/step_provision.go @@ -4,6 +4,7 @@ import ( "github.com/mitchellh/multistep" "github.com/mitchellh/packer/packer" "log" + "time" ) // StepProvision runs the provisioners. @@ -22,13 +23,31 @@ func (*StepProvision) Run(state map[string]interface{}) multistep.StepAction { hook := state["hook"].(packer.Hook) ui := state["ui"].(packer.Ui) + // Run the provisioner in a goroutine so we can continually check + // for cancellations... log.Println("Running the provision hook") - if err := hook.Run(packer.HookProvision, ui, comm, nil); err != nil { - state["error"] = err - return multistep.ActionHalt - } + errCh := make(chan error, 1) + go func() { + errCh <- hook.Run(packer.HookProvision, ui, comm, nil) + }() + + for { + select { + case err := <-errCh: + if err != nil { + state["error"] = err + return multistep.ActionHalt + } - return multistep.ActionContinue + return multistep.ActionContinue + case <-time.After(1 * time.Second): + if _, ok := state[multistep.StateCancelled]; ok { + log.Println("Cancelling provisioning due to interrupt...") + hook.Cancel() + return multistep.ActionHalt + } + } + } } func (*StepProvision) Cleanup(map[string]interface{}) {} diff --git a/packer/build.go b/packer/build.go index b8e0fa165..e324fdea7 100644 --- a/packer/build.go +++ b/packer/build.go @@ -207,10 +207,12 @@ func (b *coreBuild) Run(originalUi Ui, cache Cache) ([]Artifact, error) { hooks[HookProvision] = make([]Hook, 0, 1) } - hooks[HookProvision] = append(hooks[HookProvision], &ProvisionHook{provisioners}) + hooks[HookProvision] = append(hooks[HookProvision], &ProvisionHook{ + Provisioners: provisioners, + }) } - hook := &DispatchHook{hooks} + hook := &DispatchHook{Mapping: hooks} artifacts := make([]Artifact, 0, 1) // The builder just has a normal Ui, but targetted diff --git a/packer/build_test.go b/packer/build_test.go index c5eedbc81..cbac0f208 100644 --- a/packer/build_test.go +++ b/packer/build_test.go @@ -13,10 +13,10 @@ func testBuild() *coreBuild { builderConfig: 42, builderType: "foo", hooks: map[string][]Hook{ - "foo": []Hook{&TestHook{}}, + "foo": []Hook{&MockHook{}}, }, provisioners: []coreBuildProvisioner{ - coreBuildProvisioner{&TestProvisioner{}, []interface{}{42}}, + coreBuildProvisioner{&MockProvisioner{}, []interface{}{42}}, }, postProcessors: [][]coreBuildPostProcessor{ []coreBuildPostProcessor{ @@ -59,9 +59,9 @@ func TestBuild_Prepare(t *testing.T) { assert.Equal(builder.prepareConfig, []interface{}{42, packerConfig}, "prepare config should be 42") coreProv := build.provisioners[0] - prov := coreProv.provisioner.(*TestProvisioner) - assert.True(prov.prepCalled, "prepare should be called") - assert.Equal(prov.prepConfigs, []interface{}{42, packerConfig}, "prepare should be called with proper config") + prov := coreProv.provisioner.(*MockProvisioner) + assert.True(prov.PrepCalled, "prepare should be called") + assert.Equal(prov.PrepConfigs, []interface{}{42, packerConfig}, "prepare should be called with proper config") corePP := build.postProcessors[0][0] pp := corePP.processor.(*TestPostProcessor) @@ -104,9 +104,9 @@ func TestBuild_Prepare_Debug(t *testing.T) { assert.Equal(builder.prepareConfig, []interface{}{42, packerConfig}, "prepare config should be 42") coreProv := build.provisioners[0] - prov := coreProv.provisioner.(*TestProvisioner) - assert.True(prov.prepCalled, "prepare should be called") - assert.Equal(prov.prepConfigs, []interface{}{42, packerConfig}, "prepare should be called with proper config") + prov := coreProv.provisioner.(*MockProvisioner) + assert.True(prov.PrepCalled, "prepare should be called") + assert.Equal(prov.PrepConfigs, []interface{}{42, packerConfig}, "prepare should be called with proper config") } func TestBuildPrepare_variables_default(t *testing.T) { @@ -187,14 +187,14 @@ func TestBuild_Run(t *testing.T) { dispatchHook := builder.runHook dispatchHook.Run("foo", nil, nil, 42) - hook := build.hooks["foo"][0].(*TestHook) - assert.True(hook.runCalled, "run should be called") - assert.Equal(hook.runData, 42, "should have correct data") + hook := build.hooks["foo"][0].(*MockHook) + assert.True(hook.RunCalled, "run should be called") + assert.Equal(hook.RunData, 42, "should have correct data") // Verify provisioners run dispatchHook.Run(HookProvision, nil, nil, 42) - prov := build.provisioners[0].provisioner.(*TestProvisioner) - assert.True(prov.provCalled, "provision should be called") + prov := build.provisioners[0].provisioner.(*MockProvisioner) + assert.True(prov.ProvCalled, "provision should be called") // Verify post-processor was run pp := build.postProcessors[0][0].processor.(*TestPostProcessor) diff --git a/packer/environment_test.go b/packer/environment_test.go index e301e8c7a..6cda3f7d2 100644 --- a/packer/environment_test.go +++ b/packer/environment_test.go @@ -20,7 +20,7 @@ func init() { func testComponentFinder() *ComponentFinder { builderFactory := func(n string) (Builder, error) { return testBuilder(), nil } ppFactory := func(n string) (PostProcessor, error) { return new(TestPostProcessor), nil } - provFactory := func(n string) (Provisioner, error) { return new(TestProvisioner), nil } + provFactory := func(n string) (Provisioner, error) { return new(MockProvisioner), nil } return &ComponentFinder{ Builder: builderFactory, PostProcessor: ppFactory, @@ -227,7 +227,7 @@ func TestEnvironment_DefaultCli_Version(t *testing.T) { func TestEnvironment_Hook(t *testing.T) { assert := asserts.NewTestingAsserts(t, true) - hook := &TestHook{} + hook := &MockHook{} hooks := make(map[string]Hook) hooks["foo"] = hook @@ -309,7 +309,7 @@ func TestEnvironment_PostProcessor_Error(t *testing.T) { func TestEnvironmentProvisioner(t *testing.T) { assert := asserts.NewTestingAsserts(t, true) - p := &TestProvisioner{} + p := &MockProvisioner{} ps := make(map[string]Provisioner) ps["foo"] = p diff --git a/packer/hook.go b/packer/hook.go index 4a79222f4..e5e7ad8a9 100644 --- a/packer/hook.go +++ b/packer/hook.go @@ -1,5 +1,9 @@ package packer +import ( + "sync" +) + // This is the hook that should be fired for provisioners to run. const HookProvision = "packer_provision" @@ -11,19 +15,40 @@ const HookProvision = "packer_provision" // you must reference the documentation for the specific hook you're interested // in. In addition to that, the Hook is given access to a UI so that it can // output things to the user. +// +// Cancel is called when the hook needs to be cancelled. This will usually +// be called when Run is still in progress so the mechanism that handles this +// must be race-free. Cancel should attempt to cancel the hook in the +// quickest, safest way possible. type Hook interface { Run(string, Ui, Communicator, interface{}) error + Cancel() } // A Hook implementation that dispatches based on an internal mapping. type DispatchHook struct { Mapping map[string][]Hook + + l sync.Mutex + cancelled bool + runningHook Hook } // Runs the hook with the given name by dispatching it to the proper // hooks if a mapping exists. If a mapping doesn't exist, then nothing // happens. func (h *DispatchHook) Run(name string, ui Ui, comm Communicator, data interface{}) error { + h.l.Lock() + h.cancelled = false + h.l.Unlock() + + // Make sure when we exit that we reset the running hook. + defer func() { + h.l.Lock() + defer h.l.Unlock() + h.runningHook = nil + }() + hooks, ok := h.Mapping[name] if !ok { // No hooks for that name. No problem. @@ -31,6 +56,15 @@ func (h *DispatchHook) Run(name string, ui Ui, comm Communicator, data interface } for _, hook := range hooks { + h.l.Lock() + if h.cancelled { + h.l.Unlock() + return nil + } + + h.runningHook = hook + h.l.Unlock() + if err := hook.Run(name, ui, comm, data); err != nil { return err } @@ -38,3 +72,16 @@ func (h *DispatchHook) Run(name string, ui Ui, comm Communicator, data interface return nil } + +// Cancels all the hooks that are currently in-flight, if any. This will +// block until the hooks are all cancelled. +func (h *DispatchHook) Cancel() { + h.l.Lock() + defer h.l.Unlock() + + if h.runningHook != nil { + h.runningHook.Cancel() + } + + h.cancelled = true +} diff --git a/packer/hook_mock.go b/packer/hook_mock.go new file mode 100644 index 000000000..7177329e3 --- /dev/null +++ b/packer/hook_mock.go @@ -0,0 +1,31 @@ +package packer + +// MockHook is an implementation of Hook that can be used for tests. +type MockHook struct { + RunFunc func() error + + RunCalled bool + RunComm Communicator + RunData interface{} + RunName string + RunUi Ui + CancelCalled bool +} + +func (t *MockHook) Run(name string, ui Ui, comm Communicator, data interface{}) error { + t.RunCalled = true + t.RunComm = comm + t.RunData = data + t.RunName = name + t.RunUi = ui + + if t.RunFunc == nil { + return nil + } + + return t.RunFunc() +} + +func (t *MockHook) Cancel() { + t.CancelCalled = true +} diff --git a/packer/hook_test.go b/packer/hook_test.go index e43dfd1fd..bf88e65df 100644 --- a/packer/hook_test.go +++ b/packer/hook_test.go @@ -2,52 +2,89 @@ package packer import ( "cgl.tideland.biz/asserts" + "sync" "testing" + "time" ) -type TestHook struct { - runCalled bool - runComm Communicator - runData interface{} - runName string - runUi Ui +// A helper Hook implementation for testing cancels. +type CancelHook struct { + sync.Mutex + cancelCh chan struct{} + doneCh chan struct{} + + Cancelled bool } -func (t *TestHook) Run(name string, ui Ui, comm Communicator, data interface{}) error { - t.runCalled = true - t.runComm = comm - t.runData = data - t.runName = name - t.runUi = ui +func (h *CancelHook) Run(string, Ui, Communicator, interface{}) error { + h.Lock() + h.cancelCh = make(chan struct{}) + h.doneCh = make(chan struct{}) + h.Unlock() + + defer close(h.doneCh) + + select { + case <-h.cancelCh: + h.Cancelled = true + case <-time.After(1 * time.Second): + } + return nil } +func (h *CancelHook) Cancel() { + h.Lock() + close(h.cancelCh) + h.Unlock() + + <-h.doneCh +} + func TestDispatchHook_Implements(t *testing.T) { assert := asserts.NewTestingAsserts(t, true) var r Hook - c := &DispatchHook{nil} + c := &DispatchHook{} assert.Implementor(c, &r, "should be a Hook") } func TestDispatchHook_Run_NoHooks(t *testing.T) { // Just make sure nothing blows up - dh := &DispatchHook{make(map[string][]Hook)} + dh := &DispatchHook{} dh.Run("foo", nil, nil, nil) } func TestDispatchHook_Run(t *testing.T) { assert := asserts.NewTestingAsserts(t, true) - hook := &TestHook{} + hook := &MockHook{} mapping := make(map[string][]Hook) mapping["foo"] = []Hook{hook} - dh := &DispatchHook{mapping} + dh := &DispatchHook{Mapping: mapping} dh.Run("foo", nil, nil, 42) - assert.True(hook.runCalled, "run should be called") - assert.Equal(hook.runName, "foo", "should be proper event") - assert.Equal(hook.runData, 42, "should be correct data") + assert.True(hook.RunCalled, "run should be called") + assert.Equal(hook.RunName, "foo", "should be proper event") + assert.Equal(hook.RunData, 42, "should be correct data") +} + +func TestDispatchHook_cancel(t *testing.T) { + hook := new(CancelHook) + + dh := &DispatchHook{ + Mapping: map[string][]Hook{ + "foo": []Hook{hook}, + }, + } + + go dh.Run("foo", nil, nil, 42) + time.Sleep(100 * time.Millisecond) + dh.Cancel() + + if !hook.Cancelled { + t.Fatal("hook should've cancelled") + } } diff --git a/packer/plugin/hook.go b/packer/plugin/hook.go index 90b0779d9..5d4dbf06a 100644 --- a/packer/plugin/hook.go +++ b/packer/plugin/hook.go @@ -19,8 +19,17 @@ func (c *cmdHook) Run(name string, ui packer.Ui, comm packer.Communicator, data return c.hook.Run(name, ui, comm, data) } +func (c *cmdHook) Cancel() { + defer func() { + r := recover() + c.checkExit(r, nil) + }() + + c.hook.Cancel() +} + func (c *cmdHook) checkExit(p interface{}, cb func()) { - if c.client.Exited() { + if c.client.Exited() && cb != nil { cb() } else if p != nil && !Killed { log.Panic(p) diff --git a/packer/plugin/hook_test.go b/packer/plugin/hook_test.go index e6880e616..6f897da19 100644 --- a/packer/plugin/hook_test.go +++ b/packer/plugin/hook_test.go @@ -1,17 +1,10 @@ package plugin import ( - "github.com/mitchellh/packer/packer" "os/exec" "testing" ) -type helperHook byte - -func (helperHook) Run(string, packer.Ui, packer.Communicator, interface{}) error { - return nil -} - func TestHook_NoExist(t *testing.T) { c := NewClient(&ClientConfig{Cmd: exec.Command("i-should-not-exist")}) defer c.Kill() diff --git a/packer/plugin/plugin_test.go b/packer/plugin/plugin_test.go index 0610887ec..17018f82d 100644 --- a/packer/plugin/plugin_test.go +++ b/packer/plugin/plugin_test.go @@ -2,6 +2,7 @@ package plugin import ( "fmt" + "github.com/mitchellh/packer/packer" "log" "os" "os/exec" @@ -54,7 +55,7 @@ func TestHelperProcess(*testing.T) { case "command": ServeCommand(new(helperCommand)) case "hook": - ServeHook(new(helperHook)) + ServeHook(new(packer.MockHook)) case "invalid-rpc-address": fmt.Println("lolinvalid") case "mock": @@ -63,7 +64,7 @@ func TestHelperProcess(*testing.T) { case "post-processor": ServePostProcessor(new(helperPostProcessor)) case "provisioner": - ServeProvisioner(new(helperProvisioner)) + ServeProvisioner(new(packer.MockProvisioner)) case "start-timeout": time.Sleep(1 * time.Minute) os.Exit(1) diff --git a/packer/plugin/provisioner.go b/packer/plugin/provisioner.go index 7445c4165..0feb9d727 100644 --- a/packer/plugin/provisioner.go +++ b/packer/plugin/provisioner.go @@ -28,6 +28,15 @@ func (c *cmdProvisioner) Provision(ui packer.Ui, comm packer.Communicator) error return c.p.Provision(ui, comm) } +func (c *cmdProvisioner) Cancel() { + defer func() { + r := recover() + c.checkExit(r, nil) + }() + + c.p.Cancel() +} + func (c *cmdProvisioner) checkExit(p interface{}, cb func()) { if c.client.Exited() && cb != nil { cb() diff --git a/packer/plugin/provisioner_test.go b/packer/plugin/provisioner_test.go index 4a665411c..f0d7eb773 100644 --- a/packer/plugin/provisioner_test.go +++ b/packer/plugin/provisioner_test.go @@ -1,21 +1,10 @@ package plugin import ( - "github.com/mitchellh/packer/packer" "os/exec" "testing" ) -type helperProvisioner byte - -func (helperProvisioner) Prepare(...interface{}) error { - return nil -} - -func (helperProvisioner) Provision(packer.Ui, packer.Communicator) error { - return nil -} - func TestProvisioner_NoExist(t *testing.T) { c := NewClient(&ClientConfig{Cmd: exec.Command("i-should-not-exist")}) defer c.Kill() diff --git a/packer/provisioner.go b/packer/provisioner.go index 99a0a2ac2..3592fb919 100644 --- a/packer/provisioner.go +++ b/packer/provisioner.go @@ -1,5 +1,9 @@ package packer +import ( + "sync" +) + // A provisioner is responsible for installing and configuring software // on a machine prior to building the actual image. type Provisioner interface { @@ -13,6 +17,11 @@ type Provisioner interface { // is guaranteed to be connected to some machine so that provisioning // can be done. Provision(Ui, Communicator) error + + // Cancel is called to cancel the provisioning. This is usually called + // while Provision is still being called. The Provisioner should act + // to stop its execution as quickly as possible in a race-free way. + Cancel() } // A Hook implementation that runs the given provisioners. @@ -20,11 +29,25 @@ type ProvisionHook struct { // The provisioners to run as part of the hook. These should already // be prepared (by calling Prepare) at some earlier stage. Provisioners []Provisioner + + lock sync.Mutex + runningProvisioner Provisioner } // Runs the provisioners in order. func (h *ProvisionHook) Run(name string, ui Ui, comm Communicator, data interface{}) error { + defer func() { + h.lock.Lock() + defer h.lock.Unlock() + + h.runningProvisioner = nil + }() + for _, p := range h.Provisioners { + h.lock.Lock() + h.runningProvisioner = p + h.lock.Unlock() + if err := p.Provision(ui, comm); err != nil { return err } @@ -32,3 +55,13 @@ func (h *ProvisionHook) Run(name string, ui Ui, comm Communicator, data interfac return nil } + +// Cancels the privisioners that are still running. +func (h *ProvisionHook) Cancel() { + h.lock.Lock() + defer h.lock.Unlock() + + if h.runningProvisioner != nil { + h.runningProvisioner.Cancel() + } +} diff --git a/packer/provisioner_mock.go b/packer/provisioner_mock.go new file mode 100644 index 000000000..b61f642af --- /dev/null +++ b/packer/provisioner_mock.go @@ -0,0 +1,34 @@ +package packer + +// MockProvisioner is an implementation of Provisioner that can be +// used for tests. +type MockProvisioner struct { + ProvFunc func() error + + PrepCalled bool + PrepConfigs []interface{} + ProvCalled bool + ProvUi Ui + CancelCalled bool +} + +func (t *MockProvisioner) Prepare(configs ...interface{}) error { + t.PrepCalled = true + t.PrepConfigs = configs + return nil +} + +func (t *MockProvisioner) Provision(ui Ui, comm Communicator) error { + t.ProvCalled = true + t.ProvUi = ui + + if t.ProvFunc == nil { + return nil + } + + return t.ProvFunc() +} + +func (t *MockProvisioner) Cancel() { + t.CancelCalled = true +} diff --git a/packer/provisioner_test.go b/packer/provisioner_test.go index 1db8b55bc..a3d97d511 100644 --- a/packer/provisioner_test.go +++ b/packer/provisioner_test.go @@ -1,23 +1,10 @@ package packer -import "testing" - -type TestProvisioner struct { - prepCalled bool - prepConfigs []interface{} - provCalled bool -} - -func (t *TestProvisioner) Prepare(configs ...interface{}) error { - t.prepCalled = true - t.prepConfigs = configs - return nil -} - -func (t *TestProvisioner) Provision(Ui, Communicator) error { - t.provCalled = true - return nil -} +import ( + "sync" + "testing" + "time" +) func TestProvisionHook_Impl(t *testing.T) { var raw interface{} @@ -28,23 +15,68 @@ func TestProvisionHook_Impl(t *testing.T) { } func TestProvisionHook(t *testing.T) { - pA := &TestProvisioner{} - pB := &TestProvisioner{} + pA := &MockProvisioner{} + pB := &MockProvisioner{} ui := testUi() var comm Communicator = nil var data interface{} = nil - hook := &ProvisionHook{[]Provisioner{pA, pB}} + hook := &ProvisionHook{ + Provisioners: []Provisioner{pA, pB}, + } + hook.Run("foo", ui, comm, data) - if !pA.provCalled { + if !pA.ProvCalled { t.Error("provision should be called on pA") } - if !pB.provCalled { + if !pB.ProvCalled { t.Error("provision should be called on pB") } } +func TestProvisionHook_cancel(t *testing.T) { + var lock sync.Mutex + order := make([]string, 0, 2) + + p := &MockProvisioner{ + ProvFunc: func() error { + time.Sleep(50 * time.Millisecond) + + lock.Lock() + defer lock.Unlock() + order = append(order, "prov") + + return nil + }, + } + + hook := &ProvisionHook{ + Provisioners: []Provisioner{p}, + } + + finished := make(chan struct{}) + go func() { + hook.Run("foo", nil, nil, nil) + close(finished) + }() + + // Cancel it while it is running + time.Sleep(10 * time.Millisecond) + hook.Cancel() + lock.Lock() + order = append(order, "cancel") + lock.Unlock() + + // Wait + <-finished + + // Verify order + if order[0] != "cancel" || order[1] != "prov" { + t.Fatalf("bad: %#v", order) + } +} + // TODO(mitchellh): Test that they're run in the proper order diff --git a/packer/rpc/builder_test.go b/packer/rpc/builder_test.go index 8dff67b30..5235de9b5 100644 --- a/packer/rpc/builder_test.go +++ b/packer/rpc/builder_test.go @@ -72,7 +72,7 @@ func TestBuilderRPC(t *testing.T) { // Test Run cache := new(testCache) - hook := &testHook{} + hook := &packer.MockHook{} ui := &testUi{} artifact, err := bClient.Run(ui, hook, cache) assert.Nil(err, "should have no error") @@ -83,7 +83,7 @@ func TestBuilderRPC(t *testing.T) { assert.True(cache.lockCalled, "lock should be called") b.runHook.Run("foo", nil, nil, nil) - assert.True(hook.runCalled, "run should be called") + assert.True(hook.RunCalled, "run should be called") b.runUi.Say("format") assert.True(ui.sayCalled, "say should be called") diff --git a/packer/rpc/hook.go b/packer/rpc/hook.go index 27840f080..687d991a7 100644 --- a/packer/rpc/hook.go +++ b/packer/rpc/hook.go @@ -2,6 +2,7 @@ package rpc import ( "github.com/mitchellh/packer/packer" + "log" "net/rpc" ) @@ -37,6 +38,13 @@ func (h *hook) Run(name string, ui packer.Ui, comm packer.Communicator, data int return h.client.Call("Hook.Run", args, new(interface{})) } +func (h *hook) Cancel() { + err := h.client.Call("Hook.Cancel", new(interface{}), new(interface{})) + if err != nil { + log.Printf("Hook.Cancel error: %s", err) + } +} + func (h *HookServer) Run(args *HookRunArgs, reply *interface{}) error { client, err := rpc.Dial("tcp", args.RPCAddress) if err != nil { @@ -50,3 +58,8 @@ func (h *HookServer) Run(args *HookRunArgs, reply *interface{}) error { *reply = nil return nil } + +func (h *HookServer) Cancel(args *interface{}, reply *interface{}) error { + h.hook.Cancel() + return nil +} diff --git a/packer/rpc/hook_test.go b/packer/rpc/hook_test.go index fcc2aea7a..1d226056a 100644 --- a/packer/rpc/hook_test.go +++ b/packer/rpc/hook_test.go @@ -4,24 +4,17 @@ import ( "cgl.tideland.biz/asserts" "github.com/mitchellh/packer/packer" "net/rpc" + "reflect" + "sync" "testing" + "time" ) -type testHook struct { - runCalled bool - runUi packer.Ui -} - -func (h *testHook) Run(name string, ui packer.Ui, comm packer.Communicator, data interface{}) error { - h.runCalled = true - return nil -} - func TestHookRPC(t *testing.T) { assert := asserts.NewTestingAsserts(t, true) // Create the UI to test - h := new(testHook) + h := new(packer.MockHook) // Serve server := rpc.NewServer() @@ -37,7 +30,11 @@ func TestHookRPC(t *testing.T) { // Test Run ui := &testUi{} hClient.Run("foo", ui, nil, 42) - assert.True(h.runCalled, "run should be called") + assert.True(h.RunCalled, "run should be called") + + // Test Cancel + hClient.Cancel() + assert.True(h.CancelCalled, "cancel should be called") } func TestHook_Implements(t *testing.T) { @@ -48,3 +45,56 @@ func TestHook_Implements(t *testing.T) { assert.Implementor(h, &r, "should be a Hook") } + +func TestHook_cancelWhileRun(t *testing.T) { + var finishLock sync.Mutex + finishOrder := make([]string, 0, 2) + + h := &packer.MockHook{ + RunFunc: func() error { + time.Sleep(100 * time.Millisecond) + + finishLock.Lock() + finishOrder = append(finishOrder, "run") + finishLock.Unlock() + return nil + }, + } + + // Serve + server := rpc.NewServer() + RegisterHook(server, h) + address := serveSingleConn(server) + + // Create the client over RPC and run some methods to verify it works + client, err := rpc.Dial("tcp", address) + if err != nil { + t.Fatalf("err: %s", err) + } + + hClient := Hook(client) + + // Start the run + finished := make(chan struct{}) + go func() { + hClient.Run("foo", nil, nil, nil) + close(finished) + }() + + // Cancel it pretty quickly. + time.Sleep(10 * time.Millisecond) + hClient.Cancel() + + finishLock.Lock() + finishOrder = append(finishOrder, "cancel") + finishLock.Unlock() + + // Verify things are good + <-finished + + // Check the results + expected := []string{"cancel", "run"} + if !reflect.DeepEqual(finishOrder, expected) { + t.Fatalf("bad: %#v", finishOrder) + } +} diff --git a/packer/rpc/provisioner.go b/packer/rpc/provisioner.go index dbe366e56..7d3ed1617 100644 --- a/packer/rpc/provisioner.go +++ b/packer/rpc/provisioner.go @@ -2,6 +2,7 @@ package rpc import ( "github.com/mitchellh/packer/packer" + "log" "net/rpc" ) @@ -47,6 +48,13 @@ func (p *provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return p.client.Call("Provisioner.Provision", args, new(interface{})) } +func (p *provisioner) Cancel() { + err := p.client.Call("Provisioner.Cancel", new(interface{}), new(interface{})) + if err != nil { + log.Printf("Provisioner.Cancel err: %s", err) + } +} + func (p *ProvisionerServer) Prepare(args *ProvisionerPrepareArgs, reply *error) error { *reply = p.p.Prepare(args.Configs...) if *reply != nil { @@ -71,3 +79,8 @@ func (p *ProvisionerServer) Provision(args *ProvisionerProvisionArgs, reply *int return nil } + +func (p *ProvisionerServer) Cancel(args *interface{}, reply *interface{}) error { + p.p.Cancel() + return nil +} diff --git a/packer/rpc/provisioner_test.go b/packer/rpc/provisioner_test.go index 106ae62eb..e251d814f 100644 --- a/packer/rpc/provisioner_test.go +++ b/packer/rpc/provisioner_test.go @@ -7,32 +7,11 @@ import ( "testing" ) -type testProvisioner struct { - prepareCalled bool - prepareConfigs []interface{} - provCalled bool - provComm packer.Communicator - provUi packer.Ui -} - -func (p *testProvisioner) Prepare(configs ...interface{}) error { - p.prepareCalled = true - p.prepareConfigs = configs - return nil -} - -func (p *testProvisioner) Provision(ui packer.Ui, comm packer.Communicator) error { - p.provCalled = true - p.provComm = comm - p.provUi = ui - return nil -} - func TestProvisionerRPC(t *testing.T) { assert := asserts.NewTestingAsserts(t, true) // Create the interface to test - p := new(testProvisioner) + p := new(packer.MockProvisioner) // Start the server server := rpc.NewServer() @@ -47,17 +26,23 @@ func TestProvisionerRPC(t *testing.T) { config := 42 pClient := Provisioner(client) pClient.Prepare(config) - assert.True(p.prepareCalled, "prepare should be called") - assert.Equal(p.prepareConfigs, []interface{}{42}, "prepare should be called with right arg") + assert.True(p.PrepCalled, "prepare should be called") + assert.Equal(p.PrepConfigs, []interface{}{42}, "prepare should be called with right arg") // Test Provision ui := &testUi{} - comm := new(packer.MockCommunicator) + comm := &packer.MockCommunicator{} pClient.Provision(ui, comm) - assert.True(p.provCalled, "provision should be called") + assert.True(p.ProvCalled, "provision should be called") - p.provUi.Say("foo") + p.ProvUi.Say("foo") assert.True(ui.sayCalled, "say should be called") + + // Test Cancel + pClient.Cancel() + if !p.CancelCalled { + t.Fatal("cancel should be called") + } } func TestProvisioner_Implements(t *testing.T) { diff --git a/packer/template_test.go b/packer/template_test.go index 398e48975..2e3fbc672 100644 --- a/packer/template_test.go +++ b/packer/template_test.go @@ -589,7 +589,7 @@ func TestTemplate_Build(t *testing.T) { "test-builder": builder, } - provisioner := &TestProvisioner{} + provisioner := &MockProvisioner{} provisionerMap := map[string]Provisioner{ "test-prov": provisioner, } @@ -677,7 +677,7 @@ func TestTemplate_Build_ProvisionerOverride(t *testing.T) { "test-builder": builder, } - provisioner := &TestProvisioner{} + provisioner := &MockProvisioner{} provisionerMap := map[string]Provisioner{ "test-prov": provisioner, } diff --git a/provisioner/chef-solo/provisioner.go b/provisioner/chef-solo/provisioner.go index 2d255375d..fe96ad467 100644 --- a/provisioner/chef-solo/provisioner.go +++ b/provisioner/chef-solo/provisioner.go @@ -183,6 +183,12 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return nil } +func (p *Provisioner) Cancel() { + // Just hard quit. It isn't a big deal if what we're doing keeps + // running on the other side. + os.Exit(0) +} + func (p *Provisioner) uploadDirectory(ui packer.Ui, comm packer.Communicator, dst string, src string) error { if err := p.createDir(ui, comm, dst); err != nil { return err diff --git a/provisioner/file/provisioner.go b/provisioner/file/provisioner.go index c49047f65..096b10437 100644 --- a/provisioner/file/provisioner.go +++ b/provisioner/file/provisioner.go @@ -84,3 +84,9 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { } return err } + +func (p *Provisioner) Cancel() { + // Just hard quit. It isn't a big deal if what we're doing keeps + // running on the other side. + os.Exit(0) +} diff --git a/provisioner/salt-masterless/provisioner.go b/provisioner/salt-masterless/provisioner.go index 24d29097c..8e0f11f1f 100644 --- a/provisioner/salt-masterless/provisioner.go +++ b/provisioner/salt-masterless/provisioner.go @@ -200,6 +200,12 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return nil } +func (p *Provisioner) Cancel() { + // Just hard quit. It isn't a big deal if what we're doing keeps + // running on the other side. + os.Exit(0) +} + func UploadLocalDirectory(localDir string, remoteDir string, comm packer.Communicator, ui packer.Ui) (err error) { visitPath := func(localPath string, f os.FileInfo, err error) (err2 error) { localRelPath := strings.Replace(localPath, localDir, "", 1) diff --git a/provisioner/shell/provisioner.go b/provisioner/shell/provisioner.go index 29799cba9..e8a5f61a5 100644 --- a/provisioner/shell/provisioner.go +++ b/provisioner/shell/provisioner.go @@ -281,6 +281,12 @@ func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error { return nil } +func (p *Provisioner) Cancel() { + // Just hard quit. It isn't a big deal if what we're doing keeps + // running on the other side. + os.Exit(0) +} + // retryable will retry the given function over and over until a // non-error is returned. func (p *Provisioner) retryable(f func() error) error {