From a9b0081828efeb97f8ff04c0fd6d80b3f1ae663e Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Mon, 9 Dec 2013 14:24:55 -0800 Subject: [PATCH] packer/rpc: muxconn is a lot more sane, acts like bsd socket --- packer/rpc/artifact_test.go | 17 +-- packer/rpc/client.go | 38 ++++++ packer/rpc/client_test.go | 9 ++ packer/rpc/muxconn.go | 248 ++++++++++++++++++++++++++++++------ packer/rpc/muxconn_test.go | 20 ++- packer/rpc/server.go | 9 -- packer/rpc/server_new.go | 61 +++++++++ packer/rpc/ui.go | 13 +- packer/rpc/ui_test.go | 2 +- 9 files changed, 337 insertions(+), 80 deletions(-) create mode 100644 packer/rpc/client.go create mode 100644 packer/rpc/client_test.go create mode 100644 packer/rpc/server_new.go diff --git a/packer/rpc/artifact_test.go b/packer/rpc/artifact_test.go index 336fa6d8b..4cc49df83 100644 --- a/packer/rpc/artifact_test.go +++ b/packer/rpc/artifact_test.go @@ -2,7 +2,6 @@ package rpc import ( "github.com/mitchellh/packer/packer" - "net/rpc" "reflect" "testing" ) @@ -31,19 +30,13 @@ func (testArtifact) Destroy() error { func TestArtifactRPC(t *testing.T) { // Create the interface to test - a := new(testArtifact) + a := new(packer.MockArtifact) // Start the server - server := rpc.NewServer() - RegisterArtifact(server, a) - address := serveSingleConn(server) - - // Create the client over RPC and run some methods to verify it works - client, err := rpc.Dial("tcp", address) - if err != nil { - t.Fatalf("err: %s", err) - } - aClient := Artifact(client) + server := NewServer() + server.RegisterArtifact(a) + client := testClient(t, server) + aClient := client.Artifact() // Test if aClient.BuilderId() != "bid" { diff --git a/packer/rpc/client.go b/packer/rpc/client.go new file mode 100644 index 000000000..73b3ce1d8 --- /dev/null +++ b/packer/rpc/client.go @@ -0,0 +1,38 @@ +package rpc + +import ( + "github.com/mitchellh/packer/packer" + "io" + "net/rpc" +) + +// Client is the client end that communicates with a Packer RPC server. +// Establishing a connection is up to the user, the Client can just +// communicate over any ReadWriteCloser. +type Client struct { + mux *MuxConn + client *rpc.Client +} + +func NewClient(rwc io.ReadWriteCloser) (*Client, error) { + // Create the MuxConn around the RWC and get the client to server stream. + // This is the primary stream that we use to communicate with the + // remote RPC server. On the remote side Server.ServeConn also listens + // on this stream ID. + mux := NewMuxConn(rwc) + stream, err := mux.Dial(0) + if err != nil { + return nil, err + } + + return &Client{ + mux: mux, + client: rpc.NewClient(stream), + }, nil +} + +func (c *Client) Artifact() packer.Artifact { + return &artifact{ + client: c.client, + } +} diff --git a/packer/rpc/client_test.go b/packer/rpc/client_test.go new file mode 100644 index 000000000..1f9b930a7 --- /dev/null +++ b/packer/rpc/client_test.go @@ -0,0 +1,9 @@ +package rpc + +import ( + "testing" +) + +func testClient(t *testing.T, server *Server) *Client { + return nil +} diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go index 9907ce102..cf296be04 100644 --- a/packer/rpc/muxconn.go +++ b/packer/rpc/muxconn.go @@ -6,6 +6,7 @@ import ( "io" "log" "sync" + "time" ) // MuxConn is a connection that can be used bi-directionally for RPC. Normally, @@ -20,15 +21,24 @@ import ( // we decided to cut a lot of corners and make this easily usable for Packer. type MuxConn struct { rwc io.ReadWriteCloser - streams map[byte]io.WriteCloser + streams map[byte]*Stream mu sync.RWMutex wlock sync.Mutex } +type muxPacketType byte + +const ( + muxPacketSyn muxPacketType = iota + muxPacketAck + muxPacketFin + muxPacketData +) + func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { m := &MuxConn{ rwc: rwc, - streams: make(map[byte]io.WriteCloser), + streams: make(map[byte]*Stream), } go m.loop() @@ -46,56 +56,140 @@ func (m *MuxConn) Close() error { for _, w := range m.streams { w.Close() } - m.streams = make(map[byte]io.WriteCloser) + m.streams = make(map[byte]*Stream) return m.rwc.Close() } -// Stream returns a io.ReadWriteCloser that will only read/write to the -// given stream ID. No handshake is done so if the remote end does not -// have a stream open with the same ID, then the messages will simply -// be dropped. -// -// This is one of those cases where we cut corners. Since Packer only does -// local connections, we can assume that both ends are ready at a certain -// point. In a real muxer, we'd probably want a handshake here. -func (m *MuxConn) Stream(id byte) (io.ReadWriteCloser, error) { +// Accept accepts a multiplexed connection with the given ID. This +// will block until a request is made to connect. +func (m *MuxConn) Accept(id byte) (io.ReadWriteCloser, error) { + stream, err := m.openStream(id) + if err != nil { + return nil, err + } + + // If the stream isn't closed, then it is already open somehow + stream.mu.Lock() + if stream.state != streamStateSynRecv && stream.state != streamStateClosed { + stream.mu.Unlock() + return nil, fmt.Errorf("Stream already open in bad state: %d", stream.state) + } + + if stream.state == streamStateSynRecv { + // Fast track establishing since we already got the syn + stream.setState(streamStateEstablished) + stream.mu.Unlock() + } + + if stream.state != streamStateEstablished { + // Go into the listening state + stream.setState(streamStateListen) + stream.mu.Unlock() + + // Wait for the connection to establish + ACCEPT_ESTABLISH_LOOP: + for { + time.Sleep(50 * time.Millisecond) + stream.mu.Lock() + switch stream.state { + case streamStateListen: + stream.mu.Unlock() + case streamStateEstablished: + stream.mu.Unlock() + break ACCEPT_ESTABLISH_LOOP + default: + defer stream.mu.Unlock() + return nil, fmt.Errorf("Stream went to bad state: %d", stream.state) + } + } + } + + // Send the ack down + if _, err := m.write(stream.id, muxPacketAck, nil); err != nil { + return nil, err + } + + return stream, nil +} + +// Dial opens a connection to the remote end using the given stream ID. +// An Accept on the remote end will only work with if the IDs match. +func (m *MuxConn) Dial(id byte) (io.ReadWriteCloser, error) { + stream, err := m.openStream(id) + if err != nil { + return nil, err + } + + // If the stream isn't closed, then it is already open somehow + stream.mu.Lock() + if stream.state != streamStateClosed { + stream.mu.Unlock() + return nil, fmt.Errorf("Stream already open in bad state: %d", stream.state) + } + + // Open a connection + if _, err := m.write(stream.id, muxPacketSyn, nil); err != nil { + return nil, err + } + stream.setState(streamStateSynSent) + stream.mu.Unlock() + + for { + time.Sleep(50 * time.Millisecond) + stream.mu.Lock() + switch stream.state { + case streamStateSynSent: + stream.mu.Unlock() + case streamStateEstablished: + stream.mu.Unlock() + return stream, nil + default: + defer stream.mu.Unlock() + return nil, fmt.Errorf("Stream went to bad state: %d", stream.state) + } + } +} + +func (m *MuxConn) openStream(id byte) (*Stream, error) { m.mu.Lock() + defer m.mu.Unlock() - if _, ok := m.streams[id]; ok { - m.mu.Unlock() - return nil, fmt.Errorf("Stream %d already exists", id) + if stream, ok := m.streams[id]; ok { + return stream, nil } // Create the stream object and channel where data will be sent to dataR, dataW := io.Pipe() // Set the data channel so we can write to it. - m.streams[id] = dataW - - // Unlock the lock so that the reader can access the stream writer. - m.mu.Unlock() - stream := &Stream{ id: id, mux: m, reader: dataR, + writer: dataW, } + stream.setState(streamStateClosed) - return stream, nil + m.streams[id] = stream + return m.streams[id], nil } func (m *MuxConn) loop() { defer m.Close() + var id byte + var packetType muxPacketType + var length int32 for { - var id byte - var length int32 - if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil { log.Printf("[ERR] Error reading stream ID: %s", err) return } + if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil { + log.Printf("[ERR] Error reading packet type: %s", err) + return + } if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil { log.Printf("[ERR] Error reading length: %s", err) return @@ -103,44 +197,115 @@ func (m *MuxConn) loop() { // TODO(mitchellh): probably would be better to re-use a buffer... data := make([]byte, length) - if _, err := m.rwc.Read(data); err != nil { - log.Printf("[ERR] Error reading data: %s", err) + if length > 0 { + if _, err := m.rwc.Read(data); err != nil { + log.Printf("[ERR] Error reading data: %s", err) + return + } + } + + stream, err := m.openStream(id) + if err != nil { + log.Printf("[ERR] Error opening stream %d: %s", id, err) return } - m.mu.RLock() - w, ok := m.streams[id] - if ok { - // Note that if this blocks, it'll block the whole read loop. - // Danger here... not sure how to handle it though. - w.Write(data) + switch packetType { + case muxPacketAck: + stream.mu.Lock() + if stream.state == streamStateSynSent { + stream.setState(streamStateEstablished) + } else { + log.Printf("[ERR] Ack received for stream in state: %d", stream.state) + } + stream.mu.Unlock() + case muxPacketSyn: + stream.mu.Lock() + switch stream.state { + case streamStateClosed: + stream.setState(streamStateSynRecv) + case streamStateListen: + stream.setState(streamStateEstablished) + default: + log.Printf("[ERR] Syn received for stream in state: %d", stream.state) + } + stream.mu.Unlock() + case muxPacketFin: + stream.mu.Lock() + stream.setState(streamStateClosed) + stream.writer.Close() + stream.mu.Unlock() + + m.mu.Lock() + delete(m.streams, stream.id) + m.mu.Unlock() + case muxPacketData: + stream.mu.Lock() + if stream.state == streamStateEstablished { + stream.writer.Write(data) + } else { + log.Printf("[ERR] Data received for stream in state: %d", stream.state) + } + stream.mu.Unlock() } - m.mu.RUnlock() } } -func (m *MuxConn) write(id byte, p []byte) (int, error) { +func (m *MuxConn) write(id byte, dataType muxPacketType, p []byte) (int, error) { m.wlock.Lock() defer m.wlock.Unlock() if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil { return 0, err } + if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil { + return 0, err + } if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil { return 0, err } + if len(p) == 0 { + return 0, nil + } return m.rwc.Write(p) } // Stream is a single stream of data and implements io.ReadWriteCloser type Stream struct { - id byte - mux *MuxConn - reader io.Reader + id byte + mux *MuxConn + reader io.Reader + writer io.WriteCloser + state streamState + stateUpdated time.Time + mu sync.Mutex } +type streamState byte + +const ( + streamStateClosed streamState = iota + streamStateListen + streamStateSynRecv + streamStateSynSent + streamStateEstablished + streamStateFinWait +) + func (s *Stream) Close() error { - // Not functional yet, does it ever have to be? + s.mu.Lock() + defer s.mu.Unlock() + + if s.state != streamStateEstablished { + return fmt.Errorf("Stream in bad state: %d", s.state) + } + + if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil { + return err + } + + s.setState(streamStateClosed) + s.writer.Close() return nil } @@ -149,5 +314,10 @@ func (s *Stream) Read(p []byte) (int, error) { } func (s *Stream) Write(p []byte) (int, error) { - return s.mux.write(s.id, p) + return s.mux.write(s.id, muxPacketData, p) +} + +func (s *Stream) setState(state streamState) { + s.state = state + s.stateUpdated = time.Now().UTC() } diff --git a/packer/rpc/muxconn_test.go b/packer/rpc/muxconn_test.go index 9c7b69eb7..fce29b3af 100644 --- a/packer/rpc/muxconn_test.go +++ b/packer/rpc/muxconn_test.go @@ -56,24 +56,21 @@ func TestMuxConn(t *testing.T) { // When the server is done doneCh := make(chan struct{}) - readyCh := make(chan struct{}) // The server side go func() { defer close(doneCh) - s0, err := server.Stream(0) + s0, err := server.Accept(0) if err != nil { t.Fatalf("err: %s", err) } - s1, err := server.Stream(1) + s1, err := server.Dial(1) if err != nil { t.Fatalf("err: %s", err) } - close(readyCh) - var wg sync.WaitGroup wg.Add(2) @@ -96,19 +93,16 @@ func TestMuxConn(t *testing.T) { wg.Wait() }() - s0, err := client.Stream(0) + s0, err := client.Dial(0) if err != nil { t.Fatalf("err: %s", err) } - s1, err := client.Stream(1) + s1, err := client.Accept(1) if err != nil { t.Fatalf("err: %s", err) } - // Wait for the server to be ready - <-readyCh - if _, err := s0.Write([]byte("hello")); err != nil { t.Fatalf("err: %s", err) } @@ -124,8 +118,9 @@ func TestMuxConn_clientClosesStreams(t *testing.T) { client, server := testMux(t) defer client.Close() defer server.Close() + go server.Accept(0) - s0, err := client.Stream(0) + s0, err := client.Dial(0) if err != nil { t.Fatalf("err: %s", err) } @@ -146,8 +141,9 @@ func TestMuxConn_serverClosesStreams(t *testing.T) { client, server := testMux(t) defer client.Close() defer server.Close() + go server.Accept(0) - s0, err := client.Stream(0) + s0, err := client.Dial(0) if err != nil { t.Fatalf("err: %s", err) } diff --git a/packer/rpc/server.go b/packer/rpc/server.go index c9154160a..e1ab19663 100644 --- a/packer/rpc/server.go +++ b/packer/rpc/server.go @@ -1,15 +1,10 @@ package rpc import ( - "fmt" "github.com/mitchellh/packer/packer" "net/rpc" - "sync/atomic" ) -// This keeps track of the endpoint ID to use when registering artifacts. -var endpointId uint64 = 0 - // Registers the appropriate endpoint on an RPC server to serve an // Artifact. func RegisterArtifact(s *rpc.Server, a packer.Artifact) { @@ -82,10 +77,6 @@ func RegisterUi(s *rpc.Server, ui packer.Ui) { // The endpoint name is returned. func registerComponent(s *rpc.Server, name string, rcvr interface{}, id bool) string { endpoint := name - if id { - fmt.Sprintf("%s.%d", endpoint, atomic.AddUint64(&endpointId, 1)) - } - s.RegisterName(endpoint, rcvr) return endpoint } diff --git a/packer/rpc/server_new.go b/packer/rpc/server_new.go new file mode 100644 index 000000000..4fcda1f00 --- /dev/null +++ b/packer/rpc/server_new.go @@ -0,0 +1,61 @@ +package rpc + +import ( + "fmt" + "github.com/mitchellh/packer/packer" + "io" + "log" + "net/rpc" + "sync/atomic" +) + +// Server represents an RPC server for Packer. This must be paired on +// the other side with a Client. +type Server struct { + endpointId uint64 + rpcServer *rpc.Server +} + +// NewServer returns a new Packer RPC server. +func NewServer() *Server { + return &Server{ + endpointId: 0, + rpcServer: rpc.NewServer(), + } +} + +func (s *Server) RegisterArtifact(a packer.Artifact) { + s.registerComponent("Artifact", &ArtifactServer{a}, false) +} + +// ServeConn serves a single connection over the RPC server. It is up +// to the caller to obtain a proper io.ReadWriteCloser. +func (s *Server) ServeConn(conn io.ReadWriteCloser) { + mux := NewMuxConn(conn) + defer mux.Close() + + // Get stream ID 0, which we always use as the stream for serving + // our RPC server on. + stream, err := mux.Accept(0) + if err != nil { + log.Printf("[ERR] Error retrieving stream for serving: %s", err) + return + } + + s.rpcServer.ServeConn(stream) +} + +// registerComponent registers a single Packer RPC component onto +// the RPC server. If id is true, then a unique ID number will be appended +// onto the end of the endpoint. +// +// The endpoint name is returned. +func (s *Server) registerComponent(name string, rcvr interface{}, id bool) string { + endpoint := name + if id { + fmt.Sprintf("%s.%d", endpoint, atomic.AddUint64(&s.endpointId, 1)) + } + + s.rpcServer.RegisterName(endpoint, rcvr) + return endpoint +} diff --git a/packer/rpc/ui.go b/packer/rpc/ui.go index 1857e4928..4d7ccc57f 100644 --- a/packer/rpc/ui.go +++ b/packer/rpc/ui.go @@ -9,8 +9,7 @@ import ( // An implementation of packer.Ui where the Ui is actually executed // over an RPC connection. type Ui struct { - client *rpc.Client - endpoint string + client *rpc.Client } // UiServer wraps a packer.Ui implementation and makes it exportable @@ -26,12 +25,12 @@ type UiMachineArgs struct { } func (u *Ui) Ask(query string) (result string, err error) { - err = u.client.Call(u.endpoint+".Ask", query, &result) + err = u.client.Call("Ui.Ask", query, &result) return } func (u *Ui) Error(message string) { - if err := u.client.Call(u.endpoint+".Error", message, new(interface{})); err != nil { + if err := u.client.Call("Ui.Error", message, new(interface{})); err != nil { log.Printf("Error in Ui RPC call: %s", err) } } @@ -42,19 +41,19 @@ func (u *Ui) Machine(t string, args ...string) { Args: args, } - if err := u.client.Call(u.endpoint+".Machine", rpcArgs, new(interface{})); err != nil { + if err := u.client.Call("Ui.Machine", rpcArgs, new(interface{})); err != nil { log.Printf("Error in Ui RPC call: %s", err) } } func (u *Ui) Message(message string) { - if err := u.client.Call(u.endpoint+".Message", message, new(interface{})); err != nil { + if err := u.client.Call("Ui.Message", message, new(interface{})); err != nil { log.Printf("Error in Ui RPC call: %s", err) } } func (u *Ui) Say(message string) { - if err := u.client.Call(u.endpoint+".Say", message, new(interface{})); err != nil { + if err := u.client.Call("Ui.Say", message, new(interface{})); err != nil { log.Printf("Error in Ui RPC call: %s", err) } } diff --git a/packer/rpc/ui_test.go b/packer/rpc/ui_test.go index 3a46bf02c..5241f7989 100644 --- a/packer/rpc/ui_test.go +++ b/packer/rpc/ui_test.go @@ -62,7 +62,7 @@ func TestUiRPC(t *testing.T) { panic(err) } - uiClient := &Ui{client: client, endpoint: "Ui"} + uiClient := &Ui{client: client} // Basic error and say tests result, err := uiClient.Ask("query")