|
|
|
|
@ -33,6 +33,7 @@ type muxPacketType byte
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
|
muxPacketSyn muxPacketType = iota
|
|
|
|
|
muxPacketSynAck
|
|
|
|
|
muxPacketAck
|
|
|
|
|
muxPacketFin
|
|
|
|
|
muxPacketData
|
|
|
|
|
@ -77,49 +78,27 @@ func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
|
|
|
|
|
|
|
|
|
|
// If the stream isn't closed, then it is already open somehow
|
|
|
|
|
stream.mu.Lock()
|
|
|
|
|
defer stream.mu.Unlock()
|
|
|
|
|
if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
|
|
|
|
|
stream.mu.Unlock()
|
|
|
|
|
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if stream.state == streamStateSynRecv {
|
|
|
|
|
// Fast track establishing since we already got the syn
|
|
|
|
|
stream.setState(streamStateEstablished)
|
|
|
|
|
stream.mu.Unlock()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if stream.state != streamStateEstablished {
|
|
|
|
|
// Go into the listening state
|
|
|
|
|
if stream.state == streamStateClosed {
|
|
|
|
|
// Go into the listening state and wait for a syn
|
|
|
|
|
stream.setState(streamStateListen)
|
|
|
|
|
if err := stream.waitState(streamStateSynRecv); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Register a state change listener to wait for changes
|
|
|
|
|
stateCh := make(chan streamState, 10)
|
|
|
|
|
stream.registerStateListener(stateCh)
|
|
|
|
|
defer func() {
|
|
|
|
|
stream.mu.Lock()
|
|
|
|
|
defer stream.mu.Unlock()
|
|
|
|
|
stream.deregisterStateListener(stateCh)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
stream.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
// Wait for the connection to establish
|
|
|
|
|
ACCEPT_ESTABLISH_LOOP:
|
|
|
|
|
for {
|
|
|
|
|
state := <-stateCh
|
|
|
|
|
switch state {
|
|
|
|
|
case streamStateListen:
|
|
|
|
|
case streamStateEstablished:
|
|
|
|
|
break ACCEPT_ESTABLISH_LOOP
|
|
|
|
|
default:
|
|
|
|
|
defer stream.mu.Unlock()
|
|
|
|
|
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
|
|
|
|
|
}
|
|
|
|
|
if stream.state == streamStateSynRecv {
|
|
|
|
|
// Send a syn-ack
|
|
|
|
|
if _, err := m.write(stream.id, muxPacketSynAck, nil); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Send the ack down
|
|
|
|
|
if _, err := m.write(stream.id, muxPacketAck, nil); err != nil {
|
|
|
|
|
if err := stream.waitState(streamStateEstablished); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -136,8 +115,8 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
|
|
|
|
|
|
|
|
|
|
// If the stream isn't closed, then it is already open somehow
|
|
|
|
|
stream.mu.Lock()
|
|
|
|
|
defer stream.mu.Unlock()
|
|
|
|
|
if stream.state != streamStateClosed {
|
|
|
|
|
stream.mu.Unlock()
|
|
|
|
|
return nil, fmt.Errorf("Stream %d already open in bad state: %d", id, stream.state)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@ -147,28 +126,12 @@ func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
|
|
|
|
|
}
|
|
|
|
|
stream.setState(streamStateSynSent)
|
|
|
|
|
|
|
|
|
|
// Register a state change listener to wait for changes
|
|
|
|
|
stateCh := make(chan streamState, 10)
|
|
|
|
|
stream.registerStateListener(stateCh)
|
|
|
|
|
defer func() {
|
|
|
|
|
stream.mu.Lock()
|
|
|
|
|
defer stream.mu.Unlock()
|
|
|
|
|
stream.deregisterStateListener(stateCh)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
stream.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
for {
|
|
|
|
|
state := <-stateCh
|
|
|
|
|
switch state {
|
|
|
|
|
case streamStateSynSent:
|
|
|
|
|
case streamStateEstablished:
|
|
|
|
|
return stream, nil
|
|
|
|
|
default:
|
|
|
|
|
defer stream.mu.Unlock()
|
|
|
|
|
return nil, fmt.Errorf("Stream %d went to bad state: %d", id, stream.state)
|
|
|
|
|
}
|
|
|
|
|
if err := stream.waitState(streamStateEstablished); err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
m.write(id, muxPacketAck, nil)
|
|
|
|
|
return stream, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NextId returns the next available stream ID that isn't currently
|
|
|
|
|
@ -247,6 +210,7 @@ func (m *MuxConn) loop() {
|
|
|
|
|
// Force close every stream that we know about when we exit so
|
|
|
|
|
// that they all read EOF and don't block forever.
|
|
|
|
|
defer func() {
|
|
|
|
|
log.Printf("[INFO] Mux connection loop exiting")
|
|
|
|
|
m.mu.Lock()
|
|
|
|
|
defer m.mu.Unlock()
|
|
|
|
|
for _, w := range m.streams {
|
|
|
|
|
@ -288,12 +252,23 @@ func (m *MuxConn) loop() {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
//log.Printf("[DEBUG] Stream %d received packet %d", id, packetType)
|
|
|
|
|
log.Printf("[TRACE] Stream %d received packet %d", id, packetType)
|
|
|
|
|
switch packetType {
|
|
|
|
|
case muxPacketSyn:
|
|
|
|
|
stream.mu.Lock()
|
|
|
|
|
switch stream.state {
|
|
|
|
|
case streamStateClosed:
|
|
|
|
|
fallthrough
|
|
|
|
|
case streamStateListen:
|
|
|
|
|
stream.setState(streamStateSynRecv)
|
|
|
|
|
default:
|
|
|
|
|
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
|
|
|
|
|
}
|
|
|
|
|
stream.mu.Unlock()
|
|
|
|
|
case muxPacketAck:
|
|
|
|
|
stream.mu.Lock()
|
|
|
|
|
switch stream.state {
|
|
|
|
|
case streamStateSynSent:
|
|
|
|
|
case streamStateSynRecv:
|
|
|
|
|
stream.setState(streamStateEstablished)
|
|
|
|
|
case streamStateFinWait1:
|
|
|
|
|
stream.setState(streamStateFinWait2)
|
|
|
|
|
@ -301,15 +276,13 @@ func (m *MuxConn) loop() {
|
|
|
|
|
log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
|
|
|
|
|
}
|
|
|
|
|
stream.mu.Unlock()
|
|
|
|
|
case muxPacketSyn:
|
|
|
|
|
case muxPacketSynAck:
|
|
|
|
|
stream.mu.Lock()
|
|
|
|
|
switch stream.state {
|
|
|
|
|
case streamStateClosed:
|
|
|
|
|
stream.setState(streamStateSynRecv)
|
|
|
|
|
case streamStateListen:
|
|
|
|
|
case streamStateSynSent:
|
|
|
|
|
stream.setState(streamStateEstablished)
|
|
|
|
|
default:
|
|
|
|
|
log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
|
|
|
|
|
log.Printf("[ERR] SynAck received for stream in state: %d", stream.state)
|
|
|
|
|
}
|
|
|
|
|
stream.mu.Unlock()
|
|
|
|
|
case muxPacketFin:
|
|
|
|
|
@ -451,6 +424,7 @@ func (s *Stream) deregisterStateListener(ch chan<- streamState) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *Stream) setState(state streamState) {
|
|
|
|
|
log.Printf("[TRACE] Stream %d went to state %d", s.id, state)
|
|
|
|
|
s.state = state
|
|
|
|
|
s.stateUpdated = time.Now().UTC()
|
|
|
|
|
for ch, _ := range s.stateChange {
|
|
|
|
|
@ -460,3 +434,22 @@ func (s *Stream) setState(state streamState) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (s *Stream) waitState(target streamState) error {
|
|
|
|
|
// Register a state change listener to wait for changes
|
|
|
|
|
stateCh := make(chan streamState, 10)
|
|
|
|
|
s.registerStateListener(stateCh)
|
|
|
|
|
s.mu.Unlock()
|
|
|
|
|
|
|
|
|
|
defer func() {
|
|
|
|
|
s.mu.Lock()
|
|
|
|
|
s.deregisterStateListener(stateCh)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
state := <-stateCh
|
|
|
|
|
if state == target {
|
|
|
|
|
return nil
|
|
|
|
|
} else {
|
|
|
|
|
return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|