Skip to content

Commit ba55733

Browse files
committed Aug 7, 2024
Update multiple retriever related stuff
1 parent a1d7211 commit ba55733

14 files changed

+173
-64
lines changed
 

‎docker-compose.yml

+10-5
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ services:
1212
environment:
1313
- discovery.type=single-node
1414
- xpack.security.enabled=false
15+
- ES_JAVA_OPTS=-Xms2g -Xmx2g
1516
ports:
1617
- 9200:9200
1718
volumes:
@@ -31,14 +32,18 @@ services:
3132
- kibana_data:/usr/share/kibana/data
3233
depends_on:
3334
- elasticsearch
34-
redis:
35-
image: redis:latest
35+
mongo:
36+
image: mongo:latest
3637
ports:
37-
- 6379:6379
38+
- 27017:27017
3839
volumes:
39-
- redis_data:/data
40+
- mongo_data:/data/db
41+
deploy:
42+
resources:
43+
limits:
44+
memory: 2g
4045
volumes:
4146
qdrant_data:
4247
elasticsearch_data:
4348
kibana_data:
44-
redis_data:
49+
mongo_data:

‎src/cache/cache.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
1-
import redis
1+
from pymongo import MongoClient
22

33
SIX_MONTHS = 60 * 60 * 24 * 30 * 6
44

55

66
class Cache:
77
def __init__(self):
8-
self.redis = redis.Redis(host="localhost", port=6379, db=0)
8+
mongo_client = MongoClient("mongodb://localhost:27017/")
9+
db = mongo_client["polish-nl-qa"]
10+
self.key_value_collection = db["key_value"]
911

1012
def get(self, key):
11-
return self.redis.get(key)
13+
maybe_cached_value = self.key_value_collection.find_one({"key": key})
14+
15+
if maybe_cached_value is None:
16+
return None
17+
18+
return maybe_cached_value["value"]
1219

1320
def set(self, key, value):
14-
return self.redis.set(key, value, SIX_MONTHS)
21+
self.key_value_collection.delete_many({"key": key})
22+
item = {"key": key, "value": value}
23+
self.key_value_collection.insert_one(item)
24+
25+
def unset(self, key):
26+
return self.key_value_collection.delete_one({"key": key})

‎src/clear_cache.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,7 @@
33
r = redis.StrictRedis(host="localhost", port=6379, db=0)
44

55
# comment out the keys you want to keep
6-
prefixes = [
7-
"count:*",
8-
"vectorizer:*",
9-
"prompt:*",
10-
"reranker:*",
11-
"query:*",
12-
]
6+
prefixes = ["count:*", "vectorizer:*", "prompt:*", "reranker:*", "query:*", "score:*"]
137

148
for prefix in prefixes:
159
print(f"Clearing keys with prefix: {prefix}")

‎src/common/names.py

+22
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,22 @@
1010
"BAAI/bge-m3",
1111
]
1212

13+
QUERY_PREFIX_MAP = {
14+
"sdadas/mmlw-retrieval-roberta-large": "zapytanie: ",
15+
"ipipan/silver-retriever-base-v1": "Pytanie: ",
16+
"intfloat/multilingual-e5-large": "query: ",
17+
"sdadas/mmlw-roberta-large": "zapytanie: ",
18+
"BAAI/bge-m3": "",
19+
}
20+
21+
PASSAGE_PREFIX_MAP = {
22+
"sdadas/mmlw-retrieval-roberta-large": "",
23+
"ipipan/silver-retriever-base-v1": "",
24+
"intfloat/multilingual-e5-large": "passage: ",
25+
"sdadas/mmlw-roberta-large": "",
26+
"BAAI/bge-m3": "",
27+
}
28+
1329
DISTANCES = [Distance.COSINE, Distance.EUCLID]
1430

1531
INDEX_NAMES = [
@@ -22,4 +38,10 @@
2238
"morfologik_stopwords_index",
2339
]
2440

