Commit bcb5330: Return count of relevant docs
committed Aug 3, 2024 · 1 parent c1830d0

7 files changed (+113 -67 lines)
src/common/names.py (-9)

@@ -22,13 +22,4 @@
     "morfologik_stopwords_index",
 ]
 
-SEMANTIC_TYPES = ["interquartile", "standard_deviation", "percentile"]
-
 CHUNK_SIZES = [(500, 100), (1000, 200), (2000, 500), (100000, 0)]
-
-
-CHARACTER_SPLITTING_FUNCTION = [
-    "character-500",
-    "character-1000",
-    "character-2000",
-    "character-100000",
-]

src/common/utils.py (+36 -25)

@@ -1,11 +1,5 @@
 import uuid
-from common.names import (
-    CHUNK_SIZES,
-    DATASET_NAMES,
-    DISTANCES,
-    MODEL_NAMES,
-    SEMANTIC_TYPES,
-)
+from common.names import CHUNK_SIZES, DATASET_NAMES, DISTANCES, INDEX_NAMES, MODEL_NAMES
 from common.passage import Passage
 import hashlib
 
@@ -50,25 +44,42 @@ def get_reranker_hash(model: str, query: str, passage_ids: list, count: int):
     return "reranker:" + hashed
 
 
-def get_all_qdrant_collection_names():
-    names = []
-    for dataset_name in DATASET_NAMES:
-        for model_name in MODEL_NAMES:
-            for distance in DISTANCES:
-                for chunk_size, _ in CHUNK_SIZES:
-                    name = get_qdrant_collection_name(
-                        dataset_name, model_name, "character", chunk_size, distance
-                    )
-                    names.append(name)
-
-                for semantic_type in SEMANTIC_TYPES:
-                    name = get_qdrant_collection_name(
-                        dataset_name, model_name, semantic_type, 1.5, distance
-                    )
-                    names.append(name)
-
-    return names
+def get_relevant_document_count_hash(id: str, dataset_key: str):
+    hashed = hashlib.sha256((id + dataset_key).encode()).hexdigest()
+    return "count:" + hashed
 
 
 def get_dataset_key(dataset_name: str, split: str):
     return replace_slash_with_dash(f"{dataset_name}-{split}")
+
+
+def get_all_es_index_combinations():
+    dataset_keys = [
+        get_dataset_key(dataset_name, split)
+        for dataset_name in DATASET_NAMES
+        for split, _ in CHUNK_SIZES
+    ]
+
+    return [
+        (index, dataset_key) for index in INDEX_NAMES for dataset_key in dataset_keys
+    ]
+
+
+def get_all_qdrant_model_combinations():
+    dataset_keys = [
+        get_dataset_key(dataset_name, split)
+        for dataset_name in DATASET_NAMES
+        for split, _ in CHUNK_SIZES
+    ]
+
+    qdrant_collection_names = [
+        get_qdrant_collection_name(model, distance)
+        for model in MODEL_NAMES
+        for distance in DISTANCES
+    ]
+
+    return [
+        (collection_name, dataset_key)
+        for collection_name in qdrant_collection_names
+        for dataset_key in dataset_keys
+    ]
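The two new enumeration helpers replace the old nested-loop collection-name builder with plain Cartesian products. A minimal sketch of how a caller might consume them; the benchmark-runner framing is an assumption, not code from this commit:

from common.utils import get_all_es_index_combinations, get_all_qdrant_model_combinations

# Every (index, dataset_key) pair: len(INDEX_NAMES) * len(DATASET_NAMES) * len(CHUNK_SIZES) entries.
for index_name, dataset_key in get_all_es_index_combinations():
    print(index_name, dataset_key)

# Every (collection_name, dataset_key) pair: one Qdrant collection per model/distance combination.
for collection_name, dataset_key in get_all_qdrant_model_combinations():
    print(collection_name, dataset_key)

Note that get_dataset_key is called with the chunk size as the split argument, so a dataset key encodes both the dataset and its chunking configuration.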

src/evaluation/retriever_evaluator.py (+8 -9)

@@ -7,10 +7,9 @@
 
 class RetrieverEvaluator:
     # Calculate NDCG for top 10 results
-    def calculate_ndcg(self, result: Result, correct_passage_title: str) -> float:
+    def calculate_ndcg(self, result: Result, correct_passage_id: str) -> float:
         relevances = [
-            1 if passage.title == correct_passage_title else 0
-            for passage in result.passages
+            1 if passage.id == correct_passage_id else 0 for passage in result.passages
         ]
 
         sorted_relevances = sorted(relevances, reverse=True)
@@ -23,22 +22,22 @@ def calculate_ndcg(self, result: Result, correct_passage_title: str) -> float:
         return dcg / idcg if idcg != 0 else 0
 
     # Calculate MRR for top 10 results
-    def calculate_mrr(self, result: Result, correct_passage_title: str) -> float:
+    def calculate_mrr(self, result: Result, correct_passage_id: str) -> float:
         for i, passage in enumerate(result.passages):
-            if passage.title == correct_passage_title:
+            if passage.id == correct_passage_id:
                 return 1 / (i + 1)
         return 0
 
     # Calculate recall for top 10 results
     def calculate_recall(
-        self, result: Result, correct_passage_title: str, relevant_documents_count: int
+        self, result: Result, correct_passage_id: str, relevant_documents_count: int
    ) -> float:
         relevant_documents = sum(
-            1 for passage in result.passages if passage.title == correct_passage_title
+            1 for passage in result.passages if passage.id == correct_passage_id
         )
 
         return relevant_documents / relevant_documents_count
 
     # Calculate accuracy for top 1 result
-    def calculate_accuracy(self, result: Result, correct_passage_title: str) -> float:
-        return 1 if result.passages[0].title == correct_passage_title else 0
+    def calculate_accuracy(self, result: Result, correct_passage_id: str) -> float:
+        return 1 if result.passages[0].id == correct_passage_id else 0
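After this change every metric keys on passage.id instead of passage.title, and recall divides by the relevant-document count that the repositories now provide. A self-contained sanity check of the arithmetic; the Passage and Result stubs are simplified stand-ins for the project's own classes:

from dataclasses import dataclass

from evaluation.retriever_evaluator import RetrieverEvaluator  # assumes src/ is on PYTHONPATH

@dataclass
class Passage:  # stand-in for common.passage.Passage
    id: str

@dataclass
class Result:  # stand-in for common.result.Result
    passages: list

evaluator = RetrieverEvaluator()
result = Result(passages=[Passage("p2"), Passage("p7"), Passage("p7")])

# "p7" first appears at rank 2, so MRR = 1/2; the top-1 hit is wrong, so accuracy = 0;
# 2 of the 3 relevant chunks were retrieved, so recall = 2/3.
assert evaluator.calculate_mrr(result, "p7") == 0.5
assert evaluator.calculate_accuracy(result, "p7") == 0
assert evaluator.calculate_recall(result, "p7", relevant_documents_count=3) == 2 / 3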

src/main.py (+1 -22)

@@ -1,29 +1,8 @@
-from ast import Dict
-from tkinter import Place
-from xml.etree.ElementInclude import include
-from elasticsearch import Elasticsearch
 from langchain_text_splitters import RecursiveCharacterTextSplitter
-from common.names import DATASET_NAMES, INDEX_NAMES
-from common.passage import Passage
+from common.names import DATASET_NAMES
 from common.passage_factory import PassageFactory
-from common.utils import replace_slash_with_dash
 from dataset.poquad_dataset_getter import PoquadDatasetGetter
-from repository.es_repository import ESRepository
-from repository.qdrant_repository import QdrantRepository
-from qdrant_client import QdrantClient
-from qdrant_client.models import Distance, VectorParams
-from vectorizer.hf_vectorizer import HFVectorizer
-from langchain_experimental.text_splitter import SemanticChunker
-from langchain_community.embeddings import HuggingFaceEmbeddings
 from dataset.polqa_dataset_getter import PolqaDatasetGetter
-from elasticsearch import Elasticsearch
-from qdrant_client import QdrantClient
-from cache.cache import Cache
-from common.models_dimensions import MODEL_DIMENSIONS_MAP
-from common.names import DISTANCES, MODEL_NAMES
-from common.utils import get_all_qdrant_collection_names
-from repository.qdrant_repository import QdrantRepository
-from vectorizer.hf_vectorizer import HFVectorizer
 
 
 def get_passage_factory(

src/repository/es_repository.py (+30 -1)

@@ -3,7 +3,10 @@
 from cache.cache import Cache
 from common.passage import Passage
 from common.result import Result
-from common.utils import get_es_query_hash
+from common.utils import (
+    get_es_query_hash,
+    get_relevant_document_count_hash,
+)
 from repository.repository import Repository
 import json
 
@@ -74,3 +77,29 @@ def find(self, query: str, dataset_key: str) -> Result:
     def delete(self, query: str):
         body = {"query": {"match": {"text": query}}}
         return self.client.delete_by_query(index=self.index_name, body=body)
+
+    def count_relevant_documents(self, id, dataset_key) -> int:
+        hash_key = get_relevant_document_count_hash(id, dataset_key)
+
+        cached_value = self.cache.get(hash_key)
+
+        if cached_value:
+            return int(cached_value)
+
+        if cached_value:
+            return int(cached_value)
+
+        body = {
+            "query": {
+                "bool": {
+                    "must": [
+                        {"match": {"id": id}},
+                        {"match": {"dataset_key": dataset_key}},
+                    ]
+                }
+            },
+        }
+
+        response = self.client.count(index=self.index_name, body=body)
+
+        return response["count"]
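As committed, the if cached_value guard is duplicated, and, unlike the Qdrant version below, the computed count is never written back to the cache, so this cache key can only ever miss. A cleaned-up sketch of the same method; dropping the duplicate guard and adding the cache.set call are suggested fixes, not part of this commit:

    def count_relevant_documents(self, id, dataset_key) -> int:
        hash_key = get_relevant_document_count_hash(id, dataset_key)

        cached_value = self.cache.get(hash_key)
        if cached_value:
            return int(cached_value)

        # Count documents matching both the passage id and the dataset key.
        body = {
            "query": {
                "bool": {
                    "must": [
                        {"match": {"id": id}},
                        {"match": {"dataset_key": dataset_key}},
                    ]
                }
            }
        }
        response = self.client.count(index=self.index_name, body=body)

        # Mirror the Qdrant implementation: cache the count for next time.
        self.cache.set(hash_key, str(response["count"]))
        return response["count"]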

src/repository/qdrant_repository.py (+34 -1)

@@ -3,7 +3,11 @@
 from common.models_dimensions import MODEL_DIMENSIONS_MAP
 from common.passage import Passage
 from common.result import Result
-from common.utils import get_prompt_hash, get_qdrant_collection_name
+from common.utils import (
+    get_prompt_hash,
+    get_qdrant_collection_name,
+    get_relevant_document_count_hash,
+)
 from repository.repository import Repository
 from qdrant_client import QdrantClient, models
 from qdrant_client.models import VectorParams, PointStruct, Distance
@@ -169,3 +173,32 @@ def get_repository(
             vectorizer,
             cache,
         )
+
+    def count_relevant_documents(self, id, dataset_key) -> int:
+        hash_key = get_relevant_document_count_hash(id, dataset_key)
+
+        cached_value = self.cache.get(hash_key)
+
+        if cached_value:
+            return int(cached_value)
+
+        result = self.qdrant.count(
+            collection_name=self.collection_name,
+            count_filter=models.Filter(
+                must=[
+                    models.FieldCondition(
+                        key="id",
+                        match=models.MatchValue(value=int(id)),
+                    ),
+                    models.FieldCondition(
+                        key="dataset_key",
+                        match=models.MatchValue(value=dataset_key),
+                    ),
+                ]
+            ),
+            exact=True,
+        )
+
+        self.cache.set(hash_key, str(result.count))
+
+        return result.count
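Two details worth noting: exact=True makes Qdrant scan rather than estimate, which matters because the result is cached, and the id is coerced to int for the payload match while the ES version matches it as passed. A standalone sketch of the same count call; the host, port, collection name, and filter values are placeholders:

from qdrant_client import QdrantClient, models

client = QdrantClient(host="localhost", port=6333)  # placeholder connection

result = client.count(
    collection_name="example_collection",  # placeholder name
    count_filter=models.Filter(
        must=[
            models.FieldCondition(key="id", match=models.MatchValue(value=42)),
            models.FieldCondition(key="dataset_key", match=models.MatchValue(value="example-key")),
        ]
    ),
    exact=True,  # full scan instead of an approximate, index-based count
)
print(result.count)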

src/repository/repository.py (+4)

@@ -20,3 +20,7 @@ def find(self, query, dataset_key) -> Result:
     @abstractmethod
     def delete(self, query):
         pass
+
+    @abstractmethod
+    def count_relevant_documents(self, id, dataset_key) -> int:
+        pass
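With count_relevant_documents on the abstract base class, evaluation code can fetch the recall denominator without knowing which backend it is talking to. A hedged sketch of the intended call pattern; the helper function below is illustrative, not part of this commit:

def evaluate_recall(
    repo: Repository,
    evaluator: RetrieverEvaluator,
    query: str,
    passage_id: str,
    dataset_key: str,
) -> float:
    # Works identically for ESRepository and QdrantRepository.
    result = repo.find(query, dataset_key)
    relevant_count = repo.count_relevant_documents(passage_id, dataset_key)
    return evaluator.calculate_recall(result, passage_id, relevant_count)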
