mirror of
https://github.com/gravitational/teleport
synced 2024-10-20 09:13:39 +00:00
Add the ability to run a specific tool to Assist. (#31113)
This commit contains: - a refactoring of the model action logic so that it is exposed through the `DoAction` function - a `RunTool()` capability exposed through ai.Client and assist.Assist - a new "audit-query" action that directly invokes the audit query generation tool
This commit is contained in:
parent
699f1c23ad
commit
dac5af41d0
|
@ -321,13 +321,13 @@ func generateAccessRequestResponse(t *testing.T) string {
|
|||
func TestChat_Complete_AuditQuery(t *testing.T) {
|
||||
// Test setup: generate the responses that will be served by our OpenAI mock
|
||||
action := model.PlanOutput{
|
||||
Action: "Audit Query Generation",
|
||||
Action: tools.AuditQueryGenerationToolName,
|
||||
ActionInput: "Lists user who connected to a server as root.",
|
||||
Reasoning: "foo",
|
||||
}
|
||||
selectedAction, err := json.Marshal(action)
|
||||
require.NoError(t, err)
|
||||
generatedQuery := "SELECT user FROM session_start WHERE login='root'"
|
||||
const generatedQuery = "SELECT user FROM session_start WHERE login='root'"
|
||||
|
||||
responses := []string{
|
||||
// The model must select the audit query tool
|
||||
|
|
|
@ -92,6 +92,30 @@ func (client *Client) NewCommand(username string) *Chat {
|
|||
}
|
||||
}
|
||||
|
||||
func (client *Client) RunTool(ctx context.Context, toolContext *modeltools.ToolContext, toolName, toolInput string) (any, *tokens.TokenCount, error) {
|
||||
tools := []modeltools.Tool{
|
||||
&modeltools.CommandExecutionTool{},
|
||||
&modeltools.EmbeddingRetrievalTool{},
|
||||
&modeltools.AuditQueryGenerationTool{LLM: client.svc},
|
||||
}
|
||||
// The following tools are only available in the enterprise build. They will fail
|
||||
// if included in OSS due to the lack of the required backend APIs.
|
||||
if modules.GetModules().BuildType() == modules.BuildEnterprise {
|
||||
tools = append(tools, &modeltools.AccessRequestCreateTool{},
|
||||
&modeltools.AccessRequestsListTool{},
|
||||
&modeltools.AccessRequestListRequestableRolesTool{},
|
||||
&modeltools.AccessRequestListRequestableResourcesTool{})
|
||||
}
|
||||
agent := model.NewAgent(toolContext, tools...)
|
||||
action := &model.AgentAction{
|
||||
Action: toolName,
|
||||
Input: toolInput,
|
||||
Reasoning: "Tool invoked directly",
|
||||
}
|
||||
|
||||
return agent.DoAction(ctx, client.svc, action)
|
||||
}
|
||||
|
||||
func (client *Client) NewAuditQuery(username string) *Chat {
|
||||
toolContext := &modeltools.ToolContext{User: username}
|
||||
return &Chat{
|
||||
|
|
110
lib/ai/client_test.go
Normal file
110
lib/ai/client_test.go
Normal file
|
@ -0,0 +1,110 @@
|
|||
/*
|
||||
* Copyright 2023 Gravitational, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1"
|
||||
"github.com/gravitational/teleport/lib/ai/model/output"
|
||||
"github.com/gravitational/teleport/lib/ai/model/tools"
|
||||
"github.com/gravitational/teleport/lib/ai/testutils"
|
||||
)
|
||||
|
||||
func TestRunTool_AuditQueryGeneration(t *testing.T) {
|
||||
// Test setup: starting a mock openai server and creating the client
|
||||
const generatedQuery = "SELECT user FROM session_start WHERE login='root'"
|
||||
|
||||
responses := []string{
|
||||
// Then the audit query tool chooses to request session.start events
|
||||
"session.start",
|
||||
// Finally the tool builds a query based on the provided schemas
|
||||
generatedQuery,
|
||||
}
|
||||
server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
cfg := openai.DefaultConfig("secret-test-token")
|
||||
cfg.BaseURL = server.URL
|
||||
|
||||
client := NewClientFromConfig(cfg)
|
||||
|
||||
// Doing the test: Check that the AuditQueryGeneration tool can be invoked
|
||||
// through client.RunTool and validate its response.
|
||||
ctx := context.Background()
|
||||
toolCtx := &tools.ToolContext{User: "alice"}
|
||||
response, _, err := client.RunTool(ctx, toolCtx, tools.AuditQueryGenerationToolName, "List users who connected to a server as root")
|
||||
require.NoError(t, err)
|
||||
message, ok := response.(*output.StreamingMessage)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, generatedQuery, message.WaitAndConsume())
|
||||
}
|
||||
|
||||
type mockEmbeddingGetter struct {
|
||||
response []*assistpb.EmbeddedDocument
|
||||
}
|
||||
|
||||
func (m *mockEmbeddingGetter) GetAssistantEmbeddings(ctx context.Context, in *assistpb.GetAssistantEmbeddingsRequest, opts ...grpc.CallOption) (*assistpb.GetAssistantEmbeddingsResponse, error) {
|
||||
return &assistpb.GetAssistantEmbeddingsResponse{Embeddings: m.response}, nil
|
||||
}
|
||||
|
||||
func TestRunTool_EmbeddingRetrieval(t *testing.T) {
|
||||
// Test setup: starting a mock openai server and embedding getter,
|
||||
// then create the client.
|
||||
mock := &mockEmbeddingGetter{
|
||||
[]*assistpb.EmbeddedDocument{
|
||||
{
|
||||
Id: "1",
|
||||
Content: "foo",
|
||||
SimilarityScore: 1,
|
||||
},
|
||||
{
|
||||
Id: "2",
|
||||
Content: "bar",
|
||||
SimilarityScore: 0.9,
|
||||
},
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
toolCtx := &tools.ToolContext{AssistEmbeddingServiceClient: mock}
|
||||
|
||||
responses := make([]string, 0)
|
||||
server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses))
|
||||
t.Cleanup(server.Close)
|
||||
|
||||
cfg := openai.DefaultConfig("secret-test-token")
|
||||
cfg.BaseURL = server.URL
|
||||
client := NewClientFromConfig(cfg)
|
||||
|
||||
// Doing the test: Check that the EmbeddingRetrieval tool can be invoked
|
||||
// through client.RunTool and validate its response.
|
||||
input := tools.EmbeddingRetrievalToolInput{Question: "Find foobar"}
|
||||
inputText, err := json.Marshal(input)
|
||||
require.NoError(t, err)
|
||||
response, _, err := client.RunTool(ctx, toolCtx, "Nodes names and labels retrieval", string(inputText))
|
||||
require.NoError(t, err)
|
||||
message, ok := response.(*output.Message)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "foo\nbar\n", message.Content)
|
||||
}
|
|
@ -133,6 +133,28 @@ func (a *Agent) PlanAndExecute(ctx context.Context, llm *openai.Client, chatHist
|
|||
}
|
||||
}
|
||||
|
||||
func (a *Agent) DoAction(ctx context.Context, llm *openai.Client, action *AgentAction) (any, *tokens.TokenCount, error) {
|
||||
state := &executionState{
|
||||
llm: llm,
|
||||
tokenCount: tokens.NewTokenCount(),
|
||||
}
|
||||
out, err := a.doAction(ctx, state, action)
|
||||
if err != nil {
|
||||
return nil, nil, trace.Wrap(err)
|
||||
}
|
||||
switch {
|
||||
case out.finish != nil:
|
||||
// If the tool already breaks execution, we don't have to do anything
|
||||
return out.finish.output, state.tokenCount, nil
|
||||
case out.observation != "":
|
||||
// If the tool doesn't break execution and returns a single observation,
|
||||
// we wrap the observation in a Message.
|
||||
return &output.Message{Content: out.observation}, state.tokenCount, nil
|
||||
default:
|
||||
return nil, state.tokenCount, trace.Errorf("action %s did not end execution nor returned an observation", action.Action)
|
||||
}
|
||||
}
|
||||
|
||||
// stepOutput represents the inputs and outputs of a single thought step.
|
||||
type stepOutput struct {
|
||||
// if the agent is done, finish is set.
|
||||
|
@ -189,6 +211,10 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres
|
|||
// If action is set, the agent is not done and called upon a tool.
|
||||
progressUpdates(action)
|
||||
|
||||
return a.doAction(ctx, state, action)
|
||||
}
|
||||
|
||||
func (a *Agent) doAction(ctx context.Context, state *executionState, action *AgentAction) (stepOutput, error) {
|
||||
var tool tools.Tool
|
||||
for _, candidate := range a.tools {
|
||||
if candidate.Name() == action.Action {
|
||||
|
|
|
@ -30,12 +30,14 @@ import (
|
|||
"github.com/gravitational/teleport/lib/ai/tokens"
|
||||
)
|
||||
|
||||
// AuditQueryGenerationToolName is the name of the audit query generation tool.
const AuditQueryGenerationToolName = "Audit Query Generation"
|
||||
|
||||
// AuditQueryGenerationTool is the tool named by AuditQueryGenerationToolName.
type AuditQueryGenerationTool struct {
|
||||
// LLM is the OpenAI client held by the tool.
// NOTE(review): presumably used to generate the query — confirm against the tool's Run method.
LLM *openai.Client
|
||||
}
|
||||
|
||||
func (t *AuditQueryGenerationTool) Name() string {
|
||||
return "Audit Query Generation"
|
||||
return AuditQueryGenerationToolName
|
||||
}
|
||||
|
||||
func (t *AuditQueryGenerationTool) Description() string {
|
||||
|
|
|
@ -176,6 +176,45 @@ func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, e
|
|||
return a.client.Summary(ctx, message)
|
||||
}
|
||||
|
||||
// RunTool runs a model tool without an ai.Chat.
|
||||
func (a *Assist) RunTool(ctx context.Context, onMessage onMessageFunc, toolName, userInput string, toolContext *tools.ToolContext,
|
||||
) (*tokens.TokenCount, error) {
|
||||
message, tc, err := a.client.RunTool(ctx, toolContext, toolName, userInput)
|
||||
if err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
|
||||
switch message := message.(type) {
|
||||
case *output.Message:
|
||||
if err := onMessage(MessageKindAssistantMessage, []byte(message.Content), a.clock.Now().UTC()); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
case *output.GeneratedCommand:
|
||||
if err := onMessage(MessageKindCommand, []byte(message.Command), a.clock.Now().UTC()); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
case *output.StreamingMessage:
|
||||
if err := func() error {
|
||||
var text strings.Builder
|
||||
defer onMessage(MessageKindAssistantPartialFinalize, nil, a.clock.Now().UTC())
|
||||
for part := range message.Parts {
|
||||
text.WriteString(part)
|
||||
|
||||
if err := onMessage(MessageKindAssistantPartialMessage, []byte(part), a.clock.Now().UTC()); err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}(); err != nil {
|
||||
return nil, trace.Wrap(err)
|
||||
}
|
||||
default:
|
||||
return nil, trace.Errorf("Unexpected message type: %T", message)
|
||||
}
|
||||
|
||||
return tc, nil
|
||||
}
|
||||
|
||||
// GenerateCommandSummary summarizes the output of a command executed on one or
|
||||
// many nodes. The conversation history is also sent into the prompt in order
|
||||
// to gather context and know what information is relevant in the command output.
|
||||
|
|
|
@ -48,6 +48,8 @@ const (
|
|||
actionSSHGenerateCommand = "ssh-cmdgen"
|
||||
// actionSSHExplainCommand is a name of the action for explaining terminal output in SSH session.
|
||||
actionSSHExplainCommand = "ssh-explain"
|
||||
// actionGenerateAuditQuery is the name of the action for generating audit queries.
|
||||
actionGenerateAuditQuery = "audit-query"
|
||||
)
|
||||
|
||||
// createAssistantConversationResponse is a response for POST /webapi/assistant/conversations.
|
||||
|
@ -489,6 +491,8 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
|
|||
err = h.assistGenSSHCommandLoop(ctx, assistClient, ws, sctx.GetUser(), authClient)
|
||||
case actionSSHExplainCommand:
|
||||
err = h.assistSSHExplainOutputLoop(ctx, assistClient, ws, authClient)
|
||||
case actionGenerateAuditQuery:
|
||||
err = h.assistGenAuditQueryLoop(ctx, assistClient, ws, sctx.GetUser(), authClient)
|
||||
default:
|
||||
err = h.assistChatLoop(ctx, assistClient, authClient, conversationID, sctx, ws)
|
||||
}
|
||||
|
@ -500,6 +504,49 @@ type usageReporter interface {
|
|||
SubmitUsageEvent(ctx context.Context, req *proto.SubmitUsageEventRequest) error
|
||||
}
|
||||
|
||||
// assistGenAuditQueryLoop reads the user's input and generates an audit query.
|
||||
func (h *Handler) assistGenAuditQueryLoop(ctx context.Context, assistClient *assist.Assist, ws *websocket.Conn, username string, usageRep usageReporter) error {
|
||||
for {
|
||||
_, payload, err := ws.ReadMessage()
|
||||
if err != nil {
|
||||
if wsIsClosed(err) {
|
||||
break
|
||||
}
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
onMessage := func(kind assist.MessageType, payload []byte, createdTime time.Time) error {
|
||||
return onMessageFn(ws, kind, payload, createdTime)
|
||||
}
|
||||
|
||||
toolCtx := &tools.ToolContext{User: username}
|
||||
|
||||
tokenCount, err := assistClient.RunTool(ctx, onMessage, tools.AuditQueryGenerationToolName, string(payload), toolCtx)
|
||||
if err != nil {
|
||||
return trace.Wrap(err)
|
||||
}
|
||||
|
||||
prompt, completion := tokens.CountTokens(tokenCount)
|
||||
|
||||
usageEventReq := &clientproto.SubmitUsageEventRequest{
|
||||
Event: &usageeventsv1.UsageEventOneOf{
|
||||
Event: &usageeventsv1.UsageEventOneOf_AssistAction{
|
||||
AssistAction: &usageeventsv1.AssistAction{
|
||||
Action: actionGenerateAuditQuery,
|
||||
TotalTokens: int64(completion + prompt),
|
||||
PromptTokens: int64(prompt),
|
||||
CompletionTokens: int64(completion),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if err := usageRep.SubmitUsageEvent(ctx, usageEventReq); err != nil {
|
||||
h.log.WithError(err).Warn("Failed to emit usage event")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// assistSSHExplainOutputLoop reads the user's input and generates a command summary.
|
||||
func (h *Handler) assistSSHExplainOutputLoop(ctx context.Context, assistClient *assist.Assist, ws *websocket.Conn, usageRep usageReporter) error {
|
||||
_, payload, err := ws.ReadMessage()
|
||||
|
|
Loading…
Reference in a new issue