exp/ssh: fix length header leaking into channel data streams.

The payload of a data message is defined as an SSH string type,
which uses the first four bytes to encode its length. When channelData
and channelExtendedData were added I defined Payload as []byte to
be able to use it directly without a string to []byte conversion. This
resulted in the length data leaking into the payload data.

This CL fixes the bug, and restores agl's original fast path code.

Additionally, a bug whereby s.lock was not released if a packet arrived
for an invalid channel has been fixed.

Finally, as they were no longer used, I have removed
the channelData and channelExtedendData structs.

R=agl, rsc
CC=golang-dev
https://golang.org/cl/5330053
This commit is contained in:
Dave Cheney 2011-10-29 14:22:30 -04:00 committed by Adam Langley
parent 604e10c34d
commit 0f6b80c694
3 changed files with 131 additions and 114 deletions

View file

@ -258,51 +258,71 @@ func (c *ClientConn) openChan(typ string) (*clientChan, os.Error) {
// mainloop reads incoming messages and routes channel messages
// to their respective ClientChans.
func (c *ClientConn) mainLoop() {
// TODO(dfc) signal the underlying close to all channels
defer c.Close()
for {
packet, err := c.readPacket()
if err != nil {
// TODO(dfc) signal the underlying close to all channels
c.Close()
return
break
}
// TODO(dfc) A note on blocking channel use.
// The msg, win, data and dataExt channels of a clientChan can
// cause this loop to block indefinately if the consumer does
// not service them.
switch msg := decode(packet).(type) {
case *channelOpenMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelOpenConfirmMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelOpenFailureMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelCloseMsg:
ch := c.getChan(msg.PeersId)
close(ch.win)
close(ch.data)
close(ch.dataExt)
c.chanlist.remove(msg.PeersId)
case *channelEOFMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestSuccessMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestFailureMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestMsg:
c.getChan(msg.PeersId).msg <- msg
case *windowAdjustMsg:
c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
case *channelData:
c.getChan(msg.PeersId).data <- msg.Payload
case *channelExtendedData:
// RFC 4254 5.2 defines data_type_code 1 to be data destined
// for stderr on interactive sessions. Other data types are
// silently discarded.
if msg.Datatype == 1 {
c.getChan(msg.PeersId).dataExt <- msg.Payload
switch packet[0] {
case msgChannelData:
if len(packet) < 9 {
// malformed data packet
break
}
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
packet = packet[9:]
c.getChan(peersId).data <- packet[:length]
}
case msgChannelExtendedData:
if len(packet) < 13 {
// malformed data packet
break
}
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
datatype := uint32(packet[5])<<24 | uint32(packet[6])<<16 | uint32(packet[7])<<8 | uint32(packet[8])
if length := int(packet[9])<<24 | int(packet[10])<<16 | int(packet[11])<<8 | int(packet[12]); length > 0 {
packet = packet[13:]
// RFC 4254 5.2 defines data_type_code 1 to be data destined
// for stderr on interactive sessions. Other data types are
// silently discarded.
if datatype == 1 {
c.getChan(peersId).dataExt <- packet[:length]
}
}
default:
fmt.Printf("mainLoop: unhandled %#v\n", msg)
switch msg := decode(packet).(type) {
case *channelOpenMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelOpenConfirmMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelOpenFailureMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelCloseMsg:
ch := c.getChan(msg.PeersId)
close(ch.win)
close(ch.data)
close(ch.dataExt)
c.chanlist.remove(msg.PeersId)
case *channelEOFMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestSuccessMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestFailureMsg:
c.getChan(msg.PeersId).msg <- msg
case *channelRequestMsg:
c.getChan(msg.PeersId).msg <- msg
case *windowAdjustMsg:
c.getChan(msg.PeersId).win <- int(msg.AdditionalBytes)
default:
fmt.Printf("mainLoop: unhandled %#v\n", msg)
}
}
}
}

View file

@ -144,19 +144,6 @@ type channelOpenFailureMsg struct {
Language string
}
// See RFC 4254, section 5.2.
type channelData struct {
PeersId uint32
Payload []byte `ssh:"rest"`
}
// See RFC 4254, section 5.2.
type channelExtendedData struct {
PeersId uint32
Datatype uint32
Payload []byte `ssh:"rest"`
}
type channelRequestMsg struct {
PeersId uint32
Request string
@ -612,10 +599,6 @@ func decode(packet []byte) interface{} {
msg = new(channelOpenFailureMsg)
case msgChannelWindowAdjust:
msg = new(windowAdjustMsg)
case msgChannelData:
msg = new(channelData)
case msgChannelExtendedData:
msg = new(channelExtendedData)
case msgChannelEOF:
msg = new(channelEOFMsg)
case msgChannelClose:

View file

@ -581,75 +581,89 @@ func (s *ServerConn) Accept() (Channel, os.Error) {
return nil, err
}
switch msg := decode(packet).(type) {
case *channelOpenMsg:
c := new(channel)
c.chanType = msg.ChanType
c.theirId = msg.PeersId
c.theirWindow = msg.PeersWindow
c.maxPacketSize = msg.MaxPacketSize
c.extraData = msg.TypeSpecificData
c.myWindow = defaultWindowSize
c.serverConn = s
c.cond = sync.NewCond(&c.lock)
c.pendingData = make([]byte, c.myWindow)
switch packet[0] {
case msgChannelData:
if len(packet) < 9 {
// malformed data packet
return nil, ParseError{msgChannelData}
}
peersId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
s.lock.Lock()
c.myId = s.nextChanId
s.nextChanId++
s.channels[c.myId] = c
s.lock.Unlock()
return c, nil
case *channelRequestMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
c, ok := s.channels[peersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *channelData:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
continue
if length := int(packet[5])<<24 | int(packet[6])<<16 | int(packet[7])<<8 | int(packet[8]); length > 0 {
packet = packet[9:]
c.handleData(packet[:length])
}
c.handleData(msg.Payload)
s.lock.Unlock()
case *channelEOFMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *channelCloseMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *globalRequestMsg:
if msg.WantReply {
if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
return nil, err
}
}
case UnexpectedMessageError:
return nil, msg
case *disconnectMsg:
return nil, os.EOF
default:
// Unknown message. Ignore.
switch msg := decode(packet).(type) {
case *channelOpenMsg:
c := new(channel)
c.chanType = msg.ChanType
c.theirId = msg.PeersId
c.theirWindow = msg.PeersWindow
c.maxPacketSize = msg.MaxPacketSize
c.extraData = msg.TypeSpecificData
c.myWindow = defaultWindowSize
c.serverConn = s
c.cond = sync.NewCond(&c.lock)
c.pendingData = make([]byte, c.myWindow)
s.lock.Lock()
c.myId = s.nextChanId
s.nextChanId++
s.channels[c.myId] = c
s.lock.Unlock()
return c, nil
case *channelRequestMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *channelEOFMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *channelCloseMsg:
s.lock.Lock()
c, ok := s.channels[msg.PeersId]
if !ok {
s.lock.Unlock()
continue
}
c.handlePacket(msg)
s.lock.Unlock()
case *globalRequestMsg:
if msg.WantReply {
if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
return nil, err
}
}
case UnexpectedMessageError:
return nil, msg
case *disconnectMsg:
return nil, os.EOF
default:
// Unknown message. Ignore.
}
}
}