From 21c534b17c835bebcf715f43f845d89e8e8df222 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Justinas=20Stankevi=C4=8Dius?= Date: Mon, 15 May 2023 18:48:52 +0300 Subject: [PATCH] Add rate limiting to Assist (#26011) * Add rate limiting to Assist * Only rate limit Assist in Cloud * Add a comment to assistantLimiter * Fixes after rebase * Add 'rate-limited' test case to assistant_test * Handle CHAT_MESSAGE_UI in Assist web UI * Add godoc * CHAT_MESSAGE_UI -> CHAT_MESSAGE_ERROR * Run assistant test cases in parallel --- lib/assist/assist.go | 8 +- lib/web/apiserver.go | 27 ++- lib/web/assistant.go | 23 ++- lib/web/assistant_test.go | 166 +++++++++++++----- .../teleport/src/Assist/contexts/messages.tsx | 12 +- 5 files changed, 180 insertions(+), 56 deletions(-) diff --git a/lib/assist/assist.go b/lib/assist/assist.go index 1a003cbce7f..0e24731d08d 100644 --- a/lib/assist/assist.go +++ b/lib/assist/assist.go @@ -58,6 +58,8 @@ const ( MessageKindAssistantPartialFinalize MessageType = "CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE" // MessageKindSystemMessage is the type of Assist message that contains the system message. MessageKindSystemMessage MessageType = "CHAT_MESSAGE_SYSTEM" + // MessageKindError is the type of Assist message that is presented to user as information, but not stored persistently in the conversation. This can include backend error messages and the like. + MessageKindError MessageType = "CHAT_MESSAGE_ERROR" ) // Assist is the Teleport Assist client. @@ -210,7 +212,7 @@ type TokensUsed struct { // Prompt is a number of tokens used in the prompt. Prompt int // Completion is a number of tokens used in the completion. - Competition int + Completion int } // ProcessComplete processes the completion request and returns the number of tokens used. @@ -386,8 +388,8 @@ func (c *Chat) ProcessComplete(ctx context.Context, } return &TokensUsed{ - Prompt: promptTokens, - Competition: numTokens, + Prompt: promptTokens, + Completion: numTokens, }, nil } diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index bbc83f4b204..2a4f907457b 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -51,6 +51,7 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/exp/slices" "golang.org/x/mod/semver" + "golang.org/x/time/rate" "google.golang.org/protobuf/encoding/protojson" "github.com/gravitational/teleport" @@ -92,6 +93,15 @@ import ( const ( // SSOLoginFailureMessage is a generic error message to avoid disclosing sensitive SSO failure messages. SSOLoginFailureMessage = "Failed to login. Please check Teleport's log for more details." + + // assistantTokensPerHour defines how many assistant rate limiter tokens are replenished every hour. + assistantTokensPerHour = 140 + // assistantLimiterRate is the rate (in tokens per second) + // at which tokens for the assistant rate limiter are replenished + assistantLimiterRate = rate.Limit(assistantTokensPerHour / float64(time.Hour/time.Second)) + // assistantLimiterCapacity is the total capacity of the token bucket for the assistant rate limiter. + // The bucket starts full, prefilled for a week. + assistantLimiterCapacity = assistantTokensPerHour * 24 * 7 ) // healthCheckAppServerFunc defines a function used to perform a health check @@ -111,7 +121,13 @@ type Handler struct { clock clockwork.Clock limiter *limiter.RateLimiter highLimiter *limiter.RateLimiter - healthCheckAppServer healthCheckAppServerFunc + // assistantLimiter limits the amount of tokens that can be consumed + // by OpenAI API calls when using a shared key. + // golang.org/x/time/rate is used, as the oxy ratelimiter + // is quite tightly tied to individual http.Requests, + // and instead we want to consume arbitrary amounts of tokens. + assistantLimiter *rate.Limiter + healthCheckAppServer healthCheckAppServerFunc // sshPort specifies the SSH proxy port extracted // from configuration sshPort string @@ -301,6 +317,15 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { healthCheckAppServer: cfg.HealthCheckAppServer, } + // Check for self-hosted vs Cloud. + // TODO(justinas): this needs to be modified when we allow user-supplied API keys in Cloud + if modules.GetModules().Features().Cloud { + h.assistantLimiter = rate.NewLimiter(assistantLimiterRate, assistantLimiterCapacity) + } else { + // Set up a limiter with "infinite limit", the "burst" parameter is ignored + h.assistantLimiter = rate.NewLimiter(rate.Inf, 0) + } + // for properly handling url-encoded parameter values. h.UseRawPath = true diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 171efb6d98d..db7197bc4b2 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -405,6 +405,17 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } + // We can not know how many tokens we will consume in advance. + // Try to consume a small amount of tokens first. + const lookaheadTokens = 100 + if !h.assistantLimiter.AllowN(time.Now(), lookaheadTokens) { + err := onMessageFn(assist.MessageKindError, []byte("You have reached the rate limit. Please try again later."), h.clock.Now().UTC()) + if err != nil { + return trace.Wrap(err) + } + continue + } + //TODO(jakule): Should we sanitize the payload? if err := chat.InsertAssistantMessage(ctx, assist.MessageKindUserMessage, wsIncoming.Payload); err != nil { return trace.Wrap(err) @@ -415,14 +426,22 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } + // Once we know how many tokens were consumed for prompt+completion, + // consume the remaining tokens from the rate limiter bucket. + extraTokens := usedTokens.Prompt + usedTokens.Completion - lookaheadTokens + if extraTokens < 0 { + extraTokens = 0 + } + h.assistantLimiter.ReserveN(time.Now(), extraTokens) + usageEventReq := &proto.SubmitUsageEventRequest{ Event: &usageeventsv1.UsageEventOneOf{ Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{ AssistCompletion: &usageeventsv1.AssistCompletionEvent{ ConversationId: conversationID, - TotalTokens: int64(usedTokens.Prompt + usedTokens.Competition), + TotalTokens: int64(usedTokens.Prompt + usedTokens.Completion), PromptTokens: int64(usedTokens.Prompt), - CompletionTokens: int64(usedTokens.Competition), + CompletionTokens: int64(usedTokens.Completion), }, }, }, diff --git a/lib/web/assistant_test.go b/lib/web/assistant_test.go index 551dc1d273c..25cff80c560 100644 --- a/lib/web/assistant_test.go +++ b/lib/web/assistant_test.go @@ -32,6 +32,7 @@ import ( "github.com/gravitational/trace" "github.com/sashabaranov/go-openai" "github.com/stretchr/testify/require" + "golang.org/x/time/rate" "github.com/gravitational/teleport/lib/assist" "github.com/gravitational/teleport/lib/client" @@ -40,46 +41,9 @@ import ( func Test_runAssistant(t *testing.T) { t.Parallel() - responses := [][]byte{ - generateTextResponse(), - } - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/event-stream") - - require.GreaterOrEqual(t, len(responses), 1, "Unexpected request") - dataBytes := responses[0] - - _, err := w.Write(dataBytes) - require.NoError(t, err, "Write error") - - responses = responses[1:] - })) - defer server.Close() - - openaiCfg := openai.DefaultConfig("test-token") - openaiCfg.BaseURL = server.URL - s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) - - ws, err := s.makeAssistant(t, s.authPack(t, "foo")) - require.NoError(t, err) - t.Cleanup(func() { require.NoError(t, ws.Close()) }) - - _, payload, err := ws.ReadMessage() - require.NoError(t, err) - - var msg assistantMessage - err = json.Unmarshal(payload, &msg) - require.NoError(t, err) - - require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) - require.Contains(t, msg.Payload, "Hey, I'm Teleport") - - err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) - require.NoError(t, err) - - readPartialMessage := func() string { - _, payload, err = ws.ReadMessage() + readPartialMessage := func(t *testing.T, ws *websocket.Conn) string { + var msg assistantMessage + _, payload, err := ws.ReadMessage() require.NoError(t, err) err = json.Unmarshal(payload, &msg) @@ -89,13 +53,9 @@ func Test_runAssistant(t *testing.T) { return msg.Payload } - require.Contains(t, readPartialMessage(), "Which") - require.Contains(t, readPartialMessage(), "node do") - require.Contains(t, readPartialMessage(), "you want") - require.Contains(t, readPartialMessage(), "use?") - - readStraemEnd := func() { - _, payload, err = ws.ReadMessage() + readStreamEnd := func(t *testing.T, ws *websocket.Conn) { + var msg assistantMessage + _, payload, err := ws.ReadMessage() require.NoError(t, err) err = json.Unmarshal(payload, &msg) @@ -104,7 +64,117 @@ func Test_runAssistant(t *testing.T) { require.Equal(t, assist.MessageKindAssistantPartialFinalize, msg.Type) } - readStraemEnd() + readRateLimitedMessage := func(t *testing.T, ws *websocket.Conn) { + var msg assistantMessage + _, payload, err := ws.ReadMessage() + require.NoError(t, err) + + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) + + require.Equal(t, assist.MessageKindError, msg.Type) + require.Equal(t, msg.Payload, "You have reached the rate limit. Please try again later.") + } + + testCases := []struct { + name string + responses [][]byte + setup func(*WebSuite) + act func(*testing.T, *websocket.Conn) + }{ + { + name: "normal", + responses: [][]byte{ + generateTextResponse(), + }, + act: func(t *testing.T, ws *websocket.Conn) { + err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) + require.NoError(t, err) + + require.Contains(t, readPartialMessage(t, ws), "Which") + require.Contains(t, readPartialMessage(t, ws), "node do") + require.Contains(t, readPartialMessage(t, ws), "you want") + require.Contains(t, readPartialMessage(t, ws), "use?") + + readStreamEnd(t, ws) + }, + }, + { + name: "rate limited", + responses: [][]byte{ + generateTextResponse(), + generateTextResponse(), + }, + setup: func(s *WebSuite) { + // 101 token capacity (lookaheadTokens+1) and a slow replenish rate + // to let the first completion request succeed, but not the second one + s.webHandler.handler.assistantLimiter = rate.NewLimiter(rate.Limit(0.001), 101) + + }, + act: func(t *testing.T, ws *websocket.Conn) { + err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`)) + require.NoError(t, err) + + require.Contains(t, readPartialMessage(t, ws), "Which") + require.Contains(t, readPartialMessage(t, ws), "node do") + require.Contains(t, readPartialMessage(t, ws), "you want") + require.Contains(t, readPartialMessage(t, ws), "use?") + + readStreamEnd(t, ws) + + err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "all nodes, please"}`)) + require.NoError(t, err) + + readRateLimitedMessage(t, ws) + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + responses := tc.responses + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + + require.GreaterOrEqual(t, len(responses), 1, "Unexpected request") + dataBytes := responses[0] + + _, err := w.Write(dataBytes) + require.NoError(t, err, "Write error") + + responses = responses[1:] + })) + t.Cleanup(server.Close) + + openaiCfg := openai.DefaultConfig("test-token") + openaiCfg.BaseURL = server.URL + s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg}) + + if tc.setup != nil { + tc.setup(s) + } + + ws, err := s.makeAssistant(t, s.authPack(t, "foo")) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) }) + + _, payload, err := ws.ReadMessage() + require.NoError(t, err) + + var msg assistantMessage + err = json.Unmarshal(payload, &msg) + require.NoError(t, err) + + // Expect "hello" message + require.Equal(t, assist.MessageKindAssistantMessage, msg.Type) + require.Contains(t, msg.Payload, "Hey, I'm Teleport") + + tc.act(t, ws) + }) + } + } func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, error) { diff --git a/web/packages/teleport/src/Assist/contexts/messages.tsx b/web/packages/teleport/src/Assist/contexts/messages.tsx index 2f65447d463..ad99de260e0 100644 --- a/web/packages/teleport/src/Assist/contexts/messages.tsx +++ b/web/packages/teleport/src/Assist/contexts/messages.tsx @@ -165,7 +165,10 @@ async function convertServerMessage( message: ServerMessage, clusterId: string ): Promise { - if (message.type === 'CHAT_MESSAGE_ASSISTANT') { + if ( + message.type === 'CHAT_MESSAGE_ASSISTANT' || + message.type === 'CHAT_MESSAGE_ERROR' + ) { const newMessage: Message = { author: Author.Teleport, timestamp: message.created_time, @@ -276,6 +279,8 @@ async function convertServerMessage( return (messages: Message[]) => messages.push(newMessage); } + + throw new Error('unrecognized message type'); } function findIntersection(elems: T[][]): T[] { @@ -377,9 +382,12 @@ export function MessagesContextProvider( if (lastMessage !== null) { const value = JSON.parse(lastMessage.data) as ServerMessage; + // When a streaming message ends, or a non-streaming message arrives if ( value.type === 'CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE' || - value.type === 'COMMAND' + value.type === 'COMMAND' || + value.type === 'CHAT_MESSAGE_ASSISTANT' || + value.type === 'CHAT_MESSAGE_ERROR' ) { setResponding(false); }