diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go index f449024b1..bc9de8e72 100644 --- a/packer/rpc/muxconn.go +++ b/packer/rpc/muxconn.go @@ -400,8 +400,6 @@ func (m *MuxConn) loop() { stream.mu.Unlock() case muxPacketData: - unlocked := false - stream.mu.Lock() switch stream.state { case streamStateFinWait1: @@ -409,26 +407,15 @@ func (m *MuxConn) loop() { case streamStateFinWait2: fallthrough case streamStateEstablished: - if len(data) > 0 { - // Get a reference to the write channel while we have - // the lock because otherwise the field might change. - // We unlock early here because the write might block - // for a long time. - writeCh := stream.writeCh - stream.mu.Unlock() - unlocked = true - - // Blocked write, this provides some backpressure on - // the connection if there is a lot of data incoming. - writeCh <- data + if len(data) > 0 && stream.writeCh != nil { + //log.Printf("[TRACE] %p: Stream %d (%s) WRITE-START", m, id, from) + stream.writeCh <- data + //log.Printf("[TRACE] %p: Stream %d (%s) WRITE-END", m, id, from) } default: log.Printf("[ERR] Data received for stream in state: %d", stream.state) } - - if !unlocked { - stream.mu.Unlock() - } + stream.mu.Unlock() } } } @@ -516,6 +503,7 @@ func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream { go func() { defer dataW.Close() + drain := false for { data := <-writeCh if data == nil { @@ -524,8 +512,14 @@ func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream { return } + if drain { + // We're draining, meaning we're just waiting for the + // write channel to close. + continue + } + if _, err := dataW.Write(data); err != nil { - return + drain = true } } }() @@ -568,7 +562,10 @@ func (s *Stream) Write(p []byte) (int, error) { } func (s *Stream) closeWriter() { - s.writeCh <- nil + if s.writeCh != nil { + s.writeCh <- nil + s.writeCh = nil + } } func (s *Stream) setState(state streamState) { @@ -594,6 +591,7 @@ func (s *Stream) waitState(target streamState) error { delete(s.stateChange, stateCh) }() + //log.Printf("[TRACE] %p: Stream %d (%s) waiting for state: %d", s.mux, s.id, s.from, target) state := <-stateCh if state == target { return nil diff --git a/packer/rpc/muxconn_test.go b/packer/rpc/muxconn_test.go index 6abeec9d2..8784b1410 100644 --- a/packer/rpc/muxconn_test.go +++ b/packer/rpc/muxconn_test.go @@ -76,6 +76,7 @@ func TestMuxConn(t *testing.T) { go func() { defer wg.Done() + defer s1.Close() data := readStream(t, s1) if data != "another" { t.Fatalf("bad: %#v", data) @@ -84,6 +85,7 @@ func TestMuxConn(t *testing.T) { go func() { defer wg.Done() + defer s0.Close() data := readStream(t, s0) if data != "hello" { t.Fatalf("bad: %#v", data) @@ -110,6 +112,9 @@ func TestMuxConn(t *testing.T) { t.Fatalf("err: %s", err) } + s0.Close() + s1.Close() + // Wait for the server to be done <-doneCh } @@ -131,18 +136,20 @@ func TestMuxConn_lotsOfData(t *testing.T) { t.Fatalf("err: %s", err) } - var wg sync.WaitGroup - wg.Add(1) + var data [1024]byte + for { + n, err := s0.Read(data[:]) + if err == io.EOF { + break + } - go func() { - defer wg.Done() - data := readStream(t, s0) - if data != "hello" { - t.Fatalf("bad: %#v", data) + dataString := string(data[0:n]) + if dataString != "hello" { + t.Fatalf("bad: %#v", dataString) } - }() + } - wg.Wait() + s0.Close() }() s0, err := client.Dial(0) @@ -156,6 +163,10 @@ func TestMuxConn_lotsOfData(t *testing.T) { } } + if err := s0.Close(); err != nil { + t.Fatalf("err: %s", err) + } + // Wait for the server to be done <-doneCh }