mirror of
https://github.com/gravitational/teleport
synced 2024-10-19 16:53:57 +00:00
assist: emit action events on tool usage (#33737)
* emit correct events on assist actions * fix typo
This commit is contained in:
parent
c6193f439e
commit
a5c4df7ebf
|
@ -25,7 +25,6 @@ import (
|
|||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/gravitational/trace"
|
||||
"github.com/julienschmidt/httprouter"
|
||||
|
@ -343,11 +342,8 @@ func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
func (h *Handler) reportTokenUsage(usedTokens *tokens.TokenCount, conversationID string, authClient auth.ClientI) {
|
||||
// Create a new context to not be bounded by the request timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// reserveTokens preemptively reserves tokens in the ratelimiter.
|
||||
func (h *Handler) reserveTokens(usedTokens *tokens.TokenCount) (int, int) {
|
||||
promptTokens, completionTokens := usedTokens.CountAll()
|
||||
|
||||
// Once we know how many tokens were consumed for prompt+completion,
|
||||
|
@ -357,7 +353,16 @@ func (h *Handler) reportTokenUsage(usedTokens *tokens.TokenCount, conversationID
|
|||
extraTokens = 0
|
||||
}
|
||||
h.assistantLimiter.ReserveN(time.Now(), extraTokens)
|
||||
return promptTokens, completionTokens
|
||||
}
|
||||
|
||||
// reportTokenUsage sends a token usage event for a conversation.
|
||||
func (h *Handler) reportConversationTokenUsage(authClient auth.ClientI, usedTokens *tokens.TokenCount, conversationID string) {
|
||||
// Create a new context to not be bounded by the request timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
promptTokens, completionTokens := h.reserveTokens(usedTokens)
|
||||
usageEventReq := &proto.SubmitUsageEventRequest{
|
||||
Event: &usageeventsv1.UsageEventOneOf{
|
||||
Event: &usageeventsv1.UsageEventOneOf_AssistCompletion{
|
||||
|
@ -370,6 +375,32 @@ func (h *Handler) reportTokenUsage(usedTokens *tokens.TokenCount, conversationID
|
|||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := authClient.SubmitUsageEvent(ctx, usageEventReq); err != nil {
|
||||
h.log.WithError(err).Warn("Failed to emit usage event")
|
||||
}
|
||||
}
|
||||
|
||||
// reportTokenUsage sends a token usage event for an action.
|
||||
func (h *Handler) reportActionTokenUsage(authClient auth.ClientI, usedTokens *tokens.TokenCount, action string) {
|
||||
// Create a new context to not be bounded by the request timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
promptTokens, completionTokens := h.reserveTokens(usedTokens)
|
||||
usageEventReq := &proto.SubmitUsageEventRequest{
|
||||
Event: &usageeventsv1.UsageEventOneOf{
|
||||
Event: &usageeventsv1.UsageEventOneOf_AssistAction{
|
||||
AssistAction: &usageeventsv1.AssistAction{
|
||||
Action: action,
|
||||
TotalTokens: int64(promptTokens + completionTokens),
|
||||
PromptTokens: int64(promptTokens),
|
||||
CompletionTokens: int64(completionTokens),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if err := authClient.SubmitUsageEvent(ctx, usageEventReq); err != nil {
|
||||
h.log.WithError(err).Warn("Failed to emit usage event")
|
||||
}
|
||||
|
@ -529,7 +560,7 @@ func (h *Handler) assistGenAuditQueryLoop(ctx context.Context, assistClient *ass
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
go h.reportTokenUsage(tokenCount, uuid.NewString(), authClient)
|
||||
go h.reportActionTokenUsage(authClient, tokenCount, tools.AuditQueryGenerationToolName)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -568,7 +599,7 @@ func (h *Handler) assistSSHExplainOutputLoop(ctx context.Context, assistClient *
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
go h.reportTokenUsage(tokenCount, uuid.NewString(), authClient)
|
||||
go h.reportActionTokenUsage(authClient, tokenCount, "SSH Explain")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -603,7 +634,8 @@ func (h *Handler) assistGenSSHCommandLoop(ctx context.Context, assistClient *ass
|
|||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
go h.reportTokenUsage(tokenCount, uuid.NewString(), authClient)
|
||||
tool := tools.CommandExecutionTool{}
|
||||
go h.reportActionTokenUsage(authClient, tokenCount, tool.Name())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -673,7 +705,7 @@ func (h *Handler) assistChatLoop(ctx context.Context, assistClient *assist.Assis
|
|||
|
||||
// Token usage reporting is asynchronous as we might still be streaming
|
||||
// a message, and we don't want to block everything.
|
||||
go h.reportTokenUsage(usedTokens, conversationID, authClient)
|
||||
go h.reportConversationTokenUsage(authClient, usedTokens, conversationID)
|
||||
}
|
||||
|
||||
h.log.Debug("end assistant conversation loop")
|
||||
|
|
Loading…
Reference in a new issue