Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update medbench #678

Merged
merged 4 commits into from
Dec 9, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config-zh-cn.yaml
Original file line number Diff line number Diff line change
@@ -5,7 +5,8 @@ exclude: |
opencompass/utils/internal/|
opencompass/openicl/icl_evaluator/hf_metrics/|
opencompass/datasets/lawbench/utils|
opencompass/datasets/lawbench/evaluation_functions/
opencompass/datasets/lawbench/evaluation_functions/|
opencompass/datasets/medbench
)
repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -5,7 +5,8 @@ exclude: |
opencompass/utils/internal/|
opencompass/openicl/icl_evaluator/hf_metrics/|
opencompass/datasets/lawbench/utils|
opencompass/datasets/lawbench/evaluation_functions/
opencompass/datasets/lawbench/evaluation_functions/|
opencompass/datasets/medbench/
)
repos:
- repo: https://github.com/PyCQA/flake8
4 changes: 4 additions & 0 deletions configs/datasets/MedBench/medbench_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from mmengine.config import read_base

with read_base():
from .medbench_gen_d44f24 import medbench_datasets # noqa: F401, F403
160 changes: 160 additions & 0 deletions configs/datasets/MedBench/medbench_gen_d44f24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import AccEvaluator
from opencompass.datasets import (
MedBenchDataset,
MedBenchEvaluator,
MedBenchEvaluator_Cloze,
MedBenchEvaluator_IE,
MedBenchEvaluator_mcq,
MedBenchEvaluator_CMeEE,
MedBenchEvaluator_CMeIE,
MedBenchEvaluator_CHIP_CDEE,
MedBenchEvaluator_CHIP_CDN,
MedBenchEvaluator_CHIP_CTC,
MedBenchEvaluator_NLG,
MedBenchEvaluator_TF,
MedBenchEvaluator_EMR,
)
from opencompass.utils.text_postprocessors import first_capital_postprocess

medbench_reader_cfg = dict(
input_columns=['problem_input'], output_column='label')

medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # 选择题,用acc判断

medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答

medbench_cloze_sets = ['Triage'] # 限定域QA,有标答

medbench_single_choice_sets = ['Medicine_attack'] # 正确与否判断,有标答

medbench_ie_sets = ['EMR', 'CMeEE'] # 判断识别的实体是否一致,用F1评价

#, 'CMeIE', 'CHIP_CDEE', 'CHIP_CDN', 'CHIP_CTC', 'Doc_parsing', 'MRG'

medbench_datasets = []


for name in medbench_single_choice_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # retriver 不起作用,以输入参数为准 (zero-shot / few-shot)
inferencer=dict(type=GenInferencer))

medbench_eval_cfg = dict(
evaluator=dict(type=MedBenchEvaluator_TF), pred_role="BOT")

medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))

for name in medbench_multiple_choices_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # retriver 不起作用,以输入参数为准 (zero-shot / few-shot)
inferencer=dict(type=GenInferencer))

medbench_eval_cfg = dict(
evaluator=dict(type=MedBenchEvaluator), pred_role="BOT")

medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))

for name in medbench_qa_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # retriver 不起作用,以输入参数为准 (zero-shot / few-shot)
inferencer=dict(type=GenInferencer))

medbench_eval_cfg = dict(
evaluator=dict(type=MedBenchEvaluator_NLG), pred_role="BOT")

medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))

for name in medbench_cloze_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # retriver 不起作用,以输入参数为准 (zero-shot / few-shot)
inferencer=dict(type=GenInferencer))

medbench_eval_cfg = dict(
evaluator=dict(type=MedBenchEvaluator_Cloze), pred_role="BOT")

medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))

for name in medbench_ie_sets:
medbench_infer_cfg = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[dict(role="HUMAN", prompt='{problem_input}')])),
retriever=dict(type=ZeroRetriever
), # retriver 不起作用,以输入参数为准 (zero-shot / few-shot)
inferencer=dict(type=GenInferencer))

medbench_eval_cfg = dict(
evaluator=dict(type=eval('MedBenchEvaluator_'+name)), pred_role="BOT")

