Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Together AI in OpenAIEmbeddings wrapper #304

Merged
merged 1 commit into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/_sidebar.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
- [GCP Vertex AI](/modules/retrieval/text_embedding/integrations/gcp_vertex_ai.md)
- [Ollama](/modules/retrieval/text_embedding/integrations/ollama.md)
- [Mistral AI](/modules/retrieval/text_embedding/integrations/mistralai.md)
- [Together AI](/modules/retrieval/text_embedding/integrations/together_ai.md)
- [Prem App](/modules/retrieval/text_embedding/integrations/prem.md)
- [Vector stores](/modules/retrieval/vector_stores/vector_stores.md)
- Integrations
Expand Down
25 changes: 18 additions & 7 deletions docs/modules/retrieval/text_embedding/integrations/openai.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
# OpenAI
# OpenAIEmbeddings

Let's load the OpenAI Embedding class.
You can use the `OpenAIEmbeddings` wrapper to consume OpenAI embedding models.

```dart
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
final embeddings = OpenAIEmbeddings(apiKey: openaiApiKey);
const text = 'This is a test document.';
final res = await embeddings.embedQuery(text);
final res = await embeddings.embedDocuments([text]);
final openAiApiKey = Platform.environment['OPENAI_API_KEY'];
final embeddings = OpenAIEmbeddings(apiKey: openAiApiKey);

// Embedding a document
const doc = Document(pageContent: 'This is a test document.');
final res1 = await embeddings.embedDocuments([doc]);
print(res1);
// [[-0.003105443, 0.011136302, -0.0040295827, -0.011749065, ...]]

// Embedding a retrieval query
const text = 'This is a test query.';
final res2 = await embeddings.embedQuery(text);
print(res2);
// [-0.005047946, 0.0050882488, -0.0051957234, -0.019143905, ...]

embeddings.close();
```
30 changes: 30 additions & 0 deletions docs/modules/retrieval/text_embedding/integrations/together_ai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Together AI Embeddings

[Together AI](https://www.together.ai/) offers several leading [embedding models](https://docs.together.ai/docs/embedding-models#embedding-models) through its OpenAI compatible API.

You can consume Together AI API using the `OpenAIEmbeddings` wrapper in the same way you would use the OpenAI API.

The only difference is that you need to change the base URL to `https://api.together.xyz/v1`:

```dart
final togetherAiApiKey = Platform.environment['TOGETHER_AI_API_KEY'];
final embeddings = OpenAIEmbeddings(
apiKey: togetherAiApiKey,
baseUrl: 'https://api.together.xyz/v1',
model: 'togethercomputer/m2-bert-80M-32k-retrieval',
);

// Embedding a document
const doc = Document(pageContent: 'This is a test document.');
final res1 = await embeddings.embedDocuments([doc]);
print(res1);
// [[-0.038838703, 0.0580902, 0.022614542, 0.0078403875, ...]]

// Embedding a retrieval query
const text = 'This is a test query.';
final res2 = await embeddings.embedQuery(text);
print(res2);
// [-0.019722218, 0.04656633, -0.0074559706, 0.005712764, ...]

embeddings.close();
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// ignore_for_file: avoid_print
import 'dart:io';

import 'package:langchain/langchain.dart';
import 'package:langchain_openai/langchain_openai.dart';

void main(final List<String> arguments) async {
final openAiApiKey = Platform.environment['OPENAI_API_KEY'];
final embeddings = OpenAIEmbeddings(apiKey: openAiApiKey);

// Embedding a document
const doc = Document(pageContent: 'This is a test document.');
final res1 = await embeddings.embedDocuments([doc]);
print(res1);
// [[-0.003105443, 0.011136302, -0.0040295827, -0.011749065, ...]]

// Embedding a retrieval query
const text = 'This is a test query.';
final res2 = await embeddings.embedQuery(text);
print(res2);
// [-0.005047946, 0.0050882488, -0.0051957234, -0.019143905, ...]

embeddings.close();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// ignore_for_file: avoid_print
import 'dart:io';

import 'package:langchain/langchain.dart';
import 'package:langchain_openai/langchain_openai.dart';

void main(final List<String> arguments) async {
final togetherAiApiKey = Platform.environment['TOGETHER_AI_API_KEY'];
final embeddings = OpenAIEmbeddings(
apiKey: togetherAiApiKey,
baseUrl: 'https://api.together.xyz/v1',
model: 'togethercomputer/m2-bert-80M-32k-retrieval',
);

// Embedding a document
const doc = Document(pageContent: 'This is a test document.');
final res1 = await embeddings.embedDocuments([doc]);
print(res1);
// [[-0.038838703, 0.0580902, 0.022614542, 0.0078403875, ...]]

// Embedding a retrieval query
const text = 'This is a test query.';
final res2 = await embeddings.embedQuery(text);
print(res2);
// [-0.019722218, 0.04656633, -0.0074559706, 0.005712764, ...]

embeddings.close();
}
9 changes: 8 additions & 1 deletion packages/langchain_openai/lib/src/embeddings/openai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import 'package:openai_dart/openai_dart.dart';
/// - [Embeddings guide](https://platform.openai.com/docs/guides/embeddings/limitations-risks)
/// - [Embeddings API docs](https://platform.openai.com/docs/api-reference/embeddings)
///
/// You can also use this wrapper to consume OpenAI-compatible APIs like [Together AI](https://www.together.ai).
///
/// ### Authentication
///
/// The OpenAI API uses API keys for authentication. Visit your
Expand Down Expand Up @@ -122,7 +124,7 @@ class OpenAIEmbeddings implements Embeddings {
OpenAIEmbeddings({
final String? apiKey,
final String? organization,
final String? baseUrl,
final String baseUrl = 'https://api.openai.com/v1',
final Map<String, String>? headers,
final Map<String, dynamic>? queryParams,
final http.Client? client,
Expand Down Expand Up @@ -199,4 +201,9 @@ class OpenAIEmbeddings implements Embeddings {
);
return data.data.first.embeddingVector;
}

/// Closes the client and cleans up any resources associated with it.
void close() {
_client.endSession();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
@TestOn('vm')
library; // Uses dart:io

import 'dart:io';

import 'package:langchain_openai/langchain_openai.dart';
import 'package:test/test.dart';

void main() {
group('Together AI Embeddings tests', () {
late OpenAIEmbeddings embeddings;

setUp(() async {
embeddings = OpenAIEmbeddings(
apiKey: Platform.environment['TOGETHER_AI_API_KEY'],
baseUrl: 'https://api.together.xyz/v1',
);
});

tearDown(() {
embeddings.close();
});

test('Test AI Embeddings models', () async {
final models = [
'togethercomputer/m2-bert-80M-2k-retrieval',
'togethercomputer/m2-bert-80M-8k-retrieval',
'togethercomputer/m2-bert-80M-32k-retrieval',
'WhereIsAI/UAE-Large-V1',
'BAAI/bge-large-en-v1.5',
'BAAI/bge-base-en-v1.5',
'sentence-transformers/msmarco-bert-base-dot-v5',
'bert-base-uncased',
];
for (final model in models) {
embeddings.model = model;
final res = await embeddings.embedQuery('Hello world');
expect(res.length, greaterThan(0));
await Future<void>.delayed(const Duration(seconds: 1)); // Rate limit
}
});
});
}