Skip to content

Commit 491f220

Browse files
committedAug 12, 2024
Update retrievers to work and save evaluations correctly
1 parent 3d276cc commit 491f220

14 files changed

+706
-131
lines changed
 

‎elasticsearch/basic_index.json

+7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
},
1818
"dataset_key": {
1919
"type": "keyword"
20+
},
21+
"metadata": {
22+
"properties": {
23+
"passage_id": {
24+
"type": "keyword"
25+
}
26+
}
2027
}
2128
}
2229
}

‎elasticsearch/morfologik_index.json

+7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
},
1818
"dataset_key": {
1919
"type": "keyword"
20+
},
21+
"metadata": {
22+
"properties": {
23+
"passage_id": {
24+
"type": "keyword"
25+
}
26+
}
2027
}
2128
}
2229
}

‎elasticsearch/morfologik_stopwords_index.json

+7
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@
3838
},
3939
"dataset_key": {
4040
"type": "keyword"
41+
},
42+
"metadata": {
43+
"properties": {
44+
"passage_id": {
45+
"type": "keyword"
46+
}
47+
}
4148
}
4249
}
4350
}

‎elasticsearch/morfologik_whitespace_index.json

+7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
},
3232
"dataset_key": {
3333
"type": "keyword"
34+
},
35+
"metadata": {
36+
"properties": {
37+
"passage_id": {
38+
"type": "keyword"
39+
}
40+
}
3441
}
3542
}
3643
}

‎elasticsearch/polish_index.json

+7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,13 @@
1717
},
1818
"dataset_key": {
1919
"type": "keyword"
20+
},
21+
"metadata": {
22+
"properties": {
23+
"passage_id": {
24+
"type": "keyword"
25+
}
26+
}
2027
}
2128
}
2229
}

‎elasticsearch/polish_stopwords_index.json

+7
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,13 @@
3737
},
3838
"dataset_key": {
3939
"type": "keyword"
40+
},
41+
"metadata": {
42+
"properties": {
43+
"passage_id": {
44+
"type": "keyword"
45+
}
46+
}
4047
}
4148
}
4249
}

‎elasticsearch/polish_whitespace_index.json

+7
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@
3131
},
3232
"dataset_key": {
3333
"type": "keyword"
34+
},
35+
"metadata": {
36+
"properties": {
37+
"passage_id": {
38+
"type": "keyword"
39+
}
40+
}
3441
}
3542
}
3643
}

‎src/common/utils.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,19 @@ def get_vectorizer_hash(model: str, prompt: str):
3939
return "vectorizer:" + hashed
4040

4141

42-
def get_prompt_hash(model: str, dataset_key: str, prompt: str, distance: str):
42+
def get_prompt_hash(
43+
model: str, dataset_key: str, prompt: str, distance: str, size: int
44+
):
4345
hashed = hashlib.sha256(
44-
(model + dataset_key + prompt + distance).encode()
46+
(model + dataset_key + prompt + distance + str(size)).encode()
4547
).hexdigest()
4648
return "prompt:" + hashed
4749

4850

49-
def get_es_query_hash(index_name: str, dataset_key: str, query: str):
50-
hashed = hashlib.sha256((index_name + dataset_key + query).encode()).hexdigest()
51+
def get_es_query_hash(index_name: str, dataset_key: str, query: str, size: int):
52+
hashed = hashlib.sha256(
53+
(index_name + dataset_key + query + str(size)).encode()
54+
).hexdigest()
5155
return "query:" + hashed
5256

5357

‎src/notebooks/01_retrievers_evaluation.ipynb

+614-102
Large diffs are not rendered by default.

‎src/repository/es_repository.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def insert_many(self, data: list[Passage]):
2424
documents = [d.dict() for d in data]
2525
return helpers.bulk(self.client, documents, index=self.index_name)
2626

27-
def find(self, query: str, dataset_key: str) -> Result:
28-
hash_key = get_es_query_hash(self.index_name, dataset_key, query)
27+
def find(self, query: str, dataset_key: str, size: int = 10) -> Result:
28+
hash_key = get_es_query_hash(self.index_name, dataset_key, query, size)
2929
cached_value = self.cache.get(hash_key)
3030

3131
if cached_value:
@@ -34,7 +34,7 @@ def find(self, query: str, dataset_key: str) -> Result:
3434
return Result(query, passages)
3535

