
Commit 059d6d9

Add hallucination score

Committed Sep 4, 2024 · 1 parent 04b7a55 · commit 059d6d9

File tree: 7 files changed, +1384 -349 lines


src/common/result.py (+1 -1)

@@ -3,6 +3,6 @@


 class Result:
-    def __init__(self, query: str, passages: List[Tuple[Passage, int]]) -> None:
+    def __init__(self, query: str, passages: List[Tuple[Passage, float]]) -> None:
         self.query = query
         self.passages = passages

src/common/utils.py (+12 -8)

@@ -11,6 +11,8 @@
 from common.passage import Passage
 import hashlib

+from common.result import Result
+

 def get_passages_for_embedding(dataset):
     unique_contexts = set((row["title"], row["context"]) for row in dataset)
@@ -75,19 +77,21 @@ def get_generator_hash(query: str, context: str, type: str, model: str):
     return "generator:" + hashed


-def get_ner_hash(answer: str, context: str):
+def get_faithfulness_hash(answer: str, context: str):
     hashed = hashlib.sha256((answer + context).encode()).hexdigest()
-    return "ner:" + hashed
+    return "faithfulness:" + hashed


-def get_halucination_hash(answer: str, context: str):
-    hashed = hashlib.sha256((answer + context).encode()).hexdigest()
-    return "halucination:" + hashed
+def get_answer_relevance_hash(original_question: str, answer: str):
+    hashed = hashlib.sha256((original_question + answer).encode()).hexdigest()
+    return "answer_relevance:" + hashed


-def get_answer_reranker_hash(answer: str, passages: list[Passage]):
-    hashed = hashlib.sha256((answer + str(passages)).encode()).hexdigest()
-    return "answer_reranker:" + hashed
+def get_query_to_context_relevance_hash(result: Result):
+    hashed = hashlib.sha256(
+        (result.query + str([passage for (passage, _) in result.passages])).encode()
+    ).hexdigest()
+    return "query_to_context_relevance:" + hashed


 def get_query_reranker_hash(query: str, answer: str):
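
Note: these helpers only build namespaced sha256 cache keys; the read-through caching itself happens in the callers. A minimal sketch of that pattern, using only the Cache get/set calls visible elsewhere in this diff (the wrapper function and the compute callback are hypothetical, not part of the commit):

    from common.utils import get_faithfulness_hash

    def cached_faithfulness(cache, answer: str, context: str, compute) -> float:
        # Key is "faithfulness:" + sha256(answer + context).
        hash_key = get_faithfulness_hash(answer, context)

        cached = cache.get(hash_key)
        if cached:
            return float(cached)  # scores are cached as strings

        score = compute(answer, context)
        cache.set(hash_key, str(score))
        return score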

src/evaluation/hallucination_evaluator.py (-148)

This file was deleted.

src/evaluation/ragas_evaulator.py (+214, new file)

from sentence_transformers import CrossEncoder
from cache.cache import Cache
from common.models_dimensions import RERANKER_MODEL_DIMENSIONS_MAP
from common.result import Result
import nltk
from nltk.tokenize import sent_tokenize
import ssl
from mlx_lm import load, generate
from common.utils import (
    get_answer_relevance_hash,
    get_faithfulness_hash,
    get_query_to_context_relevance_hash,
)
from vectorizer.vectorizer import Vectorizer


class RAGASEvaluator:
    def __init__(
        self,
        reranker_model_name: str,
        cache: Cache,
        generator_model_name: str,
        vectorizer: Vectorizer,
    ):
        # Allow nltk to download tokenizer data even when HTTPS certificate
        # verification is unavailable.
        try:
            _create_unverified_https_context = ssl._create_unverified_context
        except AttributeError:
            pass
        else:
            ssl._create_default_https_context = _create_unverified_https_context

        nltk.download("punkt")  # needed by sent_tokenize

        self.reranker_model_name = reranker_model_name
        self.reranker_model = CrossEncoder(
            reranker_model_name,
            max_length=RERANKER_MODEL_DIMENSIONS_MAP[reranker_model_name],
        )
        model, tokenizer = load(generator_model_name)
        self.generator_model = model
        self.generator_tokenizer = tokenizer
        self.cache = cache
        self.vectorizer = vectorizer

    def context_precision(self, result: Result, correct_passage_id: str) -> float:
        # Mean of precision@i over the ranks i where the retrieved passage is
        # relevant: sum_i(precision@i * relevant_at_i) / (# relevant passages).
        precisions_sum = 0

        for i in range(1, len(result.passages) + 1):
            sub_precision = 0

            for j in range(0, i):
                if result.passages[j][0].id == correct_passage_id:
                    sub_precision += 1

            relevant_at_i = (
                1 if result.passages[i - 1][0].id == correct_passage_id else 0
            )

            precisions_sum += (sub_precision / i) * relevant_at_i

        sum_of_relevant_passages = sum(
            1 for passage in result.passages if passage[0].id == correct_passage_id
        )

        return (
            precisions_sum / sum_of_relevant_passages
            if sum_of_relevant_passages > 0
            else 0
        )

    def context_recall(self, result: Result, correct_passage_id: str) -> float:
        # Binary recall: 1 if the correct passage was retrieved at all.
        for passage in result.passages:
            if passage[0].id == correct_passage_id:
                return 1

        return 0

    def _get_relevancy_between_answer_and_context(
        self, answer: str, contexts: list[str]
    ):
        # A sentence counts as grounded if the cross-encoder scores it above
        # 0.5 against at least one context sentence.
        pairs = [[answer, context] for context in contexts]
        results = self.reranker_model.predict(pairs)
        return 1 if max(results) > 0.5 else 0

    def faithfulness(self, result: Result, answer: str) -> float:
        # Fraction of answer sentences supported by the retrieved context;
        # results are cached by (answer, context).
        sentences = sent_tokenize(answer)
        context = " ".join([passage[0].context for passage in result.passages])
        context_sentences = sent_tokenize(context)

        hash_key = get_faithfulness_hash(answer, context)

        maybe_faithfulness = self.cache.get(hash_key)
        if maybe_faithfulness:
            return float(maybe_faithfulness)

        faithfulness = sum(
            self._get_relevancy_between_answer_and_context(sentence, context_sentences)
            for sentence in sentences
        ) / len(sentences)

        self.cache.set(hash_key, str(faithfulness))

        return faithfulness

    def answer_relevance(self, original_question: str, answer: str) -> float:
        # Generate questions the answer could be answering, then measure their
        # embedding similarity to the original question.
        hash_key = get_answer_relevance_hash(original_question, answer)

        maybe_answer_relevance = self.cache.get(hash_key)
        if maybe_answer_relevance:
            return float(maybe_answer_relevance)

        # The prompt is in Polish; in English it reads: "Based on the given
        # context, generate three questions that could be asked about this
        # text. Return them in the format '1. ... / 2. ... / 3. ...'. Each
        # question must end with a period and be on its own line. Return only
        # the questions numbered 1, 2 and 3."
        flattened_answer = answer.replace("\n", " ")
        prompt = f"""
[INST]
Na podstawie podanego kontekstu wygeneruj trzy pytania, które mogłyby zostać zadane w kontekście tego tekstu.
Zwróć pytania w formacie
1. Pierwsze pytanie
2. Drugie pytanie
3. Trzecie pytanie

Każde pytanie musi być zakończone kropką i znajdować się w osobnej linii. Zwróć tylko pytania z numerami 1, 2 i 3.

Kontekst: {flattened_answer}
[/INST]
"""

        generated_questions = generate(
            self.generator_model,
            self.generator_tokenizer,
            prompt=prompt,
            max_tokens=300,
        )

        # Keep at most the three lines that start with "1", "2" or "3".
        questions = generated_questions.split("\n")
        filtered_questions = [
            q.strip()
            for q in questions
            if q.strip() and q.strip()[0] in {"1", "2", "3"}
        ][:3]

        original_question_vector = self.vectorizer.get_vector(
            f"zapytanie: {original_question}"
        )
        sentence_vectors = [
            self.vectorizer.get_vector(f"zapytanie: {q}") for q in filtered_questions
        ]
        results = [
            self.vectorizer.get_similarity(original_question_vector, v).item()
            for v in sentence_vectors
        ]

        score = (
            sum(results) / len(filtered_questions)
            if len(filtered_questions) > 0
            else 0
        )

        self.cache.set(hash_key, str(score))

        return score

    def query_to_context_relevance(self, result: Result) -> float:
        # Mean embedding similarity between the query and each retrieved
        # passage's context.
        hash_key = get_query_to_context_relevance_hash(result)

        maybe_query_to_context_relevance = self.cache.get(hash_key)
        if maybe_query_to_context_relevance:
            return float(maybe_query_to_context_relevance)

        query_vector = self.vectorizer.get_vector(result.query)
        context_vectors = [
            self.vectorizer.get_vector(passage[0].context)
            for passage in result.passages
        ]

        similarities = [
            self.vectorizer.get_similarity(query_vector, v).item()
            for v in context_vectors
        ]

        score = sum(similarities) / len(context_vectors)

        self.cache.set(hash_key, str(score))

        return score

    def ragas(self, result: Result, correct_passage_id: str, answer: str) -> float:
        # RAGAS score: unweighted mean of the four component metrics.
        context_precision = self.context_precision(result, correct_passage_id)
        context_recall = self.context_recall(result, correct_passage_id)
        faithfulness = self.faithfulness(result, answer)
        answer_relevance = self.answer_relevance(result.query, answer)

        return (
            context_precision + context_recall + faithfulness + answer_relevance
        ) / 4

    def hallucination(self, result: Result, answer: str) -> float:
        # Hallucination score: mean of faithfulness, answer relevance and
        # query-to-context relevance (higher means better grounded).
        faithfulness = self.faithfulness(result, answer)
        answer_relevance = self.answer_relevance(result.query, answer)
        context_relevance = self.query_to_context_relevance(result)

        return (faithfulness + answer_relevance + context_relevance) / 3
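
Note: a minimal wiring sketch for the new evaluator. The model-name strings are placeholders and the no-argument Cache constructor is an assumption; only the RAGASEvaluator signature and the HFVectorizer(model_name, cache) signature come from this diff:

    from cache.cache import Cache
    from evaluation.ragas_evaulator import RAGASEvaluator
    from vectorizer.hf_vectorizer import HFVectorizer

    cache = Cache()  # assumed constructor
    vectorizer = HFVectorizer("<embedding-model-name>", cache)

    evaluator = RAGASEvaluator(
        reranker_model_name="<key in RERANKER_MODEL_DIMENSIONS_MAP>",
        cache=cache,
        generator_model_name="<model loadable by mlx_lm>",
        vectorizer=vectorizer,
    )

    # result: a Result holding the query and retrieved passages;
    # answer: the generated answer being evaluated.
    ragas_score = evaluator.ragas(result, correct_passage_id="...", answer=answer)
    hallucination_score = evaluator.hallucination(result, answer=answer)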

src/notebooks/00_test.ipynb (+99 -104)

Large diffs are not rendered by default.

src/notebooks/02_generators_evaluation.ipynb (+1,054 -85)

Large diffs are not rendered by default.

src/vectorizer/hf_vectorizer.py (+4 -3)

@@ -10,8 +10,9 @@

 class HFVectorizer(Vectorizer):
     def __init__(self, model_name: str, cache: Cache):
+        self.device = "mps"
         self.model_name = model_name
-        self.model = SentenceTransformer(model_name)
+        self.model = SentenceTransformer(model_name, device=self.device)
         self.max_seq_length = self.model.max_seq_length
         self.cache = cache

@@ -32,7 +33,7 @@ def get_vector(self, query: str) -> Any:
         vector_json = json.dumps(vector_list)
         self.cache.set(hash_key, vector_json)

-        return hashed_vector
+        return hashed_vector.to(self.device)

     def get_similarity(self, vector1: Any, vector2: Any) -> float:
-        return self.model.similarity(vector1, vector2)
+        return self.model.similarity(vector1.to(self.device), vector2.to(self.device))
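
Note: the device is hardcoded to "mps" (Apple Silicon's Metal backend), so this constructor will fail on machines without MPS support. One possible fallback, not part of this commit, using only standard torch availability checks:

    import torch

    def pick_device() -> str:
        # Prefer Apple Metal, then CUDA, then CPU.
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"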
