mirror of
https://github.com/Microsoft/vscode
synced 2024-10-04 02:14:06 +00:00
Joh/ministerial-swan (#212096)
* first cut of embeddings API https://github.com/microsoft/vscode/issues/212083 * add event * fix tests
This commit is contained in:
parent
facea7e2c2
commit
6874fc7394
|
@ -19,6 +19,7 @@
|
|||
"documentFiltersExclusive",
|
||||
"documentPaste",
|
||||
"editorInsets",
|
||||
"embeddings",
|
||||
"extensionRuntime",
|
||||
"extensionsAny",
|
||||
"externalUriOpener",
|
||||
|
@ -80,7 +81,7 @@
|
|||
"id": "api-test.participant2",
|
||||
"name": "participant2",
|
||||
"description": "test",
|
||||
"commands": [ ]
|
||||
"commands": []
|
||||
}
|
||||
],
|
||||
"configuration": {
|
||||
|
|
|
@ -20,6 +20,7 @@ import './mainThreadBulkEdits';
|
|||
import './mainThreadLanguageModels';
|
||||
import './mainThreadChatAgents2';
|
||||
import './mainThreadChatVariables';
|
||||
import './mainThreadEmbeddings';
|
||||
import './mainThreadCodeInsets';
|
||||
import './mainThreadCLICommands';
|
||||
import './mainThreadClipboard';
|
||||
|
|
112
src/vs/workbench/api/browser/mainThreadEmbeddings.ts
Normal file
112
src/vs/workbench/api/browser/mainThreadEmbeddings.ts
Normal file
|
@ -0,0 +1,112 @@
|
|||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import { CancellationToken } from 'vs/base/common/cancellation';
|
||||
import { Emitter, Event } from 'vs/base/common/event';
|
||||
import { DisposableMap, DisposableStore, IDisposable } from 'vs/base/common/lifecycle';
|
||||
import { InstantiationType, registerSingleton } from 'vs/platform/instantiation/common/extensions';
|
||||
import { createDecorator } from 'vs/platform/instantiation/common/instantiation';
|
||||
import { ExtHostContext, ExtHostEmbeddingsShape, MainContext, MainThreadEmbeddingsShape } from 'vs/workbench/api/common/extHost.protocol';
|
||||
import { extHostNamedCustomer, IExtHostContext } from 'vs/workbench/services/extensions/common/extHostCustomers';
|
||||
|
||||
|
||||
interface IEmbeddingsProvider {
|
||||
provideEmbeddings(input: string[], token: CancellationToken): Promise<{ values: number[] }[]>;
|
||||
}
|
||||
|
||||
const IEmbeddingsService = createDecorator<IEmbeddingsService>('embeddingsService');
|
||||
|
||||
interface IEmbeddingsService {
|
||||
|
||||
_serviceBrand: undefined;
|
||||
|
||||
readonly onDidChange: Event<void>;
|
||||
|
||||
allProviders: Iterable<string>;
|
||||
|
||||
registerProvider(id: string, provider: IEmbeddingsProvider): IDisposable;
|
||||
|
||||
computeEmbeddings(id: string, input: string[], token: CancellationToken): Promise<{ values: number[] }[]>;
|
||||
}
|
||||
|
||||
class EmbeddingsService implements IEmbeddingsService {
|
||||
_serviceBrand: undefined;
|
||||
|
||||
private providers: Map<string, IEmbeddingsProvider>;
|
||||
|
||||
private readonly _onDidChange = new Emitter<void>();
|
||||
readonly onDidChange: Event<void> = this._onDidChange.event;
|
||||
|
||||
constructor() {
|
||||
this.providers = new Map<string, IEmbeddingsProvider>();
|
||||
}
|
||||
|
||||
get allProviders(): Iterable<string> {
|
||||
return this.providers.keys();
|
||||
}
|
||||
|
||||
registerProvider(id: string, provider: IEmbeddingsProvider): IDisposable {
|
||||
this.providers.set(id, provider);
|
||||
this._onDidChange.fire();
|
||||
return {
|
||||
dispose: () => {
|
||||
this.providers.delete(id);
|
||||
this._onDidChange.fire();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
computeEmbeddings(id: string, input: string[], token: CancellationToken): Promise<{ values: number[] }[]> {
|
||||
const provider = this.providers.get(id);
|
||||
if (provider) {
|
||||
return provider.provideEmbeddings(input, token);
|
||||
} else {
|
||||
return Promise.reject(new Error(`No embeddings provider registered with id: ${id}`));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
registerSingleton(IEmbeddingsService, EmbeddingsService, InstantiationType.Delayed);
|
||||
|
||||
@extHostNamedCustomer(MainContext.MainThreadEmbeddings)
|
||||
export class MainThreadEmbeddings implements MainThreadEmbeddingsShape {
|
||||
|
||||
private readonly _store = new DisposableStore();
|
||||
private readonly _providers = this._store.add(new DisposableMap<number>);
|
||||
private readonly _proxy: ExtHostEmbeddingsShape;
|
||||
|
||||
constructor(
|
||||
context: IExtHostContext,
|
||||
@IEmbeddingsService private readonly embeddingsService: IEmbeddingsService
|
||||
) {
|
||||
this._proxy = context.getProxy(ExtHostContext.ExtHostEmbeddings);
|
||||
|
||||
this._store.add(embeddingsService.onDidChange((() => {
|
||||
this._proxy.$acceptEmbeddingModels(Array.from(embeddingsService.allProviders));
|
||||
})));
|
||||
}
|
||||
|
||||
dispose(): void {
|
||||
this._store.dispose();
|
||||
}
|
||||
|
||||
$registerEmbeddingProvider(handle: number, identifier: string): void {
|
||||
const registration = this.embeddingsService.registerProvider(identifier, {
|
||||
provideEmbeddings: (input: string[], token: CancellationToken): Promise<{ values: number[] }[]> => {
|
||||
return this._proxy.$provideEmbeddings(handle, input, token);
|
||||
}
|
||||
});
|
||||
this._providers.set(handle, registration);
|
||||
}
|
||||
|
||||
$unregisterEmbeddingProvider(handle: number): void {
|
||||
this._providers.deleteAndDispose(handle);
|
||||
}
|
||||
|
||||
$computeEmbeddings(embeddingsModel: string, input: string[], token: CancellationToken): Promise<{ values: number[] }[]> {
|
||||
return this.embeddingsService.computeEmbeddings(embeddingsModel, input, token);
|
||||
}
|
||||
}
|
|
@ -44,6 +44,7 @@ import { ExtHostDocumentSaveParticipant } from 'vs/workbench/api/common/extHostD
|
|||
import { ExtHostDocuments } from 'vs/workbench/api/common/extHostDocuments';
|
||||
import { IExtHostDocumentsAndEditors } from 'vs/workbench/api/common/extHostDocumentsAndEditors';
|
||||
import { IExtHostEditorTabs } from 'vs/workbench/api/common/extHostEditorTabs';
|
||||
import { ExtHostEmbeddings } from 'vs/workbench/api/common/extHostEmbedding';
|
||||
import { ExtHostAiEmbeddingVector } from 'vs/workbench/api/common/extHostEmbeddingVector';
|
||||
import { Extension, IExtHostExtensionService } from 'vs/workbench/api/common/extHostExtensionService';
|
||||
import { ExtHostFileSystem } from 'vs/workbench/api/common/extHostFileSystem';
|
||||
|
@ -214,6 +215,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I
|
|||
const extHostAiEmbeddingVector = rpcProtocol.set(ExtHostContext.ExtHostAiEmbeddingVector, new ExtHostAiEmbeddingVector(rpcProtocol));
|
||||
const extHostStatusBar = rpcProtocol.set(ExtHostContext.ExtHostStatusBar, new ExtHostStatusBar(rpcProtocol, extHostCommands.converter));
|
||||
const extHostSpeech = rpcProtocol.set(ExtHostContext.ExtHostSpeech, new ExtHostSpeech(rpcProtocol));
|
||||
const extHostEmbeddings = rpcProtocol.set(ExtHostContext.ExtHostEmbeddings, new ExtHostEmbeddings(rpcProtocol));
|
||||
|
||||
// Check that no named customers are missing
|
||||
const expected = Object.values<ProxyIdentifier<any>>(ExtHostContext);
|
||||
|
@ -1450,6 +1452,27 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I
|
|||
getLanguageModelInformation(languageModel: string) {
|
||||
checkProposedApiEnabled(extension, 'languageModels');
|
||||
return extHostLanguageModels.getLanguageModelInfo(languageModel);
|
||||
},
|
||||
// --- embeddings
|
||||
get embeddingModels() {
|
||||
checkProposedApiEnabled(extension, 'embeddings');
|
||||
return extHostEmbeddings.embeddingsModels;
|
||||
},
|
||||
onDidChangeEmbeddingModels: (listener, thisArgs?, disposables?) => {
|
||||
checkProposedApiEnabled(extension, 'embeddings');
|
||||
return extHostEmbeddings.onDidChange(listener, thisArgs, disposables);
|
||||
},
|
||||
registerEmbeddingsProvider(embeddingsModel, provider) {
|
||||
checkProposedApiEnabled(extension, 'embeddings');
|
||||
return extHostEmbeddings.registerEmbeddingsProvider(extension, embeddingsModel, provider);
|
||||
},
|
||||
async computeEmbeddings(embeddingsModel, input, token?): Promise<any> {
|
||||
checkProposedApiEnabled(extension, 'embeddings');
|
||||
if (typeof input === 'string') {
|
||||
return extHostEmbeddings.computeEmbeddings(embeddingsModel, input, token);
|
||||
} else {
|
||||
return extHostEmbeddings.computeEmbeddings(embeddingsModel, input, token);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
@ -1216,6 +1216,17 @@ export interface ExtHostLanguageModelsShape {
|
|||
$provideTokenLength(handle: number, value: string | IChatMessage, token: CancellationToken): Promise<number>;
|
||||
}
|
||||
|
||||
export interface MainThreadEmbeddingsShape extends IDisposable {
|
||||
$registerEmbeddingProvider(handle: number, identifier: string): void;
|
||||
$unregisterEmbeddingProvider(handle: number): void;
|
||||
$computeEmbeddings(embeddingsModel: string, input: string[], token: CancellationToken): Promise<({ values: number[] }[])>;
|
||||
}
|
||||
|
||||
export interface ExtHostEmbeddingsShape {
|
||||
$provideEmbeddings(handle: number, input: string[], token: CancellationToken): Promise<{ values: number[] }[]>;
|
||||
$acceptEmbeddingModels(models: string[]): void;
|
||||
}
|
||||
|
||||
export interface IExtensionChatAgentMetadata extends Dto<IChatAgentMetadata> {
|
||||
hasFollowups?: boolean;
|
||||
}
|
||||
|
@ -2774,6 +2785,7 @@ export const MainContext = {
|
|||
MainThreadAuthentication: createProxyIdentifier<MainThreadAuthenticationShape>('MainThreadAuthentication'),
|
||||
MainThreadBulkEdits: createProxyIdentifier<MainThreadBulkEditsShape>('MainThreadBulkEdits'),
|
||||
MainThreadLanguageModels: createProxyIdentifier<MainThreadLanguageModelsShape>('MainThreadLanguageModels'),
|
||||
MainThreadEmbeddings: createProxyIdentifier<MainThreadEmbeddingsShape>('MainThreadEmbeddings'),
|
||||
MainThreadChatAgents2: createProxyIdentifier<MainThreadChatAgentsShape2>('MainThreadChatAgents2'),
|
||||
MainThreadChatVariables: createProxyIdentifier<MainThreadChatVariablesShape>('MainThreadChatVariables'),
|
||||
MainThreadClipboard: createProxyIdentifier<MainThreadClipboardShape>('MainThreadClipboard'),
|
||||
|
@ -2897,6 +2909,7 @@ export const ExtHostContext = {
|
|||
ExtHostChatVariables: createProxyIdentifier<ExtHostChatVariablesShape>('ExtHostChatVariables'),
|
||||
ExtHostChatProvider: createProxyIdentifier<ExtHostLanguageModelsShape>('ExtHostChatProvider'),
|
||||
ExtHostSpeech: createProxyIdentifier<ExtHostSpeechShape>('ExtHostSpeech'),
|
||||
ExtHostEmbeddings: createProxyIdentifier<ExtHostEmbeddingsShape>('ExtHostEmbeddings'),
|
||||
ExtHostAiRelatedInformation: createProxyIdentifier<ExtHostAiRelatedInformationShape>('ExtHostAiRelatedInformation'),
|
||||
ExtHostAiEmbeddingVector: createProxyIdentifier<ExtHostAiEmbeddingVectorShape>('ExtHostAiEmbeddingVector'),
|
||||
ExtHostTheming: createProxyIdentifier<ExtHostThemingShape>('ExtHostTheming'),
|
||||
|
|
92
src/vs/workbench/api/common/extHostEmbedding.ts
Normal file
92
src/vs/workbench/api/common/extHostEmbedding.ts
Normal file
|
@ -0,0 +1,92 @@
|
|||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
import { CancellationToken } from 'vs/base/common/cancellation';
|
||||
import { Emitter, Event } from 'vs/base/common/event';
|
||||
import { IDisposable, toDisposable } from 'vs/base/common/lifecycle';
|
||||
import { IExtensionDescription } from 'vs/platform/extensions/common/extensions';
|
||||
import { ExtHostEmbeddingsShape, IMainContext, MainContext, MainThreadEmbeddingsShape } from 'vs/workbench/api/common/extHost.protocol';
|
||||
import type * as vscode from 'vscode';
|
||||
|
||||
|
||||
export class ExtHostEmbeddings implements ExtHostEmbeddingsShape {
|
||||
|
||||
private readonly _proxy: MainThreadEmbeddingsShape;
|
||||
private readonly _provider = new Map<number, { id: string; provider: vscode.EmbeddingsProvider }>();
|
||||
|
||||
private readonly _onDidChange = new Emitter<void>();
|
||||
readonly onDidChange: Event<void> = this._onDidChange.event;
|
||||
|
||||
private _allKnownModels = new Set<string>();
|
||||
private _handlePool: number = 0;
|
||||
|
||||
constructor(
|
||||
mainContext: IMainContext
|
||||
) {
|
||||
this._proxy = mainContext.getProxy(MainContext.MainThreadEmbeddings);
|
||||
}
|
||||
|
||||
registerEmbeddingsProvider(_extension: IExtensionDescription, embeddingsModel: string, provider: vscode.EmbeddingsProvider): IDisposable {
|
||||
if (this._allKnownModels.has(embeddingsModel)) {
|
||||
throw new Error('An embeddings provider for this model is already registered');
|
||||
}
|
||||
|
||||
const handle = this._handlePool++;
|
||||
|
||||
this._proxy.$registerEmbeddingProvider(handle, embeddingsModel);
|
||||
this._provider.set(handle, { id: embeddingsModel, provider });
|
||||
|
||||
return toDisposable(() => {
|
||||
this._proxy.$unregisterEmbeddingProvider(handle);
|
||||
this._provider.delete(handle);
|
||||
});
|
||||
}
|
||||
|
||||
async computeEmbeddings(embeddingsModel: string, input: string, token?: vscode.CancellationToken): Promise<vscode.Embedding>;
|
||||
async computeEmbeddings(embeddingsModel: string, input: string[], token?: vscode.CancellationToken): Promise<vscode.Embedding[]>;
|
||||
async computeEmbeddings(embeddingsModel: string, input: string | string[], token?: vscode.CancellationToken): Promise<vscode.Embedding[] | vscode.Embedding> {
|
||||
|
||||
token ??= CancellationToken.None;
|
||||
|
||||
let returnSingle = false;
|
||||
if (typeof input === 'string') {
|
||||
input = [input];
|
||||
returnSingle = true;
|
||||
}
|
||||
const result = await this._proxy.$computeEmbeddings(embeddingsModel, input, token);
|
||||
if (result.length !== input.length) {
|
||||
throw new Error();
|
||||
}
|
||||
if (returnSingle) {
|
||||
if (result.length !== 1) {
|
||||
throw new Error();
|
||||
}
|
||||
return result[0];
|
||||
}
|
||||
return result;
|
||||
|
||||
}
|
||||
|
||||
async $provideEmbeddings(handle: number, input: string[], token: CancellationToken): Promise<{ values: number[] }[]> {
|
||||
const data = this._provider.get(handle);
|
||||
if (!data) {
|
||||
return [];
|
||||
}
|
||||
const result = await data.provider.provideEmbeddings(input, token);
|
||||
if (!result) {
|
||||
return [];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
get embeddingsModels(): string[] {
|
||||
return Array.from(this._allKnownModels);
|
||||
}
|
||||
|
||||
$acceptEmbeddingModels(models: string[]): void {
|
||||
this._allKnownModels = new Set(models);
|
||||
this._onDidChange.fire();
|
||||
}
|
||||
}
|
|
@ -60,6 +60,7 @@ export const allApiProposals = Object.freeze({
|
|||
editSessionIdentityProvider: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.editSessionIdentityProvider.d.ts',
|
||||
editorHoverVerbosityLevel: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.editorHoverVerbosityLevel.d.ts',
|
||||
editorInsets: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.editorInsets.d.ts',
|
||||
embeddings: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.embeddings.d.ts',
|
||||
extensionRuntime: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.extensionRuntime.d.ts',
|
||||
extensionsAny: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.extensionsAny.d.ts',
|
||||
externalUriOpener: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.externalUriOpener.d.ts',
|
||||
|
|
33
src/vscode-dts/vscode.proposed.embeddings.d.ts
vendored
Normal file
33
src/vscode-dts/vscode.proposed.embeddings.d.ts
vendored
Normal file
|
@ -0,0 +1,33 @@
|
|||
/*---------------------------------------------------------------------------------------------
|
||||
* Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
* Licensed under the MIT License. See License.txt in the project root for license information.
|
||||
*--------------------------------------------------------------------------------------------*/
|
||||
|
||||
declare module 'vscode' {
|
||||
|
||||
// https://github.com/microsoft/vscode/issues/212083
|
||||
|
||||
export interface Embedding {
|
||||
readonly values: number[];
|
||||
}
|
||||
|
||||
// TODO@API strictly not the right namespace...
|
||||
export namespace lm {
|
||||
|
||||
export const embeddingModels: string[];
|
||||
|
||||
export const onDidChangeEmbeddingModels: Event<void>;
|
||||
|
||||
export function computeEmbeddings(embeddingsModel: string, input: string, token?: CancellationToken): Thenable<Embedding>;
|
||||
|
||||
export function computeEmbeddings(embeddingsModel: string, input: string[], token?: CancellationToken): Thenable<Embedding[]>;
|
||||
}
|
||||
|
||||
export interface EmbeddingsProvider {
|
||||
provideEmbeddings(input: string[], token: CancellationToken): ProviderResult<Embedding[]>;
|
||||
}
|
||||
|
||||
export namespace lm {
|
||||
export function registerEmbeddingsProvider(embeddingsModel: string, provider: EmbeddingsProvider): Disposable;
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue