Joh/ministerial-swan (#212096)

* first cut of embeddings API

https://github.com/microsoft/vscode/issues/212083

* add event

* fix tests
This commit is contained in:
Johannes Rieken 2024-05-06 16:50:13 +02:00 committed by GitHub
parent facea7e2c2
commit 6874fc7394
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 277 additions and 1 deletions

View file

@ -19,6 +19,7 @@
"documentFiltersExclusive",
"documentPaste",
"editorInsets",
"embeddings",
"extensionRuntime",
"extensionsAny",
"externalUriOpener",
@ -80,7 +81,7 @@
"id": "api-test.participant2",
"name": "participant2",
"description": "test",
"commands": [ ]
"commands": []
}
],
"configuration": {

View file

@ -20,6 +20,7 @@ import './mainThreadBulkEdits';
import './mainThreadLanguageModels';
import './mainThreadChatAgents2';
import './mainThreadChatVariables';
import './mainThreadEmbeddings';
import './mainThreadCodeInsets';
import './mainThreadCLICommands';
import './mainThreadClipboard';

View file

@ -0,0 +1,112 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { CancellationToken } from 'vs/base/common/cancellation';
import { Emitter, Event } from 'vs/base/common/event';
import { DisposableMap, DisposableStore, IDisposable } from 'vs/base/common/lifecycle';
import { InstantiationType, registerSingleton } from 'vs/platform/instantiation/common/extensions';
import { createDecorator } from 'vs/platform/instantiation/common/instantiation';
import { ExtHostContext, ExtHostEmbeddingsShape, MainContext, MainThreadEmbeddingsShape } from 'vs/workbench/api/common/extHost.protocol';
import { extHostNamedCustomer, IExtHostContext } from 'vs/workbench/services/extensions/common/extHostCustomers';
/**
 * Something that can compute embeddings for text. Providers are handed to
 * `IEmbeddingsService.registerProvider` together with the id they serve.
 */
interface IEmbeddingsProvider {
/**
 * Computes embeddings for the given strings.
 *
 * @param input The strings to compute embeddings for.
 * @param token Cancellation token passed along to the provider.
 * @returns A promise of embedding objects, each exposing its vector as `values`.
 * NOTE(review): the extension-host caller rejects results whose length differs
 * from `input.length`; presumably results are expected in input order - confirm.
 */
provideEmbeddings(input: string[], token: CancellationToken): Promise<{ values: number[] }[]>;
}
// Service identifier for the embeddings service; used as the `@IEmbeddingsService`
// parameter decorator when the service is injected into a constructor.
const IEmbeddingsService = createDecorator<IEmbeddingsService>('embeddingsService');
/**
 * Registry of embeddings providers keyed by a string id, plus a way to run a
 * computation against the provider registered for a given id.
 */
interface IEmbeddingsService {
_serviceBrand: undefined;
/** Event that signals a change to the set of registered providers. */
readonly onDidChange: Event<void>;
/** The ids of all providers that are currently registered. */
allProviders: Iterable<string>;
/**
 * Registers `provider` under `id`.
 *
 * @returns A disposable that undoes the registration.
 */
registerProvider(id: string, provider: IEmbeddingsProvider): IDisposable;
/**
 * Computes embeddings for `input` using the provider registered under `id`.
 *
 * @param id The id of the provider to use.
 * @param input The strings to compute embeddings for.
 * @param token Cancellation token passed along to the provider.
 */
computeEmbeddings(id: string, input: string[], token: CancellationToken): Promise<{ values: number[] }[]>;
}
/**
 * In-memory implementation of {@link IEmbeddingsService}: keeps one provider
 * per id in a map and fires `onDidChange` whenever that map changes.
 */
class EmbeddingsService implements IEmbeddingsService {
	_serviceBrand: undefined;

	private readonly providers = new Map<string, IEmbeddingsProvider>();

	private readonly _onDidChange = new Emitter<void>();
	readonly onDidChange: Event<void> = this._onDidChange.event;

	/** Ids of all currently registered providers. */
	get allProviders(): Iterable<string> {
		return this.providers.keys();
	}

	/**
	 * Registers `provider` under `id`, replacing any provider already stored
	 * for that id, and fires `onDidChange`.
	 *
	 * @returns A disposable that removes this registration. It is a no-op when
	 * called more than once, or when the id has since been taken over by a
	 * different provider.
	 */
	registerProvider(id: string, provider: IEmbeddingsProvider): IDisposable {
		this.providers.set(id, provider);
		this._onDidChange.fire();

		let disposed = false;
		return {
			dispose: () => {
				if (disposed) {
					return;
				}
				disposed = true;
				// Only remove the entry when it is still ours. Otherwise a stale
				// disposable would silently unregister whichever provider replaced it.
				if (this.providers.get(id) === provider) {
					this.providers.delete(id);
					this._onDidChange.fire();
				}
			}
		};
	}

	/**
	 * Delegates to the provider registered under `id`.
	 *
	 * @returns The provider's result, or a rejected promise when no provider is
	 * registered for `id`.
	 */
	computeEmbeddings(id: string, input: string[], token: CancellationToken): Promise<{ values: number[] }[]> {
		const provider = this.providers.get(id);
		if (!provider) {
			return Promise.reject(new Error(`No embeddings provider registered with id: ${id}`));
		}
		return provider.provideEmbeddings(input, token);
	}
}
// Bind the `IEmbeddingsService` identifier to `EmbeddingsService` as a singleton.
// NOTE(review): `InstantiationType.Delayed` presumably defers construction until the
// service is first used - confirm against the instantiation service.
registerSingleton(IEmbeddingsService, EmbeddingsService, InstantiationType.Delayed);
/**
 * Main-thread counterpart of the extension host embeddings actor. It mirrors
 * extension-provided embeddings providers into the `IEmbeddingsService` and
 * pushes the list of registered provider ids back to the extension host each
 * time the service reports a change.
 */
@extHostNamedCustomer(MainContext.MainThreadEmbeddings)
export class MainThreadEmbeddings implements MainThreadEmbeddingsShape {

	// Owns everything that has to be cleaned up when this actor is disposed.
	private readonly _store = new DisposableStore();
	// Service registrations, keyed by the handle chosen by the extension host.
	private readonly _providers = this._store.add(new DisposableMap<number>());
	private readonly _proxy: ExtHostEmbeddingsShape;

	constructor(
		context: IExtHostContext,
		@IEmbeddingsService private readonly embeddingsService: IEmbeddingsService
	) {
		this._proxy = context.getProxy(ExtHostContext.ExtHostEmbeddings);

		// Forward the full, current list of provider ids on every change.
		const publishModels = (): void => {
			const modelIds = Array.from(embeddingsService.allProviders);
			this._proxy.$acceptEmbeddingModels(modelIds);
		};
		const changeListener = embeddingsService.onDidChange(publishModels);
		this._store.add(changeListener);
	}

	/** Releases the change listener and all provider registrations. */
	dispose(): void {
		this._store.dispose();
	}

	/**
	 * Registers a provider under `identifier` whose work is forwarded to the
	 * extension host provider identified by `handle`.
	 */
	$registerEmbeddingProvider(handle: number, identifier: string): void {
		const forwardToExtHost = (input: string[], token: CancellationToken): Promise<{ values: number[] }[]> =>
			this._proxy.$provideEmbeddings(handle, input, token);

		this._providers.set(
			handle,
			this.embeddingsService.registerProvider(identifier, { provideEmbeddings: forwardToExtHost })
		);
	}

	/** Drops and disposes the registration stored for `handle`. */
	$unregisterEmbeddingProvider(handle: number): void {
		this._providers.deleteAndDispose(handle);
	}

	/** Computes embeddings for `input` using the provider registered for `embeddingsModel`. */
	$computeEmbeddings(embeddingsModel: string, input: string[], token: CancellationToken): Promise<{ values: number[] }[]> {
		const { embeddingsService } = this;
		return embeddingsService.computeEmbeddings(embeddingsModel, input, token);
	}
}

View file

@ -44,6 +44,7 @@ import { ExtHostDocumentSaveParticipant } from 'vs/workbench/api/common/extHostD
import { ExtHostDocuments } from 'vs/workbench/api/common/extHostDocuments';
import { IExtHostDocumentsAndEditors } from 'vs/workbench/api/common/extHostDocumentsAndEditors';
import { IExtHostEditorTabs } from 'vs/workbench/api/common/extHostEditorTabs';
import { ExtHostEmbeddings } from 'vs/workbench/api/common/extHostEmbedding';
import { ExtHostAiEmbeddingVector } from 'vs/workbench/api/common/extHostEmbeddingVector';
import { Extension, IExtHostExtensionService } from 'vs/workbench/api/common/extHostExtensionService';
import { ExtHostFileSystem } from 'vs/workbench/api/common/extHostFileSystem';
@ -214,6 +215,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I
const extHostAiEmbeddingVector = rpcProtocol.set(ExtHostContext.ExtHostAiEmbeddingVector, new ExtHostAiEmbeddingVector(rpcProtocol));
const extHostStatusBar = rpcProtocol.set(ExtHostContext.ExtHostStatusBar, new ExtHostStatusBar(rpcProtocol, extHostCommands.converter));
const extHostSpeech = rpcProtocol.set(ExtHostContext.ExtHostSpeech, new ExtHostSpeech(rpcProtocol));
const extHostEmbeddings = rpcProtocol.set(ExtHostContext.ExtHostEmbeddings, new ExtHostEmbeddings(rpcProtocol));
// Check that no named customers are missing
const expected = Object.values<ProxyIdentifier<any>>(ExtHostContext);
@ -1450,6 +1452,27 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I
getLanguageModelInformation(languageModel: string) {
checkProposedApiEnabled(extension, 'languageModels');
return extHostLanguageModels.getLanguageModelInfo(languageModel);
},
// --- embeddings
// Returns the embedding model ids currently known to the extension host.
// Gated by the `embeddings` API proposal via `checkProposedApiEnabled`.
get embeddingModels() {
checkProposedApiEnabled(extension, 'embeddings');
return extHostEmbeddings.embeddingsModels;
},
// Subscribes `listener` to the extension host's embeddings change event.
// Gated by the `embeddings` API proposal; the check runs on every subscription.
onDidChangeEmbeddingModels: (listener, thisArgs?, disposables?) => {
checkProposedApiEnabled(extension, 'embeddings');
return extHostEmbeddings.onDidChange(listener, thisArgs, disposables);
},
// Registers `provider` for `embeddingsModel` on behalf of the calling extension and
// returns the disposable that undoes the registration.
// Gated by the `embeddings` API proposal.
registerEmbeddingsProvider(embeddingsModel, provider) {
checkProposedApiEnabled(extension, 'embeddings');
return extHostEmbeddings.registerEmbeddingsProvider(extension, embeddingsModel, provider);
},
// Computes embeddings for a single string or for an array of strings.
// Gated by the `embeddings` API proposal.
// NOTE(review): both branches make the same call. `ExtHostEmbeddings.computeEmbeddings`
// is overloaded for `string` and `string[]`, so the `typeof` narrowing presumably exists
// only to let each call resolve against one overload - confirm before collapsing them.
async computeEmbeddings(embeddingsModel, input, token?): Promise<any> {
checkProposedApiEnabled(extension, 'embeddings');
if (typeof input === 'string') {
return extHostEmbeddings.computeEmbeddings(embeddingsModel, input, token);
} else {
return extHostEmbeddings.computeEmbeddings(embeddingsModel, input, token);
}
}
};

