minio/internal/grid/muxserver.go

332 lines
7.9 KiB
Go

// Copyright (c) 2015-2023 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package grid
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/minio/minio/internal/logger"
)
const lastPingThreshold = 4 * clientPingInterval
type muxServer struct {
ID uint64
LastPing int64
SendSeq, RecvSeq uint32
Resp chan []byte
BaseFlags Flags
ctx context.Context
cancel context.CancelFunc
inbound chan []byte
parent *Connection
sendMu sync.Mutex
recvMu sync.Mutex
outBlock chan struct{}
}
func newMuxStateless(ctx context.Context, msg message, c *Connection, handler StatelessHandler) *muxServer {
var cancel context.CancelFunc
ctx = setCaller(ctx, c.remote)
if msg.DeadlineMS > 0 {
ctx, cancel = context.WithTimeout(ctx, time.Duration(msg.DeadlineMS)*time.Millisecond)
} else {
ctx, cancel = context.WithCancel(ctx)
}
m := muxServer{
ID: msg.MuxID,
RecvSeq: msg.Seq + 1,
SendSeq: msg.Seq,
ctx: ctx,
cancel: cancel,
parent: c,
LastPing: time.Now().Unix(),
BaseFlags: c.baseFlags,
}
go func() {
// TODO: Handle
}()
return &m
}
func newMuxStream(ctx context.Context, msg message, c *Connection, handler StreamHandler) *muxServer {
var cancel context.CancelFunc
ctx = setCaller(ctx, c.remote)
if len(handler.Subroute) > 0 {
ctx = setSubroute(ctx, handler.Subroute)
}
if msg.DeadlineMS > 0 {
ctx, cancel = context.WithTimeout(ctx, time.Duration(msg.DeadlineMS)*time.Millisecond+c.addDeadline)
} else {
ctx, cancel = context.WithCancel(ctx)
}
send := make(chan []byte)
inboundCap, outboundCap := handler.InCapacity, handler.OutCapacity
if outboundCap <= 0 {
outboundCap = 1
}
m := muxServer{
ID: msg.MuxID,
RecvSeq: msg.Seq + 1,
SendSeq: msg.Seq,
ctx: ctx,
cancel: cancel,
parent: c,
inbound: nil,
outBlock: make(chan struct{}, outboundCap),
LastPing: time.Now().Unix(),
BaseFlags: c.baseFlags,
}
// Acknowledge Mux created.
var ack message
ack.Op = OpAckMux
ack.Flags = m.BaseFlags
ack.MuxID = m.ID
m.send(ack)
if debugPrint {
fmt.Println("connected stream mux:", ack.MuxID)
}
// Data inbound to the handler
var handlerIn chan []byte
if inboundCap > 0 {
m.inbound = make(chan []byte, inboundCap)
handlerIn = make(chan []byte, 1)
go func(inbound <-chan []byte) {
defer close(handlerIn)
// Send unblocks when we have delivered the message to the handler.
for in := range inbound {
handlerIn <- in
m.send(message{Op: OpUnblockClMux, MuxID: m.ID, Flags: c.baseFlags})
}
}(m.inbound)
}
for i := 0; i < outboundCap; i++ {
m.outBlock <- struct{}{}
}
// Handler goroutine.
var handlerErr *RemoteErr
go func() {
start := time.Now()
defer func() {
if debugPrint {
fmt.Println("Mux", m.ID, "Handler took", time.Since(start).Round(time.Millisecond))
}
if r := recover(); r != nil {
logger.LogIf(ctx, fmt.Errorf("grid handler (%v) panic: %v", msg.Handler, r))
err := RemoteErr(fmt.Sprintf("panic: %v", r))
handlerErr = &err
}
if debugPrint {
fmt.Println("muxServer: Mux", m.ID, "Returned with", handlerErr)
}
close(send)
}()
// handlerErr is guarded by 'send' channel.
handlerErr = handler.Handle(ctx, msg.Payload, handlerIn, send)
}()
// Response sender gorutine...
go func(outBlock <-chan struct{}) {
defer m.parent.deleteMux(true, m.ID)
for {
// Process outgoing message.
var payload []byte
var ok bool
select {
case payload, ok = <-send:
case <-ctx.Done():
return
}
select {
case <-ctx.Done():
return
case <-outBlock:
}
msg := message{
MuxID: m.ID,
Op: OpMuxServerMsg,
Flags: c.baseFlags,
}
if !ok {
if debugPrint {
fmt.Println("muxServer: Mux", m.ID, "send EOF", handlerErr)
}
msg.Flags |= FlagEOF
if handlerErr != nil {
msg.Flags |= FlagPayloadIsErr
msg.Payload = []byte(*handlerErr)
}
msg.setZeroPayloadFlag()
m.send(msg)
return
}
msg.Payload = payload
msg.setZeroPayloadFlag()
m.send(msg)
}
}(m.outBlock)
// Remote aliveness check.
if msg.DeadlineMS == 0 || msg.DeadlineMS > uint32(lastPingThreshold/time.Millisecond) {
go func() {
t := time.NewTicker(lastPingThreshold / 4)
defer t.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-t.C:
last := time.Since(time.Unix(atomic.LoadInt64(&m.LastPing), 0))
if last > lastPingThreshold {
logger.LogIf(m.ctx, fmt.Errorf("canceling remote connection %s not seen for %v", m.parent, last))
m.close()
return
}
}
}
}()
}
return &m
}
// checkSeq will check if sequence number is correct and increment it by 1.
func (m *muxServer) checkSeq(seq uint32) (ok bool) {
if seq != m.RecvSeq {
if debugPrint {
fmt.Printf("expected sequence %d, got %d\n", m.RecvSeq, seq)
}
m.disconnect(fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq))
return false
}
m.RecvSeq++
return true
}
func (m *muxServer) message(msg message) {
if debugPrint {
fmt.Printf("muxServer: recevied message %d, length %d\n", msg.Seq, len(msg.Payload))
}
m.recvMu.Lock()
defer m.recvMu.Unlock()
if cap(m.inbound) == 0 {
m.disconnect("did not expect inbound message")
return
}
if !m.checkSeq(msg.Seq) {
return
}
// Note, on EOF no value can be sent.
if msg.Flags&FlagEOF != 0 {
if len(msg.Payload) > 0 {
logger.LogIf(m.ctx, fmt.Errorf("muxServer: EOF message with payload"))
}
close(m.inbound)
m.inbound = nil
return
}
select {
case <-m.ctx.Done():
case m.inbound <- msg.Payload:
if debugPrint {
fmt.Printf("muxServer: Sent seq %d to handler\n", msg.Seq)
}
default:
m.disconnect("handler blocked")
}
}
func (m *muxServer) unblockSend(seq uint32) {
if !m.checkSeq(seq) {
return
}
m.recvMu.Lock()
defer m.recvMu.Unlock()
if m.outBlock == nil {
// Closed
return
}
select {
case m.outBlock <- struct{}{}:
default:
logger.LogIf(m.ctx, errors.New("output unblocked overflow"))
}
}
func (m *muxServer) ping(seq uint32) pongMsg {
if !m.checkSeq(seq) {
msg := fmt.Sprintf("receive sequence number mismatch. want %d, got %d", m.RecvSeq, seq)
return pongMsg{Err: &msg}
}
select {
case <-m.ctx.Done():
err := context.Cause(m.ctx).Error()
return pongMsg{Err: &err}
default:
atomic.StoreInt64(&m.LastPing, time.Now().Unix())
return pongMsg{}
}
}
func (m *muxServer) disconnect(msg string) {
if debugPrint {
fmt.Println("Mux", m.ID, "disconnecting. Reason:", msg)
}
if msg != "" {
m.send(message{Op: OpMuxServerMsg, MuxID: m.ID, Flags: FlagPayloadIsErr | FlagEOF, Payload: []byte(msg)})
} else {
m.send(message{Op: OpDisconnectClientMux, MuxID: m.ID})
}
m.parent.deleteMux(true, m.ID)
}
func (m *muxServer) send(msg message) {
m.sendMu.Lock()
defer m.sendMu.Unlock()
msg.MuxID = m.ID
msg.Seq = m.SendSeq
m.SendSeq++
if debugPrint {
fmt.Printf("Mux %d, Sending %+v\n", m.ID, msg)
}
logger.LogIf(m.ctx, m.parent.queueMsg(msg, nil))
}
func (m *muxServer) close() {
m.cancel()
m.recvMu.Lock()
defer m.recvMu.Unlock()
if m.inbound != nil {
close(m.inbound)
m.inbound = nil
}
if m.outBlock != nil {
close(m.outBlock)
m.outBlock = nil
}
}