diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 0dc82b311d3..843f93f9d39 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -25,7 +25,6 @@ import ( "net/http" "time" - "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" @@ -343,11 +342,8 @@ func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter return nil, nil } -func (h *Handler) reportTokenUsage(usedTokens *tokens.TokenCount, conversationID string, authClient auth.ClientI) { - // Create a new context to not be bounded by the request timeout. - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - +// reserveTokens preemptively reserves tokens in the ratelimiter. +func (h *Handler) reserveTokens(usedTokens *tokens.TokenCount) (int, int) { promptTokens, completionTokens := usedTokens.CountAll() // Once we know how many tokens were consumed for prompt+completion, @@ -357,7 +353,16 @@ func (h *Handler) reportTokenUsage(usedTokens *tokens.TokenCount, conversationID extraTokens = 0 } h.assistantLimiter.ReserveN(time.Now(), extraTokens) + return promptTokens, completionTokens +} +// reportTokenUsage sends a token usage event for a conversation. +func (h *Handler) reportConversationTokenUsage(authClient auth.ClientI, usedTokens *tokens.TokenCount, conversationID string) { + // Create a new context to not be bounded by the request timeout. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + promptTokens, completionTokens := h.reserveTokens(usedTokens) usageEventReq := &proto.SubmitUsageEventRequest{ Event: &usageeventsv1.UsageEventOneOf{ Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{ @@ -370,6 +375,32 @@ func (h *Handler) reportTokenUsage(usedTokens *tokens.TokenCount, conversationID }, }, } + + if err := authClient.SubmitUsageEvent(ctx, usageEventReq); err != nil { + h.log.WithError(err).Warn("Failed to emit usage event") + } +} + +// reportTokenUsage sends a token usage event for an action. +func (h *Handler) reportActionTokenUsage(authClient auth.ClientI, usedTokens *tokens.TokenCount, action string) { + // Create a new context to not be bounded by the request timeout. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + promptTokens, completionTokens := h.reserveTokens(usedTokens) + usageEventReq := &proto.SubmitUsageEventRequest{ + Event: &usageeventsv1.UsageEventOneOf{ + Event: &usageeventsv1.UsageEventOneOf_AssistAction{ + AssistAction: &usageeventsv1.AssistAction{ + Action: action, + TotalTokens: int64(promptTokens + completionTokens), + PromptTokens: int64(promptTokens), + CompletionTokens: int64(completionTokens), + }, + }, + }, + } + if err := authClient.SubmitUsageEvent(ctx, usageEventReq); err != nil { h.log.WithError(err).Warn("Failed to emit usage event") } @@ -529,7 +560,7 @@ func (h *Handler) assistGenAuditQueryLoop(ctx context.Context, assistClient *ass return trace.Wrap(err) } - go h.reportTokenUsage(tokenCount, uuid.NewString(), authClient) + go h.reportActionTokenUsage(authClient, tokenCount, tools.AuditQueryGenerationToolName) } return nil } @@ -568,7 +599,7 @@ func (h *Handler) assistSSHExplainOutputLoop(ctx context.Context, assistClient * return trace.Wrap(err) } - go h.reportTokenUsage(tokenCount, uuid.NewString(), authClient) + go h.reportActionTokenUsage(authClient, tokenCount, "SSH Explain") return nil } @@ -603,7 +634,8 @@ func (h *Handler) assistGenSSHCommandLoop(ctx context.Context, assistClient *ass return trace.Wrap(err) } - go h.reportTokenUsage(tokenCount, uuid.NewString(), authClient) + tool := tools.CommandExecutionTool{} + go h.reportActionTokenUsage(authClient, tokenCount, tool.Name()) } return nil } @@ -673,7 +705,7 @@ func (h *Handler) assistChatLoop(ctx context.Context, assistClient *assist.Assis // Token usage reporting is asynchronous as we might still be streaming // a message, and we don't want to block everything. - go h.reportTokenUsage(usedTokens, conversationID, authClient) + go h.reportConversationTokenUsage(authClient, usedTokens, conversationID) } h.log.Debug("end assistant conversation loop")