Add rate limiting to Assist (#26011)

* Add rate limiting to Assist

* Only rate limit Assist in Cloud

* Add a comment to assistantLimiter

* Fixes after rebase

* Add 'rate-limited' test case to assistant_test

* Handle CHAT_MESSAGE_UI in Assist web UI

* Add godoc

* CHAT_MESSAGE_UI -> CHAT_MESSAGE_ERROR

* Run assistant test cases in parallel
This commit is contained in:
Justinas Stankevičius 2023-05-15 18:48:52 +03:00 committed by GitHub
parent 9f794f1049
commit 21c534b17c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 180 additions and 56 deletions

View file

@ -58,6 +58,8 @@ const (
MessageKindAssistantPartialFinalize MessageType = "CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE"
// MessageKindSystemMessage is the type of Assist message that contains the system message.
MessageKindSystemMessage MessageType = "CHAT_MESSAGE_SYSTEM"
// MessageKindError is the type of Assist message that is presented to user as information, but not stored persistently in the conversation. This can include backend error messages and the like.
MessageKindError MessageType = "CHAT_MESSAGE_ERROR"
)
// Assist is the Teleport Assist client.
@ -210,7 +212,7 @@ type TokensUsed struct {
// Prompt is a number of tokens used in the prompt.
Prompt int
// Completion is a number of tokens used in the completion.
Competition int
Completion int
}
// ProcessComplete processes the completion request and returns the number of tokens used.
@ -386,8 +388,8 @@ func (c *Chat) ProcessComplete(ctx context.Context,
}
return &TokensUsed{
Prompt: promptTokens,
Competition: numTokens,
Prompt: promptTokens,
Completion: numTokens,
}, nil
}

View file

