
Commit 2e9db77

liushz authored Sep 18, 2024
[Feature] Add custom model postprocess function (#1519)
Co-authored-by: liushz <[email protected]>
1 parent c9a7026 commit 2e9db77

9 files changed: +653 −9 lines changed
 
@@ -0,0 +1,52 @@
```python
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.datasets import GSM8KDataset, gsm8k_dataset_postprocess
from opencompass.datasets import MATHEvaluator, math_postprocess_v2
from opencompass.utils.model_postprocessors import navie_model_postprocess
from opencompass.utils.postprocessors.naive import MATH_NAVIE_PROMPT_TEMPLATE

gsm8k_reader_cfg = dict(input_columns=['question'], output_column='answer')

gsm8k_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(role='HUMAN', prompt='{question}\nPlease reason step by step, and put your final answer within \\boxed{}.'),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=GenInferencer, max_out_len=512),
)

# You can write your own postprocess prompt like:
# GSM8K_NAVIE_PROMPT_TEMPLATE = """
# There is a detailed explanation of the final answer you should extract:
# 1. ...
# 2. ...
# ...
# """

gsm8k_eval_cfg = dict(
    evaluator=dict(type=MATHEvaluator, version='v2'),
    pred_postprocessor=dict(type=math_postprocess_v2),
    dataset_postprocessor=dict(type=gsm8k_dataset_postprocess),
    model_postprocessor=dict(
        type=navie_model_postprocess,
        custom_instruction=MATH_NAVIE_PROMPT_TEMPLATE,
        model_name='',
        api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
)

gsm8k_datasets = [
    dict(
        abbr='gsm8k',
        type=GSM8KDataset,
        path='opencompass/gsm8k',
        reader_cfg=gsm8k_reader_cfg,
        infer_cfg=gsm8k_infer_cfg,
        eval_cfg=gsm8k_eval_cfg,
    )
]
```
@@ -0,0 +1,141 @@
```python
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import FixKRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccwithDetailsEvaluator
from opencompass.datasets import MMLUDataset
from opencompass.utils.text_postprocessors import first_option_postprocess
from opencompass.utils.model_postprocessors import navie_model_postprocess
from opencompass.utils.postprocessors.naive import OPTION_NAVIE_PROMPT_TEMPLATE


# None of the mmlu dataset in huggingface is correctly parsed, so we use our own dataset reader
# Please download the dataset from https://people.eecs.berkeley.edu/~hendrycks/data.tar

mmlu_reader_cfg = dict(
    input_columns=['input', 'A', 'B', 'C', 'D'],
    output_column='target',
    train_split='dev')

mmlu_all_sets = [
    'college_biology',
    'college_chemistry',
    'college_computer_science',
    'college_mathematics',
    'college_physics',
    'electrical_engineering',
    'astronomy',
    'anatomy',
    'abstract_algebra',
    'machine_learning',
    'clinical_knowledge',
    'global_facts',
    'management',
    'nutrition',
    'marketing',
    'professional_accounting',
    'high_school_geography',
    'international_law',
    'moral_scenarios',
    'computer_security',
    'high_school_microeconomics',
    'professional_law',
    'medical_genetics',
    'professional_psychology',
    'jurisprudence',
    'world_religions',
    'philosophy',
    'virology',
    'high_school_chemistry',
    'public_relations',
    'high_school_macroeconomics',
    'human_sexuality',
    'elementary_mathematics',
    'high_school_physics',
    'high_school_computer_science',
    'high_school_european_history',
    'business_ethics',
    'moral_disputes',
    'high_school_statistics',
    'miscellaneous',
    'formal_logic',
    'high_school_government_and_politics',
    'prehistory',
    'security_studies',
    'high_school_biology',
    'logical_fallacies',
    'high_school_world_history',
    'professional_medicine',
    'high_school_mathematics',
    'college_medicine',
    'high_school_us_history',
    'sociology',
    'econometrics',
    'high_school_psychology',
    'human_aging',
    'us_foreign_policy',
    'conceptual_physics',
]

mmlu_datasets = []
for _name in mmlu_all_sets:
    _hint = f'There is a single choice question about {_name.replace("_", " ")}. Answer the question by replying A, B, C or D.'
    mmlu_infer_cfg = dict(
        ice_template=dict(
            type=PromptTemplate,
            template=dict(round=[
                dict(
                    role='HUMAN',
                    prompt=
                    f'{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: '
                ),
                dict(role='BOT', prompt='{target}\n')
            ]),
        ),
        prompt_template=dict(
            type=PromptTemplate,
            template=dict(
                begin='</E>',
                round=[
                    dict(
                        role='HUMAN',
                        prompt=f'{_hint}\nQuestion: {{input}}\nA. {{A}}\nB. {{B}}\nC. {{C}}\nD. {{D}}\nAnswer: '
                    ),
                ],
            ),
            ice_token='</E>',
        ),
        retriever=dict(type=FixKRetriever, fix_id_list=[0, 1, 2, 3, 4]),
        inferencer=dict(type=GenInferencer),
    )

    # You can write your own postprocess prompt like:
    # MMLU_NAVIE_PROMPT_TEMPLATE = """
    # There is a detailed explanation of the final answer you should extract:
    # 1. ...
    # 2. ...
    # ...
    # """

    mmlu_eval_cfg = dict(
        evaluator=dict(type=AccwithDetailsEvaluator),
        pred_postprocessor=dict(type=first_option_postprocess, options='ABCD'),
        model_postprocessor=dict(
            type=navie_model_postprocess,
            custom_instruction=OPTION_NAVIE_PROMPT_TEMPLATE,
            model_name='',
            api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
    )

    mmlu_datasets.append(
        dict(
            abbr=f'lukaemon_mmlu_{_name}',
            type=MMLUDataset,
            path='opencompass/mmlu',
            name=_name,
            reader_cfg=mmlu_reader_cfg,
            infer_cfg=mmlu_infer_cfg,
            eval_cfg=mmlu_eval_cfg,
        ))

del _name, _hint
```

‎opencompass/utils/model_postprocessors.py

+62 −9 lines changed
```diff
@@ -6,12 +6,66 @@
 
 from opencompass.registry import TEXT_POSTPROCESSORS
 
+from .postprocessors.naive import NaiveExtractor, format_input_naive
 from .postprocessors.xfinder.extractor import Extractor
 from .postprocessors.xfinder.xfinder_utils import (DataProcessor,
                                                    convert_to_xfinder_format)
 
 
-def gen_output(ori_data, extractor):
+def gen_output_naive(ori_data, extractor):
+    extracted_answers = []
+    for item in tqdm(ori_data):
+        user_input = extractor.prepare_input(item)
+        extracted_answer = extractor.gen_output(user_input)
+        item['extracted_answer'] = extracted_answer
+        extracted_answers.append(extracted_answer)
+
+    return extracted_answers
+
+
+@TEXT_POSTPROCESSORS.register_module('naive')
+def navie_model_postprocess(preds: list, model_name: str,
+                            custom_instruction: str, api_url: Union[str, list],
+                            **kwargs) -> list:
+    """Postprocess the text extracted by custom model.
+    Args:
+        preds (list): The question, reference answer and model prediction.
+        model_name (str): The name of the model.
+        custom_instruction (str): Custom instruction for the dataset.
+        url (Union[str, list]): The api url of the model.
+
+    Returns:
+        list: The postprocessed answers.
+    """
+
+    def _eval_pred(texts, extractor, num_processes=8):
+        ori_data = texts
+        extracted_answers = []
+        batched_ori_data = []
+        # Split data into batches
+        num_processes = min(num_processes, len(ori_data))
+        batch_size = len(ori_data) // num_processes
+        for i in range(0, len(ori_data), batch_size):
+            batched_ori_data.append(ori_data[i:i + batch_size])
+        with Pool(num_processes) as p:
+            results = p.map(partial(gen_output_naive, extractor=extractor),
+                            batched_ori_data)
+        for result in results:
+            extracted_answers.extend(result)
+        return extracted_answers
+
+    format_data = format_input_naive(preds)
+    assert api_url is not None, 'Please provide the api url.'
+    extractor = NaiveExtractor(
+        model_name=model_name,
+        custom_instruction=custom_instruction,
+        url=api_url.split(',') if ',' in api_url else api_url)
+    calc_acc_func = partial(_eval_pred, extractor=extractor)
+    extracted_answers = calc_acc_func(format_data)
+    return extracted_answers
+
+
+def gen_output_xfinder(ori_data, extractor):
     ext_cor_pairs = []
     extracted_data = []
     extracted_answers = []
@@ -30,9 +84,8 @@ def gen_output(ori_data, extractor):
 
 
 @TEXT_POSTPROCESSORS.register_module('xfinder')
-def xfinder_postprocess(preds: list, question_type: str,
-                        xfinder_model_name: str,
-                        xfiner_api_url: Union[str, list], **kwargs) -> list:
+def xfinder_postprocess(preds: list, question_type: str, model_name: str,
+                        api_url: Union[str, list], **kwargs) -> list:
     """Postprocess the text extracted by xFinder model.
     Args:
         preds (list): The question, reference answer and model prediction.
@@ -56,7 +109,7 @@ def _eval_pred(texts, data_processor, extractor, num_processes=8):
         for i in range(0, len(ori_data), batch_size):
             batched_ori_data.append(ori_data[i:i + batch_size])
         with Pool(num_processes) as p:
-            results = p.map(partial(gen_output, extractor=extractor),
+            results = p.map(partial(gen_output_xfinder, extractor=extractor),
                             batched_ori_data)
         for result in results:
             extracted_answers += result[0]
@@ -65,11 +118,11 @@ def _eval_pred(texts, data_processor, extractor, num_processes=8):
         return extracted_answers
 
     format_data = convert_to_xfinder_format(question_type, preds)
-    assert xfiner_api_url is not None, 'Please provide the api url.'
+    assert api_url is not None, 'Please provide the api url.'
     data_processor = DataProcessor()
-    extractor = Extractor(model_name=xfinder_model_name,
-                          url=xfiner_api_url.split(',')
-                          if ',' in xfiner_api_url else xfiner_api_url)
+    extractor = Extractor(
+        model_name=model_name,
+        url=api_url.split(',') if ',' in api_url else api_url)
     calc_acc_func = partial(_eval_pred,
                             data_processor=data_processor,
                             extractor=extractor)
```
@@ -0,0 +1,11 @@
```python
OPTION_NAVIE_PROMPT_TEMPLATE = """
There is a detailed explanation of the final answer you should extract:
1. You should extract the final answer option like 'A', 'B', 'C', 'D' ... from the given output sentences.
2. The question is a single choice question, so the final answer option should be one of the options, not a combination of options.
"""  # noqa

MATH_NAVIE_PROMPT_TEMPLATE = """
This is a detailed explanation of the final answer you should extract:
1. The question type is a math question, so the final answer should be a number, set, vector, matrix, interval, expression, function, equation, or inequality and any combination of them.
2. If the final answer includes additional symbols, such as units, you should exclude them and only extract the pure final answer.
"""  # noqa
```
@@ -0,0 +1,71 @@

## Short Usage Introduction for Naive Model Postprocessor with Custom Model

### Step 1: Deploy an API server using vLLM or LMDeploy

```bash
lmdeploy serve api_server meta-llama/Meta-Llama-3-8B-Instruct --model-name llama3-8b-instruct --server-port 23333 --backend turbomind --tp 1
```
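
If you would rather serve the model with vLLM, its OpenAI-compatible server can be started with a command along the following lines (a sketch only; exact flags vary between vLLM versions, so check `--help` for your install):

```bash
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --served-model-name llama3-8b-instruct \
    --port 23333
```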

### Step 2: Add Naive Model Postprocessor to the configuration file

Taking GSM8K as an example, add the following lines to the configuration file and replace `api_url` with the actual address of your API server (multiple addresses can be given, separated by commas).

```python
...
from opencompass.utils.model_postprocessors import navie_model_postprocess
from opencompass.utils.postprocessors.naive import MATH_NAVIE_PROMPT_TEMPLATE

...

gsm8k_eval_cfg = dict(
    evaluator=dict(type=MATHEvaluator, version='v2'),
    pred_postprocessor=dict(type=math_postprocess_v2),
    dataset_postprocessor=dict(type=gsm8k_dataset_postprocess),
    # Add the following lines to use the naive model postprocessor
    model_postprocessor=dict(
        type=navie_model_postprocess,
        custom_instruction=MATH_NAVIE_PROMPT_TEMPLATE,
        model_name='llama3-8b-instruct',
        api_url='http://0.0.0.0:23333/v1,http://0.0.0.0:23334/v1')
)
...
```

The prompt used for extraction can also be customized via the `custom_instruction` parameter. Two default templates are currently provided: `MATH_NAVIE_PROMPT_TEMPLATE` for math-answer extraction (e.g. GSM8K and MATH) and `OPTION_NAVIE_PROMPT_TEMPLATE` for option extraction (e.g. MMLU). You can also write your own prompt template, like:

```python
OPTION_NAVIE_PROMPT_TEMPLATE = """
There is a detailed explanation of the final answer you should extract:
1. You should extract the final answer option like 'A', 'B', 'C', 'D' ... from the given output sentences.
2. The question is a single choice question, so the final answer option should be one of the options, not a combination of options.
"""
```

Your prompt should start with `There is a detailed explanation of the final answer you should extract:`, followed by your customized instructions; a hypothetical example is sketched below.
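
For instance, a custom template for a dataset whose answers are yes-or-no judgements might look like this (illustrative only, not shipped with this commit):

```python
# Hypothetical template following the required opening line
YESNO_NAVIE_PROMPT_TEMPLATE = """
There is a detailed explanation of the final answer you should extract:
1. The question asks for a yes-or-no judgement, so the final answer should be exactly 'Yes' or 'No'.
2. If the output sentences give a judgement together with extra justification, extract only the single word 'Yes' or 'No'.
"""
```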

### Step 3: Run the Evaluation as Usual

Now run the evaluation as usual with the modified configuration file. The custom model is used as the post-process model to produce the final result, which is reported under the `model_postprocess_accuracy` metric, like:

```Markdown
dataset    version    metric                        mode    llama-3-8b-instruct-turbomind
---------  ---------  ----------------------------  ------  -------------------------------
gsm8k      a58960     accuracy                      gen     73.46
gsm8k      a58960     model_postprocess_accuracy    gen     78.77
```
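
For reference, a typical launch command looks like the one below; the entry config name is illustrative and should point to whichever config collects your model and the dataset configs shown above:

```bash
# Hypothetical entry config; -w sets the output/work directory
python run.py configs/eval_llama3_with_model_postprocess.py -w outputs/model_postprocess
```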

## Experiment Results

We have tested this model postprocess method with different post-process models (Qwen2-72B-Chat, Llama3-8b-Chat) on the GSM8K and MMLU predictions of `Meta-Llama-3-8B-Instruct` under the settings above, with the following results:

```Markdown
| Dataset | Type   | Config ID | Regex Postprocess Score | Model Postprocess Score (Llama3-8b-Instruct) | Model Postprocess Score (Qwen2-72B-Chat) |
| ------- | ------ | --------- | ----------------------- | -------------------------------------------- | ---------------------------------------- |
| gsm8k   | math   | a58960    | 73.46                   | 79.08                                        | 78.77                                    |
| mmlu    | option | 4d595a    | 67.89                   | 65.26                                        | 67.94                                    |
```

In the evaluation output, rows whose `metric` column reads `model_postprocess_accuracy` give the final result after the `Naive Model Postprocessor` is applied.
@@ -0,0 +1,2 @@
```python
from .extractor import *  # noqa
from .PROMPT_TEMPLATE import *  # noqa
```
@@ -0,0 +1,121 @@
```python
# Naive model extractor for OpenCompass, modified from xFinder: https://github.com/IAAR-Shanghai/xFinder # noqa
import json
import time
from logging import getLogger

from openai import OpenAI

Meta_Instruction = """I will provide you with a question, output sentences along with an answer range. The output sentences are the response of the question provided. The answer range could either describe the type of answer expected or list all possible valid answers. Using the information provided, you must accurately and precisely determine and extract the intended key answer from the output sentences. Please don't have your subjective thoughts about the question.
First, you need to determine whether the content of the output sentences is relevant to the given question. If the entire output sentences are unrelated to the question (meaning the output sentences are not addressing the question), then output [No valid answer].
Otherwise, ignore the parts of the output sentences that have no relevance to the question and then extract the key answer that matches the answer range.
Below are some special cases you need to be aware of:
(1) If the output sentences present multiple different answers, carefully determine if the later provided answer is a correction or modification of a previous one. If so, extract this corrected or modified answer as the final response. Conversely, if the output sentences fluctuate between multiple answers without a clear final answer, you should output [No valid answer].
(2) If the answer range is a list and the key answer in the output sentences is not explicitly listed among the candidate options in the answer range, also output [No valid answer].
(3) You should only return the precise answer you extract, without processing the answer. Please return only the answer and do not add any additional content.

"""  # noqa


def format_input_naive(data):
    format_data = []
    for item in data:
        template = {}
        question = item['origin_prompt'][-1]['prompt']
        llm_output = item['prediction']
        correct_answer = item['reference'] if item['reference'] else item[
            'gold']
        template['correct_answer'] = correct_answer
        template['question'] = question
        template['llm_output'] = llm_output

        format_data.append(template)
    return format_data


class NaiveExtractor:

    def __init__(
            self,
            model_name,
            model_path=None,
            url=None,
            temperature=0,
            max_tokens=3000,
            api_key='EMPTY',
            SYSTEM='You are a help assistant tasked with extracting the precise key answer from given output sentences. You must only provide the extracted key answer without including any additional text.',  # noqa
            custom_instruction=''):
        self.model_name = model_name
        self.SYSTEM = SYSTEM
        self.model_path = model_path
        self.url = url
        self.api_key = api_key
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.custom_instruction = custom_instruction
        self.logger = getLogger(__name__)

    def prepare_input(self, item):
        user_input = Meta_Instruction + self.custom_instruction + \
            "Question: \"\"\"" + item['question'] + "\"\"\"\n\n" + \
            "Output sentences: \"\"\"" + item['llm_output'] + "\"\"\"\n\n" + \
            'Key extracted answer: '

        return user_input

    def gen_output(self, query):
        return self.openai_infer(query)

    def openai_infer(self, query: str, retry=9) -> str:
        """Perform inference on the OpenAI model.

        Args:
            query (str): The input query.

        Returns:
            str: The extracted answer (xFinder's output).
        """
        if isinstance(self.url, list):
            # Randomly api for better load balancing
            import random
            self.url = random.choice(self.url)
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.url,
        )
        self.retry = retry

        t = time.time()
        retry = self.retry
        response = ''
        while retry > 0:
            try:
                chat_response = self.client.chat.completions.create(
                    model=self.client.models.list().data[0].id
                    if self.model_name == '' else self.model_name,
                    messages=[
                        {
                            'role': 'system',
                            'content': self.SYSTEM
                        },
                        {
                            'role': 'user',
                            'content': query
                        },
                    ],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                js_response = json.loads(chat_response.model_dump_json())
                response = js_response['choices'][0]['message']['content']
                break
            except Exception as e:
                self.logger.info(f'Error: {e}')
                self.logger.info(f'{self.url} is down. Retrying...')
                self.logger.info(f'Time elapsed: {time.time() - t} seconds')
                time.sleep(6)
                retry -= 1
        if retry == 0:
            response = 'Error: Failed to get response.'
            self.logger.info(f'{response} after {self.retry} tries.')
            raise ValueError('The api is down')
        return response.strip()
```
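
As a minimal standalone sketch (not part of this commit) of how the new extractor can be exercised against the API server from the README, with a record in the shape expected by `format_input_naive`:

```python
from opencompass.utils.postprocessors.naive import (MATH_NAVIE_PROMPT_TEMPLATE,
                                                    NaiveExtractor,
                                                    format_input_naive)

# One prediction record: the keys mirror what format_input_naive reads.
preds = [{
    'origin_prompt': [{'role': 'HUMAN', 'prompt': 'What is 12 * 7?'}],
    'prediction': 'First, 12 * 7 = 84. The final answer is \\boxed{84}.',
    'reference': '84',
}]

items = format_input_naive(preds)
extractor = NaiveExtractor(
    model_name='llama3-8b-instruct',
    url='http://0.0.0.0:23333/v1',
    custom_instruction=MATH_NAVIE_PROMPT_TEMPLATE)
for item in items:
    # Should print '84' if the endpoint extracts the boxed answer correctly.
    print(extractor.gen_output(extractor.prepare_input(item)))
```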
