make sure errors are recreated when making LM requests (#216807)

* add integration tests for LanguageModelChat#sendRequest

* make sure errors are recreated when making LM requests

* disable test with a note for later

* fix remote integration tests
Johannes Rieken 2024-06-24 09:50:48 +02:00 committed by GitHub
parent 34107733c7
commit feae5bf5d6
9 changed files with 217 additions and 42 deletions
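
To make the behavioral change concrete before the diffs below: after this commit, a failing language model request rejects with a real Error whose name, message, and stack have been recreated from the serialized form, so `instanceof Error` checks work on the extension side again. A hedged sketch of consuming code (the helper name and prompt are made up, not part of the commit):

import * as vscode from 'vscode';

// Hypothetical helper, for illustration only.
async function askModel(): Promise<string | undefined> {
	const [model] = await vscode.lm.selectChatModels({ vendor: 'test-lm-vendor' });
	if (!model) {
		return undefined;
	}
	try {
		const response = await model.sendRequest([vscode.LanguageModelChatMessage.User('Hello')]);
		let text = '';
		for await (const chunk of response.text) {
			text += chunk;
		}
		return text;
	} catch (err) {
		// The point of this commit: a provider failure rejects with a
		// recreated Error, not a plain serialized object.
		console.error(err instanceof Error ? err.stack : String(err));
		return undefined;
	}
}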

View file

@@ -7,13 +7,14 @@
"enabledApiProposals": [
"activeComment",
"authSession",
"defaultChatParticipant",
"chatParticipantPrivate",
"chatProvider",
"chatVariableResolver",
"contribViewsRemote",
"contribStatusBarItems",
"contribViewsRemote",
"createFileSystemWatcher",
"customEditorMove",
"defaultChatParticipant",
"diffCommand",
"documentFiltersExclusive",
"documentPaste",
@@ -27,6 +28,8 @@
"findTextInFiles",
"fsChunks",
"interactive",
"languageStatusText",
"lmTools",
"mappedEditsProvider",
"notebookCellExecutionState",
"notebookDeprecated",
@@ -35,26 +38,24 @@
"notebookMime",
"portsAttributes",
"quickPickSortByLabel",
"languageStatusText",
"resolvers",
"scmActionButton",
"scmSelectedProvider",
"scmTextDocument",
"scmValidation",
"taskPresentationGroup",
"telemetry",
"terminalDataWriteEvent",
"terminalDimensions",
"terminalShellIntegration",
"tunnels",
"testObserver",
"textSearchProvider",
"timeline",
"tokenInformation",
"treeViewActiveItem",
"treeViewReveal",
"workspaceTrust",
"telemetry",
"lmTools"
"tunnels",
"workspaceTrust"
],
"private": true,
"activationEvents": [],
@@ -64,6 +65,11 @@
},
"icon": "media/icon.png",
"contributes": {
"languageModels": [
{
"vendor": "test-lm-vendor"
}
],
"chatParticipants": [
{
"id": "api-test.participant",

View file

@@ -0,0 +1,153 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import 'mocha';
import * as assert from 'assert';
import * as vscode from 'vscode';
import { assertNoRpc, closeAllEditors, DeferredPromise, disposeAll } from '../utils';
suite('lm', function () {
	let disposables: vscode.Disposable[] = [];
	setup(function () {
		disposables = [];
	});
	teardown(async function () {
		assertNoRpc();
		await closeAllEditors();
		disposeAll(disposables);
	});
	test('lm request and stream', async function () {
		let p: vscode.Progress<vscode.ChatResponseFragment> | undefined;
		const defer = new DeferredPromise<void>();
		disposables.push(vscode.lm.registerChatModelProvider('test-lm', {
			async provideLanguageModelResponse(_messages, _options, _extensionId, progress, _token) {
				p = progress;
				return defer.p;
			},
			async provideTokenCount(_text, _token) {
				return 1;
			},
		}, {
			name: 'test-lm',
			version: '1.0.0',
			family: 'test',
			vendor: 'test-lm-vendor',
			maxInputTokens: 100,
			maxOutputTokens: 100,
		}));
		const models = await vscode.lm.selectChatModels({ id: 'test-lm' });
		assert.strictEqual(models.length, 1);
		const request = await models[0].sendRequest([vscode.LanguageModelChatMessage.User('Hello')]);
		// assert we have a request immediately
		assert.ok(request);
		assert.ok(p);
		assert.strictEqual(defer.isSettled, false);
		let streamDone = false;
		let responseText = '';
		const pp = (async () => {
			for await (const chunk of request.text) {
				responseText += chunk;
			}
			streamDone = true;
		})();
		assert.strictEqual(responseText, '');
		assert.strictEqual(streamDone, false);
		p.report({ index: 0, part: 'Hello' });
		defer.complete();
		await pp;
		await new Promise(r => setTimeout(r, 1000));
		assert.strictEqual(streamDone, true);
		assert.strictEqual(responseText, 'Hello');
	});
	test('lm request fail', async function () {
		disposables.push(vscode.lm.registerChatModelProvider('test-lm', {
			async provideLanguageModelResponse(_messages, _options, _extensionId, _progress, _token) {
				throw new Error('BAD');
			},
			async provideTokenCount(_text, _token) {
				return 1;
			},
		}, {
			name: 'test-lm',
			version: '1.0.0',
			family: 'test',
			vendor: 'test-lm-vendor',
			maxInputTokens: 100,
			maxOutputTokens: 100,
		}));
		const models = await vscode.lm.selectChatModels({ id: 'test-lm' });
		assert.strictEqual(models.length, 1);
		try {
			await models[0].sendRequest([vscode.LanguageModelChatMessage.User('Hello')]);
			assert.ok(false, 'EXPECTED error');
		} catch (error) {
			assert.ok(error instanceof Error);
		}
	});
	test('lm stream fail', async function () {
		const defer = new DeferredPromise<void>();
		disposables.push(vscode.lm.registerChatModelProvider('test-lm', {
			async provideLanguageModelResponse(_messages, _options, _extensionId, _progress, _token) {
				return defer.p;
			},
			async provideTokenCount(_text, _token) {
				return 1;
			},
		}, {
			name: 'test-lm',
			version: '1.0.0',
			family: 'test',
			vendor: 'test-lm-vendor',
			maxInputTokens: 100,
			maxOutputTokens: 100,
		}));
		const models = await vscode.lm.selectChatModels({ id: 'test-lm' });
		assert.strictEqual(models.length, 1);
		const res = await models[0].sendRequest([vscode.LanguageModelChatMessage.User('Hello')]);
		assert.ok(res);
		const result = (async () => {
			for await (const _chunk of res.text) {
			}
		})();
		defer.error(new Error('STREAM FAIL'));
		try {
			await result;
			assert.ok(false, 'EXPECTED error');
		} catch (error) {
			assert.ok(error);
			// assert.ok(error instanceof Error); // todo@jrieken requires one more insiders
		}
	});
});
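
The tests above drive the provider through DeferredPromise from the shared test utils. For readers without that module at hand, a minimal stand-in is sketched below; the member names (p, isSettled, complete, error) match how the tests use it, but this is an illustrative approximation, not the actual implementation from ../utils:

// Illustrative stand-in only; the real utility has more features.
class DeferredPromise<T> {
	private _resolve!: (value: T) => void;
	private _reject!: (err: unknown) => void;
	private _settled = false;
	readonly p: Promise<T> = new Promise<T>((resolve, reject) => {
		this._resolve = resolve;
		this._reject = reject;
	});
	get isSettled(): boolean {
		return this._settled;
	}
	complete(value: T): void {
		this._settled = true;
		this._resolve(value);
	}
	error(err: unknown): void {
		this._settled = true;
		this._reject(err);
	}
}

With T = void, defer.complete() can be called without an argument, which is how the first test releases the provider.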

View file

@@ -137,6 +137,19 @@ export function transformErrorForSerialization(error: any): any {
return error;
}
export function transformErrorFromSerialization(data: SerializedError): Error {
	let error: Error;
	if (data.noTelemetry) {
		error = new ErrorNoTelemetry();
	} else {
		error = new Error();
		error.name = data.name;
	}
	error.message = data.message;
	error.stack = data.stack;
	return error;
}
// see https://github.com/v8/v8/wiki/Stack%20Trace%20API#basic-stack-traces
export interface V8CallSite {
getThis(): unknown;
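
Together with the existing transformErrorForSerialization, the new function gives a symmetric round trip. A short sketch of the intended usage, using only the two functions from this file:

import { transformErrorForSerialization, transformErrorFromSerialization } from 'vs/base/common/errors';

const original = new Error('BAD');
// Wire format: a plain object { $isError, name, message, stack, noTelemetry }
// that survives a JSON-style channel.
const wire = transformErrorForSerialization(original);
// Receiving side: a genuine Error again, with name/message/stack restored.
const restored = transformErrorFromSerialization(wire);
console.assert(restored instanceof Error);
console.assert(restored.message === original.message);
console.assert(restored.stack === original.stack);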

View file

@@ -3,7 +3,7 @@
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { SerializedError, onUnexpectedError, ErrorNoTelemetry } from 'vs/base/common/errors';
import { SerializedError, onUnexpectedError, transformErrorFromSerialization } from 'vs/base/common/errors';
import { extHostNamedCustomer } from 'vs/workbench/services/extensions/common/extHostCustomers';
import { MainContext, MainThreadErrorsShape } from 'vs/workbench/api/common/extHost.protocol';
@@ -16,11 +16,7 @@ export class MainThreadErrors implements MainThreadErrorsShape {
$onUnexpectedError(err: any | SerializedError): void {
if (err && err.$isError) {
const { name, message, stack } = err;
err = err.noTelemetry ? new ErrorNoTelemetry() : new Error();
err.message = message;
err.name = name;
err.stack = stack;
err = transformErrorFromSerialization(err);
}
onUnexpectedError(err);
}

View file

@@ -6,7 +6,7 @@
import { Action } from 'vs/base/common/actions';
import { VSBuffer } from 'vs/base/common/buffer';
import { CancellationToken } from 'vs/base/common/cancellation';
import { SerializedError } from 'vs/base/common/errors';
import { SerializedError, transformErrorFromSerialization } from 'vs/base/common/errors';
import { FileAccess } from 'vs/base/common/network';
import Severity from 'vs/base/common/severity';
import { URI, UriComponents } from 'vs/base/common/uri';
@@ -73,19 +73,13 @@ export class MainThreadExtensionService implements MainThreadExtensionServiceShape {
this._internalExtensionService._onDidActivateExtension(extensionId, codeLoadingTime, activateCallTime, activateResolvedTime, activationReason);
}
$onExtensionRuntimeError(extensionId: ExtensionIdentifier, data: SerializedError): void {
const error = new Error();
error.name = data.name;
error.message = data.message;
error.stack = data.stack;
const error = transformErrorFromSerialization(data);
this._internalExtensionService._onExtensionRuntimeError(extensionId, error);
console.error(`[${extensionId.value}]${error.message}`);
console.error(error.stack);
}
async $onExtensionActivationError(extensionId: ExtensionIdentifier, data: SerializedError, missingExtensionDependency: MissingExtensionDependency | null): Promise<void> {
const error = new Error();
error.name = data.name;
error.message = data.message;
error.stack = data.stack;
const error = transformErrorFromSerialization(data);
this._internalExtensionService._onDidActivateExtensionError(extensionId, error);

View file

@@ -5,6 +5,7 @@
import { AsyncIterableSource, DeferredPromise } from 'vs/base/common/async';
import { CancellationToken } from 'vs/base/common/cancellation';
import { SerializedError, transformErrorForSerialization, transformErrorFromSerialization } from 'vs/base/common/errors';
import { Emitter, Event } from 'vs/base/common/event';
import { Disposable, DisposableMap, DisposableStore, IDisposable, toDisposable } from 'vs/base/common/lifecycle';
import { localize } from 'vs/nls';
@@ -77,20 +78,26 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
this._providerRegistrations.set(handle, dipsosables);
}
async $handleResponsePart(requestId: number, chunk: IChatResponseFragment): Promise<void> {
this._pendingProgress.get(requestId)?.stream.emitOne(chunk);
async $reportResponsePart(requestId: number, chunk: IChatResponseFragment): Promise<void> {
const data = this._pendingProgress.get(requestId);
this._logService.trace('[LM] report response PART', Boolean(data), requestId, chunk);
if (data) {
data.stream.emitOne(chunk);
}
}
async $handleResponseDone(requestId: number, error: any | undefined): Promise<void> {
async $reportResponseDone(requestId: number, err: SerializedError | undefined): Promise<void> {
const data = this._pendingProgress.get(requestId);
this._logService.trace('[LM] report response DONE', Boolean(data), requestId, err);
if (data) {
this._pendingProgress.delete(requestId);
if (error) {
data.defer.error(error);
if (err) {
const error = transformErrorFromSerialization(err);
data.stream.reject(error);
data.defer.error(error);
} else {
data.defer.complete(undefined);
data.stream.resolve();
data.defer.complete(undefined);
}
}
}
@@ -108,7 +115,7 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
}
async $tryStartChatRequest(extension: ExtensionIdentifier, providerId: string, requestId: number, messages: IChatMessage[], options: {}, token: CancellationToken): Promise<any> {
this._logService.debug('[CHAT] extension request STARTED', extension.value, requestId);
this._logService.trace('[CHAT] request STARTED', extension.value, requestId);
const response = await this._chatProviderService.sendChatRequest(providerId, extension, messages, options, token);
@@ -116,24 +123,26 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
// This method must return before the response is done (has streamed all parts)
// and because of that we consume the stream without awaiting
// !!! IMPORTANT !!!
(async () => {
const streaming = (async () => {
try {
for await (const part of response.stream) {
this._logService.trace('[CHAT] request PART', extension.value, requestId, part);
await this._proxy.$acceptResponsePart(requestId, part);
}
this._logService.trace('[CHAT] request DONE', extension.value, requestId);
} catch (err) {
this._logService.error('[CHAT] extension request ERRORED in STREAM', err, extension.value, requestId);
this._proxy.$acceptResponseDone(requestId, err);
this._proxy.$acceptResponseDone(requestId, transformErrorForSerialization(err));
}
})();
// When the response is done (signaled via its result) we tell the EH
response.result.then(() => {
Promise.allSettled([response.result, streaming]).then(() => {
this._logService.debug('[CHAT] extension request DONE', extension.value, requestId);
this._proxy.$acceptResponseDone(requestId, undefined);
}, err => {
this._logService.error('[CHAT] extension request ERRORED', err, extension.value, requestId);
this._proxy.$acceptResponseDone(requestId, err);
this._proxy.$acceptResponseDone(requestId, transformErrorForSerialization(err));
});
}
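
The control flow above is easy to misread, so here is a reduced sketch of the pattern (all names hypothetical, error serialization elided): the stream is consumed without awaiting so the RPC method can return immediately, and DONE is reported only once both the result promise and the streaming loop have settled, which is the race the switch to Promise.allSettled closes.

async function pumpResponse<T>(
	stream: AsyncIterable<T>,
	result: Promise<unknown>,
	sendPart: (part: T) => Promise<void>,
	sendDone: (err?: unknown) => void
): Promise<void> {
	// Start consuming immediately, but do NOT await here: the enclosing RPC
	// method must return while parts are still streaming.
	const streaming = (async () => {
		for await (const part of stream) {
			await sendPart(part);
		}
	})();
	// Only after BOTH the overall result and the streaming loop settle is it
	// safe to report DONE; reporting on `result` alone can overtake the last
	// parts or swallow a stream failure.
	const outcomes = await Promise.allSettled([result, streaming]);
	const failure = outcomes.find((o): o is PromiseRejectedResult => o.status === 'rejected');
	sendDone(failure ? failure.reason : undefined);
}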

View file

@@ -1203,8 +1203,8 @@ export interface MainThreadLanguageModelsShape extends IDisposable {
$registerLanguageModelProvider(handle: number, identifier: string, metadata: ILanguageModelChatMetadata): void;
$unregisterProvider(handle: number): void;
$tryStartChatRequest(extension: ExtensionIdentifier, provider: string, requestId: number, messages: IChatMessage[], options: {}, token: CancellationToken): Promise<void>;
$handleResponsePart(requestId: number, chunk: IChatResponseFragment): Promise<void>;
$handleResponseDone(requestId: number, error: any | undefined): Promise<void>;
$reportResponsePart(requestId: number, chunk: IChatResponseFragment): Promise<void>;
$reportResponseDone(requestId: number, error: SerializedError | undefined): Promise<void>;
$selectChatModels(selector: ILanguageModelChatSelector): Promise<string[]>;
$whenLanguageModelChatRequestMade(identifier: string, extension: ExtensionIdentifier, participant?: string, tokenCount?: number): void;
$countTokens(provider: string, value: string | IChatMessage, token: CancellationToken): Promise<number>;
@@ -1215,7 +1215,7 @@ export interface ExtHostLanguageModelsShape {
$updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void;
$startChatRequest(handle: number, requestId: number, from: ExtensionIdentifier, messages: IChatMessage[], options: { [name: string]: any }, token: CancellationToken): Promise<void>;
$acceptResponsePart(requestId: number, chunk: IChatResponseFragment): Promise<void>;
$acceptResponseDone(requestId: number, error: any | undefined): Promise<void>;
$acceptResponseDone(requestId: number, error: SerializedError | undefined): Promise<void>;
$provideTokenLength(handle: number, value: string | IChatMessage, token: CancellationToken): Promise<number>;
}
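
The tightened signatures (SerializedError instead of any) reflect why the transform functions exist at all: a raw Error does not survive JSON-style marshalling, because its interesting properties are non-enumerable. A quick demonstration in plain Node/TypeScript, independent of VS Code:

const boom = new Error('BAD');
console.log(JSON.stringify(boom)); // '{}', because message and stack are non-enumerable
console.log(Object.keys(boom));    // []
// Hence the sender copies the fields into a plain SerializedError object and
// the receiver rebuilds a real Error, which is what the two transform
// functions from vs/base/common/errors do.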

View file

@@ -6,7 +6,7 @@
import { AsyncIterableObject, AsyncIterableSource } from 'vs/base/common/async';
import { CancellationToken } from 'vs/base/common/cancellation';
import { toErrorMessage } from 'vs/base/common/errorMessage';
import { CancellationError } from 'vs/base/common/errors';
import { CancellationError, SerializedError, transformErrorForSerialization, transformErrorFromSerialization } from 'vs/base/common/errors';
import { Emitter, Event } from 'vs/base/common/event';
import { Iterable } from 'vs/base/common/iterator';
import { IDisposable, toDisposable } from 'vs/base/common/lifecycle';
@@ -212,7 +212,7 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
return;
}
this._proxy.$handleResponsePart(requestId, { index: fragment.index, part });
this._proxy.$reportResponsePart(requestId, { index: fragment.index, part });
});
let p: Promise<any>;
@@ -243,9 +243,9 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
}
p.then(() => {
this._proxy.$handleResponseDone(requestId, undefined);
this._proxy.$reportResponseDone(requestId, undefined);
}, err => {
this._proxy.$handleResponseDone(requestId, err);
this._proxy.$reportResponseDone(requestId, transformErrorForSerialization(err));
});
}
@@ -411,7 +411,7 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
}
}
async $acceptResponseDone(requestId: number, error: any | undefined): Promise<void> {
async $acceptResponseDone(requestId: number, error: SerializedError | undefined): Promise<void> {
const data = this._pendingRequest.get(requestId);
if (!data) {
return;
@@ -420,7 +420,7 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
if (error) {
// we error the stream because that's the only way to signal
// that the request has failed
data.res.reject(error);
data.res.reject(transformErrorFromSerialization(error));
} else {
data.res.resolve();
}
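
Rejecting the pending async-iterable source is what makes the failure observable to the consumer: it surfaces as a throw inside the for-await loop, which is exactly what the 'lm stream fail' test asserts. A standalone sketch of that mechanism with a plain async generator (names made up):

async function* failingStream(): AsyncGenerator<string> {
	yield 'partial ';
	throw new Error('STREAM FAIL');
}

async function consume(): Promise<void> {
	try {
		for await (const chunk of failingStream()) {
			console.log(chunk);
		}
	} catch (err) {
		// The generator's throw rejects the pending next(), and for-await
		// re-throws it here on the consuming side.
		console.error((err as Error).message); // 'STREAM FAIL'
	}
}

consume();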

View file

@@ -105,6 +105,10 @@ async function runTestsInBrowser(browserType: BrowserType, endpoint: url.UrlWith
console.error(`Error saving web client logs (${error})`);
}
if (args.debug) {
return;
}
try {
await browser.close();
} catch (error) {