diff --git a/internal/grid/connection.go b/internal/grid/connection.go index eae23c40a..4d5b45d4f 100644 --- a/internal/grid/connection.go +++ b/internal/grid/connection.go @@ -925,137 +925,141 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) { c.handleMsgWg.Add(2) c.reconnectMu.Unlock() - // Read goroutine - go func() { - defer func() { - if rec := recover(); rec != nil { - gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec)) - debug.PrintStack() - } - c.connChange.L.Lock() - if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) { - c.connChange.Broadcast() - } - c.connChange.L.Unlock() - conn.Close() - c.handleMsgWg.Done() - }() + // Start reader and writer + go c.readStream(ctx, conn, cancel) + c.writeStream(ctx, conn, cancel) +} - controlHandler := wsutil.ControlFrameHandler(conn, c.side) - wsReader := wsutil.Reader{ - Source: conn, - State: c.side, - CheckUTF8: true, - SkipHeaderCheck: false, - OnIntermediate: controlHandler, +// readStream handles the read side of the connection. +// It will read messages and send them to c.handleMsg. +// If an error occurs the cancel function will be called and conn be closed. +// The function will block until the connection is closed or an error occurs. +func (c *Connection) readStream(ctx context.Context, conn net.Conn, cancel context.CancelCauseFunc) { + defer func() { + if rec := recover(); rec != nil { + gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec)) + debug.PrintStack() } - readDataInto := func(dst []byte, rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, error) { - dst = dst[:0] - for { - hdr, err := wsReader.NextFrame() - if err != nil { - return nil, err - } - if hdr.OpCode.IsControl() { - if err := controlHandler(hdr, &wsReader); err != nil { - return nil, err - } - continue - } - if hdr.OpCode&want == 0 { - if err := wsReader.Discard(); err != nil { - return nil, err - } - continue - } - if int64(cap(dst)) < hdr.Length+1 { - dst = make([]byte, 0, hdr.Length+hdr.Length>>3) - } - return readAllInto(dst[:0], &wsReader) - } - } - - // Keep reusing the same buffer. - var msg []byte - for { - if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected { - cancel(ErrDisconnected) - return - } - if cap(msg) > readBufferSize*4 { - // Don't keep too much memory around. - msg = nil - } - - var err error - msg, err = readDataInto(msg, conn, c.side, ws.OpBinary) - if err != nil { - cancel(ErrDisconnected) - if !xnet.IsNetworkOrHostDown(err, true) { - gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF) - } - return - } - block := c.blockMessages.Load() - if block != nil && *block != nil { - <-*block - } - - if c.incomingBytes != nil { - c.incomingBytes(int64(len(msg))) - } - - // Parse the received message - var m message - subID, remain, err := m.parse(msg) - if err != nil { - if !xnet.IsNetworkOrHostDown(err, true) { - gridLogIf(ctx, fmt.Errorf("ws parse package: %w", err)) - } - cancel(ErrDisconnected) - return - } - if debugPrint { - fmt.Printf("%s Got msg: %v\n", c.Local, m) - } - if m.Op != OpMerged { - c.inMessages.Add(1) - c.handleMsg(ctx, m, subID) - continue - } - // Handle merged messages. - messages := int(m.Seq) - c.inMessages.Add(int64(messages)) - for i := 0; i < messages; i++ { - if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected { - cancel(ErrDisconnected) - return - } - var next []byte - next, remain, err = msgp.ReadBytesZC(remain) - if err != nil { - if !xnet.IsNetworkOrHostDown(err, true) { - gridLogIf(ctx, fmt.Errorf("ws read merged: %w", err)) - } - cancel(ErrDisconnected) - return - } - - m.Payload = nil - subID, _, err = m.parse(next) - if err != nil { - if !xnet.IsNetworkOrHostDown(err, true) { - gridLogIf(ctx, fmt.Errorf("ws parse merged: %w", err)) - } - cancel(ErrDisconnected) - return - } - c.handleMsg(ctx, m, subID) - } + cancel(ErrDisconnected) + c.connChange.L.Lock() + if atomic.CompareAndSwapUint32((*uint32)(&c.state), StateConnected, StateConnectionError) { + c.connChange.Broadcast() } + c.connChange.L.Unlock() + conn.Close() + c.handleMsgWg.Done() }() - // Write function. + controlHandler := wsutil.ControlFrameHandler(conn, c.side) + wsReader := wsutil.Reader{ + Source: conn, + State: c.side, + CheckUTF8: true, + SkipHeaderCheck: false, + OnIntermediate: controlHandler, + } + readDataInto := func(dst []byte, rw io.ReadWriter, s ws.State, want ws.OpCode) ([]byte, error) { + dst = dst[:0] + for { + hdr, err := wsReader.NextFrame() + if err != nil { + return nil, err + } + if hdr.OpCode.IsControl() { + if err := controlHandler(hdr, &wsReader); err != nil { + return nil, err + } + continue + } + if hdr.OpCode&want == 0 { + if err := wsReader.Discard(); err != nil { + return nil, err + } + continue + } + if int64(cap(dst)) < hdr.Length+1 { + dst = make([]byte, 0, hdr.Length+hdr.Length>>3) + } + return readAllInto(dst[:0], &wsReader) + } + } + + // Keep reusing the same buffer. + var msg []byte + for atomic.LoadUint32((*uint32)(&c.state)) == StateConnected { + if cap(msg) > readBufferSize*4 { + // Don't keep too much memory around. + msg = nil + } + + var err error + msg, err = readDataInto(msg, conn, c.side, ws.OpBinary) + if err != nil { + if !xnet.IsNetworkOrHostDown(err, true) { + gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF) + } + return + } + block := c.blockMessages.Load() + if block != nil && *block != nil { + <-*block + } + + if c.incomingBytes != nil { + c.incomingBytes(int64(len(msg))) + } + + // Parse the received message + var m message + subID, remain, err := m.parse(msg) + if err != nil { + if !xnet.IsNetworkOrHostDown(err, true) { + gridLogIf(ctx, fmt.Errorf("ws parse package: %w", err)) + } + return + } + if debugPrint { + fmt.Printf("%s Got msg: %v\n", c.Local, m) + } + if m.Op != OpMerged { + c.inMessages.Add(1) + c.handleMsg(ctx, m, subID) + continue + } + // Handle merged messages. + messages := int(m.Seq) + c.inMessages.Add(int64(messages)) + for i := 0; i < messages; i++ { + if atomic.LoadUint32((*uint32)(&c.state)) != StateConnected { + return + } + var next []byte + next, remain, err = msgp.ReadBytesZC(remain) + if err != nil { + if !xnet.IsNetworkOrHostDown(err, true) { + gridLogIf(ctx, fmt.Errorf("ws read merged: %w", err)) + } + return + } + + m.Payload = nil + subID, _, err = m.parse(next) + if err != nil { + if !xnet.IsNetworkOrHostDown(err, true) { + gridLogIf(ctx, fmt.Errorf("ws parse merged: %w", err)) + } + return + } + c.handleMsg(ctx, m, subID) + } + } +} + +// writeStream handles the read side of the connection. +// It will grab messages from c.outQueue and write them to the connection. +// If an error occurs the cancel function will be called and conn be closed. +// The function will block until the connection is closed or an error occurs. +func (c *Connection) writeStream(ctx context.Context, conn net.Conn, cancel context.CancelCauseFunc) { defer func() { if rec := recover(); rec != nil { gridLogIf(ctx, fmt.Errorf("handleMessages: panic recovered: %v", rec))