@ -51,6 +51,7 @@ import (
"golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"golang.org/x/mod/semver"
"golang.org/x/time/rate"
"google.golang.org/protobuf/encoding/protojson"
"github.com/gravitational/teleport"
@ -92,6 +93,15 @@ import (
const (
// SSOLoginFailureMessage is a generic error message to avoid disclosing sensitive SSO failure messages.
SSOLoginFailureMessage = "Failed to login. Please check Teleport's log for more details."
// assistantTokensPerHour defines how many assistant rate limiter tokens are replenished every hour.
assistantTokensPerHour = 140
// assistantLimiterRate is the rate (in tokens per second, as a rate.Limit)
// at which tokens for the assistant rate limiter are replenished.
// Derived by spreading assistantTokensPerHour evenly over the 3600 seconds of an hour.
assistantLimiterRate = rate.Limit(assistantTokensPerHour / float64(time.Hour/time.Second))
// assistantLimiterCapacity is the total capacity (burst size) of the token bucket
// for the assistant rate limiter. The bucket starts full, prefilled for a week's
// worth of tokens (assistantTokensPerHour * 24 hours * 7 days).
assistantLimiterCapacity = assistantTokensPerHour * 24 * 7
)
// healthCheckAppServerFunc defines a function used to perform a health check
@ -111,7 +121,13 @@ type Handler struct {
clock clockwork.Clock
limiter *limiter.RateLimiter
highLimiter *limiter.RateLimiter
healthCheckAppServer healthCheckAppServerFunc
// assistantLimiter limits the amount of tokens that can be consumed
// by OpenAI API calls when using a shared key.
// golang.org/x/time/rate is used, as the oxy ratelimiter
// is quite tightly tied to individual http.Requests,
// and instead we want to consume arbitrary amounts of tokens.
assistantLimiter *rate.Limiter
healthCheckAppServer healthCheckAppServerFunc
// sshPort specifies the SSH proxy port extracted
// from configuration
sshPort string
@ -301,6 +317,15 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
healthCheckAppServer: cfg.HealthCheckAppServer,
}
// Check for self-hosted vs Cloud.
// TODO(justinas): this needs to be modified when we allow user-supplied API keys in Cloud
if modules.GetModules().Features().Cloud {
h.assistantLimiter = rate.NewLimiter(assistantLimiterRate, assistantLimiterCapacity)
} else {
// Set up a limiter with "infinite limit", the "burst" parameter is ignored
h.assistantLimiter = rate.NewLimiter(rate.Inf, 0)
}
// for properly handling url-encoded parameter values.
h.UseRawPath = true

View file

@ -405,6 +405,17 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
return trace.Wrap(err)
}
// We can not know how many tokens we will consume in advance.
// Try to consume a small amount of tokens first.
const lookaheadTokens = 100
if !h.assistantLimiter.AllowN(time.Now(), lookaheadTokens) {
err := onMessageFn(assist.MessageKindError, []byte("You have reached the rate limit. Please try again later."), h.clock.Now().UTC())
if err != nil {
return trace.Wrap(err)
}
continue
}
//TODO(jakule): Should we sanitize the payload?
if err := chat.InsertAssistantMessage(ctx, assist.MessageKindUserMessage, wsIncoming.Payload); err != nil {
return trace.Wrap(err)
@ -415,14 +426,22 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
return trace.Wrap(err)
}
// Once we know how many tokens were consumed for prompt+completion,
// consume the remaining tokens from the rate limiter bucket.
extraTokens := usedTokens.Prompt + usedTokens.Completion - lookaheadTokens
if extraTokens < 0 {
extraTokens = 0
}
h.assistantLimiter.ReserveN(time.Now(), extraTokens)
usageEventReq := &proto.SubmitUsageEventRequest{
Event: &usageeventsv1.UsageEventOneOf{
Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{
AssistCompletion: &usageeventsv1.AssistCompletionEvent{
ConversationId: conversationID,
TotalTokens: int64(usedTokens.Prompt + usedTokens.Competition),
TotalTokens: int64(usedTokens.Prompt + usedTokens.Completion),
PromptTokens: int64(usedTokens.Prompt),
CompletionTokens: int64(usedTokens.Competition),
CompletionTokens: int64(usedTokens.Completion),
},
},
},

View file

@ -32,6 +32,7 @@ import (
"github.com/gravitational/trace"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/require"
"golang.org/x/time/rate"
"github.com/gravitational/teleport/lib/assist"
"github.com/gravitational/teleport/lib/client"
@ -40,46 +41,9 @@ import (
func Test_runAssistant(t *testing.T) {
t.Parallel()
responses := [][]byte{
generateTextResponse(),
}
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]
_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")
responses = responses[1:]
}))
defer server.Close()
openaiCfg := openai.DefaultConfig("test-token")
openaiCfg.BaseURL = server.URL
s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})
ws, err := s.makeAssistant(t, s.authPack(t, "foo"))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, ws.Close()) })
_, payload, err := ws.ReadMessage()
require.NoError(t, err)
var msg assistantMessage
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)
require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
require.Contains(t, msg.Payload, "Hey, I'm Teleport")
err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
require.NoError(t, err)
readPartialMessage := func() string {
_, payload, err = ws.ReadMessage()
readPartialMessage := func(t *testing.T, ws *websocket.Conn) string {
var msg assistantMessage
_, payload, err := ws.ReadMessage()
require.NoError(t, err)
err = json.Unmarshal(payload, &msg)
@ -89,13 +53,9 @@ func Test_runAssistant(t *testing.T) {
return msg.Payload
}
require.Contains(t, readPartialMessage(), "Which")
require.Contains(t, readPartialMessage(), "node do")
require.Contains(t, readPartialMessage(), "you want")
require.Contains(t, readPartialMessage(), "use?")
readStraemEnd := func() {
_, payload, err = ws.ReadMessage()
readStreamEnd := func(t *testing.T, ws *websocket.Conn) {
var msg assistantMessage
_, payload, err := ws.ReadMessage()
require.NoError(t, err)
err = json.Unmarshal(payload, &msg)
@ -104,7 +64,117 @@ func Test_runAssistant(t *testing.T) {
require.Equal(t, assist.MessageKindAssistantPartialFinalize, msg.Type)
}
readStraemEnd()
readRateLimitedMessage := func(t *testing.T, ws *websocket.Conn) {
var msg assistantMessage
_, payload, err := ws.ReadMessage()
require.NoError(t, err)
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)
require.Equal(t, assist.MessageKindError, msg.Type)
require.Equal(t, msg.Payload, "You have reached the rate limit. Please try again later.")
}
testCases := []struct {
name string
responses [][]byte
setup func(*WebSuite)
act func(*testing.T, *websocket.Conn)
}{
{
name: "normal",
responses: [][]byte{
generateTextResponse(),
},
act: func(t *testing.T, ws *websocket.Conn) {
err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
require.NoError(t, err)
require.Contains(t, readPartialMessage(t, ws), "Which")
require.Contains(t, readPartialMessage(t, ws), "node do")
require.Contains(t, readPartialMessage(t, ws), "you want")
require.Contains(t, readPartialMessage(t, ws), "use?")
readStreamEnd(t, ws)
},
},
{
name: "rate limited",
responses: [][]byte{
generateTextResponse(),
generateTextResponse(),
},
setup: func(s *WebSuite) {
// 101 token capacity (lookaheadTokens+1) and a slow replenish rate
// to let the first completion request succeed, but not the second one
s.webHandler.handler.assistantLimiter = rate.NewLimiter(rate.Limit(0.001), 101)
},
act: func(t *testing.T, ws *websocket.Conn) {
err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
require.NoError(t, err)
require.Contains(t, readPartialMessage(t, ws), "Which")
require.Contains(t, readPartialMessage(t, ws), "node do")
require.Contains(t, readPartialMessage(t, ws), "you want")
require.Contains(t, readPartialMessage(t, ws), "use?")
readStreamEnd(t, ws)
err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "all nodes, please"}`))
require.NoError(t, err)
readRateLimitedMessage(t, ws)
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
responses := tc.responses
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
dataBytes := responses[0]
_, err := w.Write(dataBytes)
require.NoError(t, err, "Write error")
responses = responses[1:]
}))
t.Cleanup(server.Close)
openaiCfg := openai.DefaultConfig("test-token")
openaiCfg.BaseURL = server.URL
s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})
if tc.setup != nil {
tc.setup(s)
}
ws, err := s.makeAssistant(t, s.authPack(t, "foo"))
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, ws.Close()) })
_, payload, err := ws.ReadMessage()
require.NoError(t, err)
var msg assistantMessage
err = json.Unmarshal(payload, &msg)
require.NoError(t, err)
// Expect "hello" message
require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
require.Contains(t, msg.Payload, "Hey, I'm Teleport")
tc.act(t, ws)
})
}
}
func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, error) {

View file

@ -165,7 +165,10 @@ async function convertServerMessage(
message: ServerMessage,
clusterId: string
): Promise<MessagesAction> {
if (message.type === 'CHAT_MESSAGE_ASSISTANT') {
if (
message.type === 'CHAT_MESSAGE_ASSISTANT' ||
message.type === 'CHAT_MESSAGE_ERROR'
) {
const newMessage: Message = {
author: Author.Teleport,
timestamp: message.created_time,
@ -276,6 +279,8 @@ async function convertServerMessage(
return (messages: Message[]) => messages.push(newMessage);
}
throw new Error('unrecognized message type');
}
function findIntersection<T>(elems: T[][]): T[] {
@ -377,9 +382,12 @@ export function MessagesContextProvider(
if (lastMessage !== null) {
const value = JSON.parse(lastMessage.data) as ServerMessage;
// When a streaming message ends, or a non-streaming message arrives
if (
value.type === 'CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE' ||
value.type === 'COMMAND'
value.type === 'COMMAND' ||
value.type === 'CHAT_MESSAGE_ASSISTANT' ||
value.type === 'CHAT_MESSAGE_ERROR'
) {
setResponding(false);
}