mirror of
https://github.com/gravitational/teleport
synced 2024-10-21 09:44:51 +00:00
Add rate limiting to Assist (#26011)
* Add rate limiting to Assist * Only rate limit Assist in Cloud * Add a comment to assistantLimiter * Fixes after rebase * Add 'rate-limited' test case to assistant_test * Handle CHAT_MESSAGE_UI in Assist web UI * Add godoc * CHAT_MESSAGE_UI -> CHAT_MESSAGE_ERROR * Run assistant test cases in parallel
This commit is contained in:
parent
9f794f1049
commit
21c534b17c
|
@ -58,6 +58,8 @@ const (
|
|||
MessageKindAssistantPartialFinalize MessageType = "CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE"
|
||||
// MessageKindSystemMessage is the type of Assist message that contains the system message.
|
||||
MessageKindSystemMessage MessageType = "CHAT_MESSAGE_SYSTEM"
|
||||
// MessageKindError is the type of Assist message that is presented to user as information, but not stored persistently in the conversation. This can include backend error messages and the like.
|
||||
MessageKindError MessageType = "CHAT_MESSAGE_ERROR"
|
||||
)
|
||||
|
||||
// Assist is the Teleport Assist client.
|
||||
|
@ -210,7 +212,7 @@ type TokensUsed struct {
|
|||
// Prompt is a number of tokens used in the prompt.
|
||||
Prompt int
|
||||
// Completion is a number of tokens used in the completion.
|
||||
Competition int
|
||||
Completion int
|
||||
}
|
||||
|
||||
// ProcessComplete processes the completion request and returns the number of tokens used.
|
||||
|
@ -386,8 +388,8 @@ func (c *Chat) ProcessComplete(ctx context.Context,
|
|||
}
|
||||
|
||||
return &TokensUsed{
|
||||
Prompt: promptTokens,
|
||||
Competition: numTokens,
|
||||
Prompt: promptTokens,
|
||||
Completion: numTokens,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -51,6 +51,7 @@ import (
|
|||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/mod/semver"
|
||||
"golang.org/x/time/rate"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
"github.com/gravitational/teleport"
|
||||
|
@ -92,6 +93,15 @@ import (
|
|||
const (
|
||||
// SSOLoginFailureMessage is a generic error message to avoid disclosing sensitive SSO failure messages.
|
||||
SSOLoginFailureMessage = "Failed to login. Please check Teleport's log for more details."
|
||||
|
||||
// assistantTokensPerHour defines how many assistant rate limiter tokens are replenished every hour.
|
||||
assistantTokensPerHour = 140
|
||||
// assistantLimiterRate is the rate (in tokens per second)
|
||||
// at which tokens for the assistant rate limiter are replenished
|
||||
assistantLimiterRate = rate.Limit(assistantTokensPerHour / float64(time.Hour/time.Second))
|
||||
// assistantLimiterCapacity is the total capacity of the token bucket for the assistant rate limiter.
|
||||
// The bucket starts full, prefilled for a week.
|
||||
assistantLimiterCapacity = assistantTokensPerHour * 24 * 7
|
||||
)
|
||||
|
||||
// healthCheckAppServerFunc defines a function used to perform a health check
|
||||
|
@ -111,7 +121,13 @@ type Handler struct {
|
|||
clock clockwork.Clock
|
||||
limiter *limiter.RateLimiter
|
||||
highLimiter *limiter.RateLimiter
|
||||
healthCheckAppServer healthCheckAppServerFunc
|
||||
// assistantLimiter limits the amount of tokens that can be consumed
|
||||
// by OpenAI API calls when using a shared key.
|
||||
// golang.org/x/time/rate is used, as the oxy ratelimiter
|
||||
// is quite tightly tied to individual http.Requests,
|
||||
// and instead we want to consume arbitrary amounts of tokens.
|
||||
assistantLimiter *rate.Limiter
|
||||
healthCheckAppServer healthCheckAppServerFunc
|
||||
// sshPort specifies the SSH proxy port extracted
|
||||
// from configuration
|
||||
sshPort string
|
||||
|
@ -301,6 +317,15 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) {
|
|||
healthCheckAppServer: cfg.HealthCheckAppServer,
|
||||
}
|
||||
|
||||
// Check for self-hosted vs Cloud.
|
||||
// TODO(justinas): this needs to be modified when we allow user-supplied API keys in Cloud
|
||||
if modules.GetModules().Features().Cloud {
|
||||
h.assistantLimiter = rate.NewLimiter(assistantLimiterRate, assistantLimiterCapacity)
|
||||
} else {
|
||||
// Set up a limiter with "infinite limit", the "burst" parameter is ignored
|
||||
h.assistantLimiter = rate.NewLimiter(rate.Inf, 0)
|
||||
}
|
||||
|
||||
// for properly handling url-encoded parameter values.
|
||||
h.UseRawPath = true
|
||||
|
||||
|
|
|
@ -405,6 +405,17 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// We can not know how many tokens we will consume in advance.
|
||||
// Try to consume a small amount of tokens first.
|
||||
const lookaheadTokens = 100
|
||||
if !h.assistantLimiter.AllowN(time.Now(), lookaheadTokens) {
|
||||
err := onMessageFn(assist.MessageKindError, []byte("You have reached the rate limit. Please try again later."), h.clock.Now().UTC())
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
//TODO(jakule): Should we sanitize the payload?
|
||||
if err := chat.InsertAssistantMessage(ctx, assist.MessageKindUserMessage, wsIncoming.Payload); err != nil {
|
||||
return trace.Wrap(err)
|
||||
|
@ -415,14 +426,22 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
// Once we know how many tokens were consumed for prompt+completion,
|
||||
// consume the remaining tokens from the rate limiter bucket.
|
||||
extraTokens := usedTokens.Prompt + usedTokens.Completion - lookaheadTokens
|
||||
if extraTokens < 0 {
|
||||
extraTokens = 0
|
||||
}
|
||||
h.assistantLimiter.ReserveN(time.Now(), extraTokens)
|
||||
|
||||
usageEventReq := &proto.SubmitUsageEventRequest{
|
||||
Event: &usageeventsv1.UsageEventOneOf{
|
||||
Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{
|
||||
AssistCompletion: &usageeventsv1.AssistCompletionEvent{
|
||||
ConversationId: conversationID,
|
||||
TotalTokens: int64(usedTokens.Prompt + usedTokens.Competition),
|
||||
TotalTokens: int64(usedTokens.Prompt + usedTokens.Completion),
|
||||
PromptTokens: int64(usedTokens.Prompt),
|
||||
CompletionTokens: int64(usedTokens.Competition),
|
||||
CompletionTokens: int64(usedTokens.Completion),
|
||||
},
|
||||
},
|
||||
},
|
||||
|
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/gravitational/trace"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/time/rate"
|
||||
|
||||
"github.com/gravitational/teleport/lib/assist"
|
||||
"github.com/gravitational/teleport/lib/client"
|
||||
|
@ -40,46 +41,9 @@ import (
|
|||
func Test_runAssistant(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
responses := [][]byte{
|
||||
generateTextResponse(),
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
|
||||
dataBytes := responses[0]
|
||||
|
||||
_, err := w.Write(dataBytes)
|
||||
require.NoError(t, err, "Write error")
|
||||
|
||||
responses = responses[1:]
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
openaiCfg := openai.DefaultConfig("test-token")
|
||||
openaiCfg.BaseURL = server.URL
|
||||
s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})
|
||||
|
||||
ws, err := s.makeAssistant(t, s.authPack(t, "foo"))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, ws.Close()) })
|
||||
|
||||
_, payload, err := ws.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
|
||||
var msg assistantMessage
|
||||
err = json.Unmarshal(payload, &msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
|
||||
require.Contains(t, msg.Payload, "Hey, I'm Teleport")
|
||||
|
||||
err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
|
||||
require.NoError(t, err)
|
||||
|
||||
readPartialMessage := func() string {
|
||||
_, payload, err = ws.ReadMessage()
|
||||
readPartialMessage := func(t *testing.T, ws *websocket.Conn) string {
|
||||
var msg assistantMessage
|
||||
_, payload, err := ws.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = json.Unmarshal(payload, &msg)
|
||||
|
@ -89,13 +53,9 @@ func Test_runAssistant(t *testing.T) {
|
|||
return msg.Payload
|
||||
}
|
||||
|
||||
require.Contains(t, readPartialMessage(), "Which")
|
||||
require.Contains(t, readPartialMessage(), "node do")
|
||||
require.Contains(t, readPartialMessage(), "you want")
|
||||
require.Contains(t, readPartialMessage(), "use?")
|
||||
|
||||
readStraemEnd := func() {
|
||||
_, payload, err = ws.ReadMessage()
|
||||
readStreamEnd := func(t *testing.T, ws *websocket.Conn) {
|
||||
var msg assistantMessage
|
||||
_, payload, err := ws.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = json.Unmarshal(payload, &msg)
|
||||
|
@ -104,7 +64,117 @@ func Test_runAssistant(t *testing.T) {
|
|||
require.Equal(t, assist.MessageKindAssistantPartialFinalize, msg.Type)
|
||||
}
|
||||
|
||||
readStraemEnd()
|
||||
readRateLimitedMessage := func(t *testing.T, ws *websocket.Conn) {
|
||||
var msg assistantMessage
|
||||
_, payload, err := ws.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = json.Unmarshal(payload, &msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, assist.MessageKindError, msg.Type)
|
||||
require.Equal(t, msg.Payload, "You have reached the rate limit. Please try again later.")
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
responses [][]byte
|
||||
setup func(*WebSuite)
|
||||
act func(*testing.T, *websocket.Conn)
|
||||
}{
|
||||
{
|
||||
name: "normal",
|
||||
responses: [][]byte{
|
||||
generateTextResponse(),
|
||||
},
|
||||
act: func(t *testing.T, ws *websocket.Conn) {
|
||||
err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Contains(t, readPartialMessage(t, ws), "Which")
|
||||
require.Contains(t, readPartialMessage(t, ws), "node do")
|
||||
require.Contains(t, readPartialMessage(t, ws), "you want")
|
||||
require.Contains(t, readPartialMessage(t, ws), "use?")
|
||||
|
||||
readStreamEnd(t, ws)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "rate limited",
|
||||
responses: [][]byte{
|
||||
generateTextResponse(),
|
||||
generateTextResponse(),
|
||||
},
|
||||
setup: func(s *WebSuite) {
|
||||
// 101 token capacity (lookaheadTokens+1) and a slow replenish rate
|
||||
// to let the first completion request succeed, but not the second one
|
||||
s.webHandler.handler.assistantLimiter = rate.NewLimiter(rate.Limit(0.001), 101)
|
||||
|
||||
},
|
||||
act: func(t *testing.T, ws *websocket.Conn) {
|
||||
err := ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "show free disk space"}`))
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Contains(t, readPartialMessage(t, ws), "Which")
|
||||
require.Contains(t, readPartialMessage(t, ws), "node do")
|
||||
require.Contains(t, readPartialMessage(t, ws), "you want")
|
||||
require.Contains(t, readPartialMessage(t, ws), "use?")
|
||||
|
||||
readStreamEnd(t, ws)
|
||||
|
||||
err = ws.WriteMessage(websocket.TextMessage, []byte(`{"payload": "all nodes, please"}`))
|
||||
require.NoError(t, err)
|
||||
|
||||
readRateLimitedMessage(t, ws)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
responses := tc.responses
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
require.GreaterOrEqual(t, len(responses), 1, "Unexpected request")
|
||||
dataBytes := responses[0]
|
||||
|
||||
_, err := w.Write(dataBytes)
|
||||
require.NoError(t, err, "Write error")
|
||||
|
||||
responses = responses[1:]
|
||||
}))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
openaiCfg := openai.DefaultConfig("test-token")
|
||||
openaiCfg.BaseURL = server.URL
|
||||
s := newWebSuiteWithConfig(t, webSuiteConfig{OpenAIConfig: &openaiCfg})
|
||||
|
||||
if tc.setup != nil {
|
||||
tc.setup(s)
|
||||
}
|
||||
|
||||
ws, err := s.makeAssistant(t, s.authPack(t, "foo"))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { require.NoError(t, ws.Close()) })
|
||||
|
||||
_, payload, err := ws.ReadMessage()
|
||||
require.NoError(t, err)
|
||||
|
||||
var msg assistantMessage
|
||||
err = json.Unmarshal(payload, &msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect "hello" message
|
||||
require.Equal(t, assist.MessageKindAssistantMessage, msg.Type)
|
||||
require.Contains(t, msg.Payload, "Hey, I'm Teleport")
|
||||
|
||||
tc.act(t, ws)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *WebSuite) makeAssistant(t *testing.T, pack *authPack) (*websocket.Conn, error) {
|
||||
|
|
|
@ -165,7 +165,10 @@ async function convertServerMessage(
|
|||
message: ServerMessage,
|
||||
clusterId: string
|
||||
): Promise<MessagesAction> {
|
||||
if (message.type === 'CHAT_MESSAGE_ASSISTANT') {
|
||||
if (
|
||||
message.type === 'CHAT_MESSAGE_ASSISTANT' ||
|
||||
message.type === 'CHAT_MESSAGE_ERROR'
|
||||
) {
|
||||
const newMessage: Message = {
|
||||
author: Author.Teleport,
|
||||
timestamp: message.created_time,
|
||||
|
@ -276,6 +279,8 @@ async function convertServerMessage(
|
|||
|
||||
return (messages: Message[]) => messages.push(newMessage);
|
||||
}
|
||||
|
||||
throw new Error('unrecognized message type');
|
||||
}
|
||||
|
||||
function findIntersection<T>(elems: T[][]): T[] {
|
||||
|
@ -377,9 +382,12 @@ export function MessagesContextProvider(
|
|||
if (lastMessage !== null) {
|
||||
const value = JSON.parse(lastMessage.data) as ServerMessage;
|
||||
|
||||
// When a streaming message ends, or a non-streaming message arrives
|
||||
if (
|
||||
value.type === 'CHAT_PARTIAL_MESSAGE_ASSISTANT_FINALIZE' ||
|
||||
value.type === 'COMMAND'
|
||||
value.type === 'COMMAND' ||
|
||||
value.type === 'CHAT_MESSAGE_ASSISTANT' ||
|
||||
value.type === 'CHAT_MESSAGE_ERROR'
|
||||
) {
|
||||
setResponding(false);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue