Commit 7b58c67

committed Mar 21, 2024
add fine-tuning example
1 parent 198134a commit 7b58c67

5 files changed, +394 -73 lines changed
 

‎.gitignore

+2 -1
@@ -1,3 +1,4 @@
 env
 .env
-.DS_STORE
+.DS_STORE
+output

‎01-SentenceSimilarity.ipynb

+1 -1
@@ -15,7 +15,7 @@
 "source": [
 "## Installing a sample model\n",
 "\n",
-"We'll use **sentence-transformers/all-MiniLM-L6-v2**. It's a small and lightweight model that will be good for this particular showcase."
+"We'll use [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2). It's a small and lightweight model that will be good for this particular showcase."
 ]
 },
 {

‎03-QuestionAnswering.ipynb

+10 -69
Large diffs are not rendered by default.

‎04-FineTunningModel.ipynb

+378
@@ -0,0 +1,378 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Fine-tuning a model\n",
8+
"\n",
9+
"In this example notebook we'll fine-tune [deepset/roberta-base-squad2](https://huggingface.co/deepset/roberta-base-squad2) on the [PoQuAD dataset](https://huggingface.co/datasets/clarin-pl/poquad)."
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"## Downloading the PoQuAD dataset"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": 3,
22+
"metadata": {},
23+
"outputs": [],
24+
"source": [
25+
"from datasets import load_dataset\n",
26+
"\n",
27+
"poquad = load_dataset(\"clarin-pl/poquad\")"
28+
]
29+
},
30+
{
31+
"cell_type": "markdown",
32+
"metadata": {},
33+
"source": [
34+
"## Downloading the model's tokenizer"
35+
]
36+
},
37+
{
38+
"cell_type": "code",
39+
"execution_count": 5,
40+
"metadata": {},
41+
"outputs": [],
42+
"source": [
43+
"from transformers import AutoTokenizer\n",
44+
"\n",
45+
"model_name = 'deepset/roberta-base-squad2'\n",
46+
"\n",
47+
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
48+
]
49+
},
50+
{
51+
"cell_type": "markdown",
52+
"metadata": {},
53+
"source": [
54+
"## Adding a preprocessing function\n",
55+
"It comes from a Hugging Face example. More information can be found at https://huggingface.co/docs/transformers/tasks/question_answering"
56+
]
57+
},
58+
{
59+
"cell_type": "code",
60+
"execution_count": 6,
61+
"metadata": {},
62+
"outputs": [],
63+
"source": [
64+
"def preprocess_function(examples):\n",
65+
" questions = [q.strip() for q in examples[\"question\"]]\n",
66+
" inputs = tokenizer(\n",
67+
" questions,\n",
68+
" examples[\"context\"],\n",
69+
" max_length=384,\n",
70+
" truncation=\"only_second\",\n",
71+
" return_offsets_mapping=True,\n",
72+
" padding=\"max_length\",\n",
73+
" )\n",
74+
"\n",
75+
" offset_mapping = inputs.pop(\"offset_mapping\")\n",
76+
" answers = examples[\"answers\"]\n",
77+
" start_positions = []\n",
78+
" end_positions = []\n",
79+
"\n",
80+
" for i, offset in enumerate(offset_mapping):\n",
81+
" answer = answers[i]\n",
82+
" start_char = answer[\"answer_start\"][0]\n",
83+
" end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n",
84+
" sequence_ids = inputs.sequence_ids(i)\n",
85+
"\n",
86+
" # Find the start and end of the context\n",
87+
" idx = 0\n",
88+
" while sequence_ids[idx] != 1:\n",
89+
" idx += 1\n",
90+
" context_start = idx\n",
91+
" while sequence_ids[idx] == 1:\n",
92+
" idx += 1\n",
93+
" context_end = idx - 1\n",
94+
"\n",
95+
" # If the answer is not fully inside the context, label it (0, 0)\n",
96+
" if offset[context_start][0] > end_char or offset[context_end][1] < start_char:\n",
97+
" start_positions.append(0)\n",
98+
" end_positions.append(0)\n",
99+
" else:\n",
100+
" # Otherwise it's the start and end token positions\n",
101+
" idx = context_start\n",
102+
" while idx <= context_end and offset[idx][0] <= start_char:\n",
103+
" idx += 1\n",
104+
" start_positions.append(idx - 1)\n",
105+
"\n",
106+
" idx = context_end\n",
107+
" while idx >= context_start and offset[idx][1] >= end_char:\n",
108+
" idx -= 1\n",
109+
" end_positions.append(idx + 1)\n",
110+
"\n",
111+
" inputs[\"start_positions\"] = start_positions\n",
112+
" inputs[\"end_positions\"] = end_positions\n",
113+
" return inputs"
114+
]
115+
},
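To make the labeling step in `preprocess_function` easier to follow, here is a minimal, tokenizer-free sketch of the same search run over a hand-written offset list. The helper name `char_span_to_token_span`, the toy sentence, and the toy offsets are illustrative assumptions and are not part of the notebook; the loops mirror the ones in the cell above, with one extra bounds check added for safety.

```python
def char_span_to_token_span(offsets, sequence_ids, start_char, end_char):
    """Map a character span in the context to (start_token, end_token).

    offsets      -- one (char_start, char_end) pair per token
    sequence_ids -- None for special tokens, 0 for question, 1 for context
    Returns (0, 0) when the answer lies entirely outside the context window.
    """
    # Find the first and last context tokens.
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while idx < len(sequence_ids) and sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # The answer starts after the window ends, or ends before it starts.
    if offsets[context_start][0] > end_char or offsets[context_end][1] < start_char:
        return 0, 0

    # Last context token that begins at or before the answer start.
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1

    # First context token that ends at or after the answer end.
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return start_token, end_token


# Toy input (hypothetical tokenization of a made-up example).
context = "Marie Curie studied radioactivity."
#           <s>     Who     ?       </s>    </s>
offsets = [(0, 0), (0, 3), (3, 4), (0, 0), (0, 0),
           # Marie  Curie    studied   radio     activity  .         </s>
           (0, 5), (6, 11), (12, 19), (20, 25), (25, 33), (33, 34), (0, 0)]
sequence_ids = [None, 0, 0, None, None, 1, 1, 1, 1, 1, 1, None]

answer = "radioactivity"
start_char = context.index(answer)
end_char = start_char + len(answer)

span = char_span_to_token_span(offsets, sequence_ids, start_char, end_char)
print(span)  # token indices of "radio" and "activity"
```

If the context were truncated before the answer, the same call would return `(0, 0)`, which is the label the notebook assigns to examples whose answer falls outside the tokenized window.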
116+
{
117+
"cell_type": "markdown",
118+
"metadata": {},
119+
"source": [
120+
"## Tokenizing the dataset"
121+
]
122+
},
123+
{
124+
"cell_type": "code",
125+
"execution_count": 10,
126+
"metadata": {},
127+
"outputs": [
128+
{
129+
"data": {
130+
"application/vnd.jupyter.widget-view+json": {
131+
"model_id": "c39dae923e42422a93f19c38880e79c3",
132+
"version_major": 2,
133+
"version_minor": 0
134+
},
135+
"text/plain": [
136+
"Map: 0%| | 0/5764 [00:00<?, ? examples/s]"
137+
]
138+
},
139+
"metadata": {},
140+
"output_type": "display_data"
141+
}
142+
],
143+
"source": [
144+
"tokenized_poquad = poquad.map(preprocess_function, batched=True, remove_columns=poquad[\"train\"].column_names)"
145+
]
146+
},
147+
{
148+
"cell_type": "markdown",
149+
"metadata": {},
150+
"source": [
151+
"## Fine-tuning the model"
152+
]
153+
},
154+
{
155+
"cell_type": "code",
156+
"execution_count": 9,
157+
"metadata": {},
158+
"outputs": [],
159+
"source": [
160+
"from transformers import DefaultDataCollator\n",
161+
"\n",
162+
"data_collator = DefaultDataCollator()"
163+
]
164+
},
165+
{
166+
"cell_type": "code",
167+
"execution_count": 14,
168+
"metadata": {},
169+
"outputs": [
170+
{
171+
"data": {
172+
"application/vnd.jupyter.widget-view+json": {
173+
"model_id": "5a3ff2dafc0a4438ae1df02497a71bbd",
174+
"version_major": 2,
175+
"version_minor": 0
176+
},
177+
"text/plain": [
178+
" 0%| | 0/8661 [00:00<?, ?it/s]"
179+
]
180+
},
181+
"metadata": {},
182+
"output_type": "display_data"
183+
},
184+
{
185+
"name": "stdout",
186+
"output_type": "stream",
187+
"text": [
188+
"{'loss': 2.366, 'grad_norm': 25.46889305114746, 'learning_rate': 1.884539891467498e-05, 'epoch': 0.17}\n",
189+
"{'loss': 2.019, 'grad_norm': 31.06574249267578, 'learning_rate': 1.769079782934996e-05, 'epoch': 0.35}\n",
190+
"{'loss': 1.8536, 'grad_norm': 25.971675872802734, 'learning_rate': 1.653619674402494e-05, 'epoch': 0.52}\n",
191+
"{'loss': 1.7774, 'grad_norm': 28.223947525024414, 'learning_rate': 1.538159565869992e-05, 'epoch': 0.69}\n",
192+
"{'loss': 1.7428, 'grad_norm': 25.615482330322266, 'learning_rate': 1.42269945733749e-05, 'epoch': 0.87}\n"
193+
]
194+
},
195+
{
196+
"data": {
197+
"application/vnd.jupyter.widget-view+json": {
198+
"model_id": "37460c5225d5443c84baba38778c5250",
199+
"version_major": 2,
200+
"version_minor": 0
201+
},
202+
"text/plain": [
203+
" 0%| | 0/361 [00:00<?, ?it/s]"
204+
]
205+
},
206+
"metadata": {},
207+
"output_type": "display_data"
208+
},
209+
{
210+
"name": "stdout",
211+
"output_type": "stream",
212+
"text": [
213+
"{'eval_loss': 1.5653996467590332, 'eval_runtime': 91.9805, 'eval_samples_per_second': 62.665, 'eval_steps_per_second': 3.925, 'epoch': 1.0}\n",
214+
"{'loss': 1.6227, 'grad_norm': 44.27663803100586, 'learning_rate': 1.3072393488049879e-05, 'epoch': 1.04}\n",
215+
"{'loss': 1.4481, 'grad_norm': 27.834014892578125, 'learning_rate': 1.1917792402724858e-05, 'epoch': 1.21}\n",
216+
"{'loss': 1.4219, 'grad_norm': 34.01536560058594, 'learning_rate': 1.076319131739984e-05, 'epoch': 1.39}\n",
217+
"{'loss': 1.431, 'grad_norm': 24.65727996826172, 'learning_rate': 9.60859023207482e-06, 'epoch': 1.56}\n",
218+
"{'loss': 1.4034, 'grad_norm': 43.74335861206055, 'learning_rate': 8.453989146749799e-06, 'epoch': 1.73}\n",
219+
"{'loss': 1.3625, 'grad_norm': 39.83943176269531, 'learning_rate': 7.299388061424778e-06, 'epoch': 1.91}\n"
220+
]
221+
},
222+
{
223+
"data": {
224+
"application/vnd.jupyter.widget-view+json": {
225+
"model_id": "bc208431a52148f0bcaa0d9a405baccc",
226+
"version_major": 2,
227+
"version_minor": 0
228+
},
229+
"text/plain": [
230+
" 0%| | 0/361 [00:00<?, ?it/s]"
231+
]
232+
},
233+
"metadata": {},
234+
"output_type": "display_data"
235+
},
236+
{
237+
"name": "stdout",
238+
"output_type": "stream",
239+
"text": [
240+
"{'eval_loss': 1.4093455076217651, 'eval_runtime': 89.9364, 'eval_samples_per_second': 64.09, 'eval_steps_per_second': 4.014, 'epoch': 2.0}\n",
241+
"{'loss': 1.282, 'grad_norm': 24.49555015563965, 'learning_rate': 6.144786976099758e-06, 'epoch': 2.08}\n",
242+
"{'loss': 1.187, 'grad_norm': 29.855417251586914, 'learning_rate': 4.990185890774737e-06, 'epoch': 2.25}\n",
243+
"{'loss': 1.1464, 'grad_norm': 21.781904220581055, 'learning_rate': 3.835584805449718e-06, 'epoch': 2.42}\n",
244+
"{'loss': 1.1625, 'grad_norm': 25.76011848449707, 'learning_rate': 2.680983720124697e-06, 'epoch': 2.6}\n",
245+
"{'loss': 1.1572, 'grad_norm': 40.99302673339844, 'learning_rate': 1.5263826347996768e-06, 'epoch': 2.77}\n",
246+
"{'loss': 1.147, 'grad_norm': 25.29827880859375, 'learning_rate': 3.7178154947465653e-07, 'epoch': 2.94}\n"
247+
]
248+
},
249+
{
250+
"data": {
251+
"application/vnd.jupyter.widget-view+json": {
252+
"model_id": "d68f21bbbe2348f8bc41b520f0f81e62",
253+
"version_major": 2,
254+
"version_minor": 0
255+
},
256+
"text/plain": [
257+
" 0%| | 0/361 [00:00<?, ?it/s]"
258+
]
259+
},
260+
"metadata": {},
261+
"output_type": "display_data"
262+
},
263+
{
264+
"name": "stdout",
265+
"output_type": "stream",
266+
"text": [
267+
"{'eval_loss': 1.4216176271438599, 'eval_runtime': 90.2408, 'eval_samples_per_second': 63.874, 'eval_steps_per_second': 4.0, 'epoch': 3.0}\n",
268+
"{'train_runtime': 8207.5606, 'train_samples_per_second': 16.882, 'train_steps_per_second': 1.055, 'train_loss': 1.4954243963770333, 'epoch': 3.0}\n"
269+
]
270+
},
271+
{
272+
"data": {
273+
"text/plain": [
274+
"TrainOutput(global_step=8661, training_loss=1.4954243963770333, metrics={'train_runtime': 8207.5606, 'train_samples_per_second': 16.882, 'train_steps_per_second': 1.055, 'train_loss': 1.4954243963770333, 'epoch': 3.0})"
275+
]
276+
},
277+
"execution_count": 14,
278+
"metadata": {},
279+
"output_type": "execute_result"
280+
}
281+
],
282+
"source": [
283+
"from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer\n",
284+
"\n",
285+
"model = AutoModelForQuestionAnswering.from_pretrained(model_name)\n",
286+
"\n",
287+
"training_args = TrainingArguments(\n",
288+
" output_dir=\"output/roberta-base-squad2-pl\",\n",
289+
" evaluation_strategy=\"epoch\",\n",
290+
" learning_rate=2e-5,\n",
291+
" per_device_train_batch_size=16,\n",
292+
" per_device_eval_batch_size=16,\n",
293+
" num_train_epochs=3,\n",
294+
" weight_decay=0.01,\n",
295+
" push_to_hub=False,\n",
296+
")\n",
297+
"\n",
298+
"trainer = Trainer(\n",
299+
" model=model,\n",
300+
" args=training_args,\n",
301+
" train_dataset=tokenized_poquad[\"train\"],\n",
302+
" eval_dataset=tokenized_poquad[\"validation\"],\n",
303+
" tokenizer=tokenizer,\n",
304+
" data_collator=data_collator,\n",
305+
")\n",
306+
"\n",
307+
"trainer.train()"
308+
]
309+
},
310+
{
311+
"cell_type": "markdown",
312+
"metadata": {},
313+
"source": [
314+
"## Testing the new model"
315+
]
316+
},
317+
{
318+
"cell_type": "code",
319+
"execution_count": 17,
320+
"metadata": {},
321+
"outputs": [
322+
{
323+
"name": "stdout",
324+
"output_type": "stream",
325+
"text": [
326+
"{'score': 0.5960702300071716, 'start': 125, 'end': 145, 'answer': 'promieniotwórczością'}\n"
327+
]
328+
}
329+
],
330+
"source": [
331+
"from transformers import pipeline\n",
332+
"\n",
333+
"model = AutoModelForQuestionAnswering.from_pretrained(\"output/roberta-base-squad2-pl/checkpoint-8500\")\n",
334+
"nlp = pipeline('question-answering', model=model, tokenizer=tokenizer)\n",
335+
"\n",
336+
"context = 'Maria Skłodowska-Curie była polską i naturalizowaną francuską fizyczką i chemiczką, która prowadziła pionierskie badania nad promieniotwórczością. Była pierwszą kobietą, która zdobyła Nagrodę Nobla, pierwszą osobą i jedyną, która zdobyła Nagrody Nobla w dwóch różnych dziedzinach nauki, i była częścią rodziny Curie, która zdobyła pięć Nagród Nobla.'\n",
337+
"question = 'W jakiej dziedzinie Maria Curie prowadziła pionierskie badania?'\n",
338+
"\n",
339+
"result = nlp(question=question, context=context)\n",
340+
"\n",
341+
"print(result)"
342+
]
343+
},
344+
{
345+
"cell_type": "markdown",
346+
"metadata": {},
347+
"source": [
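348+
"The score improves substantially, from 0.02 on the base model to roughly 0.60 on this fine-tuned model. However, the answer is still not grammatically correct: to fit the question it would need to be inflected as `promieniotwórczości`, not `promieniotwórczością`. Because the model is extractive, it can only return a span copied verbatim from the context, where the word appears as `promieniotwórczością`."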
349+
]
350+
},
351+
{
352+
"cell_type": "markdown",
353+
"metadata": {},
354+
"source": []
355+
}
356+
],
357+
"metadata": {
358+
"kernelspec": {
359+
"display_name": "env",
360+
"language": "python",
361+
"name": "python3"
362+
},
363+
"language_info": {
364+
"codemirror_mode": {
365+
"name": "ipython",
366+
"version": 3
367+
},
368+
"file_extension": ".py",
369+
"mimetype": "text/x-python",
370+
"name": "python",
371+
"nbconvert_exporter": "python",
372+
"pygments_lexer": "ipython3",
373+
"version": "3.12.2"
374+
}
375+
},
376+
"nbformat": 4,
377+
"nbformat_minor": 2
378+
}
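The `learning_rate` values in the notebook's training log are consistent with a linear decay from the configured `2e-5` down to zero over the 8661 total steps, with no warmup. The sketch below reproduces the first and last logged values. The helper `linear_lr` is illustrative, and the assumption that log lines fall on multiples of 500 steps is inferred from the logged epochs and the `checkpoint-8500` directory name rather than stated in the notebook.

```python
def linear_lr(step, base_lr=2e-5, total_steps=8661):
    """Learning rate after `step` optimizer steps under linear decay to zero.

    Assumes no warmup; base_lr and total_steps are taken from the
    TrainingArguments and the progress bar in the notebook.
    """
    remaining = max(0, total_steps - step)
    return base_lr * remaining / total_steps


# Assumed logging interval of 500 steps (not stated in the notebook).
first_logged = linear_lr(500)   # log reports 1.884539891467498e-05
last_logged = linear_lr(8500)   # log reports 3.7178154947465653e-07

print(f"step  500: {first_logged:.6e}")
print(f"step 8500: {last_logged:.6e}")
```

Under the same assumption, step 8500 is the last multiple of 500 before training ends at step 8661, which would explain why the test cell loads `checkpoint-8500` rather than a checkpoint from the final step.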

‎requirements.txt

+3 -2
@@ -5,9 +5,10 @@ torchaudio
 mlx # used for macos
 jupyterlab
 ipywidgets
-transformers
+transformers==4.38.1
 datasets
 accelerate
 qdrant-client
 elasticsearch
-python-dotenv
+python-dotenv
+evaluate
