// Source: minio/internal/grid/grid_test.go (1330 lines, 36 KiB, Go)

// Copyright (c) 2015-2023 MinIO, Inc.
//
// This file is part of MinIO Object Storage stack
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.
package grid
import (
"bytes"
"context"
"errors"
"fmt"
"os"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/minio/minio/internal/logger/target/testlogger"
)
// TestSingleRoundtrip exercises single request/response calls from the local
// manager to the remote one. Both managers register an echo handler
// (handlerTest) and a handler that returns its payload as a RemoteErr
// (handlerTest2); the subtests check both with small and large payloads.
func TestSingleRoundtrip(t *testing.T) {
	defer testlogger.T.SetLogTB(t)()
	// errFatal aborts tb on a non-nil error. The testing.TB is passed
	// explicitly so that subtests fail themselves instead of calling Fatal
	// on the parent test from the subtest's goroutine.
	errFatal := func(tb testing.TB, err error) {
		tb.Helper()
		if err != nil {
			tb.Fatal(err)
		}
	}
	grid, err := SetupTestGrid(2)
	errFatal(t, err)
	// Release the grid once the test and its subtests are done,
	// the same way TestStreamSuite does.
	t.Cleanup(grid.Cleanup)

	remoteHost := grid.Hosts[1]
	local := grid.Managers[0]
	remote := grid.Managers[1]

	// Register the same two handlers on both managers, local first.
	for _, m := range []*Manager{local, remote} {
		// 1: Echo
		errFatal(t, m.RegisterSingleHandler(handlerTest, func(payload []byte) ([]byte, *RemoteErr) {
			t.Log("1: server payload: ", len(payload), "bytes.")
			return append([]byte{}, payload...), nil
		}))
		// 2: Return as error
		errFatal(t, m.RegisterSingleHandler(handlerTest2, func(payload []byte) ([]byte, *RemoteErr) {
			t.Log("2: server payload: ", len(payload), "bytes.")
			err := RemoteErr(payload)
			return nil, &err
		}))
	}

	// local to remote
	remoteConn := local.Connection(remoteHost)
	remoteConn.WaitForConnect(context.Background())
	defer testlogger.T.SetErrorTB(t)()

	t.Run("localToRemote", func(t *testing.T) {
		const testPayload = "Hello Grid World!"

		start := time.Now()
		resp, err := remoteConn.Request(context.Background(), handlerTest, []byte(testPayload))
		errFatal(t, err)
		if string(resp) != testPayload {
			t.Errorf("want %q, got %q", testPayload, string(resp))
		}
		t.Log("Roundtrip:", time.Since(start))
	})

	t.Run("localToRemoteErr", func(t *testing.T) {
		const testPayload = "Hello Grid World!"

		start := time.Now()
		resp, err := remoteConn.Request(context.Background(), handlerTest2, []byte(testPayload))
		t.Log("Roundtrip:", time.Since(start))
		if len(resp) != 0 {
			t.Errorf("want nil, got %q", string(resp))
		}
		if err != RemoteErr(testPayload) {
			t.Errorf("want error %v(%T), got %v(%T)", RemoteErr(testPayload), RemoteErr(testPayload), err, err)
		}
	})

	t.Run("localToRemoteHuge", func(t *testing.T) {
		testPayload := bytes.Repeat([]byte("?"), 1<<20)

		start := time.Now()
		resp, err := remoteConn.Request(context.Background(), handlerTest, testPayload)
		errFatal(t, err)
		if !bytes.Equal(resp, testPayload) {
			// Report sizes only; dumping two 1 MiB payloads is not useful.
			t.Errorf("response differs from payload: want %d bytes, got %d bytes", len(testPayload), len(resp))
		}
		t.Log("Roundtrip:", time.Since(start))
	})

	t.Run("localToRemoteErrHuge", func(t *testing.T) {
		testPayload := bytes.Repeat([]byte("!"), 1<<10)

		start := time.Now()
		resp, err := remoteConn.Request(context.Background(), handlerTest2, testPayload)
		if len(resp) != 0 {
			t.Errorf("want nil, got %q", string(resp))
		}
		if err != RemoteErr(testPayload) {
			t.Errorf("want error %v(%T), got %v(%T)", RemoteErr(testPayload), RemoteErr(testPayload), err, err)
		}
		t.Log("Roundtrip:", time.Since(start))
	})
}
// TestSingleRoundtripNotReady checks the errors seen by a caller when the
// remote manager has no handlers registered: a single request is expected to
// fail with a *RemoteErr, while setting up a stream is expected to fail with
// ErrUnknownHandler.
func TestSingleRoundtripNotReady(t *testing.T) {
	defer testlogger.T.SetLogTB(t)()
	// errFatal aborts the given test on a non-nil error.
	errFatal := func(t testing.TB, err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}
	grid, err := SetupTestGrid(2)
	errFatal(t, err)
	// Release the grid once the test is done, the same way TestStreamSuite does.
	t.Cleanup(grid.Cleanup)

	remoteHost := grid.Hosts[1]
	local := grid.Managers[0]

	// 1: Echo
	errFatal(t, local.RegisterSingleHandler(handlerTest, func(payload []byte) ([]byte, *RemoteErr) {
		t.Log("1: server payload: ", len(payload), "bytes.")
		return append([]byte{}, payload...), nil
	}))
	// 2: Return as error
	errFatal(t, local.RegisterSingleHandler(handlerTest2, func(payload []byte) ([]byte, *RemoteErr) {
		t.Log("2: server payload: ", len(payload), "bytes.")
		err := RemoteErr(payload)
		return nil, &err
	}))

	// Do not register remote handlers

	// local to remote
	remoteConn := local.Connection(remoteHost)
	remoteConn.WaitForConnect(context.Background())
	defer testlogger.T.SetErrorTB(t)()

	t.Run("localToRemote", func(t *testing.T) {
		const testPayload = "Hello Grid World!"
		// Single requests should have remote errors.
		_, err := remoteConn.Request(context.Background(), handlerTest, []byte(testPayload))
		if v, ok := err.(*RemoteErr); !ok || v.Error() != "Invalid Handler for type" {
			t.Fatalf("Unexpected error: %v, %T", err, err)
		}
		// Streams should not be able to set up until registered.
		// Thus, the error is a local error.
		_, err = remoteConn.NewStream(context.Background(), handlerTest, []byte(testPayload))
		if !errors.Is(err, ErrUnknownHandler) {
			t.Fatalf("Unexpected error: %v, %T", err, err)
		}
	})
}
// TestSingleRoundtripGenerics exercises typed single handlers created with
// NewSingleHandler: h1 copies request fields into a response, h2 returns the
// request string as a RemoteErr. Both are registered on both managers and
// called from local to remote.
func TestSingleRoundtripGenerics(t *testing.T) {
	defer testlogger.T.SetLogTB(t)()
	errFatal := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}
	grid, err := SetupTestGrid(2)
	errFatal(err)
	// Release the grid once the test is done, the same way TestStreamSuite does.
	t.Cleanup(grid.Cleanup)

	remoteHost := grid.Hosts[1]
	local := grid.Managers[0]
	remote := grid.Managers[1]

	// 1: Echo
	h1 := NewSingleHandler[*testRequest, *testResponse](handlerTest, func() *testRequest {
		return &testRequest{}
	}, func() *testResponse {
		return &testResponse{}
	})
	// Handles incoming requests, returns a response
	handler1 := func(req *testRequest) (resp *testResponse, err *RemoteErr) {
		resp = h1.NewResponse()
		*resp = testResponse{
			OrgNum:    req.Num,
			OrgString: req.String,
			Embedded:  *req,
		}
		return resp, nil
	}
	// Return error
	h2 := NewSingleHandler[*testRequest, *testResponse](handlerTest2, newTestRequest, newTestResponse)
	handler2 := func(req *testRequest) (resp *testResponse, err *RemoteErr) {
		r := RemoteErr(req.String)
		return nil, &r
	}
	errFatal(h1.Register(local, handler1))
	errFatal(h2.Register(local, handler2))

	errFatal(h1.Register(remote, handler1))
	errFatal(h2.Register(remote, handler2))

	// local to remote connection
	remoteConn := local.Connection(remoteHost)
	const testPayload = "Hello Grid World!"

	start := time.Now()
	req := testRequest{Num: 1, String: testPayload}
	resp, err := h1.Call(context.Background(), remoteConn, &req)
	errFatal(err)
	if resp.OrgString != testPayload {
		t.Errorf("want %q, got %q", testPayload, resp.OrgString)
	}
	t.Log("Roundtrip:", time.Since(start))
	h1.PutResponse(resp)

	start = time.Now()
	resp, err = h2.Call(context.Background(), remoteConn, &testRequest{Num: 1, String: testPayload})
	t.Log("Roundtrip:", time.Since(start))
	if err != RemoteErr(testPayload) {
		t.Errorf("want error %v(%T), got %v(%T)", RemoteErr(testPayload), RemoteErr(testPayload), err, err)
	}
	if resp != nil {
		// %+v rather than %q: resp is a struct pointer, and %q would render
		// its integer fields as quoted runes.
		t.Errorf("want nil, got %+v", resp)
	}
	h2.PutResponse(resp)
	t.Log("Roundtrip:", time.Since(start))
}
// TestSingleRoundtripGenericsRecycle exercises typed single handlers that use
// *MSS for both request and response: h1 copies all request entries into the
// response, h2 recycles the request and returns its "err" entry as a RemoteErr.
func TestSingleRoundtripGenericsRecycle(t *testing.T) {
	defer testlogger.T.SetLogTB(t)()
	errFatal := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}
	grid, err := SetupTestGrid(2)
	errFatal(err)
	// Release the grid once the test is done, the same way TestStreamSuite does.
	t.Cleanup(grid.Cleanup)

	remoteHost := grid.Hosts[1]
	local := grid.Managers[0]
	remote := grid.Managers[1]

	// 1: Echo
	h1 := NewSingleHandler[*MSS, *MSS](handlerTest, NewMSS, NewMSS)
	// Handles incoming requests, returns a response
	handler1 := func(req *MSS) (resp *MSS, err *RemoteErr) {
		resp = h1.NewResponse()
		for k, v := range *req {
			(*resp)[k] = v
		}
		return resp, nil
	}
	// Return error
	h2 := NewSingleHandler[*MSS, *MSS](handlerTest2, NewMSS, NewMSS)
	handler2 := func(req *MSS) (resp *MSS, err *RemoteErr) {
		defer req.Recycle()
		r := RemoteErr(req.Get("err"))
		return nil, &r
	}
	errFatal(h1.Register(local, handler1))
	errFatal(h2.Register(local, handler2))

	errFatal(h1.Register(remote, handler1))
	errFatal(h2.Register(remote, handler2))

	// local to remote connection
	remoteConn := local.Connection(remoteHost)
	const testPayload = "Hello Grid World!"

	start := time.Now()
	req := NewMSSWith(map[string]string{"test": testPayload})
	resp, err := h1.Call(context.Background(), remoteConn, req)
	errFatal(err)
	if resp.Get("test") != testPayload {
		t.Errorf("want %q, got %q", testPayload, resp.Get("test"))
	}
	t.Log("Roundtrip:", time.Since(start))
	h1.PutResponse(resp)

	start = time.Now()
	resp, err = h2.Call(context.Background(), remoteConn, NewMSSWith(map[string]string{"err": testPayload}))
	t.Log("Roundtrip:", time.Since(start))
	if err != RemoteErr(testPayload) {
		t.Errorf("want error %v(%T), got %v(%T)", RemoteErr(testPayload), RemoteErr(testPayload), err, err)
	}
	if resp != nil {
		t.Errorf("want nil, got %q", resp)
	}
	t.Log("Roundtrip:", time.Since(start))
	h2.PutResponse(resp)
}
// TestStreamSuite runs all streaming tests as subtests against one shared
// two-node test grid. Each subtest runs under a deadlock watchdog (timeout)
// and afterwards asserts that neither direction of the connection has active
// streams left.
func TestStreamSuite(t *testing.T) {
	defer testlogger.T.SetErrorTB(t)()

	grid, err := SetupTestGrid(2)
	if err != nil {
		t.Fatal(err)
	}
	t.Cleanup(grid.Cleanup)

	local, localHost := grid.Managers[0], grid.Hosts[0]
	remote, remoteHost := grid.Managers[1], grid.Hosts[1]

	connLocalToRemote := local.Connection(remoteHost)
	connRemoteLocal := remote.Connection(localHost)

	// Subtests in execution order: name, watchdog limit and body.
	suite := []struct {
		name  string
		limit time.Duration
		run   func(t *testing.T)
	}{
		{"testStreamRoundtrip", 5 * time.Second, func(t *testing.T) {
			testStreamRoundtrip(t, local, remote)
		}},
		{"testStreamCancel", 5 * time.Second, func(t *testing.T) {
			testStreamCancel(t, local, remote)
		}},
		{"testStreamDeadline", 5 * time.Second, func(t *testing.T) {
			testStreamDeadline(t, local, remote)
		}},
		{"testServerOutCongestion", 1 * time.Minute, func(t *testing.T) {
			testServerOutCongestion(t, local, remote)
		}},
		{"testServerInCongestion", 1 * time.Minute, func(t *testing.T) {
			testServerInCongestion(t, local, remote)
		}},
		{"testGenericsStreamRoundtrip", 1 * time.Minute, func(t *testing.T) {
			testGenericsStreamRoundtrip(t, local, remote)
		}},
		{"testGenericsStreamRoundtripSubroute", 1 * time.Minute, func(t *testing.T) {
			testGenericsStreamRoundtripSubroute(t, local, remote)
		}},
		{"testServerStreamResponseBlocked", 1 * time.Minute, func(t *testing.T) {
			testServerStreamResponseBlocked(t, local, remote)
		}},
		{"testServerStreamOnewayNoPing", 1 * time.Minute, func(t *testing.T) {
			testServerStreamNoPing(t, local, remote, 0)
		}},
		{"testServerStreamTwowayNoPing", 1 * time.Minute, func(t *testing.T) {
			testServerStreamNoPing(t, local, remote, 1)
		}},
		{"testServerStreamTwowayPing", 1 * time.Minute, func(t *testing.T) {
			testServerStreamPingRunning(t, local, remote, 1, false, false)
		}},
		{"testServerStreamTwowayPingReq", 1 * time.Minute, func(t *testing.T) {
			testServerStreamPingRunning(t, local, remote, 1, false, true)
		}},
		{"testServerStreamTwowayPingResp", 1 * time.Minute, func(t *testing.T) {
			testServerStreamPingRunning(t, local, remote, 1, true, false)
		}},
		{"testServerStreamTwowayPingReqResp", 1 * time.Minute, func(t *testing.T) {
			testServerStreamPingRunning(t, local, remote, 1, true, true)
		}},
		{"testServerStreamOnewayPing", 1 * time.Minute, func(t *testing.T) {
			testServerStreamPingRunning(t, local, remote, 0, false, true)
		}},
		{"testServerStreamOnewayPingUnblocked", 1 * time.Minute, func(t *testing.T) {
			testServerStreamPingRunning(t, local, remote, 0, false, false)
		}},
	}

	for _, sub := range suite {
		sub := sub // own copy per iteration for toolchains before Go 1.22
		t.Run(sub.name, func(t *testing.T) {
			defer timeout(sub.limit)()
			sub.run(t)
			assertNoActive(t, connRemoteLocal)
			assertNoActive(t, connLocalToRemote)
		})
	}
}
// testStreamRoundtrip opens a stream from local to remote against an echo
// handler and sends the numbers 0 through 10 one at a time, checking that each
// response is the stream payload followed by the number just sent. The request
// channel is closed after the response to 10, and the test then waits for the
// response channel to end.
func testStreamRoundtrip(t *testing.T, local, remote *Manager) {
	defer testlogger.T.SetErrorTB(t)()
	defer timeout(5 * time.Second)()
	must := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}

	// We fake a local and remote server.
	remoteHost := remote.HostName()

	// Register both handlers on local first, then on remote.
	for _, m := range []*Manager{local, remote} {
		// 1: Echo. Every request is answered with payload+request.
		must(m.RegisterStreamingHandler(handlerTest, StreamHandler{
			Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr {
				for in := range request {
					out := append([]byte{}, payload...)
					out = append(out, in...)
					resp <- out
				}
				t.Log(GetCaller(ctx).Name, "Handler done")
				return nil
			},
			OutCapacity: 1,
			InCapacity:  1,
		}))
		// 2: Return as error. The first request ends the stream with an error.
		must(m.RegisterStreamingHandler(handlerTest2, StreamHandler{
			Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr {
				in, ok := <-request
				if !ok {
					return nil
				}
				t.Log("2: Got err request", string(in))
				err := RemoteErr(append(payload, in...))
				return &err
			},
			OutCapacity: 1,
			InCapacity:  1,
		}))
	}

	// local to remote
	conn := local.Connection(remoteHost)
	const testPayload = "Hello Grid World!"

	began := time.Now()
	stream, err := conn.NewStream(context.Background(), handlerTest, []byte(testPayload))
	must(err)

	var seq int
	stream.Requests <- []byte(strconv.Itoa(seq))
	for resp := range stream.responses {
		must(resp.Err)
		t.Logf("got resp: %+v", string(resp.Msg))
		if want := testPayload + strconv.Itoa(seq); string(resp.Msg) != want {
			t.Errorf("want %q, got %q", want, string(resp.Msg))
		}
		if seq < 10 {
			seq++
			t.Log("sending new client request")
			stream.Requests <- []byte(strconv.Itoa(seq))
			continue
		}
		// Last response received; signal that no more requests follow.
		close(stream.Requests)
	}
	t.Log("EOF. 10 Roundtrips:", time.Since(began))
}
// testStreamCancel checks that canceling the client context of a stream
// cancels the server handler's context and that the client's response channel
// ends with context.Canceled. It runs against a handler registered with
// InCapacity 0 ("unbuffered") and one with InCapacity 1 ("buffered"), the
// latter both with and without queued requests.
func testStreamCancel(t *testing.T, local, remote *Manager) {
	defer testlogger.T.SetErrorTB(t)()
	// errFatal aborts tb on a non-nil error. The testing.TB is passed
	// explicitly so that subtests fail themselves instead of calling Fatal
	// on the parent test from the subtest's goroutine.
	errFatal := func(tb testing.TB, err error) {
		tb.Helper()
		if err != nil {
			tb.Fatal(err)
		}
	}

	// We fake a local and remote server.
	remoteHost := remote.HostName()

	serverCanceled := make(chan struct{})
	// waitCancel is the handler body shared by both registrations: it blocks
	// until its context is done and then signals serverCanceled.
	waitCancel := func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr {
		<-ctx.Done()
		serverCanceled <- struct{}{}
		fmt.Println(GetCaller(ctx).Name, "Server Context canceled")
		return nil
	}
	register := func(manager *Manager) {
		errFatal(t, manager.RegisterStreamingHandler(handlerTest, StreamHandler{
			Handle:      waitCancel,
			OutCapacity: 1,
			InCapacity:  0,
		}))
		errFatal(t, manager.RegisterStreamingHandler(handlerTest2, StreamHandler{
			Handle:      waitCancel,
			OutCapacity: 1,
			InCapacity:  1,
		}))
	}
	register(local)
	register(remote)

	// local to remote
	testHandler := func(t *testing.T, handler HandlerID, sendReq bool) {
		remoteConn := local.Connection(remoteHost)
		const testPayload = "Hello Grid World!"

		ctx, cancel := context.WithCancel(context.Background())
		// Safety net for early exits (e.g. NewStream failing); the explicit
		// cancel below is the one under test and a second call is a no-op.
		defer cancel()
		st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload))
		errFatal(t, err)
		clientCanceled := make(chan time.Time, 1)
		// err is written by the goroutine below and only read again after
		// receiving from clientCanceled.
		go func(t *testing.T) {
			for resp := range st.responses {
				t.Log("got resp:", string(resp.Msg), "err:", resp.Err)
				if err != nil {
					t.Log("ERROR: got second error:", resp.Err, "first:", err)
					continue
				}
				err = resp.Err
			}
			t.Log("Client Context canceled. err state:", err)
			clientCanceled <- time.Now()
		}(t)
		start := time.Now()
		if st.Requests != nil {
			defer close(st.Requests)
		}
		// Fill up queue.
		for sendReq {
			select {
			case st.Requests <- []byte("Hello"):
				time.Sleep(10 * time.Millisecond)
			default:
				sendReq = false
			}
		}
		cancel()
		<-serverCanceled
		t.Log("server cancel time:", time.Since(start))
		clientEnd := <-clientCanceled
		if !errors.Is(err, context.Canceled) {
			t.Error("expected context.Canceled, got", err)
		}
		t.Log("client after", time.Since(clientEnd))
	}
	// local to remote, unbuffered
	t.Run("unbuffered", func(t *testing.T) {
		testHandler(t, handlerTest, false)
	})
	t.Run("buffered", func(t *testing.T) {
		testHandler(t, handlerTest2, false)
	})
	// Distinct name: a second "buffered" would only be reachable as "buffered#01".
	t.Run("bufferedWithRequests", func(t *testing.T) {
		testHandler(t, handlerTest2, true)
	})
}
// testStreamDeadline tests that a deadline on the client context ends a
// stream on both sides: the server handler's context becomes done and the
// client's response channel ends with context.DeadlineExceeded. It runs
// against a handler registered with InCapacity 0 ("unbuffered") and one with
// InCapacity 1 ("buffered").
func testStreamDeadline(t *testing.T, local, remote *Manager) {
	defer testlogger.T.SetErrorTB(t)()
	// errFatal aborts tb on a non-nil error. The testing.TB is passed
	// explicitly so that subtests fail themselves instead of calling Fatal
	// on the parent test from the subtest's goroutine.
	errFatal := func(tb testing.TB, err error) {
		tb.Helper()
		if err != nil {
			tb.Fatal(err)
		}
	}

	const wantDL = 50 * time.Millisecond
	// We fake a local and remote server.
	remoteHost := remote.HostName()

	serverCanceled := make(chan time.Duration, 1)
	// waitDeadline is the handler body shared by both registrations: it
	// blocks until its context is done and reports how long that took.
	waitDeadline := func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr {
		started := time.Now()
		dl, _ := ctx.Deadline()
		if testing.Verbose() {
			fmt.Println(GetCaller(ctx).Name, "Server deadline:", time.Until(dl))
		}
		<-ctx.Done()
		serverCanceled <- time.Since(started)
		if testing.Verbose() {
			fmt.Println(GetCaller(ctx).Name, "Server Context canceled with", ctx.Err(), "after", time.Since(started))
		}
		return nil
	}
	register := func(manager *Manager) {
		errFatal(t, manager.RegisterStreamingHandler(handlerTest, StreamHandler{
			Handle:      waitDeadline,
			OutCapacity: 1,
			InCapacity:  0,
		}))
		errFatal(t, manager.RegisterStreamingHandler(handlerTest2, StreamHandler{
			Handle:      waitDeadline,
			OutCapacity: 1,
			InCapacity:  1,
		}))
	}
	register(local)
	register(remote)

	// Double remote DL
	// NOTE(review): presumably debugAddToDeadline makes the managers add
	// wantDL to deadlines they handle; it is reset to zero on exit.
	local.debugMsg(debugAddToDeadline, wantDL)
	defer local.debugMsg(debugAddToDeadline, time.Duration(0))
	remote.debugMsg(debugAddToDeadline, wantDL)
	defer remote.debugMsg(debugAddToDeadline, time.Duration(0))

	testHandler := func(t *testing.T, handler HandlerID) {
		remoteConn := local.Connection(remoteHost)
		const testPayload = "Hello Grid World!"

		ctx, cancel := context.WithTimeout(context.Background(), wantDL)
		defer cancel()
		st, err := remoteConn.NewStream(ctx, handler, []byte(testPayload))
		errFatal(t, err)
		clientCanceled := make(chan time.Duration, 1)
		// err is written by the goroutine below and only read again after
		// receiving from clientCanceled.
		go func() {
			started := time.Now()
			for resp := range st.responses {
				err = resp.Err
			}
			clientCanceled <- time.Since(started)
		}()
		serverEnd := <-serverCanceled
		clientEnd := <-clientCanceled
		t.Log("server cancel time:", serverEnd)
		t.Log("client cancel time:", clientEnd)
		if !errors.Is(err, context.DeadlineExceeded) {
			t.Error("expected context.DeadlineExceeded, got", err)
		}
	}
	// local to remote, unbuffered
	t.Run("unbuffered", func(t *testing.T) {
		testHandler(t, handlerTest)
	})
	t.Run("buffered", func(t *testing.T) {
		testHandler(t, handlerTest2)
	})
}
// testServerOutCongestion starts a stream whose handler emits 100 responses
// that the client does not read at first, then issues 100 single requests on
// the same connection to verify they still complete. Finally it drains the
// stream and checks that all 100 responses arrive in order.
func testServerOutCongestion(t *testing.T, local, remote *Manager) {
	defer testlogger.T.SetErrorTB(t)()
	must := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}
	const testPayload = "Hello Grid World!"

	// We fake a local and remote server.
	remoteHost := remote.HostName()

	// firstSent is closed by the stream handler after its first response went out.
	firstSent := make(chan struct{})
	for _, m := range []*Manager{local, remote} {
		must(m.RegisterStreamingHandler(handlerTest, StreamHandler{
			Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr {
				// Send many responses.
				// Test that this doesn't block.
				for seq := byte(0); seq < 100; seq++ {
					select {
					case <-ctx.Done():
						return NewRemoteErr(ctx.Err())
					case resp <- []byte{seq}:
						// ok
					}
					if seq == 0 {
						close(firstSent)
					}
				}
				return nil
			},
			OutCapacity: 1,
			InCapacity:  0,
		}))
		must(m.RegisterSingleHandler(handlerTest2, func(payload []byte) ([]byte, *RemoteErr) {
			// Simple roundtrip
			return append([]byte{}, payload...), nil
		}))
	}

	conn := local.Connection(remoteHost)

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()
	st, err := conn.NewStream(ctx, handlerTest, []byte(testPayload))
	must(err)

	// Wait for the server to send the first response.
	<-firstSent

	// Now do 100 other requests to ensure that the server doesn't block.
	for round := 0; round < 100; round++ {
		_, reqErr := conn.Request(ctx, handlerTest2, []byte(testPayload))
		must(reqErr)
	}

	// Drain responses
	var got int
	for resp := range st.responses {
		must(resp.Err)
		if first := resp.Msg[0]; first != byte(got) {
			t.Error("expected response", got, "got", first)
		}
		got++
	}
	if got != 100 {
		t.Error("expected 100 responses, got", got)
	}
}
// testServerInCongestion starts a stream whose handler does not read incoming
// requests until released, pushes 100 requests at it in the background, and
// meanwhile issues 100 single requests on the same connection to verify they
// still complete. It then releases the handler and checks that all 100
// responses arrive in order.
func testServerInCongestion(t *testing.T, local, remote *Manager) {
	defer testlogger.T.SetErrorTB(t)()
	must := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}
	const testPayload = "Hello Grid World!"

	// We fake a local and remote server.
	remoteHost := remote.HostName()

	// release is closed by the test to let the stream handler start reading.
	release := make(chan struct{})
	for _, m := range []*Manager{local, remote} {
		must(m.RegisterStreamingHandler(handlerTest, StreamHandler{
			Handle: func(ctx context.Context, payload []byte, request <-chan []byte, resp chan<- []byte) *RemoteErr {
				// Block incoming requests.
				var next byte
				<-release
				for {
					select {
					case <-ctx.Done():
						return NewRemoteErr(ctx.Err())
					case in, ok := <-request:
						if !ok {
							return nil
						}
						if in[0] != next {
							return NewRemoteErrString(fmt.Sprintf("expected incoming %d, got %d", next, in[0]))
						}
						next++
						resp <- append([]byte{}, in...)
					}
				}
			},
			OutCapacity: 5,
			InCapacity:  5,
		}))
		must(m.RegisterSingleHandler(handlerTest2, func(payload []byte) ([]byte, *RemoteErr) {
			// Simple roundtrip
			return append([]byte{}, payload...), nil
		}))
	}

	conn := local.Connection(remoteHost)

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()
	st, err := conn.NewStream(ctx, handlerTest, []byte(testPayload))
	must(err)

	// Start sending requests.
	go func() {
		for seq := byte(0); seq < 100; seq++ {
			st.Requests <- []byte{seq}
		}
		close(st.Requests)
	}()

	// Now do 100 other requests to ensure that the server doesn't block.
	for round := 0; round < 100; round++ {
		_, reqErr := conn.Request(ctx, handlerTest2, []byte(testPayload))
		must(reqErr)
	}

	// Start processing requests.
	close(release)

	// Drain responses
	var got int
	for resp := range st.responses {
		must(resp.Err)
		if first := resp.Msg[0]; first != byte(got) {
			t.Error("expected response", got, "got", first)
		}
		got++
	}
	if got != 100 {
		t.Error("expected 100 responses, got", got)
	}
}
// testGenericsStreamRoundtrip exercises a typed stream created with NewStream.
// The handler answers every request with a response combining the stream
// payload and the request; the client sends `payloads` requests and checks
// each response as well as the total number of responses.
func testGenericsStreamRoundtrip(t *testing.T, local, remote *Manager) {
	defer testlogger.T.SetErrorTB(t)()
	defer timeout(5 * time.Second)()
	errFatal := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}

	// We fake a local and remote server.
	remoteHost := remote.HostName()
	handler := NewStream[*testRequest, *testRequest, *testResponse](handlerTest, newTestRequest, newTestRequest, newTestResponse)
	handler.InCapacity = 1
	handler.OutCapacity = 1
	const payloads = 10

	// 1: Echo
	register := func(manager *Manager) {
		errFatal(handler.Register(manager, func(ctx context.Context, pp *testRequest, in <-chan *testRequest, out chan<- *testResponse) *RemoteErr {
			n := 0
			for i := range in {
				// n requests were handled before this one, so n == payloads
				// means this request is already one too many.
				if n >= payloads {
					panic("too many requests")
				}

				// t.Log("Got request:", *i)
				out <- &testResponse{
					OrgNum:    i.Num + pp.Num,
					OrgString: pp.String + i.String,
					Embedded:  *i,
				}
				n++
			}
			return nil
		}))
	}
	register(local)
	register(remote)

	// local to remote
	remoteConn := local.Connection(remoteHost)
	const testPayload = "Hello Grid World!"

	start := time.Now()
	stream, err := handler.Call(context.Background(), remoteConn, &testRequest{Num: 1, String: testPayload})
	errFatal(err)
	go func() {
		defer close(stream.Requests)
		for i := 0; i < payloads; i++ {
			// t.Log("sending new client request")
			stream.Requests <- &testRequest{Num: i, String: testPayload}
		}
	}()
	var n int
	err = stream.Results(func(resp *testResponse) error {
		const wantString = testPayload + testPayload
		if resp.OrgString != wantString {
			t.Errorf("want %q, got %q", wantString, resp.OrgString)
		}
		if resp.OrgNum != n+1 {
			t.Errorf("want %d, got %d", n+1, resp.OrgNum)
		}
		handler.PutResponse(resp)
		n++
		return nil
	})
	errFatal(err)
	// Every request must have produced exactly one response.
	if n != payloads {
		t.Errorf("want %d responses, got %d", payloads, n)
	}
	t.Log("EOF.", payloads, " Roundtrips:", time.Since(start))
}
// testGenericsStreamRoundtripSubroute is the typed stream roundtrip test run
// through a subroute: the handler is registered under "subroute", "1", the
// client calls it via remoteConn.Subroute("subroute/1"), and the handler
// verifies the subroute it sees in its context.
func testGenericsStreamRoundtripSubroute(t *testing.T, local, remote *Manager) {
	defer testlogger.T.SetErrorTB(t)()
	defer timeout(5 * time.Second)()
	errFatal := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}

	// We fake a local and remote server.
	remoteHost := remote.HostName()
	handler := NewStream[*testRequest, *testRequest, *testResponse](handlerTest, newTestRequest, newTestRequest, newTestResponse)
	handler.InCapacity = 1
	handler.OutCapacity = 1
	const payloads = 10

	// 1: Echo
	register := func(manager *Manager) {
		errFatal(handler.Register(manager, func(ctx context.Context, pp *testRequest, in <-chan *testRequest, out chan<- *testResponse) *RemoteErr {
			sub := GetSubroute(ctx)
			if sub != "subroute/1" {
				// The handler does not run on the test goroutine, so t.Fatal
				// must not be used here; record the failure and end the
				// stream with an error instead.
				t.Error("expected subroute/1, got", sub)
				return NewRemoteErrString(fmt.Sprintf("expected subroute/1, got %v", sub))
			}
			n := 0
			for i := range in {
				// n requests were handled before this one, so n == payloads
				// means this request is already one too many.
				if n >= payloads {
					panic("too many requests")
				}

				// t.Log("Got request:", *i)
				out <- &testResponse{
					OrgNum:    i.Num + pp.Num,
					OrgString: pp.String + i.String,
					Embedded:  *i,
				}
				n++
			}
			return nil
		}, "subroute", "1"))
	}
	register(local)
	register(remote)

	// local to remote
	remoteConn := local.Connection(remoteHost)
	const testPayload = "Hello Grid World!"
	// Add subroute
	remoteSub := remoteConn.Subroute(strings.Join([]string{"subroute", "1"}, "/"))

	start := time.Now()
	stream, err := handler.Call(context.Background(), remoteSub, &testRequest{Num: 1, String: testPayload})
	errFatal(err)
	go func() {
		defer close(stream.Requests)
		for i := 0; i < payloads; i++ {
			// t.Log("sending new client request")
			stream.Requests <- &testRequest{Num: i, String: testPayload}
		}
	}()
	var n int
	err = stream.Results(func(resp *testResponse) error {
		// t.Logf("got resp: %+v", *resp.Msg)
		const wantString = testPayload + testPayload
		if resp.OrgString != wantString {
			t.Errorf("want %q, got %q", wantString, resp.OrgString)
		}
		if resp.OrgNum != n+1 {
			t.Errorf("want %d, got %d", n+1, resp.OrgNum)
		}
		handler.PutResponse(resp)
		n++
		return nil
	})
	errFatal(err)
	// Every request must have produced exactly one response.
	if n != payloads {
		t.Errorf("want %d responses, got %d", payloads, n)
	}
	t.Log("EOF.", payloads, " Roundtrips:", time.Since(start))
}
// testServerStreamResponseBlocked will test if server can handle a blocked response stream.
// The client blocks inside its first Results callback, the test waits until
// the client's response buffer is full, cancels the context, and then checks
// that the server handler saw the cancellation, that the outgoing stream is
// reported closed, and that draining the stream yields context.Canceled.
func testServerStreamResponseBlocked(t *testing.T, local, remote *Manager) {
	defer testlogger.T.SetErrorTB(t)()
	errFatal := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}

	// We fake a local and remote server.
	remoteHost := remote.HostName()

	// serverSent is closed by the handler once its second response (i == 1)
	// has been sent; serverCanceled once it observes ctx.Done().
	serverSent := make(chan struct{})
	serverCanceled := make(chan struct{})
	register := func(manager *Manager) {
		errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
			Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
				// Send many responses.
				// Test that this doesn't block.
				for i := byte(0); i < 100; i++ {
					select {
					case resp <- []byte{i}:
						// ok
					case <-ctx.Done():
						close(serverCanceled)
						return NewRemoteErr(ctx.Err())
					}
					if i == 1 {
						close(serverSent)
					}
				}
				return nil
			},
			OutCapacity: 1,
			InCapacity:  0,
		}))
	}
	register(local)
	register(remote)

	remoteConn := local.Connection(remoteHost)
	const testPayload = "Hello Grid World!"

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	// Release the context on every exit path, including errFatal below.
	// The explicit cancel later is the one under test; repeated calls are no-ops.
	defer cancel()

	st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
	errFatal(err)

	// Wait for the server to send the first response.
	<-serverSent

	// Read back from the stream and block.
	nowBlocking := make(chan struct{})
	stopBlocking := make(chan struct{})
	defer close(stopBlocking)
	go func() {
		// The result of this call is deliberately ignored; it only exists
		// to hold the first response hostage.
		_ = st.Results(func(b []byte) error {
			close(nowBlocking)
			// Block until test is done.
			<-stopBlocking
			return nil
		})
	}()

	<-nowBlocking
	// Wait for the receiver channel to fill.
	for len(st.responses) != cap(st.responses) {
		time.Sleep(time.Millisecond)
	}
	cancel()
	<-serverCanceled
	local.debugMsg(debugIsOutgoingClosed, st.muxID, func(closed bool) {
		if !closed {
			t.Error("expected outgoing closed")
		} else {
			t.Log("outgoing was closed")
		}
	})

	// Drain responses and check if error propagated.
	err = st.Results(func(b []byte) error {
		return nil
	})
	if !errors.Is(err, context.Canceled) {
		t.Error("expected context.Canceled, got", err)
	}
}
// testServerStreamNoPing will test if server and client handle no pings.
// After the handler has started, inbound messages on the client connection
// are blocked through a debug hook; the test then expects the client's
// Results call to return an error and the server handler's context to be
// canceled. inCap is used as the handler's InCapacity.
func testServerStreamNoPing(t *testing.T, local, remote *Manager, inCap int) {
	defer testlogger.T.SetErrorTB(t)()
	errFatal := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}

	// We fake a local and remote server.
	remoteHost := remote.HostName()

	// reqStarted is closed when the handler is entered,
	// serverCanceled when its context is done.
	reqStarted := make(chan struct{})
	serverCanceled := make(chan struct{})
	register := func(manager *Manager) {
		errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
			Handle: func(ctx context.Context, payload []byte, _ <-chan []byte, resp chan<- []byte) *RemoteErr {
				close(reqStarted)
				// Just wait for it to cancel.
				<-ctx.Done()
				close(serverCanceled)
				return NewRemoteErr(ctx.Err())
			},
			OutCapacity: 1,
			InCapacity:  inCap,
		}))
	}
	register(local)
	register(remote)

	remoteConn := local.Connection(remoteHost)
	const testPayload = "Hello Grid World!"
	// Shorten the client ping interval for this test and restore it on exit.
	remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
	defer remoteConn.debugMsg(debugSetClientPingDuration, clientPingInterval)

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()
	st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
	errFatal(err)

	// Wait for the server start the request.
	<-reqStarted

	// Stop processing requests
	nowBlocking := make(chan struct{})
	remoteConn.debugMsg(debugBlockInboundMessages, nowBlocking)
	// Close the channel on every exit path, including the t.Fatal below, so
	// a failure here does not leave it open for later subtests sharing the
	// connection. NOTE(review): presumably closing it is what lifts the
	// inbound block - confirm against the debug hook.
	defer close(nowBlocking)

	// Check that local returned.
	err = st.Results(func(b []byte) error {
		return nil
	})
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	t.Logf("response: %v", err)

	// Check that remote is canceled.
	<-serverCanceled
}
// testServerStreamPingRunning will test if server and client handle ping even when blocked.
// The stream is kept open for one second with a 100ms client ping interval
// and is only expected to end once the test cancels the context itself.
//
// inCap is used as the handler's InCapacity. With blockResp the handler keeps
// sending responses; with blockReq (and inCap > 0) the client keeps sending
// requests.
func testServerStreamPingRunning(t *testing.T, local, remote *Manager, inCap int, blockResp, blockReq bool) {
	defer testlogger.T.SetErrorTB(t)()
	errFatal := func(err error) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
	}

	// We fake a local and remote server.
	remoteHost := remote.HostName()

	// reqStarted is closed when the handler is entered,
	// serverCanceled when its context is done.
	reqStarted := make(chan struct{})
	serverCanceled := make(chan struct{})
	register := func(manager *Manager) {
		errFatal(manager.RegisterStreamingHandler(handlerTest, StreamHandler{
			Handle: func(ctx context.Context, payload []byte, req <-chan []byte, resp chan<- []byte) *RemoteErr {
				close(reqStarted)
				// Just wait for it to cancel.
				for blockResp {
					select {
					case <-ctx.Done():
						close(serverCanceled)
						return NewRemoteErr(ctx.Err())
					case resp <- []byte{1}:
						time.Sleep(10 * time.Millisecond)
					}
				}
				// Just wait for it to cancel.
				<-ctx.Done()
				close(serverCanceled)
				return NewRemoteErr(ctx.Err())
			},
			OutCapacity: 1,
			InCapacity:  inCap,
		}))
	}
	register(local)
	register(remote)

	remoteConn := local.Connection(remoteHost)
	const testPayload = "Hello Grid World!"
	// Shorten the client ping interval for this test and restore it on exit.
	remoteConn.debugMsg(debugSetClientPingDuration, 100*time.Millisecond)
	defer remoteConn.debugMsg(debugSetClientPingDuration, clientPingInterval)

	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()
	st, err := remoteConn.NewStream(ctx, handlerTest, []byte(testPayload))
	errFatal(err)

	// Wait for the server start the request.
	<-reqStarted

	// Block until we have exceeded the deadline several times over.
	nowBlocking := make(chan struct{})
	// mu makes "cancel + close(nowBlocking)" atomic with respect to the
	// check below, so the test sees either neither or both.
	var mu sync.Mutex
	time.AfterFunc(time.Second, func() {
		mu.Lock()
		cancel()
		close(nowBlocking)
		mu.Unlock()
	})
	if inCap > 0 {
		go func() {
			defer close(st.Requests)
			if !blockReq {
				<-nowBlocking
				return
			}
			for {
				select {
				case <-nowBlocking:
					return
				case <-st.Done():
					// Nothing to do; keep looping until nowBlocking is closed.
				case st.Requests <- []byte{1}:
					time.Sleep(10 * time.Millisecond)
				}
			}
		}()
	}
	// Check that local returned.
	err = st.Results(func(b []byte) error {
		<-st.Done()
		return ctx.Err()
	})
	mu.Lock()
	// Release the mutex on every exit path. Without this, failing below
	// would leave the AfterFunc callback blocked on mu forever, so
	// nowBlocking would never be closed and the request goroutine above
	// would never exit.
	defer mu.Unlock()
	select {
	case <-nowBlocking:
	default:
		t.Fatal("expected to be blocked. got err", err)
	}
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	t.Logf("response: %v", err)
	// Check that remote is canceled.
	<-serverCanceled
}
// timeout starts a watchdog that assumes a deadlock if it is not canceled
// within the given duration: it then prints a dump of all goroutine stacks
// and terminates the process with exit code 2.
//
// The returned cancel function disarms the watchdog. It must be called at
// most once; typical use is `defer timeout(d)()`.
func timeout(after time.Duration) (cancel func()) {
	// A stoppable timer (rather than time.After) lets a canceled watchdog
	// release its timer right away instead of keeping it until it fires.
	timer := time.NewTimer(after)
	done := make(chan struct{})
	go func() {
		defer timer.Stop()
		select {
		case <-done:
			return
		case <-timer.C:
			buf := make([]byte, 1<<20)
			stacklen := runtime.Stack(buf, true)
			fmt.Printf("=== Timeout, assuming deadlock ===\n*** goroutine dump...\n%s\n*** end\n", string(buf[:stacklen]))
			os.Exit(2)
		}
	}()
	return func() {
		close(done)
	}
}
// assertNoActive reports a test error if connection c still has incoming or
// outgoing streams. Streams may take a moment to wind down, so the stats are
// re-checked up to ten more times, 100ms apart, before the IDs of the
// remaining streams are reported.
func assertNoActive(t *testing.T, c *Connection) {
	t.Helper()
	// Tiny bit racy for tests, but we try to play nice.
	const retries = 10
	for attempt := 0; ; attempt++ {
		runtime.Gosched()
		stats := c.Stats()
		if stats.IncomingStreams == 0 && stats.OutgoingStreams == 0 {
			return
		}
		if attempt < retries {
			time.Sleep(100 * time.Millisecond)
			continue
		}

		// Out of retries: report whatever is still active.
		if stats.IncomingStreams != 0 {
			var ids []uint64
			c.inStream.Range(func(key uint64, value *muxServer) bool {
				ids = append(ids, key)
				return true
			})
			t.Errorf("expected no active streams, got %d incoming: %v", stats.IncomingStreams, ids)
		}
		if stats.OutgoingStreams != 0 {
			var ids []uint64
			c.outgoing.Range(func(key uint64, value *muxClient) bool {
				ids = append(ids, key)
				return true
			})
			t.Errorf("expected no active streams, got %d outgoing: %v", stats.OutgoingStreams, ids)
		}
		return
	}
}
// Inserted manually.
//
// This function is never called. It is a compile-time guard that pins each
// State constant to the numeric value the name tables below were built for.
func _() {
	// Each index expression is a constant that must be 0 to be in range for
	// a one-element array. An "invalid array index" compiler error therefore
	// signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var guard [1]struct{}
	_ = guard[StateUnconnected-0]
	_ = guard[StateConnecting-1]
	_ = guard[StateConnected-2]
	_ = guard[StateConnectionError-3]
	_ = guard[StateShutdown-4]
}
// stateName holds the names of all State values concatenated, in value order.
const stateName = "UnconnectedConnectingConnectedConnectionErrorShutdown"

// stateIndex holds the offsets into stateName: the name of State i is
// stateName[stateIndex[i]:stateIndex[i+1]].
var stateIndex = [...]uint8{0, 11, 21, 30, 45, 53}

// String returns the name of the state, or "State(N)" with the numeric value
// for values at or beyond the end of the name table.
func (i State) String() string {
	if i < State(len(stateIndex)-1) {
		from, to := stateIndex[i], stateIndex[i+1]
		return stateName[from:to]
	}
	return "State(" + strconv.FormatInt(int64(i), 10) + ")"
}