medbench_datasets.append(
dict(
type=MedBenchDataset,
path='./data/MedBench/' + name,
name=name,
abbr='medbench-' + name,
setting_name='zero-shot',
reader_cfg=medbench_reader_cfg,
infer_cfg=medbench_infer_cfg.copy(),
eval_cfg=medbench_eval_cfg.copy()))

del name, medbench_infer_cfg, medbench_eval_cfg
1 change: 1 addition & 0 deletions opencompass/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -56,6 +56,7 @@
from .math import * # noqa: F401, F403
from .mathbench import * # noqa: F401, F403
from .mbpp import * # noqa: F401, F403
from .medbench import * # noqa: F401, F403
from .mmlu import * # noqa: F401, F403
from .multirc import * # noqa: F401, F403
from .narrativeqa import * # noqa: F401, F403
3 changes: 3 additions & 0 deletions opencompass/datasets/medbench/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .medbench import * # noqa: F401, F403
104 changes: 104 additions & 0 deletions opencompass/datasets/medbench/constructions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# flake8: noqa
import pandas as pd


class TaskSchema(object):

def __init__(self,
passage=None,
question=None,
options=None,
label=None,
answer=None,
other=None):
self.passage = passage
self.question = question
self.options = options
self.label = label
self.answer = answer
self.other = other

def to_dict(self):
return {
'passage': self.passage,
'question': self.question,
'options': self.options,
'label': self.label,
'answer': self.answer,
'other': self.other
}


# define README.json
class MedBenchInstance(object):

def __init__(self, task_description, data_source, task_schema, output,
evaluation_metric, task_example):
self.task_description = task_description
self.data_source = data_source
self.task_schema = task_schema
self.output = output
self.evaluation_metric = evaluation_metric
self.task_example = task_example

def to_dict(self):
return {
'task description': self.task_description,
'data source': self.data_source,
'task schema': self.task_schema.to_dict(),
'output': self.output,
'evaluation metric': self.evaluation_metric,
'task example': self.task_example
}


class ChatGPTSchema(object):

def __init__(self, context=None, metadata=''):
self.context = context
self.metadata = metadata

def to_dict(self):
return {'context': self.context, 'metadata': self.metadata}


class ResultsForHumanSchema(object):

def __init__(self,
index,
problem_input,
label,
model_input='',
model_output='',
parse_result='',
first_stage_output='',
second_stage_input='',
is_correct=False):
self.index = index
self.problem_input = problem_input
self.model_input = model_input
self.model_output = model_output
self.parse_result = parse_result
self.label = label
self.first_stage_output = first_stage_output
self.second_stage_input = second_stage_input
self.is_correct = is_correct

def to_dict(self):
return {
'index': self.index,
'problem_input': self.problem_input,
'model_input': self.model_input,
'model_output': self.model_output,
'parse_result': self.parse_result,
'label': self.label,
'is_correct': self.is_correct,
'first_stage_output': self.first_stage_output,
'second_stage_input': self.second_stage_input,
}

@staticmethod
def to_tsv(result_list, path):
result_json = [item.to_dict() for item in result_list]
table = pd.json_normalize(result_json)
table.to_excel(path, index=False)
338 changes: 338 additions & 0 deletions opencompass/datasets/medbench/dataset_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,338 @@
# flake8: noqa
import ast
import json
import os

import pandas as pd
import tiktoken
from tqdm import tqdm

from .constructions import ChatGPTSchema, ResultsForHumanSchema
from .utils import extract_answer, read_jsonl, save_jsonl

# define the datasets
medbench_multiple_choices_sets = ['Health_exam', 'DDx-basic', 'DDx-advanced_pre', 'DDx-advanced_final', 'SafetyBench'] # 选择题,用acc判断

medbench_qa_sets = ['Health_Counseling', 'Medicine_Counseling', 'MedDG', 'MedSpeQA', 'MedTreat', 'CMB-Clin'] # 开放式QA,有标答

medbench_cloze_sets = ['Triage'] # 限定域QA,有标答

medbench_single_choice_sets = ['Medicine_attack'] # 正确与否判断,有标答

medbench_ie_sets = ['EMR', 'CMeEE'] # 判断识别的实体是否一致,用F1评价

def convert_zero_shot(line, dataset_name):
# passage = line['passage'] if line['passage'] is not None else ''
if dataset_name in medbench_qa_sets:
return line['question']
elif dataset_name in medbench_cloze_sets:
return '问题:' + line['question'] + '\n答案:'
elif dataset_name in medbench_multiple_choices_sets:
return '问题:' + line['question'] + ' ' \
+ '选项:' + ' '.join(line['options']) + '\n从A到G,我们应该选择'
else:
return line['question']

prefix = '该问题为单选题,所有选项中必有一个正确答案,且只有一个正确答案。\n'


# def convert_zero_shot_CoT_stage1(line, dataset_name):
# try:
# passage = line['passage'] if line['passage'] is not None else ''
# if dataset_name in english_qa_datasets:
# return passage + 'Q: ' + line['question'] + ' ' \
# + 'Answer Choices: ' + ' '.join(line['options']) + '\n' + \
# "Let's think step by step."

# elif dataset_name in chinese_qa_datasets:
# option_string = 'ABCDEFG'
# count = len(line['options'])
# if count == 1:
# count = 4
# return passage + '问题:' + line['question'] + ' ' \
# + '选项:' + ' '.join(line['options']) + '\n' + \
# '从A到{}, 我们应选择什么?让我们逐步思考:'.format(option_string[count - 1])

# elif dataset_name in english_cloze_datasets:
# return passage + 'Q: ' + line['question'] + '\n' \
# "A: Let's think step by step."

# elif dataset_name in chinese_cloze_datasets:
# return passage + '问题:' + line['question'] + '\n' \
# '答案:让我们逐步思考:'
# except NameError:
# print('Dataset not defined.')


# process few-shot raw_prompts
def combine_prompt(prompt_path,
dataset_name,
load_explanation=True,
chat_mode=False):
skip_passage = False
if dataset_name == 'sat-en-without-passage':
skip_passage = True
dataset_name = 'sat-en'
demostrations = []
# read the prompts by context and explanation
context_row = [0, 1, 3, 5, 7, 9]
explanation_row = [0, 2, 4, 6, 8, 10]
raw_prompts_context = pd.read_csv(prompt_path,
header=0,
skiprows=lambda x: x not in context_row,
keep_default_na=False)
raw_prompts_explanation = pd.read_csv(
prompt_path,
header=0,
skiprows=lambda x: x not in explanation_row,
keep_default_na=False).replace(r'\n\n', '\n', regex=True)
contexts = []
for line in list(raw_prompts_context[dataset_name]):
if line:
# print(line)
contexts.append(ast.literal_eval(line))
explanations = [
exp for exp in raw_prompts_explanation[dataset_name] if exp
]

for idx, (con, exp) in enumerate(zip(contexts, explanations)):
passage = con['passage'] if con[
'passage'] is not None and not skip_passage else ''
question = con['question']
options = con['options'] if con['options'] is not None else ''
label = con['label'] if con['label'] is not None else ''
answer = con[
'answer'] if 'answer' in con and con['answer'] is not None else ''

if dataset_name in qa_datasets:
question_input = '问题 {}. '.format(idx + 1) + passage + ' ' + question + '\n' \
+ '从以下选项中选择: ' + ' '.join(options) + '\n'
question_output = (('问题 {}的解析: '.format(idx + 1) + exp + '\n') if load_explanation else '') \
+ '答案是 {}'.format(label)

elif dataset_name in cloze_datasets:
question_input = '问题 {}. '.format(idx + 1) + question + '\n'
question_output = (('问题 {}的解析: '.format(idx + 1) + exp + '\n') if load_explanation else '') \
+ '答案是 {}'.format(answer)
else:
raise ValueError(
f'During loading few-sot examples, found unknown dataset: {dataset_name}'
)
if chat_mode:
demostrations.append((question_input, question_output))
else:
demostrations.append(question_input + question_output + '\n')

return demostrations


enc = None


def _lazy_load_enc():
global enc
if enc is None:
enc = tiktoken.encoding_for_model('gpt-4')


# cut prompt if reach max token length
def concat_prompt(demos,
dataset_name,
max_tokens,
end_of_example='\n',
verbose=False):
_lazy_load_enc()
demostration_en = 'Here are the answers for the problems in the exam.\n'
demostration_zh = '以下是考试中各个问题的答案。\n'

for i in range(len(demos)):
# print(len(enc.encode(demostration_en)), len(enc.encode(demostration_zh)))
if dataset_name in english_qa_datasets:
demostration_en = demostration_en + demos[i] + end_of_example
elif dataset_name in chinese_qa_datasets:
demostration_zh = demostration_zh + demos[i] + end_of_example
elif dataset_name in english_cloze_datasets:
demostration_en = demostration_en + demos[i] + end_of_example
elif dataset_name in chinese_cloze_datasets:
demostration_zh = demostration_zh + demos[i] + end_of_example
# break if reach max token limit
if len(enc.encode(demostration_en)) < max_tokens and len(
enc.encode(demostration_zh)) < max_tokens:
output = demostration_en if len(demostration_en) > len(
demostration_zh) else demostration_zh
prompt_num = i + 1
else:
break
if verbose:
print('max_tokens set as ', max_tokens, 'actual_tokens is',
len(enc.encode(output)), 'num_shot is', prompt_num)
return output, prompt_num


def concat_prompt_chat_mode(demos,
dataset_name,
max_tokens,
end_of_example='\n',
verbose=False):
_lazy_load_enc()
answers = []
sentences = ''
for i in range(len(demos)):
answers += [
{
'role': 'user',
'content': demos[i][0]
},
{
'role': 'assistant',
'content': demos[i][1]
},
]
sentences += json.dumps(answers[-1])
# break if reach max token limit
if len(enc.encode(sentences)) > max_tokens:
answers.pop()
answers.pop()
break
if verbose:
print('max_tokens set as ', max_tokens, 'actual_tokens is',
len(enc.encode(sentences)), 'num_shot is',
len(answers) // 2)
return answers, len(answers) // 2


def convert_few_shot(line, dataset_name, demo, n_shot, chat_mode=False):
passage = line['passage'] if line['passage'] is not None else ''
question = line['question']
options = line['options'] if line['options'] is not None else ''

if dataset_name in qa_datasets:
question_input = '问题 {}. '.format(n_shot + 1) + passage + ' ' + question + '\n' \
+ '从以下选项中选择: ' + ' '.join(options) + '\n'
# + "问题 {}的解析: ".format(n_shot + 1)

if dataset_name in cloze_datasets:
question_input = '问题 {}. '.format(n_shot + 1) + question + '\n'
# + "问题 {}的解析: ".format(n_shot + 1)
if chat_mode:
return demo + [
{
'role': 'user',
'content': question_input
},
]
else:
return demo + question_input


def load_dataset(dataset_name,
setting_name,
parent_path,
prompt_path=None,
max_tokens=None,
end_of_example='\n',
chat_mode=False,
verbose=False):
test_path = os.path.join(parent_path, dataset_name + '.jsonl')
loaded_jsonl = read_jsonl(test_path)
processed = []
if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
# process demo once if it is few-shot-CoT
processed_demos = combine_prompt(
prompt_path,
dataset_name,
load_explanation=setting_name == 'few-shot-CoT',
chat_mode=chat_mode)
if chat_mode:
chosen_prompt, n_shot = concat_prompt_chat_mode(processed_demos,
dataset_name,
max_tokens,
end_of_example,
verbose=verbose)
else:
chosen_prompt, n_shot = concat_prompt(processed_demos,
dataset_name,
max_tokens,
end_of_example,
verbose=verbose)

if verbose:
loaded_jsonl = tqdm(loaded_jsonl)
for meta_idx, line in enumerate(loaded_jsonl):
# 正确
if setting_name == 'zero-shot':
ctxt = convert_zero_shot(line, dataset_name)
elif setting_name == 'zero-shot-CoT':
ctxt = convert_zero_shot_CoT_stage1(line, dataset_name)
elif setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
ctxt = convert_few_shot(line, dataset_name, chosen_prompt, n_shot,
chat_mode)
try:
new_instance = ChatGPTSchema(context=ctxt, metadata=meta_idx)
processed.append(new_instance.to_dict())
except NameError:
print('Dataset not defined.')
return processed


def generate_second_stage_input(dataset_name,
input_list,
output_list,
with_format_prompt=False):
try:
chinese_format_prompt = '根据以上内容,你的任务是把最终的答案提取出来并填在【】中,例如【0】或者【A】。'
if dataset_name in qa_datasets:
prompt_suffix = '因此,从A到D, 我们应选择'
if with_format_prompt:
prompt_suffix = chinese_format_prompt + prompt_suffix
elif dataset_name in cloze_datasets:
prompt_suffix = '因此,答案是'
if with_format_prompt:
prompt_suffix = chinese_format_prompt + prompt_suffix
except NameError:
print('Dataset not defined.')
processed = []
for i in range(len(input_list)):
ctxt = '{0}\n{1}\n{2}'.format(input_list[i]['context'],
extract_answer(output_list[i]),
prompt_suffix)
new_instance = ChatGPTSchema(context=ctxt,
metadata=input_list[i]['metadata'])
processed.append(new_instance.to_dict())
return processed


def load_dataset_as_result_schema(dataset_name, parent_path):
test_path = os.path.join(parent_path, dataset_name + '.jsonl')
loaded_jsonl = read_jsonl(test_path)

processed = []
for i, line in enumerate(loaded_jsonl):
problem_input = convert_zero_shot(line, dataset_name)
processed.append(
ResultsForHumanSchema(
index=i,
problem_input=problem_input,
# label=line['label'] if line['label'] else line['answer']
label = line['answer']
))
return processed


if __name__ == '__main__':
# set variables
parent_dir = '../../data/exam_guidance'

# set dataset name to process
setting_name = 'zero-shot' # setting_name can be chosen from ["zero-shot", "zero-shot-CoT", "few-shot-CoT"]
data_name = 'health_exam'
save_dir = '../../experiment_input/{}/'.format(setting_name)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
processed_data = load_dataset(data_name,
setting_name,
parent_dir,
prompt_path=raw_prompt_path,
max_tokens=2048)
save_jsonl(processed_data,
os.path.join(save_dir, '{}.jsonl'.format(data_name)))
43 changes: 43 additions & 0 deletions opencompass/datasets/medbench/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# flake8: noqa
from . import dataset_loader, utils
from .math_equivalence import is_equiv


def convert_to_set(item):
if isinstance(item, list):
return set(item)
if isinstance(item, str):
return {item}
if item is None:
return {}
raise ValueError("Input can't parse:", item)


def evaluate_single_sample(dataset_name, prediction, label):
if dataset_name in dataset_loader.multi_choice_datasets:
p = convert_to_set(prediction)
l = convert_to_set(label)
return p == l
elif dataset_name in dataset_loader.math_output_datasets:
return is_equiv(prediction, label)
else:
return prediction == label


# def evaluate(dataset_name, prediction_list, label_list):
# correct = 0
# if dataset_name in multi_choice_datasets:
# for prediction, label in zip(prediction_list, label_list):
# p = convert_to_set(prediction)
# l = convert_to_set(label)
# if p == l:
# correct += 1
# elif dataset_name in math_output_datasets:
# for prediction, label in zip(prediction_list, label_list):
# if is_equiv(prediction, label):
# correct += 1
# else:
# for prediction, label in zip(prediction_list, label_list):
# if prediction == label:
# correct += 1
# return "{0:.2%}".format(correct / len(label_list))
161 changes: 161 additions & 0 deletions opencompass/datasets/medbench/math_equivalence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# flake8: noqa


# code from https://github.com/hendrycks/math/blob/main/modeling/math_equivalence.py
def _fix_fracs(string):
substrs = string.split('\\frac')
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += '\\frac'
if substr[0] == '{':
new_str += substr
else:
try:
assert len(substr) >= 2
except:
return string
a = substr[0]
b = substr[1]
if b != '{':
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}{' + b + '}' + post_substr
else:
new_str += '{' + a + '}{' + b + '}'
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += '{' + a + '}' + b + post_substr
else:
new_str += '{' + a + '}' + b
string = new_str
return string


def _fix_a_slash_b(string):
if len(string.split('/')) != 2:
return string
a = string.split('/')[0]
b = string.split('/')[1]
try:
a = int(a)
b = int(b)
assert string == '{}/{}'.format(a, b)
new_string = '\\frac{' + str(a) + '}{' + str(b) + '}'
return new_string
except:
return string


def _remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if '\\text{ ' in string:
splits = string.split('\\text{ ')
assert len(splits) == 2
return splits[0]
else:
return string


def _fix_sqrt(string):
if '\\sqrt' not in string:
return string
splits = string.split('\\sqrt')
new_string = splits[0]
for split in splits[1:]:
if split[0] != '{':
a = split[0]
new_substr = '\\sqrt{' + a + '}' + split[1:]
else:
new_substr = '\\sqrt' + split
new_string += new_substr
return new_string


def _strip_string(string):
# linebreaks
string = string.replace('\n', '')
# print(string)

# remove inverse spaces
string = string.replace('\\!', '')
# print(string)

# replace \\ with \
string = string.replace('\\\\', '\\')
# print(string)

# replace tfrac and dfrac with frac
string = string.replace('tfrac', 'frac')
string = string.replace('dfrac', 'frac')
# print(string)

# remove \left and \right
string = string.replace('\\left', '')
string = string.replace('\\right', '')
# print(string)