3636
body = {
37-
"size": 10,
37+
"size": size,
3838
"query": {
3939
"bool": {
4040
"must": [
@@ -81,23 +81,29 @@ def delete(self, query: str):
8181
body = {"query": {"match": {"text": query}}}
8282
return self.client.delete_by_query(index=self.index_name, body=body)
8383

84-
def count_relevant_documents(self, passage_id: str, dataset_key: str) -> int:
85-
hash_key = get_relevant_document_count_hash(passage_id, dataset_key)
84+
def count_relevant_documents(self, passage_ids: list[str], dataset_key: str) -> int:
85+
sorted_passage_ids = sorted(passage_ids)
86+
joined_passage_ids = ",".join(map(str, sorted_passage_ids))
87+
hash_key = get_relevant_document_count_hash(joined_passage_ids, dataset_key)
8688

8789
cached_value = self.cache.get(hash_key)
8890

8991
if cached_value:
9092
return int(cached_value)
9193

94+
is_poquad = True if "poquad" in dataset_key else False
95+
96+
must = [
97+
{"match": {"dataset_key": dataset_key}},
98+
]
99+
100+
if is_poquad:
101+
must.append({"match": {"id": passage_ids[0]}})
102+
else:
103+
must.append({"terms": {"metadata.passage_id": passage_ids}})
104+
92105
body = {
93-
"query": {
94-
"bool": {
95-
"must": [
96-
{"match": {"id": passage_id}},
97-
{"match": {"dataset_key": dataset_key}},
98-
]
99-
}
100-
},
106+
"query": {"bool": {"must": must}},
101107
}
102108

103109
response = self.client.count(index=self.index_name, body=body)

‎src/repository/qdrant_openai_repository.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ def insert_many_with_vectors(self, passages: List[Tuple[Passage, List[float]]]):
105105
collection_name=self.collection_name, wait=True, points=points
106106
)
107107

108-
def find(self, query: str, dataset_key: str) -> Result:
108+
def find(self, query: str, dataset_key: str, size: int = 10) -> Result:
109109
full_query = query
110110
hash_key = get_prompt_hash(
111-
self.model_name, dataset_key, full_query, self.distance
111+
self.model_name, dataset_key, full_query, self.distance, size
112112
)
113113

114114
cached_value = self.cache.get(hash_key)
@@ -123,7 +123,7 @@ def find(self, query: str, dataset_key: str) -> Result:
123123
data = self.qdrant.search(
124124
collection_name=self.collection_name,
125125
query_vector=vector,
126-
limit=10,
126+
limit=size,
127127
query_filter=models.Filter(
128128
must=[
129129
models.FieldCondition(

‎src/repository/qdrant_repository.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,10 @@ def insert_many_with_vectors(self, passages: List[Tuple[Passage, List[float]]]):
114114
collection_name=self.collection_name, wait=True, points=points
115115
)
116116

117-
def find(self, query: str, dataset_key: str) -> Result:
117+
def find(self, query: str, dataset_key: str, size: int = 10) -> Result:
118118
full_query = get_query_with_prefix(query, self.query_prefix)
119119
hash_key = get_prompt_hash(
120-
self.model_name, dataset_key, full_query, self.distance
120+
self.model_name, dataset_key, full_query, self.distance, size
121121
)
122122

123123
cached_value = self.cache.get(hash_key)
@@ -132,7 +132,7 @@ def find(self, query: str, dataset_key: str) -> Result:
132132
data = self.qdrant.search(
133133
collection_name=self.collection_name,
134134
query_vector=vector,
135-
limit=10,
135+
limit=size,
136136
query_filter=models.Filter(
137137
must=[
138138
models.FieldCondition(

‎src/retrievers/es_retriever.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ def __init__(
1313
self.dataset_key = dataset_key
1414
self.reranker = reranker
1515

16-
def get_relevant_passages(self, query: str) -> Result:
17-
result = self.repository.find(query, self.dataset_key)
16+
def get_relevant_passages(self, query: str, size: int = 10) -> Result:
17+
docs_size = size * 2 if self.reranker else size
18+
19+
result = self.repository.find(query, self.dataset_key, docs_size)
1820

1921
if self.reranker:
20-
result = self.reranker.rerank(result, 10, self.dataset_key)
22+
result = self.reranker.rerank(result, size, self.dataset_key)
2123

2224
return result

‎src/retrievers/qdrant_retriever.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ def __init__(
1616
self.dataset_key = dataset_key
1717
self.reranker = reranker
1818

19-
def get_relevant_passages(self, query: str) -> Result:
20-
result = self.repository.find(query, self.dataset_key)
19+
def get_relevant_passages(self, query: str, size: int = 10) -> Result:
20+
docs_size = size * 2 if self.reranker else size
21+
22+
result = self.repository.find(query, self.dataset_key, docs_size)
2123

2224
if self.reranker:
23-
result = self.reranker.rerank(result, 10, self.dataset_key)
25+
result = self.reranker.rerank(result, size, self.dataset_key)
2426

2527
return result

0 commit comments

Comments
 (0)
Please sign in to comment.