
Commit 444d8d9

Connor-Shen and Leymore authored Feb 6, 2024
[feat] support multipl-e (#846)
* [feat] support humaneval_multipl-e
* format

Co-authored-by: Leymore <[email protected]>
1 parent a6c49f1 commit 444d8d9

File tree

3 files changed: +268 -0 lines changed

 
File 1 of 3 (new file, 4 lines added)
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
    from .humaneval_multi_gen_82cf85 import humaneval_multi_datasets  # noqa: F401, F403
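
For orientation, a top-level OpenCompass experiment config would pull this shortcut in via the same read_base mechanism; a minimal sketch (the relative import path and the surrounding config layout are assumptions, not part of this commit):

from mmengine.config import read_base

with read_base():
    # assumed location under configs/; adjust to your config layout
    from .datasets.humaneval_multi.humaneval_multi_gen import \
        humaneval_multi_datasets

datasets = [*humaneval_multi_datasets]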
File 2 of 3 (new file, 46 lines added)
@@ -0,0 +1,46 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import HumanevalMultiDataset, HumanevalMultiEvaluator

humaneval_multi_reader_cfg = dict(input_columns=['prompt'], output_column='tests')

humaneval_multi_infer_cfg = dict(
    prompt_template=dict(type=PromptTemplate, template='{prompt}'),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=1024),
)

humaneval_multi_eval_cfg = {
    lang: dict(
        evaluator=dict(
            type=HumanevalMultiEvaluator,
            language=lang,
            ip_address='localhost',  # replace with your code_eval_server ip address and port
            port=5000,
        ),  # refer to https://opencompass.readthedocs.io/en/latest/advanced_guides/code_eval_service.html to launch a server
        pred_role='BOT',
    ) for lang in ['cpp', 'cs', 'd', 'go', 'java', 'jl', 'js', 'lua', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh', 'swift', 'ts']
}

'''There are four versions of humaneval-{LANG}-{version}.jsonl:
['keep', 'transform', 'reworded', 'remove']

SRCDATA-LANG-keep is the same as SRCDATA-LANG, but the text of the prompt is
totally unchanged. If the original prompt had Python doctests, they remain as
Python instead of being translated to LANG. If the original prompt had
Python-specific terminology, e.g. 'list', it remains 'list' instead of being
translated, e.g. to 'vector' for C++.

SRCDATA-LANG-transform transforms the doctests to LANG but leaves the natural
language text of the prompt unchanged.

SRCDATA-LANG-reworded transforms both the doctests and the natural language
text of the prompt to LANG.

SRCDATA-LANG-remove removes the doctests from the prompt.
'''
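
# For example, with language='cpp' and version='reworded' the loader below
# resolves to humaneval-cpp-reworded.jsonl, the fully translated variant.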

humaneval_multi_datasets = [
    dict(
        type=HumanevalMultiDataset,
        abbr=f'humaneval_multiple-{lang}',
        language=lang,
        version='reworded',  # choose from ['keep', 'transform', 'reworded', 'remove']
        num_repeats=1,
        path='./data/multi-data/humaneval_multipl-e/',
        reader_cfg=humaneval_multi_reader_cfg,
        infer_cfg=humaneval_multi_infer_cfg,
        eval_cfg=humaneval_multi_eval_cfg[lang],
    ) for lang in ['cpp', 'cs', 'd', 'go', 'java', 'jl', 'js', 'lua', 'php', 'pl', 'py', 'r', 'rb', 'rkt', 'rs', 'scala', 'sh', 'swift', 'ts']
]
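
If you only need a subset of the 19 languages, the generated list can be filtered afterwards; a minimal sketch (the choice of languages is illustrative, and it relies on the abbr naming scheme defined above):

selected = {'cpp', 'java', 'py'}
humaneval_multi_datasets = [
    d for d in humaneval_multi_datasets
    if d['abbr'].split('-')[-1] in selected  # abbr is f'humaneval_multiple-{lang}'
]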
File 3 of 3 (new file, 218 lines added)
@@ -0,0 +1,218 @@
import gzip
import json
import os
import os.path as osp
import re
import shutil
import subprocess
import tempfile
import time

import numpy as np
from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import LOAD_DATASET

from .base import BaseDataset

# currently supporting 19 languages
_LANGUAGE_NAME_DICT = {
    'cpp': 'CPP',
    'cs': 'C#',
    'd': 'D',
    'go': 'Go',
    'java': 'Java',
    'jl': 'Julia',
    'js': 'JavaScript',
    'lua': 'Lua',
    'php': 'PHP',
    'pl': 'Perl',
    'py': 'Python',
    'r': 'R',
    'rb': 'Ruby',
    'rkt': 'Racket',
    'rs': 'Rust',
    'scala': 'Scala',
    'sh': 'Shell',
    'swift': 'Swift',
    'ts': 'TypeScript',
}


@LOAD_DATASET.register_module()
class HumanevalMultiDataset(BaseDataset):

    @staticmethod
    def load(path, language, version, num_repeats: int = 1, **kwargs):
        """Load the humaneval dataset for pass@k mode.

        Note that you can use num_repeats > 1 when your model does not support
        `num_return_sequence` in generation; otherwise use the raw humaneval
        dataset and set `num_return_sequence` in the model config to generate
        multiple responses for testing pass@k > 1.

        It is better to change your dataset abbr correspondingly if you set
        num_repeats > 1, otherwise the number in `.cache/dataset_size.json`
        might be inconsistent.

        Args:
            num_repeats(int): Number of repetitions for this dataset to get
                multiple responses in special cases.
        """
        assert language in _LANGUAGE_NAME_DICT.keys(), (
            f'language must be in {list(_LANGUAGE_NAME_DICT.keys())}')
        assert version in [
            'keep', 'transform', 'reworded', 'remove'
        ], ('version must be in ["keep", "transform", "reworded", "remove"]')
        file_path = osp.join(path, f'humaneval-{language}-{version}.jsonl')
        dataset = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                dataset.extend(
                    [json.loads(line.strip()) for _ in range(num_repeats)])
        return Dataset.from_list(dataset)
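
# Usage sketch (illustrative; the path matches the config above):
#   ds = HumanevalMultiDataset.load(
#       './data/multi-data/humaneval_multipl-e/', 'cpp', 'reworded')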

class HumanevalMultiEvaluator(BaseEvaluator):

    def __init__(self,
                 language,
                 ip_address='localhost',
                 port=5000,
                 retry=2,
                 timeout=600) -> None:
        self.language = language
        self.ip_address = ip_address
        self.port = port
        self.retry = retry
        self.timeout = timeout
        super().__init__()

    def stop_at_stop_token(self, decoded_string, stop_tokens):
        """Produces the prefix of decoded_string that ends at the first
        occurrence of a stop_token.

        WARNING: the decoded_string *must not* include the prompt,
        which may itself contain stop tokens.
        """
        min_stop_index = len(decoded_string)
        for stop_token in stop_tokens:
            stop_index = decoded_string.find(stop_token)
            if stop_index != -1 and stop_index < min_stop_index:
                min_stop_index = stop_index
        return decoded_string[:min_stop_index]
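
    # e.g. self.stop_at_stop_token('foo()\n}\n// next fn', ['\n}'])
    # returns 'foo()'; truncation happens at the earliest stop token found.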

    def _code_eval_service(self, file_path):
        exec_result = subprocess.run([
            'curl', '-X', 'POST', '-F', f'file=@{file_path}', '-F',
            f'dataset=multipl-e/{self.language}',
            f'{self.ip_address}:{self.port}/evaluate'
        ],
                                     timeout=self.timeout,
                                     capture_output=True)

        if exec_result.returncode == 0 and re.match(
                r'"{.*:.*}"', exec_result.stdout.decode('utf-8')):
            return True, json.loads(exec_result.stdout.decode('utf-8'))
        else:
            if exec_result.stderr:
                try:
                    err = exec_result.stderr.decode()
                except Exception:
                    err = exec_result.stderr
            else:
                try:
                    err = exec_result.stdout.decode()
                except Exception:
                    err = exec_result.stdout
            return False, err
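
    # The curl invocation above is roughly this `requests` call (a sketch;
    # it assumes the service accepts a standard multipart upload, as the
    # curl -F flags imply):
    #   import requests
    #   with open(file_path, 'rb') as f:
    #       resp = requests.post(
    #           f'http://{self.ip_address}:{self.port}/evaluate',
    #           files={'file': f},
    #           data={'dataset': f'multipl-e/{self.language}'},
    #           timeout=self.timeout)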

    def estimator(self, n: int, c: int, k: int) -> float:
        """Calculates 1 - comb(n - c, k) / comb(n, k)."""
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))
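
    # This is the standard unbiased pass@k estimator; the product form avoids
    # computing large binomial coefficients. Quick numeric check (illustrative):
    #   n=10, c=3, k=5: 1 - comb(7, 5) / comb(10, 5) = 1 - 21/252 ~= 0.9167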

    def for_file(self, path):

        try:
            with gzip.open(path, 'rt') as f:
                data = json.load(f)
        except Exception:
            return None

        n = len(data['results'])
        c = len([
            True for r in data['results']
            if r['status'] == 'OK' and r['exit_code'] == 0
        ])
        return {
            'pass@1': self.estimator(n, c, 1),
            'pass@10': self.estimator(n, c, 10),
            'pass@100': self.estimator(n, c, 100),
            'n': n,
            'c': c,
        }
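
    # Each per-problem results file is expected to decode to something like
    # {'results': [{'status': 'OK', 'exit_code': 0, ...}, ...]}; this shape is
    # inferred from the fields read above, not documented in this commit.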

    def score(self, predictions, references, test_set):

        stop_tokens = test_set['stop_tokens'][0]
        print(stop_tokens)

        # convert to original version
        test_set = test_set.to_pandas()
        test_set_origin = test_set.drop_duplicates(subset='name')
        num_repeats = int(len(test_set) / len(test_set_origin))
        print(num_repeats)

        # Create a temporary directory using the tempfile module
        with tempfile.TemporaryDirectory() as tmpdir:
            for i in range(len(test_set_origin)):
                completions = predictions[i * num_repeats:(i + 1) *
                                          num_repeats]
                processed_completions = []
                for comp in completions:
                    comp = self.stop_at_stop_token(comp, stop_tokens)
                    processed_completions.append(comp)

                result_dict = {
                    'name': test_set_origin.iloc[i]['name'],
                    'language': test_set_origin.iloc[i]['language'],
                    'prompt': test_set_origin.iloc[i]['prompt'],
                    'tests': test_set_origin.iloc[i]['tests'],
                    'completions': processed_completions
                }

                json_str = json.dumps(result_dict)
                json_bytes = json_str.encode('utf-8')

                with gzip.GzipFile(
                        os.path.join(tmpdir, f'{result_dict["name"]}.json.gz'),
                        'w') as f:
                    f.write(json_bytes)

            # create a zip file containing all the generated .json.gz files
            zipname = os.path.join(tmpdir, 'archive')
            shutil.make_archive(zipname, 'zip', tmpdir)
            zipfile_path = f'{zipname}.zip'

            num_retry = 0
            while num_retry < self.retry:
                succeed, output = self._code_eval_service(
                    file_path=zipfile_path)
                if not succeed and '(56) Recv failure' in output:
                    # only retry when the connection failed
                    num_retry += 1
                    # wait a minute in case the service load is too high
                    time.sleep(60)
                else:
                    break

            if succeed:
                if isinstance(output, str):
                    return json.loads(output)
                elif isinstance(output, dict):
                    return output
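
        # NOTE: if every retry fails, `succeed` is False and score() falls
        # through, implicitly returning None; callers should treat that as
        # an evaluation failure.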