# Remove circ (degrees)
string = string.replace('^{\\circ}', '')
string = string.replace('^\\circ', '')

# remove dollar signs
string = string.replace('\\$', '')

# remove units (on the right)
string = _remove_right_units(string)

# remove percentage
string = string.replace('\\%', '')
string = string.replace('\%', '')

# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(' .', ' 0.')
string = string.replace('{.', '{0.')
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == '.':
string = '0' + string

# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split('=')) == 2:
if len(string.split('=')[0]) <= 2:
string = string.split('=')[1]

# fix sqrt3 --> sqrt{3}
string = _fix_sqrt(string)

# remove spaces
string = string.replace(' ', '')

# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = _fix_fracs(string)

# manually change 0.5 --> \frac{1}{2}
if string == '0.5':
string = '\\frac{1}{2}'

# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = _fix_a_slash_b(string)

return string


def is_equiv(str1, str2, verbose=False):
if str1 is None and str2 is None:
print('WARNING: Both None')
return True
if str1 is None or str2 is None:
return False

try:
ss1 = _strip_string(str1)
ss2 = _strip_string(str2)
if verbose:
print(ss1, ss2)
return ss1 == ss2
except:
return str1 == str2
646 changes: 646 additions & 0 deletions opencompass/datasets/medbench/medbench.py

Large diffs are not rendered by default.

198 changes: 198 additions & 0 deletions opencompass/datasets/medbench/post_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# flake8: noqa
import json
import re

from . import dataset_loader


def extract_last_line(string):
lines = string.split('\n')
for item in lines[::-1]:
if item.strip() != '':
string = item
break
return string


def remove_few_shot_prefix(string: str):
prefix_list = ['The answer is therefore', '答案是']
for prefix in prefix_list:
if string.startswith(prefix):
string = string[len(prefix):].strip()
elif prefix in string:
index = string.rfind(prefix)
if index >= 0:
string = string[index + len(prefix):].strip()
return string


def try_parse_few_shot_qa_single_answer(string, setting_name, language='en'):
if setting_name == 'few-shot-CoT':
string = extract_last_line(string)
if language == 'en':
pattern = 'answer is .*?([A-G])'
match = re.search(pattern, string)
elif language == 'zh':
pattern = '答案是.*?([A-G])'
match = re.search(pattern, string)
else:
raise ValueError('Unknown language {0}'.format(language))
if match:
return match.group(1)
else:
return None


def try_parse_few_shot_pattern(string: str, dataset_name, setting_name):
if setting_name == 'few-shot-CoT':
string = extract_last_line(string)
if dataset_name in dataset_loader.chinese_cloze_datasets:
return string.startswith('答案是')
elif dataset_name in dataset_loader.english_cloze_datasets:
return string.startswith('The answer is therefore')
elif dataset_name in dataset_loader.chinese_qa_datasets:
pattern = '答案是.*?([A-G])'
match = re.search(pattern, string)
return match is not None
elif dataset_name in dataset_loader.english_qa_datasets:
pattern = 'answer is .*?([A-G])'
match = re.search(pattern, string)
return match is not None
return False


def parse_few_shot_qa_single_answer(string, setting_name, language='en'):
answer = try_parse_few_shot_qa_single_answer(string, setting_name,
language)
if answer is None:
return find_first_capital_letter(string)
else:
return answer


