-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
Copy pathopenai-language-models-manager-impl.ts
117 lines (105 loc) · 4.69 KB
/
openai-language-models-manager-impl.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// *****************************************************************************
// Copyright (C) 2024 EclipseSource GmbH.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// http://www.eclipse.org/legal/epl-2.0.
//
// This Source Code may also be made available under the following Secondary
// Licenses when the conditions for such availability set forth in the Eclipse
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
// with the GNU Classpath Exception which is available at
// https://www.gnu.org/software/classpath/license.html.
//
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
// *****************************************************************************
import { LanguageModelRegistry } from '@theia/ai-core';
import { inject, injectable } from '@theia/core/shared/inversify';
import { OpenAiModel, OpenAiModelUtils } from './openai-language-model';
import { OpenAiLanguageModelsManager, OpenAiModelDescription } from '../common';
/**
 * Implementation of {@link OpenAiLanguageModelsManager} that keeps the
 * {@link LanguageModelRegistry} in sync with a set of {@link OpenAiModelDescription}s
 * and holds the manager-wide API key / API version overrides.
 */
@injectable()
export class OpenAiLanguageModelsManagerImpl implements OpenAiLanguageModelsManager {

    @inject(LanguageModelRegistry)
    protected readonly languageModelRegistry: LanguageModelRegistry;

    @inject(OpenAiModelUtils)
    protected readonly openAiModelUtils: OpenAiModelUtils;

    /** Explicitly configured API key; `undefined` means "fall back to the environment". */
    protected _apiKey: string | undefined;

    /** Explicitly configured API version; `undefined` means "fall back to the environment". */
    protected _apiVersion: string | undefined;

    /** The configured API key, or `process.env.OPENAI_API_KEY` when none has been set. */
    get apiKey(): string | undefined {
        return this._apiKey ?? process.env.OPENAI_API_KEY;
    }

    /** The configured API version, or `process.env.OPENAI_API_VERSION` when none has been set. */
    get apiVersion(): string | undefined {
        return this._apiVersion ?? process.env.OPENAI_API_VERSION;
    }

    /**
     * Registers a new {@link OpenAiModel} for every description whose id is not yet known to
     * the registry, and updates already-registered OpenAI models in place. A registered model
     * with a matching id that is not an {@link OpenAiModel} is skipped with a warning.
     *
     * Normally triggered from the frontend; to use the models on the backend without a
     * frontend, call this method directly.
     *
     * @param modelDescriptions descriptions of the models to create or update.
     */
    async createOrUpdateLanguageModels(...modelDescriptions: OpenAiModelDescription[]): Promise<void> {
        for (const description of modelDescriptions) {
            // Produces a provider that is evaluated on every call: a setting of `true` defers to
            // the manager-wide value, a non-empty string is returned as-is, anything else
            // yields `undefined`.
            const lazySetting = (key: 'apiKey' | 'apiVersion', shared: () => string | undefined) => () => {
                const configured = description[key];
                if (configured === true) {
                    return shared();
                }
                return configured ? configured : undefined;
            };
            const resolveApiKey = lazySetting('apiKey', () => this.apiKey);
            const resolveApiVersion = lazySetting('apiVersion', () => this.apiVersion);

            const existing = await this.languageModelRegistry.getLanguageModel(description.id);

            if (!existing) {
                this.languageModelRegistry.addLanguageModels([
                    new OpenAiModel(
                        description.id,
                        description.model,
                        description.enableStreaming,
                        resolveApiKey,
                        resolveApiVersion,
                        description.supportsStructuredOutput,
                        description.url,
                        this.openAiModelUtils,
                        description.developerMessageSettings,
                        description.defaultRequestSettings
                    )
                ]);
                continue;
            }

            if (!(existing instanceof OpenAiModel)) {
                console.warn(`OpenAI: model ${description.id} is not an OpenAI model`);
                continue;
            }

            // Refresh the already-registered model in place; an unset developer message
            // setting falls back to 'developer' on this path.
            existing.model = description.model;
            existing.enableStreaming = description.enableStreaming;
            existing.url = description.url;
            existing.apiKey = resolveApiKey;
            existing.apiVersion = resolveApiVersion;
            existing.developerMessageSettings = description.developerMessageSettings || 'developer';
            existing.supportsStructuredOutput = description.supportsStructuredOutput;
            existing.defaultRequestSettings = description.defaultRequestSettings;
        }
    }

    /**
     * Removes the models with the given ids from the registry.
     *
     * @param modelIds ids of the models to remove.
     */
    removeLanguageModels(...modelIds: string[]): void {
        const ids: string[] = modelIds;
        this.languageModelRegistry.removeLanguageModels(ids);
    }

    /**
     * Overrides the API key used by models whose description sets `apiKey: true`.
     * A falsy value (`undefined` or `''`) clears the override, so the
     * `OPENAI_API_KEY` environment variable is consulted again.
     */
    setApiKey(apiKey: string | undefined): void {
        this._apiKey = apiKey ? apiKey : undefined;
    }

    /**
     * Overrides the API version used by models whose description sets `apiVersion: true`.
     * A falsy value (`undefined` or `''`) clears the override, so the
     * `OPENAI_API_VERSION` environment variable is consulted again.
     */
    setApiVersion(apiVersion: string | undefined): void {
        this._apiVersion = apiVersion ? apiVersion : undefined;
    }
}