Fix two-way stream cancelation and pings (#19763)

Do not log errors on one-way streams when sending a ping fails. Instead, cancel the stream.

This also ensures pings are sent when blocked on sending responses.
Klaus Post 2024-05-22 01:25:25 -07:00 committed by GitHub
parent 9906b3ade9
commit 4d698841f4
7 changed files with 396 additions and 88 deletions
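As a rough sketch of the intended behavior (not the actual grid implementation; the stream type and the sendMsg/sendPing helpers below are hypothetical stand-ins), a client-side send loop can keep a ping ticker running even while it is blocked on queuing messages, and cancel the stream when a ping cannot be sent instead of only logging the error:

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// stream is a stand-in for a two-way mux stream.
type stream struct {
	requests <-chan []byte           // outgoing requests from the caller
	cancel   context.CancelCauseFunc // cancels the stream with a cause
	sendMsg  func([]byte) error      // transport send (assumed helper)
	sendPing func() error            // transport ping (assumed helper)
}

// run drains requests while emitting pings on a fixed interval.
// A failed ping cancels the stream with the error instead of logging it.
func (s *stream) run(ctx context.Context, pingInterval time.Duration) error {
	t := time.NewTicker(pingInterval)
	defer t.Stop()
	for {
		select {
		case <-ctx.Done():
			return context.Cause(ctx)
		case <-t.C:
			// The ping fires even when the peer is slow to accept messages.
			if err := s.sendPing(); err != nil {
				s.cancel(err) // cancel the stream; do not just log
				return err
			}
		case req, ok := <-s.requests:
			if !ok {
				return nil // caller is done sending
			}
			if err := s.sendMsg(req); err != nil {
				s.cancel(err)
				return err
			}
		}
	}
}

func main() {
	reqs := make(chan []byte, 1)
	reqs <- []byte("hello")
	close(reqs)
	ctx, cancel := context.WithCancelCause(context.Background())
	s := &stream{
		requests: reqs,
		cancel:   cancel,
		sendMsg:  func(b []byte) error { fmt.Println("send:", string(b)); return nil },
		sendPing: func() error { return nil },
	}
	if err := s.run(ctx, 100*time.Millisecond); err != nil && !errors.Is(err, context.Canceled) {
		fmt.Println("stream error:", err)
	}
}

The diffs below apply this at both ends: the mux client pings from its request loop (even while blocked waiting for a send token) and disconnects the mux when pongs stop arriving, while the mux server cancels the handler when the remote has not pinged within its interval.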


@@ -123,10 +123,11 @@ type Connection struct {
baseFlags Flags
// For testing only
debugInConn net.Conn
debugOutConn net.Conn
addDeadline time.Duration
connMu sync.Mutex
debugInConn net.Conn
debugOutConn net.Conn
blockMessages atomic.Pointer[<-chan struct{}]
addDeadline time.Duration
connMu sync.Mutex
}
// Subroute is a connection subroute that can be used to route to a specific handler with the same handler ID.
@@ -975,6 +976,11 @@ func (c *Connection) handleMessages(ctx context.Context, conn net.Conn) {
gridLogIfNot(ctx, fmt.Errorf("ws read: %w", err), net.ErrClosed, io.EOF)
return
}
block := c.blockMessages.Load()
if block != nil && *block != nil {
<-*block
}
if c.incomingBytes != nil {
c.incomingBytes(int64(len(msg)))
}
@@ -1363,6 +1369,10 @@ func (c *Connection) handleRequest(ctx context.Context, m message, subID *subHandlerID) {
}
func (c *Connection) handlePong(ctx context.Context, m message) {
if m.MuxID == 0 && m.Payload == nil {
atomic.StoreInt64(&c.LastPong, time.Now().Unix())
return
}
var pong pongMsg
_, err := pong.UnmarshalMsg(m.Payload)
PutByteBuffer(m.Payload)
@@ -1382,7 +1392,9 @@ func (c *Connection) handlePong(ctx context.Context, m message) {
func (c *Connection) handlePing(ctx context.Context, m message) {
if m.MuxID == 0 {
gridLogIf(ctx, c.queueMsg(m, &pongMsg{}))
m.Flags.Clear(FlagPayloadIsZero)
m.Op = OpPong
gridLogIf(ctx, c.queueMsg(m, nil))
return
}
// Single calls do not support pinging.
@@ -1599,7 +1611,12 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
c.connMu.Lock()
defer c.connMu.Unlock()
c.connPingInterval = args[0].(time.Duration)
if c.connPingInterval < time.Second {
panic("CONN ping interval too low")
}
case debugSetClientPingDuration:
c.connMu.Lock()
defer c.connMu.Unlock()
c.clientPingInterval = args[0].(time.Duration)
case debugAddToDeadline:
c.addDeadline = args[0].(time.Duration)
@@ -1615,6 +1632,11 @@ func (c *Connection) debugMsg(d debugMsg, args ...any) {
mid.respMu.Lock()
resp(mid.closed)
mid.respMu.Unlock()
case debugBlockInboundMessages:
c.connMu.Lock()
block := (<-chan struct{})(args[0].(chan struct{}))
c.blockMessages.Store(&block)
c.connMu.Unlock()
}
}


@@ -50,6 +50,7 @@ const (
debugSetClientPingDuration
debugAddToDeadline
debugIsOutgoingClosed
debugBlockInboundMessages
)
// TestGrid contains a grid of servers for testing purposes.


@@ -16,11 +16,12 @@ func _() {
_ = x[debugSetClientPingDuration-5]
_ = x[debugAddToDeadline-6]
_ = x[debugIsOutgoingClosed-7]
_ = x[debugBlockInboundMessages-8]
}
const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingClosed"
const _debugMsg_name = "debugShutdowndebugKillInbounddebugKillOutbounddebugWaitForExitdebugSetConnPingDurationdebugSetClientPingDurationdebugAddToDeadlinedebugIsOutgoingCloseddebugBlockInboundMessages"
var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151}
var _debugMsg_index = [...]uint8{0, 13, 29, 46, 62, 86, 112, 130, 151, 176}
func (i debugMsg) String() string {
if i < 0 || i >= debugMsg(len(_debugMsg_index)-1) {


@@ -378,6 +378,54 @@ func TestStreamSuite(t *testing.T) {
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamOnewayNoPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamNoPing(t, local, remote, 0)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayNoPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamNoPing(t, local, remote, 1)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, false, false)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPingReq", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, false, true)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPingResp", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, true, false)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamTwowayPingReqResp", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 1, true, true)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamOnewayPing", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 0, false, true)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
t.Run("testServerStreamOnewayPingUnblocked", func(t *testing.T) {
defer timeout(1 * time.Minute)()
testServerStreamPingRunning(t, local, remote, 0, false, false)
assertNoActive(t, connRemoteLocal)
assertNoActive(t, connLocalToRemote)
})
}
func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
@@ -491,12 +539,12 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
register(remote)
// local to remote
testHandler := func(t *testing.T, handler HandlerID) {
testHandler := func(t *testing.T, handler HandlerID, sendReq bool) {
remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!"
ctx, cancel := context.WithCancel(context.Background())
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload))
errFatal(err)
clientCanceled := make(chan time.Time, 1)
err = nil
@@ -513,6 +561,18 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
clientCanceled <- time.Now()
}(t)
start := time.Now()
if st.Requests != nil {
defer close(st.Requests)
}
// Fill up queue.
for sendReq {
select {
case st.Requests <- []byte("Hello"):
time.Sleep(10 * time.Millisecond)
default:
sendReq = false
}
}
cancel()
<-serverCanceled
t.Log("server cancel time:", time.Since(start))
@@ -524,11 +584,13 @@ func testStreamCancel(t *testing.T, local, remote *Manager) {
}
// local to remote, unbuffered
t.Run("unbuffered", func(t *testing.T) {
testHandler(t, handlerTest)
testHandler(t, handlerTest, false)
})
t.Run("buffered", func(t *testing.T) {
testHandler(t, handlerTest2)
testHandler(t, handlerTest2, false)
})
t.Run("buffered", func(t *testing.T) {
testHandler(t, handlerTest2, true)
})
}
@@ -1025,6 +1087,167 @@ func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
}
}
// testServerStreamNoPing will test if server and client handle no pings.
func testServerStreamNoPing(t *testing.T, local, remote *Manager, inCap int) {
defer testlogger.T.SetErrorTB(t)()
errFatal := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// We fake a local and remote server.
remoteHost := remote.HostName()
// 1: Echo
reqStarted := make(chan struct{})
serverCanceled := make(chan struct{})
register := func(manager *Manager) {
errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
close(reqStarted)
// Just wait for it to cancel.
<-ctx.Done()
close(serverCanceled)
return NewRemoteErr(ctx.Err())
},
OutCapacity: 1,
InCapacity: inCap,
}))
}
register(local)
register(remote)
remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!"
remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
errFatal(err)
// Wait for the server to start the request.
<-reqStarted
// Stop processing requests
nowBlocking := make(chan struct{})
remoteConn.debugMsg(debugBlockInboundMessages, nowBlocking)
// Check that local returned.
err = st.Results(func(b []byte) error {
return nil
})
if err == nil {
t.Fatal("expected error, got nil")
}
t.Logf("error: %v", err)
// Check that remote is canceled.
<-serverCanceled
close(nowBlocking)
}
// testServerStreamPingRunning will test if server and client handle pings even when blocked.
func testServerStreamPingRunning(t *testing.T, local, remote *Manager, inCap int, blockResp, blockReq bool) {
defer testlogger.T.SetErrorTB(t)()
errFatal := func(err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}
// We fake a local and remote server.
remoteHost := remote.HostName()
// 1: Echo
reqStarted := make(chan struct{})
serverCanceled := make(chan struct{})
register := func(manager *Manager) {
errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
Handle: func(ctx context.Context, payload []byte, req <-chan []byte, resp chan<- []byte) *RemoteErr {
close(reqStarted)
// Just wait for it to cancel.
for blockResp {
select {
case <-ctx.Done():
close(serverCanceled)
return NewRemoteErr(ctx.Err())
case resp <- []byte{1}:
time.Sleep(10 * time.Millisecond)
}
}
// Just wait for it to cancel.
<-ctx.Done()
close(serverCanceled)
return NewRemoteErr(ctx.Err())
},
OutCapacity: 1,
InCapacity: inCap,
}))
}
register(local)
register(remote)
remoteConn := local.Connection(remoteHost)
const testPayload = "Hello Grid World!"
remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
errFatal(err)
// Wait for the server to start the request.
<-reqStarted
// Block until we have exceeded the deadline several times over.
nowBlocking := make(chan struct{})
time.AfterFunc(time.Second, func() {
cancel()
close(nowBlocking)
})
if inCap > 0 {
go func() {
defer close(st.Requests)
if !blockReq {
<-nowBlocking
return
}
for {
select {
case <-nowBlocking:
return
case <-st.Done():
case st.Requests <- []byte{1}:
time.Sleep(10 * time.Millisecond)
}
}
}()
}
// Check that local returned.
err = st.Results(func(b []byte) error {
select {
case <-nowBlocking:
case <-st.Done():
return ctx.Err()
}
return nil
})
select {
case <-nowBlocking:
default:
t.Fatal("expected to be blocked. got err", err)
}
if err == nil {
t.Fatal("expected error, got nil")
}
t.Logf("error: %v", err)
// Check that remote is canceled.
<-serverCanceled
}
func timeout(after time.Duration) (cancel func()) {
c := time.After(after)
cc := make(chan struct{})


@@ -32,24 +32,25 @@ import (
// muxClient is a stateful connection to a remote.
type muxClient struct {
MuxID uint64
SendSeq, RecvSeq uint32
LastPong int64
BaseFlags Flags
ctx context.Context
cancelFn context.CancelCauseFunc
parent *Connection
respWait chan<- Response
respMu sync.Mutex
singleResp bool
closed bool
stateless bool
acked bool
init bool
deadline time.Duration
outBlock chan struct{}
subroute *subHandlerID
respErr atomic.Pointer[error]
MuxID uint64
SendSeq, RecvSeq uint32
LastPong int64
BaseFlags Flags
ctx context.Context
cancelFn context.CancelCauseFunc
parent *Connection
respWait chan<- Response
respMu sync.Mutex
singleResp bool
closed bool
stateless bool
acked bool
init bool
deadline time.Duration
outBlock chan struct{}
subroute *subHandlerID
respErr atomic.Pointer[error]
clientPingInterval time.Duration
}
// Response is a response from the server.
@@ -61,12 +62,13 @@ type Response struct {
func newMuxClient(ctx context.Context, muxID uint64, parent *Connection) *muxClient {
ctx, cancelFn := context.WithCancelCause(ctx)
return &muxClient{
MuxID: muxID,
ctx: ctx,
cancelFn: cancelFn,
parent: parent,
LastPong: time.Now().Unix(),
BaseFlags: parent.baseFlags,
MuxID: muxID,
ctx: ctx,
cancelFn: cancelFn,
parent: parent,
LastPong: time.Now().UnixNano(),
BaseFlags: parent.baseFlags,
clientPingInterval: parent.clientPingInterval,
}
}
@@ -309,11 +311,11 @@ func (m *muxClient) handleOneWayStream(respHandler chan<- Response, respServer <-chan Response) {
}
}()
var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > clientPingInterval {
ticker := time.NewTicker(clientPingInterval)
if m.deadline == 0 || m.deadline > m.clientPingInterval {
ticker := time.NewTicker(m.clientPingInterval)
defer ticker.Stop()
pingTimer = ticker.C
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
defer m.parent.deleteMux(false, m.MuxID)
for {
@@ -367,7 +369,7 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
}
// Only check ping when not closed.
if got := time.Since(time.Unix(atomic.LoadInt64(&m.LastPong), 0)); got > clientPingInterval*2 {
if got := time.Since(time.Unix(0, atomic.LoadInt64(&m.LastPong))); got > m.clientPingInterval*2 {
m.respMu.Unlock()
if debugPrint {
fmt.Printf("Mux %d: last pong %v ago, disconnecting\n", m.MuxID, got)
@@ -388,15 +390,20 @@ func (m *muxClient) doPing(respHandler chan<- Response) (ok bool) {
// responseCh is the channel that goes to the requester.
// internalResp is the channel that comes from the server.
func (m *muxClient) handleTwowayResponses(responseCh chan<- Response, internalResp <-chan Response) {
defer m.parent.deleteMux(false, m.MuxID)
defer xioutil.SafeClose(responseCh)
defer func() {
m.parent.deleteMux(false, m.MuxID)
// addErrorNonBlockingClose will close the response channel.
xioutil.SafeClose(responseCh)
}()
// Cancelation and errors are handled by handleTwowayRequests below.
for resp := range internalResp {
responseCh <- resp
m.send(message{Op: OpUnblockSrvMux, MuxID: m.MuxID})
responseCh <- resp
}
}
func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
func (m *muxClient) handleTwowayRequests(errResp chan<- Response, requests <-chan []byte) {
var errState bool
if debugPrint {
start := time.Now()
@@ -405,24 +412,30 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
}()
}
var pingTimer <-chan time.Time
if m.deadline == 0 || m.deadline > m.clientPingInterval {
ticker := time.NewTicker(m.clientPingInterval)
defer ticker.Stop()
pingTimer = ticker.C
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
// Listen for client messages.
for {
if errState {
go func() {
// Drain requests.
for range requests {
}
}()
return
}
reqLoop:
for !errState {
select {
case <-m.ctx.Done():
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
errState = true
continue
case <-pingTimer:
if !m.doPing(errResp) {
errState = true
continue
}
case req, ok := <-requests:
if !ok {
// Done send EOF
@@ -432,22 +445,28 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
msg := message{
Op: OpMuxClientMsg,
MuxID: m.MuxID,
Seq: 1,
Flags: FlagEOF,
}
msg.setZeroPayloadFlag()
err := m.send(msg)
if err != nil {
m.addErrorNonBlockingClose(internalResp, err)
m.addErrorNonBlockingClose(errResp, err)
}
return
break reqLoop
}
// Grab a send token.
sendReq:
select {
case <-m.ctx.Done():
m.addErrorNonBlockingClose(internalResp, context.Cause(m.ctx))
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
errState = true
continue
case <-pingTimer:
if !m.doPing(errResp) {
errState = true
continue
}
goto sendReq
case <-m.outBlock:
}
msg := message{
@@ -460,13 +479,41 @@ func (m *muxClient) handleTwowayRequests(internalResp chan<- Response, requests <-chan []byte) {
err := m.send(msg)
PutByteBuffer(req)
if err != nil {
m.addErrorNonBlockingClose(internalResp, err)
m.addErrorNonBlockingClose(errResp, err)
errState = true
continue
}
msg.Seq++
}
}
if errState {
// Drain requests.
for {
select {
case r, ok := <-requests:
if !ok {
return
}
PutByteBuffer(r)
default:
return
}
}
}
for !errState {
select {
case <-m.ctx.Done():
if debugPrint {
fmt.Println("Client sending disconnect to mux", m.MuxID)
}
m.addErrorNonBlockingClose(errResp, context.Cause(m.ctx))
return
case <-pingTimer:
errState = !m.doPing(errResp)
}
}
}
// checkSeq will check if sequence number is correct and increment it by 1.
@@ -502,7 +549,7 @@ func (m *muxClient) response(seq uint32, r Response) {
m.addResponse(r)
return
}
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
ok := m.addResponse(r)
if !ok {
PutByteBuffer(r.Msg)
@@ -553,7 +600,7 @@ func (m *muxClient) pong(msg pongMsg) {
m.addResponse(Response{Err: err})
return
}
atomic.StoreInt64(&m.LastPong, time.Now().Unix())
atomic.StoreInt64(&m.LastPong, time.Now().UnixNano())
}
// addResponse will add a response to the response channel.


@@ -28,21 +28,20 @@ import (
xioutil "github.com/minio/minio/internal/ioutil"
)
const lastPingThreshold = 4 * clientPingInterval
type muxServer struct {
ID uint64
LastPing int64
SendSeq, RecvSeq uint32
Resp chan []byte
BaseFlags Flags
ctx context.Context
cancel context.CancelFunc
inbound chan []byte
parent *Connection
sendMu sync.Mutex
recvMu sync.Mutex
outBlock chan struct{}
ID uint64
LastPing int64
SendSeq, RecvSeq uint32
Resp chan []byte
BaseFlags Flags
ctx context.Context
cancel context.CancelFunc
inbound chan []byte
parent *Connection
sendMu sync.Mutex
recvMu sync.Mutex
outBlock chan struct{}
clientPingInterval time.Duration
}
func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer {
@@ -89,16 +88,17 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer {
}
m := muxServer{
ID: msg.MuxID,
RecvSeq: msg.Seq + 1,
SendSeq: msg.Seq,
ctx: ctx,
cancel: cancel,
parent: c,
inbound: nil,
outBlock: make(chan struct{}, outboundCap),
LastPing: time.Now().Unix(),
BaseFlags: c.baseFlags,
ID: msg.MuxID,
RecvSeq: msg.Seq + 1,
SendSeq: msg.Seq,
ctx: ctx,
cancel: cancel,
parent: c,
inbound: nil,
outBlock: make(chan struct{}, outboundCap),
LastPing: time.Now().Unix(),
BaseFlags: c.baseFlags,
clientPingInterval: c.clientPingInterval,
}
// Acknowledge Mux created.
// Send async.
@@ -153,7 +153,7 @@ func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer {
}(m.outBlock)
// Remote aliveness check if needed.
if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) {
if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(4*c.clientPingInterval/time.Millisecond) {
go func() {
wg.Wait()
m.checkRemoteAlive()
@@ -234,7 +234,7 @@ func (m *muxServer) handleRequests(ctx context.Context, msg message, send chan<- []byte) {
// checkRemoteAlive will check if the remote is alive.
func (m *muxServer) checkRemoteAlive() {
t := time.NewTicker(lastPingThreshold / 4)
t := time.NewTicker(m.clientPingInterval)
defer t.Stop()
for {
select {
@@ -242,7 +242,7 @@ func (m *muxServer) checkRemoteAlive() {
return
case <-t.C:
last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
if last > lastPingThreshold {
if last > 4*m.clientPingInterval {
gridLogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
m.close()
return


@@ -89,12 +89,26 @@ func (s *Stream) Results(next func(b []byte) error) (err error) {
return nil
}
if resp.Err != nil {
s.cancel(resp.Err)
return resp.Err
}
err = next(resp.Msg)
if err != nil {
s.cancel(err)
return err
}
}
}
}
// Done will return a channel that will be closed when the stream is done.
// This mirrors context.Done().
func (s *Stream) Done() <-chan struct{} {
return s.ctx.Done()
}
// Err will return the error that caused the stream to end.
// This mirrors context.Err().
func (s *Stream) Err() error {
return s.ctx.Err()
}