assist: emit action events on tool usage (#33737)

* emit correct events on assist actions

* fix typo
This commit is contained in:
Joel 2023-11-08 23:57:23 +01:00 committed by GitHub
parent c6193f439e
commit a5c4df7ebf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -25,7 +25,6 @@ import (
"net/http"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
@ -343,11 +342,8 @@ func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter
return nil, nil
}
func (h *Handler) reportTokenUsage(usedTokens *tokens.TokenCount, conversationID string, authClient auth.ClientI) {
// Create a new context to not be bounded by the request timeout.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// reserveTokens preemptively reserves tokens in the ratelimiter.
func (h *Handler) reserveTokens(usedTokens *tokens.TokenCount) (int, int) {
promptTokens, completionTokens := usedTokens.CountAll()
// Once we know how many tokens were consumed for prompt+completion,
@ -357,7 +353,16 @@ func (h *Handler) reportTokenUsage(usedTokens *tokens.TokenCount, conversationID
extraTokens = 0
}
h.assistantLimiter.ReserveN(time.Now(), extraTokens)
return promptTokens, completionTokens
}
// reportTokenUsage sends a token usage event for a conversation.
func (h *Handler) reportConversationTokenUsage(authClient auth.ClientI, usedTokens *tokens.TokenCount, conversationID string) {
// Create a new context to not be bounded by the request timeout.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
promptTokens, completionTokens := h.reserveTokens(usedTokens)
usageEventReq := &proto.SubmitUsageEventRequest{
Event: &usageeventsv1.UsageEventOneOf{
Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{
@ -370,6 +375,32 @@ func (h *Handler) reportTokenUsage(usedTokens *tokens.TokenCount, conversationID
},
},
}
if err := authClient.SubmitUsageEvent(ctx, usageEventReq); err != nil {
h.log.WithError(err).Warn("Failed to emit usage event")
}
}
// reportTokenUsage sends a token usage event for an action.
func (h *Handler) reportActionTokenUsage(authClient auth.ClientI, usedTokens *tokens.TokenCount, action string) {
// Create a new context to not be bounded by the request timeout.
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
promptTokens, completionTokens := h.reserveTokens(usedTokens)
usageEventReq := &proto.SubmitUsageEventRequest{
Event: &usageeventsv1.UsageEventOneOf{
Event: &usageeventsv1.UsageEventOneOf_AssistAction{
AssistAction: &usageeventsv1.AssistAction{
Action: action,
TotalTokens: int64(promptTokens + completionTokens),
PromptTokens: int64(promptTokens),
CompletionTokens: int64(completionTokens),
},
},
},
}
if err := authClient.SubmitUsageEvent(ctx, usageEventReq); err != nil {
h.log.WithError(err).Warn("Failed to emit usage event")
}
@ -529,7 +560,7 @@ func (h *Handler) assistGenAuditQueryLoop(ctx context.Context, assistClient *ass
return trace.Wrap(err)
}
go h.reportTokenUsage(tokenCount, uuid.NewString(), authClient)
go h.reportActionTokenUsage(authClient, tokenCount, tools.AuditQueryGenerationToolName)
}
return nil
}
@ -568,7 +599,7 @@ func (h *Handler) assistSSHExplainOutputLoop(ctx context.Context, assistClient *
return trace.Wrap(err)
}
go h.reportTokenUsage(tokenCount, uuid.NewString(), authClient)
go h.reportActionTokenUsage(authClient, tokenCount, "SSH Explain")
return nil
}
@ -603,7 +634,8 @@ func (h *Handler) assistGenSSHCommandLoop(ctx context.Context, assistClient *ass
return trace.Wrap(err)
}
go h.reportTokenUsage(tokenCount, uuid.NewString(), authClient)
tool := tools.CommandExecutionTool{}
go h.reportActionTokenUsage(authClient, tokenCount, tool.Name())
}
return nil
}
@ -673,7 +705,7 @@ func (h *Handler) assistChatLoop(ctx context.Context, assistClient *assist.Assis
// Token usage reporting is asynchronous as we might still be streaming
// a message, and we don't want to block everything.
go h.reportTokenUsage(usedTokens, conversationID, authClient)
go h.reportConversationTokenUsage(authClient, usedTokens, conversationID)
}
h.log.Debug("end assistant conversation loop")