@@ -24,8 +24,8 @@ def insert_many(self, data: list[Passage]):
24
24
documents = [d .dict () for d in data ]
25
25
return helpers .bulk (self .client , documents , index = self .index_name )
26
26
27
- def find (self , query : str , dataset_key : str ) -> Result :
28
- hash_key = get_es_query_hash (self .index_name , dataset_key , query )
27
+ def find (self , query : str , dataset_key : str , size : int = 10 ) -> Result :
28
+ hash_key = get_es_query_hash (self .index_name , dataset_key , query , size )
29
29
cached_value = self .cache .get (hash_key )
30
30
31
31
if cached_value :
@@ -34,7 +34,7 @@ def find(self, query: str, dataset_key: str) -> Result:
34
34
return Result (query , passages )
35
35
36
36
body = {
37
- "size" : 10 ,
37
+ "size" : size ,
38
38
"query" : {
39
39
"bool" : {
40
40
"must" : [
@@ -81,23 +81,29 @@ def delete(self, query: str):
81
81
body = {"query" : {"match" : {"text" : query }}}
82
82
return self .client .delete_by_query (index = self .index_name , body = body )
83
83
84
- def count_relevant_documents (self , passage_id : str , dataset_key : str ) -> int :
85
- hash_key = get_relevant_document_count_hash (passage_id , dataset_key )
84
+ def count_relevant_documents (self , passage_ids : list [str ], dataset_key : str ) -> int :
85
+ sorted_passage_ids = sorted (passage_ids )
86
+ joined_passage_ids = "," .join (map (str , sorted_passage_ids ))
87
+ hash_key = get_relevant_document_count_hash (joined_passage_ids , dataset_key )
86
88
87
89
cached_value = self .cache .get (hash_key )
88
90
89
91
if cached_value :
90
92
return int (cached_value )
91
93
94
+ is_poquad = True if "poquad" in dataset_key else False
95
+
96
+ must = [
97
+ {"match" : {"dataset_key" : dataset_key }},
98
+ ]
99
+
100
+ if is_poquad :
101
+ must .append ({"match" : {"id" : passage_ids [0 ]}})
102
+ else :
103
+ must .append ({"terms" : {"metadata.passage_id" : passage_ids }})
104
+
92
105
body = {
93
- "query" : {
94
- "bool" : {
95
- "must" : [
96
- {"match" : {"id" : passage_id }},
97
- {"match" : {"dataset_key" : dataset_key }},
98
- ]
99
- }
100
- },
106
+ "query" : {"bool" : {"must" : must }},
101
107
}
102
108
103
109
response = self .client .count (index = self .index_name , body = body )
0 commit comments