From 41b323ce4fa874b1c62cede236e8d10a0e3752b3 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Wed, 11 Dec 2013 12:24:45 -0800 Subject: [PATCH] packer/plugin: communicate over unix domain sockets if you can --- packer/plugin/client.go | 36 +++++++++------ packer/plugin/client_test.go | 8 +++- packer/plugin/plugin_test.go | 8 ++-- packer/plugin/{plugin.go => server.go} | 64 ++++++++++++++++++++------ 4 files changed, 82 insertions(+), 34 deletions(-) rename packer/plugin/{plugin.go => server.go} (69%) diff --git a/packer/plugin/client.go b/packer/plugin/client.go index 3fefca15f..a06403f4f 100644 --- a/packer/plugin/client.go +++ b/packer/plugin/client.go @@ -34,7 +34,7 @@ type Client struct { exited bool doneLogging chan struct{} l sync.Mutex - address string + address net.Addr } // ClientConfig is the configuration used to initialize a new @@ -206,11 +206,11 @@ func (c *Client) Kill() { // This method is safe to call multiple times. Subsequent calls have no effect. // Once a client has been started once, it cannot be started again, even if // it was killed. -func (c *Client) Start() (address string, err error) { +func (c *Client) Start() (addr net.Addr, err error) { c.l.Lock() defer c.l.Unlock() - if c.address != "" { + if c.address != nil { return c.address, nil } @@ -320,8 +320,8 @@ func (c *Client) Start() (address string, err error) { // Trim the line and split by "|" in order to get the parts of // the output. line := strings.TrimSpace(string(lineBytes)) - parts := strings.SplitN(line, "|", 2) - if len(parts) < 2 { + parts := strings.SplitN(line, "|", 3) + if len(parts) < 3 { err = fmt.Errorf("Unrecognized remote plugin message: %s", line) return } @@ -333,10 +333,17 @@ func (c *Client) Start() (address string, err error) { return } - c.address = parts[1] - address = c.address + switch parts[1] { + case "tcp": + addr, err = net.ResolveTCPAddr("tcp", parts[2]) + case "unix": + addr, err = net.ResolveUnixAddr("unix", parts[2]) + default: + err = fmt.Errorf("Unknown address type: %s", parts[1]) + } } + c.address = addr return } @@ -361,23 +368,24 @@ func (c *Client) logStderr(r io.Reader) { } func (c *Client) packrpcClient() (*packrpc.Client, error) { - address, err := c.Start() + addr, err := c.Start() if err != nil { return nil, err } - conn, err := net.Dial("tcp", address) + conn, err := net.Dial(addr.Network(), addr.String()) if err != nil { return nil, err } - // Make sure to set keep alive so that the connection doesn't die - tcpConn := conn.(*net.TCPConn) - tcpConn.SetKeepAlive(true) + if tcpConn, ok := conn.(*net.TCPConn); ok { + // Make sure to set keep alive so that the connection doesn't die + tcpConn.SetKeepAlive(true) + } - client, err := packrpc.NewClient(tcpConn) + client, err := packrpc.NewClient(conn) if err != nil { - tcpConn.Close() + conn.Close() return nil, err } diff --git a/packer/plugin/client_test.go b/packer/plugin/client_test.go index f9257034e..d558b4912 100644 --- a/packer/plugin/client_test.go +++ b/packer/plugin/client_test.go @@ -20,8 +20,12 @@ func TestClient(t *testing.T) { t.Fatalf("err should be nil, got %s", err) } - if addr != ":1234" { - t.Fatalf("incorrect addr %s", addr) + if addr.Network() != "tcp" { + t.Fatalf("bad: %#v", addr) + } + + if addr.String() != ":1234" { + t.Fatalf("bad: %#v", addr) } // Test that it exits properly if killed diff --git a/packer/plugin/plugin_test.go b/packer/plugin/plugin_test.go index 733190ec3..d2cf2d201 100644 --- a/packer/plugin/plugin_test.go +++ b/packer/plugin/plugin_test.go @@ -51,7 +51,7 @@ func TestHelperProcess(*testing.T) { cmd, args := args[0], args[1:] switch cmd { case "bad-version": - fmt.Printf("%s1|:1234\n", APIVersion) + fmt.Printf("%s1|tcp|:1234\n", APIVersion) <-make(chan int) case "builder": server, err := Server() @@ -80,7 +80,7 @@ func TestHelperProcess(*testing.T) { case "invalid-rpc-address": fmt.Println("lolinvalid") case "mock": - fmt.Printf("%s|:1234\n", APIVersion) + fmt.Printf("%s|tcp|:1234\n", APIVersion) <-make(chan int) case "post-processor": server, err := Server() @@ -102,11 +102,11 @@ func TestHelperProcess(*testing.T) { time.Sleep(1 * time.Minute) os.Exit(1) case "stderr": - fmt.Printf("%s|:1234\n", APIVersion) + fmt.Printf("%s|tcp|:1234\n", APIVersion) log.Println("HELLO") log.Println("WORLD") case "stdin": - fmt.Printf("%s|:1234\n", APIVersion) + fmt.Printf("%s|tcp|:1234\n", APIVersion) data := make([]byte, 5) if _, err := os.Stdin.Read(data); err != nil { log.Printf("stdin read error: %s", err) diff --git a/packer/plugin/plugin.go b/packer/plugin/server.go similarity index 69% rename from packer/plugin/plugin.go rename to packer/plugin/server.go index 63c19b9b3..a3dbc7b24 100644 --- a/packer/plugin/plugin.go +++ b/packer/plugin/server.go @@ -12,6 +12,7 @@ import ( "fmt" "github.com/mitchellh/packer/packer" packrpc "github.com/mitchellh/packer/packer/rpc" + "io/ioutil" "log" "net" "os" @@ -32,7 +33,7 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69 // The APIVersion is outputted along with the RPC address. The plugin // client validates this API version and will show an error if it doesn't // know how to speak it. -const APIVersion = "1" +const APIVersion = "2" // Server waits for a connection to this plugin and returns a Packer // RPC server that you can use to register components and serve them. @@ -62,23 +63,19 @@ func Server() (*packrpc.Server, error) { log.Printf("Plugin minimum port: %d\n", minPort) log.Printf("Plugin maximum port: %d\n", maxPort) - var address string - var listener net.Listener - for port := minPort; port <= maxPort; port++ { - address = fmt.Sprintf("127.0.0.1:%d", port) - listener, err = net.Listen("tcp", address) - if err != nil { - err = nil - continue - } - - break + listener, err := serverListener(minPort, maxPort) + if err != nil { + return nil, err } defer listener.Close() // Output the address to stdout - log.Printf("Plugin address: %s\n", address) - fmt.Printf("%s|%s\n", APIVersion, address) + log.Printf("Plugin address: %s %s\n", + listener.Addr().Network(), listener.Addr().String()) + fmt.Printf("%s|%s|%s\n", + APIVersion, + listener.Addr().Network(), + listener.Addr().String()) os.Stdout.Sync() // Accept a connection @@ -105,3 +102,42 @@ func Server() (*packrpc.Server, error) { log.Println("Serving a plugin connection...") return packrpc.NewServer(conn), nil } + +func serverListener(minPort, maxPort int64) (net.Listener, error) { + if runtime.GOOS == "windows" { + return serverListener_tcp(minPort, maxPort) + } + + return serverListener_unix() +} + +func serverListener_tcp(minPort, maxPort int64) (net.Listener, error) { + for port := minPort; port <= maxPort; port++ { + address := fmt.Sprintf("127.0.0.1:%d", port) + listener, err := net.Listen("tcp", address) + if err == nil { + return listener, nil + } + } + + return nil, errors.New("Couldn't bind plugin TCP listener") +} + +func serverListener_unix() (net.Listener, error) { + tf, err := ioutil.TempFile("", "packer-plugin") + if err != nil { + return nil, err + } + path := tf.Name() + + // Close the file and remove it because it has to not exist for + // the domain socket. + if err := tf.Close(); err != nil { + return nil, err + } + if err := os.Remove(path); err != nil { + return nil, err + } + + return net.Listen("unix", path) +}