def find_first_capital_letter(answer):
letter_set = {'A', 'B', 'C', 'D', 'E', 'F'}
for c in answer:
if c in letter_set:
return c
# print("Can't find capital letter in:", answer)
return ''


def extract_answer_in_bracket(answer, prefix='【', suffix='】'):
if prefix not in answer and suffix not in answer:
# print("doesn't found special tokens in:", answer)
return ''
s = answer.index(prefix) + len(prefix)
t = answer.index(suffix)
ret = answer[s:t]
return ret


def parse_math_answer(setting_name, raw_string):
if setting_name == 'few-shot-CoT':
raw_string = extract_last_line(raw_string)
if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
raw_string = remove_few_shot_prefix(raw_string)
return raw_string

def remove_boxed(s):
left = '\\boxed{'
try:
assert s[:len(left)] == left
assert s[-1] == '}'
answer = s[len(left):-1]
if '=' in answer:
answer = answer.split('=')[-1].lstrip(' ')
return answer
except:
return None

def last_boxed_only_string(string):
idx = string.rfind('\\boxed')
if idx < 0:
idx = string.rfind('\\fbox')
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == '{':
num_left_braces_open += 1
if string[i] == '}':
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1

if right_brace_idx == None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]

return retval

def get_answer_with_dollar_sign(s):
first_pattern = '\$(.*)\$'
last_match = None
matches = re.findall(first_pattern, s)
if matches:
last_match = matches[-1]
if '=' in last_match:
last_match = last_match.split('=')[-1].lstrip(' ')
return last_match

def get_answer_without_dollar_sign(s):
last_match = None
if '=' in s:
last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
if '\\n' in last_match:
last_match = last_match.split('\\n')[0]
else:
pattern = '(?:\\$)?\d+(?:\.\d+)?(?![\w\d])'
matches = re.findall(pattern, s)
if matches:
last_match = matches[-1]
return last_match

raw_string = remove_few_shot_prefix(raw_string)
if '\\boxed' in raw_string:
answer = remove_boxed(last_boxed_only_string(raw_string))
else:
answer = get_answer_with_dollar_sign(raw_string)
if not answer:
answer = get_answer_without_dollar_sign(raw_string)
return answer


def parse_qa_multiple_answer(string):
# if setting_name == 'few-shot-CoT':
# string = extract_last_line(string)
pattern = '\(*([A-Z])\)*'
match = re.findall(pattern, string)
if match:
return match
return []


def post_process(dataset_name, setting_name, prediction):
if dataset_name in dataset_loader.english_cloze_datasets or dataset_name in dataset_loader.chinese_cloze_datasets:
return parse_math_answer(setting_name, prediction)

if dataset_name in ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']:
return parse_qa_multiple_answer(prediction, setting_name)

# all other datasets are QA problems with single answer
if 'zero-shot' in setting_name:
answer = find_first_capital_letter(prediction)
return answer

# all other datasets are QA problems with single answer and setting_name are few-shot
language = 'en' if dataset_name in dataset_loader.english_qa_datasets else 'zh'
if dataset_name in dataset_loader.english_qa_datasets or dataset_name in dataset_loader.chinese_qa_datasets:
return parse_few_shot_qa_single_answer(prediction, setting_name,
language)
else:
raise ValueError(f'Unsupported dataset name {dataset_name}')
43 changes: 43 additions & 0 deletions opencompass/datasets/medbench/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# flake8: noqa
import json


def read_jsonl(path):
with open(path, encoding='utf8') as fh:
results = []
for line in fh:
if line is None:
continue
try:
results.append(json.loads(line) if line != 'null' else line)
except Exception as e:
print(e)
print(path)
print(line)
raise e
return results


def save_jsonl(lines, directory):
with open(directory, 'w', encoding='utf8') as f:
for line in lines:
f.write(json.dumps(line, ensure_ascii=False) + '\n')


def extract_answer(js):
try:
if js is None or js == 'null':
return ''
answer = ''
if isinstance(js, str):
answer = js
elif 'text' in js['choices'][0]:
answer = js['choices'][0]['text']
else:
answer = js['choices'][0]['message']['content']
# answer = js['']
return answer
except Exception as e:
# print(e)
# print(js)
return ''