register vendors statically and active extensions on them

This commit is contained in:
Johannes 2024-05-08 15:33:33 +02:00
parent 6990923699
commit 4baa94788e
No known key found for this signature in database
GPG key ID: 6DEF802A22264FCA
2 changed files with 114 additions and 19 deletions

View file

@ -5,12 +5,16 @@
import { CancellationToken } from 'vs/base/common/cancellation';
import { Emitter, Event } from 'vs/base/common/event';
import { Iterable } from 'vs/base/common/iterator';
import { IJSONSchema } from 'vs/base/common/jsonSchema';
import { IDisposable, toDisposable } from 'vs/base/common/lifecycle';
import { isEmptyObject } from 'vs/base/common/types';
import { isFalsyOrWhitespace } from 'vs/base/common/strings';
import { localize } from 'vs/nls';
import { ExtensionIdentifier } from 'vs/platform/extensions/common/extensions';
import { createDecorator } from 'vs/platform/instantiation/common/instantiation';
import { IProgress } from 'vs/platform/progress/common/progress';
import { IExtensionService } from 'vs/workbench/services/extensions/common/extensions';
import { IExtensionService, isProposedApiEnabled } from 'vs/workbench/services/extensions/common/extensions';
import { ExtensionsRegistry } from 'vs/workbench/services/extensions/common/extensionsRegistry';
export const enum ChatMessageRole {
System,
@ -81,17 +85,93 @@ export interface ILanguageModelsService {
computeTokenLength(identifier: string, message: string | IChatMessage, token: CancellationToken): Promise<number>;
}
const languageModelType: IJSONSchema = {
type: 'object',
properties: {
vendor: {
type: 'string',
description: localize('vscode.extension.contributes.languageModels.vendor', "A globally unique vendor of language models.")
}
}
};
interface IUserFriendlyLanguageModel {
vendor: string;
}
export const languageModelExtensionPoint = ExtensionsRegistry.registerExtensionPoint<IUserFriendlyLanguageModel | IUserFriendlyLanguageModel[]>({
extensionPoint: 'languageModels',
jsonSchema: {
description: localize('vscode.extension.contributes.languageModels', "Contribute language models of a specific vendor."),
oneOf: [
languageModelType,
{
type: 'array',
items: languageModelType
}
]
},
activationEventsGenerator: (contribs: IUserFriendlyLanguageModel[], result: { push(item: string): void }) => {
for (const contrib of contribs) {
result.push(`onLanguageModel:${contrib.vendor}`);
}
}
});
export class LanguageModelsService implements ILanguageModelsService {
readonly _serviceBrand: undefined;
private readonly _providers: Map<string, ILanguageModelChat> = new Map();
private readonly _providers = new Map<string, ILanguageModelChat>();
private readonly _vendors = new Set<string>();
private readonly _onDidChangeProviders = new Emitter<{ added?: ILanguageModelChatMetadata[]; removed?: string[] }>();
readonly onDidChangeLanguageModels: Event<{ added?: ILanguageModelChatMetadata[]; removed?: string[] }> = this._onDidChangeProviders.event;
constructor(
@IExtensionService private readonly _extensionService: IExtensionService,
) { }
) {
languageModelExtensionPoint.setHandler((extensions) => {
this._vendors.clear();
for (const extension of extensions) {
if (!isProposedApiEnabled(extension.description, 'chatProvider')) {
extension.collector.error(localize('vscode.extension.contributes.languageModels.chatProviderRequired', "This contribution point requires the 'chatProvider' proposal."));
continue;
}
for (const item of Iterable.wrap(extension.value)) {
if (this._vendors.has(item.vendor)) {
extension.collector.error(localize('vscode.extension.contributes.languageModels.vendorAlreadyRegistered', "The vendor '{0}' is already registered and cannot be registered twice", item.vendor));
continue;
}
if (isFalsyOrWhitespace(item.vendor)) {
extension.collector.error(localize('vscode.extension.contributes.languageModels.emptyVendor', "The vendor field cannot be empty."));
continue;
}
if (item.vendor.trim() !== item.vendor) {
extension.collector.error(localize('vscode.extension.contributes.languageModels.whitespaceVendor', "The vendor field cannot start or end with whitespace."));
continue;
}
this._vendors.add(item.vendor);
}
}
const removed: string[] = [];
for (const [key, value] of this._providers) {
if (!this._vendors.has(value.metadata.vendor)) {
this._providers.delete(key);
removed.push(key);
}
}
if (removed.length > 0) {
this._onDidChangeProviders.fire({ removed });
}
});
}
dispose() {
this._onDidChangeProviders.dispose();
@ -107,11 +187,20 @@ export class LanguageModelsService implements ILanguageModelsService {
}
async selectLanguageModels(selector: ILanguageModelChatSelector): Promise<string[]> {
await this._extensionService.activateByEvent(`onLanguageModelChat:${selector.vendor ?? '*'}}`);
if (selector.vendor) {
// selective activation
await this._extensionService.activateByEvent(`onLanguageModelChat:${selector.vendor}}`);
} else {
// activate all extensions that do language models
const all = Array.from(this._vendors).map(vendor => this._extensionService.activateByEvent(`onLanguageModelChat:${vendor}`));
await Promise.all(all);
}
const result: string[] = [];
for (const model of this._providers.values()) {
if (selector.vendor !== undefined && model.metadata.vendor === selector.vendor
|| selector.family !== undefined && model.metadata.family === selector.family
|| selector.version !== undefined && model.metadata.version === selector.version
@ -121,7 +210,12 @@ export class LanguageModelsService implements ILanguageModelsService {
// true selection
result.push(model.metadata.identifier);
} else if (!selector || isEmptyObject(selector)) {
} else if (!selector || (
selector.vendor === undefined
&& selector.family === undefined
&& selector.version === undefined
&& selector.identifier === undefined)
) {
// no selection
result.push(model.metadata.identifier);
}
@ -131,6 +225,11 @@ export class LanguageModelsService implements ILanguageModelsService {
}
registerLanguageModelChat(identifier: string, provider: ILanguageModelChat): IDisposable {
if (!this._vendors.has(provider.metadata.vendor)) {
// throw new Error(`Chat response provider uses UNKNOWN vendor ${provider.metadata.vendor}.`);
console.warn('USING UNKNOWN vendor', provider.metadata.vendor);
this._vendors.add(provider.metadata.vendor);
}
if (this._providers.has(identifier)) {
throw new Error(`Chat response provider with identifier ${identifier} is already registered.`);
}

View file

@ -47,8 +47,6 @@ declare module 'vscode' {
role: LanguageModelChatMessageRole.Assistant;
content: AsyncIterable<string>;
};
}
/**
@ -95,12 +93,6 @@ declare module 'vscode' {
constructor(role: LanguageModelChatMessageRole, content: string, name?: string);
}
// ---------------------------
// Language Model Object (V2)
// (+) can pick by id or family
// (++) makes it harder to hardcode an identifier of a model in source code
// TODO@API name LanguageModelChatEndpoint
export interface LanguageModelChat {
/**
@ -130,6 +122,8 @@ declare module 'vscode' {
// TODO@API
// max_prompt_tokens vs output_tokens vs context_size
// readonly inputTokens: number;
// readonly outputTokens: number;
readonly contextSize: number;
/**
@ -174,9 +168,13 @@ declare module 'vscode' {
export interface LanguageModelChatSelector {
vendor?: string; // TODO@API make required?
// TODO@API make required?
vendor?: string;
family?: string;
version?: string;
id?: string;
// TODO@API tokens? min/max etc
}
@ -250,12 +248,10 @@ declare module 'vscode' {
* Select chat models by a {@link LanguageModelChatSelector selector}. This can yield in multiple or no chat models
* and extension must handle these cases, esp when no chat model exists.
*
* @param selector A chat model selector.
* @param selector A chat model selector. When omitted all chat models are returned.
* @returns An array of chat models or `undefined` when no chat model was selected.
*/
// (++) lazy activation
// (++) give specific LM to some extension
export function selectChatModels(selector: LanguageModelChatSelector): Thenable<LanguageModelChat[] | undefined>;
export function selectChatModels(selector?: LanguageModelChatSelector): Thenable<LanguageModelChat[] | undefined>;
}
/**