+from sentence_transformers import CrossEncoder
+from cache.cache import Cache
+from common.models_dimensions import RERANKER_MODEL_DIMENSIONS_MAP
+from common.result import Result
+import nltk
+from nltk.tokenize import sent_tokenize
+import ssl
+from mlx_lm import load, generate
+from common.utils import (
+    get_answer_relevance_hash,
+    get_faithfulness_hash,
+    get_query_to_context_relevance_hash,
+)
+from vectorizer.vectorizer import Vectorizer
+
+
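+# Evaluates RAG results with RAGAS-style metrics: context precision,
+# context recall, faithfulness, answer relevance and query-to-context
+# relevance.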
+class RAGASEvaluator:
+    def __init__(
+        self,
+        reranker_model_name: str,
+        cache: Cache,
+        generator_model_name: str,
+        vectorizer: Vectorizer,
+    ):
+        # Allow nltk.download() to succeed in environments where SSL
+        # certificates cannot be verified.
+        try:
+            _create_unverified_https_context = ssl._create_unverified_context
+        except AttributeError:
+            pass
+        else:
+            ssl._create_default_https_context = _create_unverified_https_context
+
+        # Sentence tokenizer data required by sent_tokenize.
+        nltk.download("punkt")
+
+        self.reranker_model_name = reranker_model_name
+        self.reranker_model = CrossEncoder(
+            reranker_model_name,
+            max_length=RERANKER_MODEL_DIMENSIONS_MAP[reranker_model_name],
+        )
+        model, tokenizer = load(generator_model_name)
+        self.generator_model = model
+        self.generator_tokenizer = tokenizer
+        self.cache = cache
+        self.vectorizer = vectorizer
+
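+    # Context precision: mean of precision@i over the ranks i at which
+    # the correct passage appears, rewarding rankings that place the
+    # relevant passage near the top.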
+    def context_precision(self, result: Result, correct_passage_id: str) -> float:
+        precisions_sum = 0.0
+
+        for i in range(1, len(result.passages) + 1):
+            # precision@i: share of the top-i passages that are relevant.
+            sub_precision = 0
+            for j in range(i):
+                if result.passages[j][0].id == correct_passage_id:
+                    sub_precision += 1
+
+            # Only ranks that hold the correct passage contribute.
+            relevant_at_i = (
+                1 if result.passages[i - 1][0].id == correct_passage_id else 0
+            )
+
+            precisions_sum += (sub_precision / i) * relevant_at_i
+
+        sum_of_relevant_passages = sum(
+            1 for passage in result.passages if passage[0].id == correct_passage_id
+        )
+
+        return (
+            precisions_sum / sum_of_relevant_passages
+            if sum_of_relevant_passages > 0
+            else 0.0
+        )
+
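+    # Context recall is binary here: each query has a single gold
+    # passage, so recall is 1.0 iff it was retrieved at all.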
+    def context_recall(self, result: Result, correct_passage_id: str) -> float:
+        is_any_relevant = any(
+            passage[0].id == correct_passage_id for passage in result.passages
+        )
+        return 1.0 if is_any_relevant else 0.0
+
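+    # Checks whether one answer sentence is supported by any of the
+    # given context sentences, using the cross-encoder reranker.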
+    def _get_relevancy_between_answer_and_context(
+        self, answer: str, contexts: list[str]
+    ) -> int:
+        # Score the sentence against every context sentence and count it
+        # as supported if any pair scores above 0.5.
+        pairs = [[answer, context] for context in contexts]
+        results = self.reranker_model.predict(pairs)
+        return 1 if max(results) > 0.5 else 0
+
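+    # Faithfulness: the fraction of answer sentences that are supported
+    # by the retrieved context; cached by an (answer, context) hash.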
+    def faithfulness(self, result: Result, answer: str) -> float:
+        sentences = sent_tokenize(answer)
+        context = " ".join(passage[0].context for passage in result.passages)
+        context_sentences = sent_tokenize(context)
+
+        hash_key = get_faithfulness_hash(answer, context)
+        maybe_faithfulness = self.cache.get(hash_key)
+        if maybe_faithfulness:
+            return float(maybe_faithfulness)
+
+        # Guard against empty answers to avoid division by zero.
+        if not sentences:
+            return 0.0
+
+        faithfulness = sum(
+            self._get_relevancy_between_answer_and_context(
+                sentence, context_sentences
+            )
+            for sentence in sentences
+        ) / len(sentences)
+
+        self.cache.set(hash_key, str(faithfulness))
+
+        return faithfulness
+
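+    # Answer relevance: generate up to three questions from the answer
+    # with the local generator model and average their embedding
+    # similarity to the original question.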
+    def answer_relevance(self, original_question: str, answer: str) -> float:
+        hash_key = get_answer_relevance_hash(original_question, answer)
+        maybe_answer_relevance = self.cache.get(hash_key)
+        if maybe_answer_relevance:
+            return float(maybe_answer_relevance)
+
+        # Prompt (in Polish): "Based on the given context, generate three
+        # questions that could be asked about this text. Return them as a
+        # numbered list (1., 2., 3.), one question per line, each ending
+        # with a period. Return only the numbered questions."
+        flattened_answer = answer.replace("\n", " ")
+        prompt = f"""
+        [INST]
+        Na podstawie podanego kontekstu wygeneruj trzy pytania, które mogłyby zostać zadane w kontekście tego tekstu.
+        Zwróć pytania w formacie
+        1. Pierwsze pytanie
+        2. Drugie pytanie
+        3. Trzecie pytanie
+
+        Każde pytanie musi być zakończone kropką i znajdować się w osobnej linii. Zwróć tylko pytania z numerami 1, 2 i 3.
+
+        Kontekst: {flattened_answer}
+        [/INST]
+        """
+
+        generated_questions = generate(
+            self.generator_model,
+            self.generator_tokenizer,
+            prompt=prompt,
+            max_tokens=300,
+        )
+
+        # Keep at most three lines that look like numbered questions.
+        questions = generated_questions.split("\n")
+        filtered_questions = [
+            q.strip()
+            for q in questions
+            if q.strip() and q.strip()[0] in {"1", "2", "3"}
+        ][:3]
+
+        # "zapytanie:" is the query prefix expected by the vectorizer model.
+        original_question_vector = self.vectorizer.get_vector(
+            f"zapytanie: {original_question}"
+        )
+        question_vectors = [
+            self.vectorizer.get_vector(f"zapytanie: {q}") for q in filtered_questions
+        ]
+        results = [
+            self.vectorizer.get_similarity(original_question_vector, v).item()
+            for v in question_vectors
+        ]
+
+        score = sum(results) / len(filtered_questions) if filtered_questions else 0.0
+
+        self.cache.set(hash_key, str(score))
+
+        return score
+
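+    # Query-to-context relevance: mean embedding similarity between the
+    # query and each retrieved passage.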
+    def query_to_context_relevance(self, result: Result) -> float:
+        hash_key = get_query_to_context_relevance_hash(result)
+        maybe_query_to_context_relevance = self.cache.get(hash_key)
+        if maybe_query_to_context_relevance:
+            return float(maybe_query_to_context_relevance)
+
+        query_vector = self.vectorizer.get_vector(result.query)
+        context_vectors = [
+            self.vectorizer.get_vector(passage[0].context)
+            for passage in result.passages
+        ]
+
+        # Named `similarities` so the `result` parameter is not shadowed.
+        similarities = [
+            self.vectorizer.get_similarity(query_vector, v).item()
+            for v in context_vectors
+        ]
+
+        score = sum(similarities) / len(context_vectors) if context_vectors else 0.0
+
+        self.cache.set(hash_key, str(score))
+
+        return score
+
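+    # Composite RAGAS score as the arithmetic mean of the four metrics
+    # (the original RAGAS formulation combines them with a harmonic
+    # mean instead).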
+    def ragas(self, result: Result, correct_passage_id: str, answer: str) -> float:
+        context_precision = self.context_precision(result, correct_passage_id)
+        context_recall = self.context_recall(result, correct_passage_id)
+        faithfulness = self.faithfulness(result, answer)
+        answer_relevance = self.answer_relevance(result.query, answer)
+
+        return (
+            context_precision + context_recall + faithfulness + answer_relevance
+        ) / 4
+
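+    # Hallucination-oriented score: mean of faithfulness, answer
+    # relevance and query-to-context relevance.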
+    def hallucination(self, result: Result, answer: str) -> float:
+        faithfulness = self.faithfulness(result, answer)
+        answer_relevance = self.answer_relevance(result.query, answer)
+        context_relevance = self.query_to_context_relevance(result)
+
+        return (faithfulness + answer_relevance + context_relevance) / 3