Add the ability to run a specific tool to Assist. (#31113)

This commit contains:
- a refactoring of the model action logic to be exposed through the
  `DoAction` function
- Expose `RunTool()` capability through ai.Client and assist.Assist
- Add the "audit-query" action that invokes directly the audit
  generation tool
This commit is contained in:
Hugo Shaka 2023-08-31 16:38:32 -04:00 committed by GitHub
parent 699f1c23ad
commit dac5af41d0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 251 additions and 3 deletions

View file

@ -321,13 +321,13 @@ func generateAccessRequestResponse(t *testing.T) string {
func TestChat_Complete_AuditQuery(t *testing.T) {
// Test setup: generate the responses that will be served by our OpenAI mock
action := model.PlanOutput{
Action: "Audit Query Generation",
Action: tools.AuditQueryGenerationToolName,
ActionInput: "Lists user who connected to a server as root.",
Reasoning: "foo",
}
selectedAction, err := json.Marshal(action)
require.NoError(t, err)
generatedQuery := "SELECT user FROM session_start WHERE login='root'"
const generatedQuery = "SELECT user FROM session_start WHERE login='root'"
responses := []string{
// The model must select the audit query tool

View file

@ -92,6 +92,30 @@ func (client *Client) NewCommand(username string) *Chat {
}
}
func (client *Client) RunTool(ctx context.Context, toolContext *modeltools.ToolContext, toolName, toolInput string) (any, *tokens.TokenCount, error) {
tools := []modeltools.Tool{
&modeltools.CommandExecutionTool{},
&modeltools.EmbeddingRetrievalTool{},
&modeltools.AuditQueryGenerationTool{LLM: client.svc},
}
// The following tools are only available in the enterprise build. They will fail
// if included in OSS due to the lack of the required backend APIs.
if modules.GetModules().BuildType() == modules.BuildEnterprise {
tools = append(tools, &modeltools.AccessRequestCreateTool{},
&modeltools.AccessRequestsListTool{},
&modeltools.AccessRequestListRequestableRolesTool{},
&modeltools.AccessRequestListRequestableResourcesTool{})
}
agent := model.NewAgent(toolContext, tools...)
action := &model.AgentAction{
Action: toolName,
Input: toolInput,
Reasoning: "Tool invoked directly",
}
return agent.DoAction(ctx, client.svc, action)
}
func (client *Client) NewAuditQuery(username string) *Chat {
toolContext := &modeltools.ToolContext{User: username}
return &Chat{

110
lib/ai/client_test.go Normal file
View file

@ -0,0 +1,110 @@
/*
* Copyright 2023 Gravitational, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ai
import (
"context"
"encoding/json"
"net/http/httptest"
"testing"
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
assistpb "github.com/gravitational/teleport/api/gen/proto/go/assist/v1"
"github.com/gravitational/teleport/lib/ai/model/output"
"github.com/gravitational/teleport/lib/ai/model/tools"
"github.com/gravitational/teleport/lib/ai/testutils"
)
func TestRunTool_AuditQueryGeneration(t *testing.T) {
// Test setup: starting a mock openai server and creating the client
const generatedQuery = "SELECT user FROM session_start WHERE login='root'"
responses := []string{
// Then the audit query tool chooses to request session.start events
"session.start",
// Finally the tool builds a query based on the provided schemas
generatedQuery,
}
server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)
cfg := openai.DefaultConfig("secret-test-token")
cfg.BaseURL = server.URL
client := NewClientFromConfig(cfg)
// Doing the test: Check that the AuditQueryGeneration tool can be invoked
// through client.RunTool and validate its response.
ctx := context.Background()
toolCtx := &tools.ToolContext{User: "alice"}
response, _, err := client.RunTool(ctx, toolCtx, tools.AuditQueryGenerationToolName, "List users who connected to a server as root")
require.NoError(t, err)
message, ok := response.(*output.StreamingMessage)
require.True(t, ok)
require.Equal(t, generatedQuery, message.WaitAndConsume())
}
type mockEmbeddingGetter struct {
response []*assistpb.EmbeddedDocument
}
func (m *mockEmbeddingGetter) GetAssistantEmbeddings(ctx context.Context, in *assistpb.GetAssistantEmbeddingsRequest, opts ...grpc.CallOption) (*assistpb.GetAssistantEmbeddingsResponse, error) {
return &assistpb.GetAssistantEmbeddingsResponse{Embeddings: m.response}, nil
}
func TestRunTool_EmbeddingRetrieval(t *testing.T) {
// Test setup: starting a mock openai server and embedding getter,
// then create the client.
mock := &mockEmbeddingGetter{
[]*assistpb.EmbeddedDocument{
{
Id: "1",
Content: "foo",
SimilarityScore: 1,
},
{
Id: "2",
Content: "bar",
SimilarityScore: 0.9,
},
},
}
ctx := context.Background()
toolCtx := &tools.ToolContext{AssistEmbeddingServiceClient: mock}
responses := make([]string, 0)
server := httptest.NewServer(testutils.GetTestHandlerFn(t, responses))
t.Cleanup(server.Close)
cfg := openai.DefaultConfig("secret-test-token")
cfg.BaseURL = server.URL
client := NewClientFromConfig(cfg)
// Doing the test: Check that the EmbeddingRetrieval tool can be invoked
// through client.RunTool and validate its response.
input := tools.EmbeddingRetrievalToolInput{Question: "Find foobar"}
inputText, err := json.Marshal(input)
require.NoError(t, err)
response, _, err := client.RunTool(ctx, toolCtx, "Nodes names and labels retrieval", string(inputText))
require.NoError(t, err)
message, ok := response.(*output.Message)
require.True(t, ok)
require.Equal(t, "foo\nbar\n", message.Content)
}

View file

@ -133,6 +133,28 @@ func (a *Agent) PlanAndExecute(ctx context.Context, llm *openai.Client, chatHist
}
}
func (a *Agent) DoAction(ctx context.Context, llm *openai.Client, action *AgentAction) (any, *tokens.TokenCount, error) {
state := &executionState{
llm: llm,
tokenCount: tokens.NewTokenCount(),
}
out, err := a.doAction(ctx, state, action)
if err != nil {
return nil, nil, trace.Wrap(err)
}
switch {
case out.finish != nil:
// If the tool already breaks execution, we don't have to do anything
return out.finish.output, state.tokenCount, nil
case out.observation != "":
// If the tool doesn't break execution and returns a single observation,
// we wrap the observation in a Message.
return &output.Message{Content: out.observation}, state.tokenCount, nil
default:
return nil, state.tokenCount, trace.Errorf("action %s did not end execution nor returned an observation", action.Action)
}
}
// stepOutput represents the inputs and outputs of a single thought step.
type stepOutput struct {
// if the agent is done, finish is set.
@ -189,6 +211,10 @@ func (a *Agent) takeNextStep(ctx context.Context, state *executionState, progres
// If action is set, the agent is not done and called upon a tool.
progressUpdates(action)
return a.doAction(ctx, state, action)
}
func (a *Agent) doAction(ctx context.Context, state *executionState, action *AgentAction) (stepOutput, error) {
var tool tools.Tool
for _, candidate := range a.tools {
if candidate.Name() == action.Action {

View file

@ -30,12 +30,14 @@ import (
"github.com/gravitational/teleport/lib/ai/tokens"
)
const AuditQueryGenerationToolName = "Audit Query Generation"
type AuditQueryGenerationTool struct {
LLM *openai.Client
}
func (t *AuditQueryGenerationTool) Name() string {
return "Audit Query Generation"
return AuditQueryGenerationToolName
}
func (t *AuditQueryGenerationTool) Description() string {

View file

@ -176,6 +176,45 @@ func (a *Assist) GenerateSummary(ctx context.Context, message string) (string, e
return a.client.Summary(ctx, message)
}
// RunTool runs a model tool without an ai.Chat.
func (a *Assist) RunTool(ctx context.Context, onMessage onMessageFunc, toolName, userInput string, toolContext *tools.ToolContext,
) (*tokens.TokenCount, error) {
message, tc, err := a.client.RunTool(ctx, toolContext, toolName, userInput)
if err != nil {
return nil, trace.Wrap(err)
}
switch message := message.(type) {
case *output.Message:
if err := onMessage(MessageKindAssistantMessage, []byte(message.Content), a.clock.Now().UTC()); err != nil {
return nil, trace.Wrap(err)
}
case *output.GeneratedCommand:
if err := onMessage(MessageKindCommand, []byte(message.Command), a.clock.Now().UTC()); err != nil {
return nil, trace.Wrap(err)
}
case *output.StreamingMessage:
if err := func() error {
var text strings.Builder
defer onMessage(MessageKindAssistantPartialFinalize, nil, a.clock.Now().UTC())
for part := range message.Parts {
text.WriteString(part)
if err := onMessage(MessageKindAssistantPartialMessage, []byte(part), a.clock.Now().UTC()); err != nil {
return trace.Wrap(err)
}
}
return nil
}(); err != nil {
return nil, trace.Wrap(err)
}
default:
return nil, trace.Errorf("Unexpected message type: %T", message)
}
return tc, nil
}
// GenerateCommandSummary summarizes the output of a command executed on one or
// many nodes. The conversation history is also sent into the prompt in order
// to gather context and know what information is relevant in the command output.

View file

@ -48,6 +48,8 @@ const (
actionSSHGenerateCommand = "ssh-cmdgen"
// actionSSHExplainCommand is a name of the action for explaining terminal output in SSH session.
actionSSHExplainCommand = "ssh-explain"
// actionGenerateAuditQuery is the name of the action for generating audit queries.
actionGenerateAuditQuery = "audit-query"
)
// createAssistantConversationResponse is a response for POST /webapi/assistant/conversations.
@ -489,6 +491,8 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request,
err = h.assistGenSSHCommandLoop(ctx, assistClient, ws, sctx.GetUser(), authClient)
case actionSSHExplainCommand:
err = h.assistSSHExplainOutputLoop(ctx, assistClient, ws, authClient)
case actionGenerateAuditQuery:
err = h.assistGenAuditQueryLoop(ctx, assistClient, ws, sctx.GetUser(), authClient)
default:
err = h.assistChatLoop(ctx, assistClient, authClient, conversationID, sctx, ws)
}
@ -500,6 +504,49 @@ type usageReporter interface {
SubmitUsageEvent(ctx context.Context, req *proto.SubmitUsageEventRequest) error
}
// assistGenAuditQueryLoop reads the user's input and generates an audit query.
func (h *Handler) assistGenAuditQueryLoop(ctx context.Context, assistClient *assist.Assist, ws *websocket.Conn, username string, usageRep usageReporter) error {
for {
_, payload, err := ws.ReadMessage()
if err != nil {
if wsIsClosed(err) {
break
}
return trace.Wrap(err)
}
onMessage := func(kind assist.MessageType, payload []byte, createdTime time.Time) error {
return onMessageFn(ws, kind, payload, createdTime)
}
toolCtx := &tools.ToolContext{User: username}
tokenCount, err := assistClient.RunTool(ctx, onMessage, tools.AuditQueryGenerationToolName, string(payload), toolCtx)
if err != nil {
return trace.Wrap(err)
}
prompt, completion := tokens.CountTokens(tokenCount)
usageEventReq := &clientproto.SubmitUsageEventRequest{
Event: &usageeventsv1.UsageEventOneOf{
Event: &usageeventsv1.UsageEventOneOf_AssistAction{
AssistAction: &usageeventsv1.AssistAction{
Action: actionGenerateAuditQuery,
TotalTokens: int64(completion + prompt),
PromptTokens: int64(prompt),
CompletionTokens: int64(completion),
},
},
},
}
if err := usageRep.SubmitUsageEvent(ctx, usageEventReq); err != nil {
h.log.WithError(err).Warn("Failed to emit usage event")
}
}
return nil
}
// assistSSHExplainOutputLoop reads the user's input and generates a command summary.
func (h *Handler) assistSSHExplainOutputLoop(ctx context.Context, assistClient *assist.Assist, ws *websocket.Conn, usageRep usageReporter) error {
_, payload, err := ws.ReadMessage()