View file

@ -1216,6 +1216,17 @@ export interface ExtHostLanguageModelsShape {
$provideTokenLength(handle: number, value: string | IChatMessage, token: CancellationToken): Promise<number>;
}
/**
 * Main-thread side of the embeddings protocol, i.e. the calls the extension
 * host makes through its `MainContext.MainThreadEmbeddings` proxy.
 */
export interface MainThreadEmbeddingsShape extends IDisposable {
/** Announces a provider for the model `identifier`; `handle` is the id the extension host uses to refer to it later. */
$registerEmbeddingProvider(handle: number, identifier: string): void;
/** Withdraws the provider that was announced with `handle`. */
$unregisterEmbeddingProvider(handle: number): void;
/** Computes embeddings for `input` with the provider registered for `embeddingsModel`. */
$computeEmbeddings(embeddingsModel: string, input: string[], token: CancellationToken): Promise<({ values: number[] }[])>;
}
/**
 * Extension-host side of the embeddings protocol, i.e. the calls the main
 * thread makes through its `ExtHostContext.ExtHostEmbeddings` proxy.
 */
export interface ExtHostEmbeddingsShape {
/** Asks the extension-provided provider behind `handle` to compute embeddings for `input`. */
$provideEmbeddings(handle: number, input: string[], token: CancellationToken): Promise<{ values: number[] }[]>;
/** Delivers the complete list of embedding model ids currently registered on the main thread. */
$acceptEmbeddingModels(models: string[]): void;
}
export interface IExtensionChatAgentMetadata extends Dto<IChatAgentMetadata> {
hasFollowups?: boolean;
}
@ -2774,6 +2785,7 @@ export const MainContext = {
MainThreadAuthentication: createProxyIdentifier<MainThreadAuthenticationShape>('MainThreadAuthentication'),
MainThreadBulkEdits: createProxyIdentifier<MainThreadBulkEditsShape>('MainThreadBulkEdits'),
MainThreadLanguageModels: createProxyIdentifier<MainThreadLanguageModelsShape>('MainThreadLanguageModels'),
MainThreadEmbeddings: createProxyIdentifier<MainThreadEmbeddingsShape>('MainThreadEmbeddings'),
MainThreadChatAgents2: createProxyIdentifier<MainThreadChatAgentsShape2>('MainThreadChatAgents2'),
MainThreadChatVariables: createProxyIdentifier<MainThreadChatVariablesShape>('MainThreadChatVariables'),
MainThreadClipboard: createProxyIdentifier<MainThreadClipboardShape>('MainThreadClipboard'),
@ -2897,6 +2909,7 @@ export const ExtHostContext = {
ExtHostChatVariables: createProxyIdentifier<ExtHostChatVariablesShape>('ExtHostChatVariables'),
ExtHostChatProvider: createProxyIdentifier<ExtHostLanguageModelsShape>('ExtHostChatProvider'),
ExtHostSpeech: createProxyIdentifier<ExtHostSpeechShape>('ExtHostSpeech'),
ExtHostEmbeddings: createProxyIdentifier<ExtHostEmbeddingsShape>('ExtHostEmbeddings'),
ExtHostAiRelatedInformation: createProxyIdentifier<ExtHostAiRelatedInformationShape>('ExtHostAiRelatedInformation'),
ExtHostAiEmbeddingVector: createProxyIdentifier<ExtHostAiEmbeddingVectorShape>('ExtHostAiEmbeddingVector'),
ExtHostTheming: createProxyIdentifier<ExtHostThemingShape>('ExtHostTheming'),

View file

@ -0,0 +1,92 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
import { CancellationToken } from 'vs/base/common/cancellation';
import { Emitter, Event } from 'vs/base/common/event';
import { IDisposable, toDisposable } from 'vs/base/common/lifecycle';
import { IExtensionDescription } from 'vs/platform/extensions/common/extensions';
import { ExtHostEmbeddingsShape, IMainContext, MainContext, MainThreadEmbeddingsShape } from 'vs/workbench/api/common/extHost.protocol';
import type * as vscode from 'vscode';
/**
 * Extension-host side of the embeddings API.
 *
 * It keeps the providers that extensions in this host have registered,
 * forwards registrations and compute requests to the main thread, and caches
 * the list of model ids the main thread last reported.
 */
export class ExtHostEmbeddings implements ExtHostEmbeddingsShape {

	private readonly _proxy: MainThreadEmbeddingsShape;
	// Providers registered from this extension host, keyed by their handle.
	private readonly _provider = new Map<number, { id: string; provider: vscode.EmbeddingsProvider }>();

	private readonly _onDidChange = new Emitter<void>();
	readonly onDidChange: Event<void> = this._onDidChange.event;

	// Model ids as last reported via `$acceptEmbeddingModels`.
	private _allKnownModels = new Set<string>();
	private _handlePool: number = 0;

	constructor(
		mainContext: IMainContext
	) {
		this._proxy = mainContext.getProxy(MainContext.MainThreadEmbeddings);
	}

	/**
	 * Registers `provider` for `embeddingsModel` and announces it to the main thread.
	 *
	 * @param _extension The registering extension (currently unused).
	 * @param embeddingsModel The model id the provider serves.
	 * @param provider The provider implementation.
	 * @returns A disposable that unregisters the provider.
	 * @throws Error when a provider for `embeddingsModel` is already registered.
	 */
	registerEmbeddingsProvider(_extension: IExtensionDescription, embeddingsModel: string, provider: vscode.EmbeddingsProvider): IDisposable {
		// `_allKnownModels` is only refreshed once the main thread echoes a change
		// back, so local registrations are checked as well. Without this, two
		// back-to-back registrations for the same model would both pass.
		if (this._allKnownModels.has(embeddingsModel) || this._hasLocalProvider(embeddingsModel)) {
			throw new Error('An embeddings provider for this model is already registered');
		}

		const handle = this._handlePool++;
		this._proxy.$registerEmbeddingProvider(handle, embeddingsModel);
		this._provider.set(handle, { id: embeddingsModel, provider });

		return toDisposable(() => {
			// Drop the cached id right away so the model can be registered again
			// without waiting for the main thread to send an updated list.
			this._allKnownModels.delete(embeddingsModel);
			this._proxy.$unregisterEmbeddingProvider(handle);
			this._provider.delete(handle);
		});
	}

	/** Tells whether this extension host currently holds a provider for `embeddingsModel`. */
	private _hasLocalProvider(embeddingsModel: string): boolean {
		for (const { id } of this._provider.values()) {
			if (id === embeddingsModel) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Computes embeddings via the main thread.
	 *
	 * @param embeddingsModel The model id to use.
	 * @param input A single string or an array of strings.
	 * @param token Optional cancellation token; defaults to `CancellationToken.None`.
	 * @returns One embedding for a string input, otherwise one embedding per input string.
	 * @throws Error when the number of results differs from the number of inputs.
	 */
	async computeEmbeddings(embeddingsModel: string, input: string, token?: vscode.CancellationToken): Promise<vscode.Embedding>;
	async computeEmbeddings(embeddingsModel: string, input: string[], token?: vscode.CancellationToken): Promise<vscode.Embedding[]>;
	async computeEmbeddings(embeddingsModel: string, input: string | string[], token?: vscode.CancellationToken): Promise<vscode.Embedding[] | vscode.Embedding> {

		token ??= CancellationToken.None;

		const returnSingle = typeof input === 'string';
		const inputs = typeof input === 'string' ? [input] : input;

		const result = await this._proxy.$computeEmbeddings(embeddingsModel, inputs, token);

		if (result.length !== inputs.length) {
			throw new Error(`Embeddings model '${embeddingsModel}' returned ${result.length} embedding(s) for ${inputs.length} input(s)`);
		}

		// For a single string, `inputs` has exactly one element, so the length
		// check above already guarantees that `result[0]` exists.
		return returnSingle ? result[0] : result;
	}

	/**
	 * Runs the locally registered provider behind `handle`.
	 *
	 * @returns The provider's embeddings, or an empty array when the handle is
	 * unknown or the provider yields no result.
	 */
	async $provideEmbeddings(handle: number, input: string[], token: CancellationToken): Promise<{ values: number[] }[]> {
		const data = this._provider.get(handle);
		if (!data) {
			return [];
		}
		const result = await data.provider.provideEmbeddings(input, token);
		return result ?? [];
	}

	/** A fresh array with the model ids that are currently known. */
	get embeddingsModels(): string[] {
		return Array.from(this._allKnownModels);
	}

	/** Replaces the known model ids with `models` and fires `onDidChange`. */
	$acceptEmbeddingModels(models: string[]): void {
		this._allKnownModels = new Set(models);
		this._onDidChange.fire();
	}
}

View file

@ -60,6 +60,7 @@ export const allApiProposals = Object.freeze({
editSessionIdentityProvider: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.editSessionIdentityProvider.d.ts',
editorHoverVerbosityLevel: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.editorHoverVerbosityLevel.d.ts',
editorInsets: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.editorInsets.d.ts',
embeddings: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.embeddings.d.ts',
extensionRuntime: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.extensionRuntime.d.ts',
extensionsAny: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.extensionsAny.d.ts',
externalUriOpener: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.externalUriOpener.d.ts',

View file

@ -0,0 +1,33 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/
declare module 'vscode' {
// https://github.com/microsoft/vscode/issues/212083
/**
 * An embedding, i.e. the numeric vector computed for one input string.
 */
export interface Embedding {
/** The components of the embedding vector. */
readonly values: number[];
}
// TODO@API strictly not the right namespace...
export namespace lm {
/**
 * The identifiers of the embedding models that are currently known.
 */
export const embeddingModels: string[];
/**
 * An event that fires when the set of known embedding models has been updated.
 */
export const onDidChangeEmbeddingModels: Event<void>;
/**
 * Computes the embedding of a single string.
 *
 * @param embeddingsModel The identifier of the embedding model to use.
 * @param input The string to compute the embedding for.
 * @param token An optional cancellation token.
 * @returns A thenable that resolves to the embedding of `input`.
 */
export function computeEmbeddings(embeddingsModel: string, input: string, token?: CancellationToken): Thenable<Embedding>;
/**
 * Computes the embeddings of multiple strings.
 *
 * @param embeddingsModel The identifier of the embedding model to use.
 * @param input The strings to compute embeddings for.
 * @param token An optional cancellation token.
 * @returns A thenable that resolves to one embedding per input string. The
 * call fails when the model returns a different number of embeddings.
 */
export function computeEmbeddings(embeddingsModel: string, input: string[], token?: CancellationToken): Thenable<Embedding[]>;
}
/**
 * Implemented by extensions that contribute an embedding model.
 */
export interface EmbeddingsProvider {
/**
 * Computes embeddings for the given strings.
 *
 * @param input The strings to compute embeddings for.
 * @param token A cancellation token.
 * @returns The embeddings, one per entry of `input`; callers of
 * `lm.computeEmbeddings` get an error when the count does not match.
 */
provideEmbeddings(input: string[], token: CancellationToken): ProviderResult<Embedding[]>;
}
export namespace lm {
/**
 * Registers an embeddings provider for the given model identifier.
 *
 * @param embeddingsModel The identifier of the embedding model that `provider` serves.
 * @param provider The provider.
 * @returns A disposable that unregisters the provider.
 * @throws When a provider for `embeddingsModel` is already registered.
 */
export function registerEmbeddingsProvider(embeddingsModel: string, provider: EmbeddingsProvider): Disposable;
}
}