From 7372c32b6b57ca810c3e7ac8ea4a2e6905b36683 Mon Sep 17 00:00:00 2001 From: Mitchell Hashimoto Date: Tue, 10 Dec 2013 15:51:22 -0800 Subject: [PATCH] packer/rpc: implement proper close_wait state --- packer/rpc/muxconn.go | 38 +++++++++++++++++++------------------- packer/rpc/muxconn_test.go | 14 ++++++++------ 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/packer/rpc/muxconn.go b/packer/rpc/muxconn.go index e29d938bb..63ac39139 100644 --- a/packer/rpc/muxconn.go +++ b/packer/rpc/muxconn.go @@ -150,6 +150,8 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { // and it closed all within the time period we wait above. // This case will be fixed when we have edge-triggered checks. fallthrough + case streamStateCloseWait: + fallthrough case streamStateEstablished: stream.mu.Unlock() return stream, nil @@ -274,7 +276,7 @@ func (m *MuxConn) loop() { case streamStateSynSent: stream.setState(streamStateEstablished) case streamStateFinWait1: - stream.remoteClose() + stream.setState(streamStateFinWait2) default: log.Printf("[ERR] Ack received for stream in state: %d", stream.state) } @@ -294,9 +296,15 @@ func (m *MuxConn) loop() { stream.mu.Lock() switch stream.state { case streamStateEstablished: + stream.setState(streamStateCloseWait) m.write(id, muxPacketAck, nil) - fallthrough + + // Close the writer on our end since we won't receive any + // more data. + stream.writeCh <- nil case streamStateFinWait1: + fallthrough + case streamStateFinWait2: stream.remoteClose() // Remove this stream from being active so that it @@ -364,34 +372,26 @@ const ( streamStateSynSent streamStateEstablished streamStateFinWait1 + streamStateFinWait2 + streamStateCloseWait ) func (s *Stream) Close() error { s.mu.Lock() - if s.state != streamStateEstablished { - s.mu.Unlock() + defer s.mu.Unlock() + + if s.state != streamStateEstablished && s.state != streamStateCloseWait { return fmt.Errorf("Stream in bad state: %d", s.state) } if _, err := s.mux.write(s.id, muxPacketFin, nil); err != nil { return err } - s.setState(streamStateFinWait1) - s.mu.Unlock() - for { - time.Sleep(50 * time.Millisecond) - s.mu.Lock() - switch s.state { - case streamStateFinWait1: - s.mu.Unlock() - case streamStateClosed: - s.mu.Unlock() - return nil - default: - defer s.mu.Unlock() - return fmt.Errorf("Stream %d went to bad state: %d", s.id, s.state) - } + if s.state == streamStateEstablished { + s.setState(streamStateFinWait1) + } else { + s.remoteClose() } return nil diff --git a/packer/rpc/muxconn_test.go b/packer/rpc/muxconn_test.go index e6c23f26f..392ee6b5a 100644 --- a/packer/rpc/muxconn_test.go +++ b/packer/rpc/muxconn_test.go @@ -118,18 +118,20 @@ func TestMuxConn_clientClosesStreams(t *testing.T) { client, server := testMux(t) defer client.Close() defer server.Close() - go server.Accept(0) + + go func() { + conn, err := server.Accept(0) + if err != nil { + t.Fatalf("err: %s", err) + } + conn.Close() + }() s0, err := client.Dial(0) if err != nil { t.Fatalf("err: %s", err) } - if err := client.Close(); err != nil { - t.Fatalf("err: %s", err) - } - - // This should block forever since we never write onto this stream. var data [1024]byte _, err = s0.Read(data[:]) if err != io.EOF {