From 062e86e21894965ba78905e305ad8abae63a2529 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Tue, 2 Sep 2014 14:05:22 -0700 Subject: [PATCH 1/3] packer/rpc: MuxBroker --- packer/rpc/mux_broker.go | 160 ++++++++++++++++++++++++++++++++++ packer/rpc/mux_broker_test.go | 82 +++++++++++++++++ 2 files changed, 242 insertions(+) create mode 100644 packer/rpc/mux_broker.go create mode 100644 packer/rpc/mux_broker_test.go diff --git a/packer/rpc/mux_broker.go b/packer/rpc/mux_broker.go new file mode 100644 index 000000000..2e1061c9d --- /dev/null +++ b/packer/rpc/mux_broker.go @@ -0,0 +1,160 @@ +package rpc + +import ( + "encoding/binary" + "fmt" + "net" + "sync" + "time" + + "github.com/hashicorp/yamux" +) + +// muxBroker is responsible for brokering multiplexed connections by unique ID. +// +// This allows a plugin to request a channel with a specific ID to connect to +// or accept a connection from, and the broker handles the details of +// holding these channels open while they're being negotiated. +type muxBroker struct { + session *yamux.Session + streams map[uint32]*muxBrokerPending + + sync.Mutex +} + +type muxBrokerPending struct { + ch chan net.Conn + doneCh chan struct{} +} + +func newMuxBroker(s *yamux.Session) *muxBroker { + return &muxBroker{ + session: s, + streams: make(map[uint32]*muxBrokerPending), + } +} + +// Accept accepts a connection by ID. +// +// This should not be called multiple times with the same ID at one time. +func (m *muxBroker) Accept(id uint32) (net.Conn, error) { + var c net.Conn + p := m.getStream(id) + select { + case c = <-p.ch: + close(p.doneCh) + case <-time.After(5 * time.Second): + m.Lock() + defer m.Unlock() + delete(m.streams, id) + + return nil, fmt.Errorf("timeout waiting for accept") + } + + // Ack our connection + if err := binary.Write(c, binary.LittleEndian, id); err != nil { + c.Close() + return nil, err + } + + return c, nil +} + +// Dial opens a connection by ID. +func (m *muxBroker) Dial(id uint32) (net.Conn, error) { + // Open the stream + stream, err := m.session.OpenStream() + if err != nil { + return nil, err + } + + // Write the stream ID onto the wire. + if err := binary.Write(stream, binary.LittleEndian, id); err != nil { + stream.Close() + return nil, err + } + + // Read the ack that we connected. Then we're off! + var ack uint32 + if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil { + stream.Close() + return nil, err + } + if ack != id { + stream.Close() + return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id) + } + + return stream, nil +} + +// Run starts the brokering and should be executed in a goroutine, since it +// blocks forever, or until the session closes. +func (m *muxBroker) Run() { + for { + stream, err := m.session.AcceptStream() + if err != nil { + // Once we receive an error, just exit + break + } + + // Read the stream ID from the stream + var id uint32 + if err := binary.Read(stream, binary.LittleEndian, &id); err != nil { + stream.Close() + continue + } + + // Initialize the waiter + p := m.getStream(id) + select { + case p.ch <- stream: + default: + } + + // Wait for a timeout + go m.timeoutWait(id, p) + } +} + +func (m *muxBroker) getStream(id uint32) *muxBrokerPending { + m.Lock() + defer m.Unlock() + + p, ok := m.streams[id] + if ok { + return p + } + + m.streams[id] = &muxBrokerPending{ + ch: make(chan net.Conn, 1), + doneCh: make(chan struct{}), + } + return m.streams[id] +} + +func (m *muxBroker) timeoutWait(id uint32, p *muxBrokerPending) { + // Wait for the stream to either be picked up and connected, or + // for a timeout. + timeout := false + select { + case <-p.doneCh: + case <-time.After(5 * time.Second): + timeout = true + } + + m.Lock() + defer m.Unlock() + + // Delete the stream so no one else can grab it + delete(m.streams, id) + + // If we timed out, then check if we have a channel in the buffer, + // and if so, close it. + if timeout { + select { + case s := <-p.ch: + s.Close() + } + } +} diff --git a/packer/rpc/mux_broker_test.go b/packer/rpc/mux_broker_test.go new file mode 100644 index 000000000..88739a0ff --- /dev/null +++ b/packer/rpc/mux_broker_test.go @@ -0,0 +1,82 @@ +package rpc + +import ( + "net" + "testing" + + "github.com/hashicorp/yamux" +) + +func TestMuxBroker(t *testing.T) { + c, s := testYamux(t) + defer c.Close() + defer s.Close() + + bc := newMuxBroker(c) + bs := newMuxBroker(s) + go bc.Run() + go bs.Run() + + go func() { + c, err := bc.Dial(5) + if err != nil { + t.Fatalf("err: %s", err) + } + + if _, err := c.Write([]byte{42}); err != nil { + t.Fatalf("err: %s", err) + } + }() + + client, err := bs.Accept(5) + if err != nil { + t.Fatalf("err: %s", err) + } + + var data [1]byte + if _, err := client.Read(data[:]); err != nil { + t.Fatalf("err: %s", err) + } + + if data[0] != 42 { + t.Fatalf("bad: %d", data[0]) + } +} + +func testYamux(t *testing.T) (client *yamux.Session, server *yamux.Session) { + l, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("err: %s", err) + } + + // Server side + doneCh := make(chan struct{}) + go func() { + defer close(doneCh) + conn, err := l.Accept() + l.Close() + if err != nil { + t.Fatalf("err: %s", err) + } + + server, err = yamux.Server(conn, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + }() + + // Client side + conn, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatalf("err: %s", err) + } + client, err = yamux.Client(conn, nil) + if err != nil { + t.Fatalf("err: %s", err) + } + + // Wait for the server + <-doneCh + + return +} From 9ffa0b8e25dfc4b0b68de681c08ab40b010d0318 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Tue, 2 Sep 2014 14:23:06 -0700 Subject: [PATCH 2/3] packer/rpc: no more muxconn --- packer/rpc/build.go | 4 +- packer/rpc/builder.go | 4 +- packer/rpc/client.go | 13 +- packer/rpc/command.go | 4 +- packer/rpc/communicator.go | 6 +- packer/rpc/environment.go | 4 +- packer/rpc/hook.go | 4 +- packer/rpc/mux_broker.go | 31 ++ packer/rpc/muxconn.go | 605 ----------------------------------- packer/rpc/muxconn_test.go | 311 ------------------ packer/rpc/post_processor.go | 4 +- packer/rpc/provisioner.go | 4 +- packer/rpc/server.go | 10 +- 13 files changed, 64 insertions(+), 940 deletions(-) delete mode 100644 packer/rpc/muxconn.go delete mode 100644 packer/rpc/muxconn_test.go diff --git a/packer/rpc/build.go b/packer/rpc/build.go index c2e6dceca..4d0b7edf4 100644 --- a/packer/rpc/build.go +++ b/packer/rpc/build.go @@ -9,14 +9,14 @@ import ( // over an RPC connection. type build struct { client *rpc.Client - mux *MuxConn + mux *muxBroker } // BuildServer wraps a packer.Build implementation and makes it exportable // as part of a Golang RPC server. type BuildServer struct { build packer.Build - mux *MuxConn + mux *muxBroker } type BuildPrepareResponse struct { diff --git a/packer/rpc/builder.go b/packer/rpc/builder.go index 5e3f429b5..0e2464bd7 100644 --- a/packer/rpc/builder.go +++ b/packer/rpc/builder.go @@ -10,14 +10,14 @@ import ( // over an RPC connection. type builder struct { client *rpc.Client - mux *MuxConn + mux *muxBroker } // BuilderServer wraps a packer.Builder implementation and makes it exportable // as part of a Golang RPC server. type BuilderServer struct { builder packer.Builder - mux *MuxConn + mux *muxBroker } type BuilderPrepareArgs struct { diff --git a/packer/rpc/client.go b/packer/rpc/client.go index 0cc8fc869..7e2ff00fc 100644 --- a/packer/rpc/client.go +++ b/packer/rpc/client.go @@ -12,22 +12,29 @@ import ( // Establishing a connection is up to the user, the Client can just // communicate over any ReadWriteCloser. type Client struct { - mux *MuxConn + mux *muxBroker client *rpc.Client closeMux bool } func NewClient(rwc io.ReadWriteCloser) (*Client, error) { - result, err := newClientWithMux(NewMuxConn(rwc), 0) + mux, err := newMuxBrokerClient(rwc) if err != nil { return nil, err } + go mux.Run() + + result, err := newClientWithMux(mux, 0) + if err != nil { + mux.Close() + return nil, err + } result.closeMux = true return result, err } -func newClientWithMux(mux *MuxConn, streamId uint32) (*Client, error) { +func newClientWithMux(mux *muxBroker, streamId uint32) (*Client, error) { clientConn, err := mux.Dial(streamId) if err != nil { return nil, err diff --git a/packer/rpc/command.go b/packer/rpc/command.go index 1484e5225..b5e5ccd52 100644 --- a/packer/rpc/command.go +++ b/packer/rpc/command.go @@ -9,14 +9,14 @@ import ( // command is actually executed over an RPC connection. type command struct { client *rpc.Client - mux *MuxConn + mux *muxBroker } // A CommandServer wraps a packer.Command and makes it exportable as part // of a Golang RPC server. type CommandServer struct { command packer.Command - mux *MuxConn + mux *muxBroker } type CommandRunArgs struct { diff --git a/packer/rpc/communicator.go b/packer/rpc/communicator.go index 9ac539323..e1d6cb649 100644 --- a/packer/rpc/communicator.go +++ b/packer/rpc/communicator.go @@ -12,14 +12,14 @@ import ( // executed over an RPC connection. type communicator struct { client *rpc.Client - mux *MuxConn + mux *muxBroker } // CommunicatorServer wraps a packer.Communicator implementation and makes // it exportable as part of a Golang RPC server. type CommunicatorServer struct { c packer.Communicator - mux *MuxConn + mux *muxBroker } type CommandFinished struct { @@ -252,7 +252,7 @@ func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *int return } -func serveSingleCopy(name string, mux *MuxConn, id uint32, dst io.Writer, src io.Reader) { +func serveSingleCopy(name string, mux *muxBroker, id uint32, dst io.Writer, src io.Reader) { conn, err := mux.Accept(id) if err != nil { log.Printf("[ERR] '%s' accept error: %s", name, err) diff --git a/packer/rpc/environment.go b/packer/rpc/environment.go index 644807bc4..5048b54ea 100644 --- a/packer/rpc/environment.go +++ b/packer/rpc/environment.go @@ -10,14 +10,14 @@ import ( // where the actual environment is executed over an RPC connection. type Environment struct { client *rpc.Client - mux *MuxConn + mux *muxBroker } // A EnvironmentServer wraps a packer.Environment and makes it exportable // as part of a Golang RPC server. type EnvironmentServer struct { env packer.Environment - mux *MuxConn + mux *muxBroker } type EnvironmentCliArgs struct { diff --git a/packer/rpc/hook.go b/packer/rpc/hook.go index 6f149d41a..4aa7d75bc 100644 --- a/packer/rpc/hook.go +++ b/packer/rpc/hook.go @@ -10,14 +10,14 @@ import ( // over an RPC connection. type hook struct { client *rpc.Client - mux *MuxConn + mux *muxBroker } // HookServer wraps a packer.Hook implementation and makes it exportable // as part of a Golang RPC server. type HookServer struct { hook packer.Hook - mux *MuxConn + mux *muxBroker } type HookRunArgs struct { diff --git a/packer/rpc/mux_broker.go b/packer/rpc/mux_broker.go index 2e1061c9d..7af76f640 100644 --- a/packer/rpc/mux_broker.go +++ b/packer/rpc/mux_broker.go @@ -3,8 +3,10 @@ package rpc import ( "encoding/binary" "fmt" + "io" "net" "sync" + "sync/atomic" "time" "github.com/hashicorp/yamux" @@ -16,6 +18,7 @@ import ( // or accept a connection from, and the broker handles the details of // holding these channels open while they're being negotiated. type muxBroker struct { + nextId uint32 session *yamux.Session streams map[uint32]*muxBrokerPending @@ -34,6 +37,24 @@ func newMuxBroker(s *yamux.Session) *muxBroker { } } +func newMuxBrokerClient(rwc io.ReadWriteCloser) (*muxBroker, error) { + s, err := yamux.Client(rwc, nil) + if err != nil { + return nil, err + } + + return newMuxBroker(s), nil +} + +func newMuxBrokerServer(rwc io.ReadWriteCloser) (*muxBroker, error) { + s, err := yamux.Server(rwc, nil) + if err != nil { + return nil, err + } + + return newMuxBroker(s), nil +} + // Accept accepts a connection by ID. // // This should not be called multiple times with the same ID at one time. @@ -60,6 +81,11 @@ func (m *muxBroker) Accept(id uint32) (net.Conn, error) { return c, nil } +// Close closes the connection and all sub-connections. +func (m *muxBroker) Close() error { + return m.session.Close() +} + // Dial opens a connection by ID. func (m *muxBroker) Dial(id uint32) (net.Conn, error) { // Open the stream @@ -88,6 +114,11 @@ func (m *muxBroker) Dial(id uint32) (net.Conn, error) { return stream, nil } +// NextId returns a unique ID to use next. +func (m *muxBroker) NextId() uint32 { + return atomic.AddUint32(&m.nextId, 1) +} + // Run starts the brokering and should be executed in a goroutine, since it // blocks forever, or until the session closes. func (m *muxBroker) Run() { diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go deleted file mode 100644 index bc9de8e72..000000000 --- a/packer/rpc/muxconn.go +++ /dev/null @@ -1,605 +0,0 @@ -package rpc - -import ( - "encoding/binary" - "fmt" - "io" - "log" - "sync" - "time" -) - -// MuxConn is able to multiplex multiple streams on top of any -// io.ReadWriteCloser. These streams act like TCP connections (Dial, Accept, -// Close, full duplex, etc.). -// -// The underlying io.ReadWriteCloser is expected to guarantee delivery -// and ordering, such as TCP. Congestion control and such aren't implemented -// by the streams, so that is also up to the underlying connection. -// -// MuxConn works using a fairly dumb multiplexing technique of simply -// framing every piece of data sent into a prefix + data format. Streams -// are established using a subset of the TCP protocol. Only a subset is -// necessary since we assume ordering on the underlying RWC. -type MuxConn struct { - curId uint32 - rwc io.ReadWriteCloser - streamsAccept map[uint32]*Stream - streamsDial map[uint32]*Stream - muAccept sync.RWMutex - muDial sync.RWMutex - wlock sync.Mutex - doneCh chan struct{} -} - -type muxPacketFrom byte -type muxPacketType byte - -const ( - muxPacketFromAccept muxPacketFrom = iota - muxPacketFromDial -) - -const ( - muxPacketSyn muxPacketType = iota - muxPacketSynAck - muxPacketAck - muxPacketFin - muxPacketData -) - -func (f muxPacketFrom) String() string { - switch f { - case muxPacketFromAccept: - return "accept" - case muxPacketFromDial: - return "dial" - default: - panic("unknown from type") - } -} - -// Create a new MuxConn around any io.ReadWriteCloser. -func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { - m := &MuxConn{ - rwc: rwc, - streamsAccept: make(map[uint32]*Stream), - streamsDial: make(map[uint32]*Stream), - doneCh: make(chan struct{}), - } - - go m.cleaner() - go m.loop() - - return m -} - -// Close closes the underlying io.ReadWriteCloser. This will also close -// all streams that are open. -func (m *MuxConn) Close() error { - m.muAccept.Lock() - m.muDial.Lock() - defer m.muAccept.Unlock() - defer m.muDial.Unlock() - - // Close all the streams - for _, w := range m.streamsAccept { - w.Close() - } - for _, w := range m.streamsDial { - w.Close() - } - m.streamsAccept = make(map[uint32]*Stream) - m.streamsDial = make(map[uint32]*Stream) - - // Close the actual connection. This will also force the loop - // to end since it'll read EOF or closed connection. - return m.rwc.Close() -} - -// Accept accepts a multiplexed connection with the given ID. This -// will block until a request is made to connect. -func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { - //log.Printf("[TRACE] %p: Accept on stream ID: %d", m, id) - - // Get the stream. It is okay if it is already in the list of streams - // because we may have prematurely received a syn for it. - m.muAccept.Lock() - stream, ok := m.streamsAccept[id] - if !ok { - stream = newStream(muxPacketFromAccept, id, m) - m.streamsAccept[id] = stream - } - m.muAccept.Unlock() - - stream.mu.Lock() - defer stream.mu.Unlock() - - // If the stream isn't closed, then it is already open somehow - if stream.state != streamStateSynRecv && stream.state != streamStateClosed { - panic(fmt.Sprintf( - "Stream %d already open in bad state: %d", id, stream.state)) - } - - if stream.state == streamStateClosed { - // Go into the listening state and wait for a syn - stream.setState(streamStateListen) - if err := stream.waitState(streamStateSynRecv); err != nil { - return nil, err - } - } - - if stream.state == streamStateSynRecv { - // Send a syn-ack - if _, err := stream.write(muxPacketSynAck, nil); err != nil { - return nil, err - } - } - - if err := stream.waitState(streamStateEstablished); err != nil { - return nil, err - } - - return stream, nil -} - -// Dial opens a connection to the remote end using the given stream ID. -// An Accept on the remote end will only work with if the IDs match. -func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { - //log.Printf("[TRACE] %p: Dial on stream ID: %d", m, id) - - m.muDial.Lock() - - // If we have any streams with this ID, then it is a failure. The - // reaper should clear out old streams once in awhile. - if stream, ok := m.streamsDial[id]; ok { - m.muDial.Unlock() - panic(fmt.Sprintf( - "Stream %d already open for dial. State: %d", - id, stream.state)) - } - - // Create the new stream and put it in our list. We can then - // unlock because dialing will no longer be allowed on that ID. - stream := newStream(muxPacketFromDial, id, m) - m.streamsDial[id] = stream - - // Don't let anyone else mess with this stream - stream.mu.Lock() - defer stream.mu.Unlock() - - m.muDial.Unlock() - - // Open a connection - if _, err := stream.write(muxPacketSyn, nil); err != nil { - return nil, err - } - - // It is safe to set the state after the write above because - // we hold the stream lock. - stream.setState(streamStateSynSent) - - if err := stream.waitState(streamStateEstablished); err != nil { - return nil, err - } - - stream.write(muxPacketAck, nil) - return stream, nil -} - -// NextId returns the next available listen stream ID that isn't currently -// taken. -func (m *MuxConn) NextId() uint32 { - m.muAccept.Lock() - defer m.muAccept.Unlock() - - for { - // We never use stream ID 0 because 0 is the zero value of a uint32 - // and we want to reserve that for "not in use" - if m.curId == 0 { - m.curId = 1 - } - - result := m.curId - m.curId += 1 - if _, ok := m.streamsAccept[result]; !ok { - return result - } - } -} - -func (m *MuxConn) cleaner() { - checks := []struct { - Map *map[uint32]*Stream - Lock *sync.RWMutex - }{ - {&m.streamsAccept, &m.muAccept}, - {&m.streamsDial, &m.muDial}, - } - - for { - done := false - select { - case <-time.After(500 * time.Millisecond): - case <-m.doneCh: - done = true - } - - for _, check := range checks { - check.Lock.Lock() - for id, s := range *check.Map { - s.mu.Lock() - - if done && s.state != streamStateClosed { - s.closeWriter() - } - - if s.state == streamStateClosed { - // Only clean up the streams that have been closed - // for a certain amount of time. - since := time.Now().UTC().Sub(s.stateUpdated) - if since > 2*time.Second { - delete(*check.Map, id) - } - } - - s.mu.Unlock() - } - check.Lock.Unlock() - } - - if done { - return - } - } -} - -func (m *MuxConn) loop() { - // Force close every stream that we know about when we exit so - // that they all read EOF and don't block forever. - defer func() { - log.Printf("[INFO] Mux connection loop exiting") - close(m.doneCh) - }() - - var from muxPacketFrom - var id uint32 - var packetType muxPacketType - var length int32 - for { - if err := binary.Read(m.rwc, binary.BigEndian, &from); err != nil { - log.Printf("[ERR] Error reading stream direction: %s", err) - return - } - if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil { - log.Printf("[ERR] Error reading stream ID: %s", err) - return - } - if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil { - log.Printf("[ERR] Error reading packet type: %s", err) - return - } - if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil { - log.Printf("[ERR] Error reading length: %s", err) - return - } - - // TODO(mitchellh): probably would be better to re-use a buffer... - data := make([]byte, length) - n := 0 - for n < int(length) { - if n2, err := m.rwc.Read(data[n:]); err != nil { - log.Printf("[ERR] Error reading data: %s", err) - return - } else { - n += n2 - } - } - - // Get the proper stream. Note that the map we look into is - // opposite the "from" because if the dial side is talking to - // us, we need to look into the accept map, and so on. - // - // Note: we also switch the "from" value so that logging - // below is correct. - var stream *Stream - switch from { - case muxPacketFromDial: - m.muAccept.Lock() - stream = m.streamsAccept[id] - m.muAccept.Unlock() - - from = muxPacketFromAccept - case muxPacketFromAccept: - m.muDial.Lock() - stream = m.streamsDial[id] - m.muDial.Unlock() - - from = muxPacketFromDial - default: - panic(fmt.Sprintf("Unknown stream direction: %d", from)) - } - - if stream == nil && packetType != muxPacketSyn { - log.Printf( - "[WARN] %p: Non-existent stream %d (%s) received packer %d", - m, id, from, packetType) - continue - } - - //log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType) - switch packetType { - case muxPacketSyn: - // If the stream is nil, this is the only case where we'll - // automatically create the stream struct. - if stream == nil { - var ok bool - - m.muAccept.Lock() - stream, ok = m.streamsAccept[id] - if !ok { - stream = newStream(muxPacketFromAccept, id, m) - m.streamsAccept[id] = stream - } - m.muAccept.Unlock() - } - - stream.mu.Lock() - switch stream.state { - case streamStateClosed: - fallthrough - case streamStateListen: - stream.setState(streamStateSynRecv) - default: - log.Printf("[ERR] Syn received for stream in state: %d", stream.state) - } - stream.mu.Unlock() - case muxPacketAck: - stream.mu.Lock() - switch stream.state { - case streamStateSynRecv: - stream.setState(streamStateEstablished) - case streamStateFinWait1: - stream.setState(streamStateFinWait2) - case streamStateLastAck: - stream.closeWriter() - fallthrough - case streamStateClosing: - stream.setState(streamStateClosed) - default: - log.Printf("[ERR] Ack received for stream in state: %d", stream.state) - } - stream.mu.Unlock() - case muxPacketSynAck: - stream.mu.Lock() - switch stream.state { - case streamStateSynSent: - stream.setState(streamStateEstablished) - default: - log.Printf("[ERR] SynAck received for stream in state: %d", stream.state) - } - stream.mu.Unlock() - case muxPacketFin: - stream.mu.Lock() - switch stream.state { - case streamStateEstablished: - stream.closeWriter() - stream.setState(streamStateCloseWait) - stream.write(muxPacketAck, nil) - case streamStateFinWait2: - stream.closeWriter() - stream.setState(streamStateClosed) - stream.write(muxPacketAck, nil) - case streamStateFinWait1: - stream.closeWriter() - stream.setState(streamStateClosing) - stream.write(muxPacketAck, nil) - default: - log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state) - } - stream.mu.Unlock() - - case muxPacketData: - stream.mu.Lock() - switch stream.state { - case streamStateFinWait1: - fallthrough - case streamStateFinWait2: - fallthrough - case streamStateEstablished: - if len(data) > 0 && stream.writeCh != nil { - //log.Printf("[TRACE] %p: Stream %d (%s) WRITE-START", m, id, from) - stream.writeCh <- data - //log.Printf("[TRACE] %p: Stream %d (%s) WRITE-END", m, id, from) - } - default: - log.Printf("[ERR] Data received for stream in state: %d", stream.state) - } - stream.mu.Unlock() - } - } -} - -func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) { - m.wlock.Lock() - defer m.wlock.Unlock() - - if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil { - return 0, err - } - if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil { - return 0, err - } - if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil { - return 0, err - } - if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil { - return 0, err - } - - // Write all the bytes. If we don't write all the bytes, report an error - var err error = nil - n := 0 - for n < len(p) { - var n2 int - n2, err = m.rwc.Write(p[n:]) - n += n2 - if err != nil { - log.Printf("[ERR] %p: Stream %d (%s) write error: %s", m, id, from, err) - break - } - } - - return n, err -} - -// Stream is a single stream of data and implements io.ReadWriteCloser. -// A Stream is full-duplex so you can write data as well as read data. -type Stream struct { - from muxPacketFrom - id uint32 - mux *MuxConn - reader io.Reader - state streamState - stateChange map[chan<- streamState]struct{} - stateUpdated time.Time - mu sync.Mutex - writeCh chan<- []byte -} - -type streamState byte - -const ( - streamStateClosed streamState = iota - streamStateListen - streamStateSynRecv - streamStateSynSent - streamStateEstablished - streamStateFinWait1 - streamStateFinWait2 - streamStateCloseWait - streamStateClosing - streamStateLastAck -) - -func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream { - // Create the stream object and channel where data will be sent to - dataR, dataW := io.Pipe() - writeCh := make(chan []byte, 4096) - - // Set the data channel so we can write to it. - stream := &Stream{ - from: from, - id: id, - mux: m, - reader: dataR, - writeCh: writeCh, - stateChange: make(map[chan<- streamState]struct{}), - } - stream.setState(streamStateClosed) - - // Start the goroutine that will read from the queue and write - // data out. - go func() { - defer dataW.Close() - - drain := false - for { - data := <-writeCh - if data == nil { - // A nil is a tombstone letting us know we're done - // accepting data. - return - } - - if drain { - // We're draining, meaning we're just waiting for the - // write channel to close. - continue - } - - if _, err := dataW.Write(data); err != nil { - drain = true - } - } - }() - - return stream -} - -func (s *Stream) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - - if s.state != streamStateEstablished && s.state != streamStateCloseWait { - return fmt.Errorf("Stream in bad state: %d", s.state) - } - - if s.state == streamStateEstablished { - s.setState(streamStateFinWait1) - } else { - s.setState(streamStateLastAck) - } - - s.write(muxPacketFin, nil) - return nil -} - -func (s *Stream) Read(p []byte) (int, error) { - return s.reader.Read(p) -} - -func (s *Stream) Write(p []byte) (int, error) { - s.mu.Lock() - state := s.state - s.mu.Unlock() - - if state != streamStateEstablished && state != streamStateCloseWait { - return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state) - } - - return s.write(muxPacketData, p) -} - -func (s *Stream) closeWriter() { - if s.writeCh != nil { - s.writeCh <- nil - s.writeCh = nil - } -} - -func (s *Stream) setState(state streamState) { - //log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state) - s.state = state - s.stateUpdated = time.Now().UTC() - for ch, _ := range s.stateChange { - select { - case ch <- state: - default: - } - } -} - -func (s *Stream) waitState(target streamState) error { - // Register a state change listener to wait for changes - stateCh := make(chan streamState, 10) - s.stateChange[stateCh] = struct{}{} - s.mu.Unlock() - - defer func() { - s.mu.Lock() - delete(s.stateChange, stateCh) - }() - - //log.Printf("[TRACE] %p: Stream %d (%s) waiting for state: %d", s.mux, s.id, s.from, target) - state := <-stateCh - if state == target { - return nil - } else { - return fmt.Errorf("Stream %d went to bad state: %d", s.id, state) - } -} - -func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) { - return s.mux.write(s.from, s.id, dataType, p) -} diff --git a/packer/rpc/muxconn_test.go b/packer/rpc/muxconn_test.go deleted file mode 100644 index 27a77fb46..000000000 --- a/packer/rpc/muxconn_test.go +++ /dev/null @@ -1,311 +0,0 @@ -package rpc - -import ( - "io" - "net" - "sync" - "testing" -) - -func readStream(t *testing.T, s io.Reader) string { - var data [1024]byte - n, err := s.Read(data[:]) - if err != nil { - t.Fatalf("err: %s", err) - } - - return string(data[0:n]) -} - -func testMux(t *testing.T) (client *MuxConn, server *MuxConn) { - l, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("err: %s", err) - } - - // Server side - doneCh := make(chan struct{}) - go func() { - defer close(doneCh) - conn, err := l.Accept() - l.Close() - if err != nil { - t.Fatalf("err: %s", err) - } - - server = NewMuxConn(conn) - }() - - // Client side - conn, err := net.Dial("tcp", l.Addr().String()) - if err != nil { - t.Fatalf("err: %s", err) - } - client = NewMuxConn(conn) - - // Wait for the server - <-doneCh - - return -} - -func TestMuxConn(t *testing.T) { - client, server := testMux(t) - defer client.Close() - defer server.Close() - - // When the server is done - doneCh := make(chan struct{}) - - // The server side - go func() { - defer close(doneCh) - - s0, err := server.Accept(0) - if err != nil { - t.Fatalf("err: %s", err) - } - - s1, err := server.Dial(1) - if err != nil { - t.Fatalf("err: %s", err) - } - - var wg sync.WaitGroup - wg.Add(2) - - go func() { - defer wg.Done() - defer s1.Close() - data := readStream(t, s1) - if data != "another" { - t.Fatalf("bad: %#v", data) - } - }() - - go func() { - defer wg.Done() - defer s0.Close() - data := readStream(t, s0) - if data != "hello" { - t.Fatalf("bad: %#v", data) - } - }() - - wg.Wait() - }() - - s0, err := client.Dial(0) - if err != nil { - t.Fatalf("err: %s", err) - } - - s1, err := client.Accept(1) - if err != nil { - t.Fatalf("err: %s", err) - } - - if _, err := s0.Write([]byte("hello")); err != nil { - t.Fatalf("err: %s", err) - } - if _, err := s1.Write([]byte("another")); err != nil { - t.Fatalf("err: %s", err) - } - - s0.Close() - s1.Close() - - // Wait for the server to be done - <-doneCh -} - -func TestMuxConn_lotsOfData(t *testing.T) { - client, server := testMux(t) - defer client.Close() - defer server.Close() - - // When the server is done - doneCh := make(chan struct{}) - - // The server side - go func() { - defer close(doneCh) - - s0, err := server.Accept(0) - if err != nil { - t.Fatalf("err: %s", err) - } - - var data [1024]byte - for { - n, err := s0.Read(data[:]) - if err == io.EOF { - break - } - - dataString := string(data[0:n]) - if dataString != "hello" { - t.Fatalf("bad: %#v", dataString) - } - } - - s0.Close() - }() - - s0, err := client.Dial(0) - if err != nil { - t.Fatalf("err: %s", err) - } - - for i := 0; i < 4096*4; i++ { - if _, err := s0.Write([]byte("hello")); err != nil { - t.Fatalf("err: %s", err) - } - } - - if err := s0.Close(); err != nil { - t.Fatalf("err: %s", err) - } - - // Wait for the server to be done - <-doneCh -} - -// This tests that even when the client end is closed, data can be -// read from the server. -func TestMuxConn_clientCloseRead(t *testing.T) { - client, server := testMux(t) - defer client.Close() - defer server.Close() - - // This channel will be closed when we close - waitCh := make(chan struct{}) - - go func() { - conn, err := server.Accept(0) - if err != nil { - t.Fatalf("err: %s", err) - } - - <-waitCh - - _, err = conn.Write([]byte("foo")) - if err != nil { - t.Fatalf("err: %s", err) - } - - conn.Close() - }() - - s0, err := client.Dial(0) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := s0.Close(); err != nil { - t.Fatalf("bad: %s", err) - } - - // Close this to continue on on the server-side - close(waitCh) - - var data [1024]byte - n, err := s0.Read(data[:]) - if string(data[:n]) != "foo" { - t.Fatalf("bad: %#v", string(data[:n])) - } -} - -func TestMuxConn_socketClose(t *testing.T) { - client, server := testMux(t) - defer client.Close() - defer server.Close() - - go func() { - _, err := server.Accept(0) - if err != nil { - t.Fatalf("err: %s", err) - } - - server.rwc.Close() - }() - - s0, err := client.Dial(0) - if err != nil { - t.Fatalf("err: %s", err) - } - - var data [1024]byte - _, err = s0.Read(data[:]) - if err != io.EOF { - t.Fatalf("err: %s", err) - } -} - -func TestMuxConn_clientClosesStreams(t *testing.T) { - client, server := testMux(t) - defer client.Close() - defer server.Close() - - go func() { - conn, err := server.Accept(0) - if err != nil { - t.Fatalf("err: %s", err) - } - conn.Close() - }() - - s0, err := client.Dial(0) - if err != nil { - t.Fatalf("err: %s", err) - } - - var data [1024]byte - _, err = s0.Read(data[:]) - if err != io.EOF { - t.Fatalf("err: %s", err) - } -} - -func TestMuxConn_serverClosesStreams(t *testing.T) { - client, server := testMux(t) - defer client.Close() - defer server.Close() - go server.Accept(0) - - s0, err := client.Dial(0) - if err != nil { - t.Fatalf("err: %s", err) - } - - if err := server.Close(); err != nil { - t.Fatalf("err: %s", err) - } - - // This should block forever since we never write onto this stream. - var data [1024]byte - _, err = s0.Read(data[:]) - if err != io.EOF { - t.Fatalf("err: %s", err) - } -} - -func TestMuxConnNextId(t *testing.T) { - client, server := testMux(t) - defer client.Close() - defer server.Close() - - a := client.NextId() - b := client.NextId() - - if a != 1 || b != 2 { - t.Fatalf("IDs should increment") - } - - a = server.NextId() - b = server.NextId() - - if a != 1 || b != 2 { - t.Fatalf("IDs should increment: %d %d", a, b) - } -} diff --git a/packer/rpc/post_processor.go b/packer/rpc/post_processor.go index 3a22c1a8a..b183780b9 100644 --- a/packer/rpc/post_processor.go +++ b/packer/rpc/post_processor.go @@ -9,14 +9,14 @@ import ( // executed over an RPC connection. type postProcessor struct { client *rpc.Client - mux *MuxConn + mux *muxBroker } // PostProcessorServer wraps a packer.PostProcessor implementation and makes it // exportable as part of a Golang RPC server. type PostProcessorServer struct { client *rpc.Client - mux *MuxConn + mux *muxBroker p packer.PostProcessor } diff --git a/packer/rpc/provisioner.go b/packer/rpc/provisioner.go index e2346cd21..08d31700a 100644 --- a/packer/rpc/provisioner.go +++ b/packer/rpc/provisioner.go @@ -10,14 +10,14 @@ import ( // executed over an RPC connection. type provisioner struct { client *rpc.Client - mux *MuxConn + mux *muxBroker } // ProvisionerServer wraps a packer.Provisioner implementation and makes it // exportable as part of a Golang RPC server. type ProvisionerServer struct { p packer.Provisioner - mux *MuxConn + mux *muxBroker } type ProvisionerPrepareArgs struct { diff --git a/packer/rpc/server.go b/packer/rpc/server.go index df50f48ad..ca9691870 100644 --- a/packer/rpc/server.go +++ b/packer/rpc/server.go @@ -29,7 +29,7 @@ const ( // Server represents an RPC server for Packer. This must be paired on // the other side with a Client. type Server struct { - mux *MuxConn + mux *muxBroker streamId uint32 server *rpc.Server closeMux bool @@ -37,12 +37,14 @@ type Server struct { // NewServer returns a new Packer RPC server. func NewServer(conn io.ReadWriteCloser) *Server { - result := newServerWithMux(NewMuxConn(conn), 0) + mux, _ := newMuxBrokerServer(conn) + result := newServerWithMux(mux, 0) result.closeMux = true + go mux.Run() return result } -func newServerWithMux(mux *MuxConn, streamId uint32) *Server { +func newServerWithMux(mux *muxBroker, streamId uint32) *Server { return &Server{ mux: mux, streamId: streamId, @@ -140,11 +142,11 @@ func (s *Server) Serve() { // Accept a connection on stream ID 0, which is always used for // normal client to server connections. stream, err := s.mux.Accept(s.streamId) - defer stream.Close() if err != nil { log.Printf("[ERR] Error retrieving stream for serving: %s", err) return } + defer stream.Close() var h codec.MsgpackHandle rpcCodec := codec.GoRpc.ServerCodec(stream, &h) From b7c604795ed6b07c4391d128d131d3569214343b Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Tue, 2 Sep 2014 14:28:21 -0700 Subject: [PATCH 3/3] packer/plugin: increase version for Yamux --- packer/plugin/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packer/plugin/server.go b/packer/plugin/server.go index a36f8beda..83292c320 100644 --- a/packer/plugin/server.go +++ b/packer/plugin/server.go @@ -33,7 +33,7 @@ const MagicCookieValue = "d602bf8f470bc67ca7faa0386276bbdd4330efaf76d1a219cb4d69 // The APIVersion is outputted along with the RPC address. The plugin // client validates this API version and will show an error if it doesn't // know how to speak it. -const APIVersion = "3" +const APIVersion = "4" // Server waits for a connection to this plugin and returns a Packer // RPC server that you can use to register components and serve them.