41+
RERANKER_MODEL_NAMES = [
42+
"sdadas/polish-reranker-large-ranknet",
43+
"BAAI/bge-reranker-v2-gemma",
44+
"unicamp-dl/mt5-13b-mmarco-100k",
45+
]
46+
2547
CHUNK_SIZES = [(500, 100), (1000, 200), (2000, 500), (100000, 0)]

‎src/common/qdrant_data_importer.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from typing import List
21
from common.passage_factory import PassageFactory
2+
from common.utils import get_query_with_prefix
33
from repository.qdrant_repository import QdrantRepository
44
from vectorizer.vectorizer import Vectorizer
55

@@ -10,17 +10,24 @@ def __init__(
1010
repository: QdrantRepository,
1111
passage_factory: PassageFactory,
1212
vectorizer: Vectorizer,
13+
prefix: str = "",
1314
):
1415
self.repository = repository
1516
self.passage_factory = passage_factory
1617
self.vectorizer = vectorizer
18+
self.prefix = prefix
1719

1820
def import_data(self):
1921
passages = self.passage_factory.get_passages()
2022

2123
for i in range(0, len(passages), 10):
2224
passages_and_vectors = [
23-
(passage, self.vectorizer.get_vector(passage.context))
25+
(
26+
passage,
27+
self.vectorizer.get_vector(
28+
get_query_with_prefix(passage.context, self.prefix)
29+
),
30+
)
2431
for passage in passages[i : i + 10]
2532
]
2633

‎src/common/utils.py

+14-11
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@ def get_vectorizer_hash(model: str, prompt: str):
3232
return "vectorizer:" + hashed
3333

3434

35-
def get_prompt_hash(model: str, prompt: str):
36-
hashed = hashlib.sha256((model + prompt).encode()).hexdigest()
35+
def get_prompt_hash(model: str, dataset_key: str, prompt: str, distance: str):
36+
hashed = hashlib.sha256(
37+
(model + dataset_key + prompt + distance).encode()
38+
).hexdigest()
3739
return "prompt:" + hashed
3840

3941

@@ -42,9 +44,11 @@ def get_es_query_hash(index_name: str, dataset_key: str, query: str):
4244
return "query:" + hashed
4345

4446

45-
def get_reranker_hash(model: str, query: str, passage_ids: list, count: int):
47+
def get_reranker_hash(
48+
model: str, query: str, passage_ids: list, dataset_key: str, count: int
49+
):
4650
hashed = hashlib.sha256(
47-
(model + query + str(passage_ids) + str(count)).encode()
51+
(model + query + str(passage_ids) + dataset_key + str(count)).encode()
4852
).hexdigest()
4953
return "reranker:" + hashed
5054

@@ -77,14 +81,13 @@ def get_all_qdrant_model_combinations():
7781
for split, _ in CHUNK_SIZES
7882
]
7983

80-
qdrant_collection_names = [
81-
get_qdrant_collection_name(model, distance)
84+
return [
85+
(model, distance, dataset_key)
86+
for dataset_key in dataset_keys
8287
for model in MODEL_NAMES
8388
for distance in DISTANCES
8489
]
8590

86-
return [
87-
(collection_name, dataset_key)
88-
for collection_name in qdrant_collection_names
89-
for dataset_key in dataset_keys
90-
]
91+
92+
def get_query_with_prefix(query: str, prefix: str):
93+
return f"{prefix}{query}"

‎src/load_qdrant_data.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
DATASET_NAMES,
77
DISTANCES,
88
MODEL_NAMES,
9+
PASSAGE_PREFIX_MAP,
10+
QUERY_PREFIX_MAP,
911
)
1012
from common.passage_factory import PassageFactory
1113
from common.qdrant_data_importer import QdrantDataImporter
@@ -22,7 +24,7 @@ def main():
2224
client = QdrantClient(host="localhost", port=6333)
2325
cache = Cache()
2426

25-
for model_name in MODEL_NAMES:
27+
for model_name in ["intfloat/multilingual-e5-large"]:
2628
vectorizer = HFVectorizer(model_name, cache)
2729
for distance in DISTANCES:
2830
insert_passage_data(client, model_name, distance, cache, vectorizer)
@@ -56,16 +58,23 @@ def insert_passage_data(
5658
chunk_size, chunk_overlap, dataset_name
5759
)
5860

