Skip to content

Commit a331c9a

Browse files
liushz and Leymore authored Dec 1, 2023
[Feature] Add wikibench dataset (#655)
* Add WikiBench * Add WikiBench * format --------- Co-authored-by: Leymore <[email protected]>
1 parent e019c83 commit a331c9a

File tree

4 files changed

+123
-0
lines changed

4 files changed

+123
-0
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from mmengine.config import read_base
2+
3+
with read_base():
4+
from .wikibench_gen_f96ece import wikibench_datasets # noqa: F401, F403
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from opencompass.openicl.icl_prompt_template import PromptTemplate
2+
from opencompass.openicl.icl_retriever import ZeroRetriever
3+
from opencompass.openicl.icl_inferencer import GenInferencer
4+
from opencompass.openicl.icl_evaluator import CircularEvaluator, AccEvaluator
5+
from opencompass.datasets import WikiBenchDataset
6+
from opencompass.utils.text_postprocessors import first_option_postprocess
7+
8+
9+
# Prompt templates keyed by prompt name; "{question}" is filled in from the
# dataset's "question" column.
single_choice_prompts = {
    "single_choice_cn": "以下是一道单项选择题,请你根据你了解的知识给出正确的答案选项。\n下面是你要回答的题目:\n{question}\n答案选项:",
}

# Maps each split name to the prompt names evaluated for it.
wikibench_sets = {
    "wiki": ["single_choice_cn"],
}

# When True, the dataset is loaded under a "circular_" name and scored with
# CircularEvaluator; otherwise AccEvaluator is used.
do_circular = True

wikibench_datasets = []

for _split in wikibench_sets:
    for _name in wikibench_sets[_split]:
        wikibench_infer_cfg = dict(
            ice_template=dict(
                type=PromptTemplate,
                template=dict(
                    begin="</E>",
                    round=[
                        dict(role="HUMAN", prompt=single_choice_prompts[_name]),
                        dict(role="BOT", prompt="{answer}"),
                    ],
                ),
                ice_token="</E>",
            ),
            retriever=dict(type=ZeroRetriever),
            inferencer=dict(type=GenInferencer),
        )
        wikibench_eval_cfg = dict(
            evaluator=dict(type=CircularEvaluator if do_circular else AccEvaluator),
            pred_postprocessor=dict(type=first_option_postprocess, options="ABCD"),
        )

        wikibench_datasets.append(
            dict(
                type=WikiBenchDataset,
                path=f"./data/WikiBench/{_name}.jsonl",
                name="circular_" + _name if do_circular else _name,
                # The conditional must only govern the suffix. Without the
                # parentheses it binds to the whole concatenation, so the
                # abbreviation collapsed to "" whenever do_circular was False.
                # The "circular" suffix is kept without a separator so the
                # abbreviation produced for do_circular=True is unchanged.
                abbr="wikibench-" + _split + "-" + _name + ("circular" if do_circular else ""),
                reader_cfg=dict(
                    input_columns=["question"],
                    output_column="answer",
                ),
                infer_cfg=wikibench_infer_cfg,
                eval_cfg=wikibench_eval_cfg,
            )
        )

‎opencompass/datasets/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
from .truthfulqa import * # noqa: F401, F403
8787
from .tydiqa import * # noqa: F401, F403
8888
from .wic import * # noqa: F401, F403
89+
from .wikibench import * # noqa: F401, F403
8990
from .winograd import * # noqa: F401, F403
9091
from .winogrande import * # noqa: F401, F403
9192
from .wnli import wnliDataset # noqa: F401, F403

‎opencompass/datasets/wikibench.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import copy
2+
import json
3+
4+
from datasets import Dataset
5+
6+
from opencompass.registry import LOAD_DATASET
7+
8+
from .base import BaseDataset
9+
10+
11+
def get_number(options):
    """Render *options* as a lettered list, one option per line.

    Options are labelled with consecutive characters starting at ``'A'``
    and every line, including the last, ends with a newline.

    Args:
        options: Iterable of option values; each is formatted with ``str``
            semantics via an f-string.

    Returns:
        The formatted text, e.g. ``'A. foo\\nB. bar\\n'``; an empty string
        when *options* is empty.
    """
    # Build the text in one pass with join rather than repeated `+=`.
    return ''.join(f'{chr(code)}. {option}\n'
                   for code, option in enumerate(options, start=ord('A')))
17+
18+
19+
@LOAD_DATASET.register_module()
class WikiBenchDataset(BaseDataset):
    """Loader that turns a WikiBench JSON-lines file into a ``Dataset``."""

    @staticmethod
    def load(path: str, name: str):
        """Read *path* line by line and build the evaluation dataset.

        The substring found in *name* selects how each JSON record is used:

        * ``'cloze'``: keep only the stripped ``question`` and ``answer``.
        * ``'circular'``: emit four copies of the record, one per rotation
          of the option order. Each copy's ``answer`` becomes
          ``'<line index>--<remapped letter>--<pattern>'`` and the lettered
          options are appended to its ``question``.
        * otherwise: append the lettered options to ``question`` and keep
          the record as is.

        Args:
            path: Path of the JSON-lines file to read.
            name: Dataset variant name, matched by substring as described
                above.

        Returns:
            A ``Dataset`` built from the collected records.

        Raises:
            KeyError: If a record lacks a field its branch needs, or if, in
                circular mode, its ``answer`` is not one of ``A``-``D``.
        """
        circular_patterns = ['ABCD', 'BCDA', 'CDAB', 'DABC']

        data = []
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale default, which mis-decodes non-ASCII text on some systems.
        # NOTE(review): assumes the data files are UTF-8 encoded - confirm.
        with open(path, 'r', encoding='utf-8') as infile:
            for index, raw_line in enumerate(infile):
                # Skip whitespace-only lines rather than failing in
                # json.loads; `index` stays the physical line number.
                if not raw_line.strip():
                    continue
                entry = json.loads(raw_line)
                if 'cloze' in name:
                    data.append({
                        'question': entry['question'].strip(),
                        'answer': entry['answer'].strip()
                    })
                elif 'circular' in name:
                    for pattern in circular_patterns:
                        rotated = copy.deepcopy(entry)
                        # Position i of the rotated list holds the option
                        # originally labelled pattern[i].
                        rotated['options'] = [
                            entry['options'][ord(letter) - ord('A')]
                            for letter in pattern
                        ]
                        # Translate the original answer letter into the
                        # letter it carries after the rotation.
                        remapped = dict(zip(pattern,
                                            'ABCD'))[rotated['answer']]
                        rotated['answer'] = (str(index) + '--' + remapped +
                                             '--' + pattern)
                        rotated['question'] = (
                            rotated['question'].strip() + '\n' +
                            get_number(rotated['options']))
                        data.append(rotated)
                else:
                    # treat as normal single choice question
                    entry['question'] = entry['question'].strip(
                    ) + '\n' + get_number(entry['options'])
                    data.append(entry)

        dataset = Dataset.from_list(data)
        return dataset

0 commit comments

Comments
 (0)
Please sign in to comment.