6
6
from common .utils import (
7
7
get_prompt_hash ,
8
8
get_qdrant_collection_name ,
9
+ get_query_with_prefix ,
9
10
get_relevant_document_count_hash ,
10
11
)
11
12
from repository .repository import Repository
@@ -27,12 +28,21 @@ def __init__(
27
28
vectors_config : VectorParams ,
28
29
vectorizer : Vectorizer ,
29
30
cache : Cache ,
31
+ passage_prefix : str = "" ,
32
+ query_prefix : str = "" ,
30
33
):
31
34
self .qdrant = client
32
35
self .collection_name = collection_name
33
36
self .model_name = model_name
34
37
self .vectorizer = vectorizer
35
38
self .cache = cache
39
+ self .passage_prefix = passage_prefix
40
+ self .query_prefix = query_prefix
41
+ self .distance = (
42
+ Distance .COSINE
43
+ if Distance .COSINE .lower () in collection_name .lower ()
44
+ else Distance .EUCLID
45
+ )
36
46
37
47
collections = self .qdrant .get_collections ()
38
48
if collection_name not in [
@@ -53,7 +63,9 @@ def insert_one(self, passage: Passage):
53
63
points = [
54
64
PointStruct (
55
65
id = str (uuid .uuid4 ()),
56
- vector = self .vectorizer .get_vector (passage .text ),
66
+ vector = self .vectorizer .get_vector (
67
+ get_query_with_prefix (passage .context , self .passage_prefix )
68
+ ),
57
69
payload = passage .dict (),
58
70
)
59
71
],
@@ -76,7 +88,9 @@ def insert_many(self, passages: List[Passage]):
76
88
points = [
77
89
PointStruct (
78
90
id = str (uuid .uuid4 ()),
79
- vector = self .vectorizer .get_vector (passage .text ),
91
+ vector = self .vectorizer .get_vector (
92
+ get_query_with_prefix (passage .context , self .passage_prefix )
93
+ ),
80
94
payload = passage .dict (),
81
95
)
82
96
for passage in passages
@@ -86,7 +100,7 @@ def insert_many(self, passages: List[Passage]):
86
100
collection_name = self .collection_name , wait = True , points = points
87
101
)
88
102
89
- def insert_many_with_vectors (self , passages : List [Tuple ]):
103
+ def insert_many_with_vectors (self , passages : List [Tuple [ Passage , List [ float ]] ]):
90
104
points = [
91
105
PointStruct (
92
106
id = str (uuid .uuid4 ()),
@@ -101,7 +115,10 @@ def insert_many_with_vectors(self, passages: List[Tuple]):
101
115
)
102
116
103
117
def find (self , query : str , dataset_key : str ) -> Result :
104
- hash_key = get_prompt_hash (self .model_name , query )
118
+ full_query = get_query_with_prefix (query , self .query_prefix )
119
+ hash_key = get_prompt_hash (
120
+ self .model_name , dataset_key , full_query , self .distance
121
+ )
105
122
106
123
cached_value = self .cache .get (hash_key )
107
124
@@ -110,7 +127,7 @@ def find(self, query: str, dataset_key: str) -> Result:
110
127
passages = [(Passage .from_dict (d ["passage" ]), d ["score" ]) for d in dicts ]
111
128
return Result (query , passages )
112
129
113
- vector = self .vectorizer .get_vector (query )
130
+ vector = self .vectorizer .get_vector (full_query )
114
131
115
132
data = self .qdrant .search (
116
133
collection_name = self .collection_name ,
@@ -164,6 +181,8 @@ def get_repository(
164
181
model_name : str ,
165
182
distance : Distance ,
166
183
cache : Cache ,
184
+ passage_prefix : str = "" ,
185
+ query_prefix : str = "" ,
167
186
):
168
187
collection_name = get_qdrant_collection_name (model_name , distance )
169
188
vectorizer = HFVectorizer (model_name , cache )
@@ -175,6 +194,8 @@ def get_repository(
175
194
VectorParams (size = MODEL_DIMENSIONS_MAP [model_name ], distance = distance ),
176
195
vectorizer ,
177
196
cache ,
197
+ passage_prefix ,
198
+ query_prefix ,
178
199
)
179
200
180
201
def count_relevant_documents (self , passage_id , dataset_key ) -> int :
0 commit comments