61+
passage_prefix = PASSAGE_PREFIX_MAP[model_name]
62+
query_prefix = QUERY_PREFIX_MAP[model_name]
63+
5964
repository = QdrantRepository(
6065
client,
6166
collection_name,
6267
model_name,
6368
VectorParams(size=MODEL_DIMENSIONS_MAP[model_name], distance=distance),
6469
vectorizer,
6570
cache,
71+
passage_prefix,
72+
query_prefix,
6673
)
6774

68-
data_importer = QdrantDataImporter(repository, passage_factory, vectorizer)
75+
data_importer = QdrantDataImporter(
76+
repository, passage_factory, vectorizer, passage_prefix
77+
)
6978

7079
data_importer.import_data()
7180

‎src/repository/es_repository.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def count_relevant_documents(self, passage_id: str, dataset_key: str) -> int:
8686

8787
cached_value = self.cache.get(hash_key)
8888

89-
# if cached_value:
90-
# return int(cached_value)
89+
if cached_value:
90+
return int(cached_value)
9191

9292
body = {
9393
"query": {

‎src/repository/qdrant_repository.py

+26-5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from common.utils import (
77
get_prompt_hash,
88
get_qdrant_collection_name,
9+
get_query_with_prefix,
910
get_relevant_document_count_hash,
1011
)
1112
from repository.repository import Repository
@@ -27,12 +28,21 @@ def __init__(
2728
vectors_config: VectorParams,
2829
vectorizer: Vectorizer,
2930
cache: Cache,
31+
passage_prefix: str = "",
32+
query_prefix: str = "",
3033
):
3134
self.qdrant = client
3235
self.collection_name = collection_name
3336
self.model_name = model_name
3437
self.vectorizer = vectorizer
3538
self.cache = cache
39+
self.passage_prefix = passage_prefix
40+
self.query_prefix = query_prefix
41+
self.distance = (
42+
Distance.COSINE
43+
if Distance.COSINE.lower() in collection_name.lower()
44+
else Distance.EUCLID
45+
)
3646

3747
collections = self.qdrant.get_collections()
3848
if collection_name not in [
@@ -53,7 +63,9 @@ def insert_one(self, passage: Passage):
5363
points=[
5464
PointStruct(
5565
id=str(uuid.uuid4()),
56-
vector=self.vectorizer.get_vector(passage.text),
66+
vector=self.vectorizer.get_vector(
67+
get_query_with_prefix(passage.context, self.passage_prefix)
68+
),
5769
payload=passage.dict(),
5870
)
5971
],
@@ -76,7 +88,9 @@ def insert_many(self, passages: List[Passage]):
7688
points = [
7789
PointStruct(
7890
id=str(uuid.uuid4()),
79-
vector=self.vectorizer.get_vector(passage.text),
91+
vector=self.vectorizer.get_vector(
92+
get_query_with_prefix(passage.context, self.passage_prefix)
93+
),
8094
payload=passage.dict(),
8195
)
8296
for passage in passages
@@ -86,7 +100,7 @@ def insert_many(self, passages: List[Passage]):
86100
collection_name=self.collection_name, wait=True, points=points
87101
)
88102

89-
def insert_many_with_vectors(self, passages: List[Tuple]):
103+
def insert_many_with_vectors(self, passages: List[Tuple[Passage, List[float]]]):
90104
points = [
91105
PointStruct(
92106
id=str(uuid.uuid4()),
@@ -101,7 +115,10 @@ def insert_many_with_vectors(self, passages: List[Tuple]):
101115
)
102116

103117
def find(self, query: str, dataset_key: str) -> Result:
104-
hash_key = get_prompt_hash(self.model_name, query)
118+
full_query = get_query_with_prefix(query, self.query_prefix)
119+
hash_key = get_prompt_hash(
120+
self.model_name, dataset_key, full_query, self.distance
121+
)
105122

106123
cached_value = self.cache.get(hash_key)
107124

@@ -110,7 +127,7 @@ def find(self, query: str, dataset_key: str) -> Result:
110127
passages = [(Passage.from_dict(d["passage"]), d["score"]) for d in dicts]
111128
return Result(query, passages)
112129

113-
vector = self.vectorizer.get_vector(query)
130+
vector = self.vectorizer.get_vector(full_query)
114131

115132
data = self.qdrant.search(
116133
collection_name=self.collection_name,
@@ -164,6 +181,8 @@ def get_repository(
164181
model_name: str,
165182
distance: Distance,
166183
cache: Cache,
184+
passage_prefix: str = "",
185+
query_prefix: str = "",
167186
):
168187
collection_name = get_qdrant_collection_name(model_name, distance)
169188
vectorizer = HFVectorizer(model_name, cache)
@@ -175,6 +194,8 @@ def get_repository(
175194
VectorParams(size=MODEL_DIMENSIONS_MAP[model_name], distance=distance),
176195
vectorizer,
177196
cache,
197+
passage_prefix,
198+
query_prefix,
178199
)
179200

180201
def count_relevant_documents(self, passage_id, dataset_key) -> int:

‎src/rerankers/hf_reranker.py

+32-18
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import json
22
from typing import List
3+
34
from cache.cache import Cache
45
from common.passage import Passage
6+
from common.result import Result
57
from common.utils import get_reranker_hash
68
from rerankers.reranker import Reranker
79
from sentence_transformers import CrossEncoder
@@ -15,37 +17,49 @@ def __init__(self, model_name: str, cache: Cache):
1517

1618
print(f"Vectorizer with model {model_name} initialized")
1719

18-
def get_relevant_passages(
19-
self, query: str, passages: List[Passage], count: int
20-
) -> List[Passage]:
21-
passages_ids = list(map(lambda passage: passage.id, passages))
20+
def rerank(self, result: Result, count: int, dataset_key: str) -> Result:
21+
if (len(result.passages)) == 0:
22+
return result
23+
24+
passages_ids = [passage[0].id for passage in result.passages]
2225
sorted_passages_ids = sorted(passages_ids)
2326

24-
reranker_hash = get_reranker_hash(
25-
self.model_name, query, sorted_passages_ids, count
27+
hash_key = get_reranker_hash(
28+
self.model_name, result.query, sorted_passages_ids, dataset_key, count
2629
)
2730

28-
maybe_cached_result = self.cache.get(reranker_hash)
31+
maybe_cached_result = self.cache.get(hash_key)
2932

3033
if maybe_cached_result:
31-
json_result = json.loads(maybe_cached_result)
32-
return list(map(lambda passage: Passage.from_dict(passage), json_result))
34+
dicts = json.loads(maybe_cached_result)
35+
passages = [(Passage.from_dict(d["passage"]), d["score"]) for d in dicts]
36+
return Result(result.query, passages)
3337

34-
pairs = [
35-
[query, passage]
36-
for passage in list(map(lambda passage: passage.context, passages))
37-
]
38+
pairs = [[result.query, passage[0].context] for passage in result.passages]
3839

3940
results = self.model.predict(pairs)
4041
results_list = results.tolist()
4142

42-
scored_passages = list(zip(results_list, passages))
43+
scored_passages = list(zip(results_list, result.passages))
4344
sorted_passages = sorted(scored_passages, key=lambda x: x[0], reverse=True)
4445
top_n_passages = sorted_passages[:count]
45-
top_n_passages = [passage for _, passage in top_n_passages]
46+
top_n_passages = [(passage, score) for score, (passage, _) in top_n_passages]
47+
48+
max_score = top_n_passages[0][1]
49+
min_score = top_n_passages[-1][1]
4650

47-
top_n_passages_dict = list(map(lambda passage: passage.dict(), top_n_passages))
51+
score_diff = max_score - min_score
4852

49-
self.cache.set(reranker_hash, json.dumps(top_n_passages_dict))
53+
normalized_passages = [
54+
(p, 1 if score_diff == 0 else (s - min_score) / score_diff)
55+
for (p, s) in top_n_passages
56+
]
57+
58+
reranked_result = Result(result.query, normalized_passages)
59+
60+
result_json = json.dumps(
61+
[{"passage": p.dict(), "score": s} for (p, s) in normalized_passages]
62+
)
63+
self.cache.set(hash_key, result_json)
5064

51-
return top_n_passages
65+
return reranked_result

‎src/rerankers/reranker.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
from abc import ABC, abstractmethod
2-
from typing import List
3-
from common.passage import Passage
2+
from common.result import Result
43

54

65
class Reranker(ABC):
76
@abstractmethod
8-
def get_relevant_passages(
9-
self, query: str, passages: List[Passage], count: int
10-
) -> List[Passage]:
7+
def rerank(self, result: Result, count: int) -> Result:
118
pass

‎src/retrievers/es_retriever.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
from typing import List
22
from common.result import Result
33
from repository.es_repository import ESRepository
4+
from rerankers.hf_reranker import HFReranker
45
from retrievers.retriever import Retriever
56

67

78
class ESRetriever(Retriever):
8-
def __init__(self, repository: ESRepository, dataset_key: str):
9+
def __init__(
10+
self, repository: ESRepository, dataset_key: str, reranker: HFReranker = None
11+
):
912
self.repository = repository
1013
self.dataset_key = dataset_key
14+
self.reranker = reranker
1115

1216
def get_relevant_passages(self, query: str) -> Result:
1317
result = self.repository.find(query, self.dataset_key)
1418

19+
if self.reranker:
20+
result = self.reranker.rerank(result, 10, self.dataset_key)
21+
1522
return result

‎src/retrievers/hybrid_retriever.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from common.result import Result
44
from repository.es_repository import ESRepository
55
from repository.qdrant_repository import QdrantRepository
6+
from rerankers.hf_reranker import HFReranker
67
from retrievers.retriever import Retriever
78

89

@@ -13,11 +14,13 @@ def __init__(
1314
qdrant_repository: QdrantRepository,
1415
dataset_key: str,
1516
alpha: float = 0.5, # weight for ES
17+
reranker: HFReranker = None,
1618
):
1719
self.es_repository = es_repository
1820
self.qdrant_repository = qdrant_repository
1921
self.dataset_key = dataset_key
2022
self.alpha = alpha
23+
self.reranker = reranker
2124

2225
def get_relevant_passages(self, query: str) -> List[str]:
2326
es_result = self.es_repository.find(query, self.dataset_key)
@@ -39,4 +42,9 @@ def get_relevant_passages(self, query: str) -> List[str]:
3942
final_results = list(combined_scores.items())
4043
final_results.sort(key=lambda x: x[1], reverse=True)
4144

42-
return Result(query, final_results[:10])
45+
result = Result(query, final_results[:10])
46+
47+
if self.reranker:
48+
result = self.reranker.rerank(result, 10, self.dataset_key)
49+
50+
return result

‎src/retrievers/qdrant_retriever.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
from typing import List
22
from common.result import Result
33
from repository.qdrant_repository import QdrantRepository
4+
from rerankers.hf_reranker import HFReranker
45
from retrievers.retriever import Retriever
56

67

78
class QdrantRetriever(Retriever):
8-
def __init__(self, repository: QdrantRepository, dataset_key: str):
9+
def __init__(
10+
self,
11+
repository: QdrantRepository,
12+
dataset_key: str,
13+
reranker: HFReranker = None,
14+
):
915
self.repository = repository
1016
self.dataset_key = dataset_key
17+
self.reranker = reranker
1118

1219
def get_relevant_passages(self, query: str) -> Result:
1320
result = self.repository.find(query, self.dataset_key)
1421

22+
if self.reranker:
23+
result = self.reranker.rerank(result, 10, self.dataset_key)
24+
1525
return result

0 commit comments

Comments (0)
Please sign in to comment.