Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9eb5cad

Browse files
authoredDec 1, 2023
[Feat] update gsm8k and math agent config (#652)
* [Feat] update gsm8k and math agent config * minor fix
1 parent a331c9a commit 9eb5cad

File tree

3 files changed

+113
-6
lines changed

3 files changed

+113
-6
lines changed
 
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from opencompass.openicl.icl_prompt_template import PromptTemplate
2+
from opencompass.openicl.icl_retriever import ZeroRetriever
3+
from opencompass.openicl.icl_inferencer import AgentInferencer
4+
from opencompass.datasets import (
5+
MATHDataset, MATHAgentEvaluator, math_postprocess
6+
)
7+
# Uses the PAL format, but it does not perform well.
8+
# Reader config: the `problem` column feeds the prompt, `solution` is the
# reference column.
math_reader_cfg = dict(input_columns=['problem'], output_column='solution')

# Inference config: a multi-turn few-shot prompt made of three worked examples
# (HUMAN question -> BOT tool call -> SYSTEM tool response -> BOT final answer)
# followed by the actual question, which is filled in from the `problem`
# column via the `{problem}` placeholder.
#
# Backslashes that must reach the model literally (LaTeX \pi) are written as
# raw strings or as `\\`, so no invalid escape sequences remain. The resulting
# prompt text is identical to what the un-escaped literals produced.
math_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                # # ################################### NEW SHOT ###################################
                # NOTE(review): the "simplifie." typo below is kept verbatim so
                # the prompt text sent to the model does not change.
                dict(role='HUMAN', prompt='Find the coefficient of $x^3$ when $3(x^2 - x^3+x) +3(x +2x^3- 3x^2 + 3x^5+x^3) -5(1+x-4x^3 - x^2)$ is simplifie.'),
                dict(role='BOT', prompt="""Tool:PythonInterpreter\nTool Input:from sympy import symbols, simplify

def solution():
    x = symbols('x')
    expr = 3*(x**2 - x**3 + x) + 3*(x + 2*x**3 - 3*x**2 + 3*x**5 + x**3) - 5*(1 + x - 4*x**3 - x**2)
    simplified_expr = simplify(expr)

    x3_coefficient = simplified_expr.as_coefficients_dict()[x**3]
    result = x3_coefficient
    return result"""),
                dict(role='SYSTEM', prompt='Response:26'),
                dict(role='BOT', prompt='FinalAnswer: The final answer is $26$. I hope it is correct.'),
                # Raw string: keeps the LaTeX backslashes literal.
                dict(role='HUMAN', prompt=r'The surface area of a sphere with radius $r$ is $4\pi r^2$. Including the area of its circular base, what is the total surface area of a hemisphere with radius 6 cm? Express your answer in terms of $\pi$.'),
                # Not a raw string because of the leading `\n`; the backslash
                # in the embedded r'{}\pi' is therefore escaped as `\\`.
                dict(role='BOT', prompt="""Tool:PythonInterpreter\nTool Input:import math

def solution():
    radius = 6

    # Surface area of the hemisphere
    hemisphere_area = 2 * math.pi * radius**2

    # Area of the circular base
    base_area = math.pi * radius**2

    # Total surface area
    total_surface_area = hemisphere_area + base_area

    # Formatting the result in LaTeX
    result = r'{}\\pi'.format(total_surface_area / math.pi)
    return result"""),
                dict(role='SYSTEM', prompt='Response:108.0\\pi'),
                dict(role='BOT', prompt=r'FinalAnswer: The final answer is $108.0\pi$. I hope it is correct.'),
                dict(role='HUMAN', prompt='Monica tosses a fair 6-sided die. If the roll is a prime number, then she wins that amount of dollars (so that, for example, if she rolls 3, then she wins 3 dollars). If the roll is composite, she wins nothing. Otherwise, she loses 3 dollars. What is the expected value of her winnings on one die toss? Express your answer as a dollar value to the nearest cent.'),
                dict(role='BOT', prompt="""Tool:PythonInterpreter\nTool Input:def solution():
    # Probabilities of each outcome
    prime_prob = 1 / 6
    composite_prob = 1 / 3
    otherwise_prob = 1 / 6

    # Expected value of each outcome
    prime_expected_value = (2 * prime_prob) + (3 * prime_prob) + (5 * prime_prob)
    composite_expected_value = 0 * composite_prob
    otherwise_expected_value = -3 * otherwise_prob

    # Total expected value
    total_expected_value = prime_expected_value + composite_expected_value + otherwise_expected_value

    # Dollar value to the nearest cent
    result = "{:.2f}".format(total_expected_value)
    return result"""),
                dict(role='SYSTEM', prompt='Response:1.17'),
                dict(role='BOT', prompt='FinalAnswer: The final answer is $1.17$. I hope it is correct.'),
                # The real question; `{problem}` matches the reader's input column.
                dict(role='HUMAN', prompt='{problem}'),
            ])),
    retriever=dict(type=ZeroRetriever),
    inferencer=dict(type=AgentInferencer),
)

# Evaluation config: predictions go through `math_postprocess` and are scored
# by `MATHAgentEvaluator`.
math_eval_cfg = dict(
    evaluator=dict(type=MATHAgentEvaluator),
    pred_postprocessor=dict(type=math_postprocess),
)

# Dataset entry tying the reader, inference and evaluation configs together.
math_datasets = [
    dict(
        abbr='math',
        type=MATHDataset,
        path='./data/math/math.json',
        reader_cfg=math_reader_cfg,
        infer_cfg=math_infer_cfg,
        eval_cfg=math_eval_cfg,
    )
]

‎opencompass/datasets/gsm8k.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,10 @@ def gsm8k_postprocess(text: str) -> str:
4949
break
5050
ret1 = ''
5151
for i in range(len(ret)):
52-
if ret[i].isdigit():
52+
# deal with potential float number
53+
if ret[i].isdigit() or ret[i] == '.':
5354
ret1 += ret[i]
54-
return ret1
55+
return ret1.strip('.')
5556

5657

5758
class Gsm8kEvaluator(BaseEvaluator):
@@ -87,15 +88,23 @@ class Gsm8kAgentEvaluator(BaseEvaluator):
8788
def __init__(self, action: str = 'PythonInterpreter'):
8889
self.action = action
8990

91+
def is_equal(self, pred, refer):
    """Return True when the final answer matches the reference.

    The two values match if they compare equal as-is, or if both can be
    converted to floats that differ by less than 1e-6 (so ``'18.0'``
    matches ``'18'``).

    Args:
        pred: Predicted answer, usually a string.
        refer: Reference answer, usually a string.

    Returns:
        bool: True on a match; False otherwise, including when either
        value cannot be interpreted as a number.
    """
    if pred == refer:
        return True
    try:
        # Compare both sides as floats: int(refer) would reject
        # non-integer reference strings and truncate float references.
        return abs(float(pred) - float(refer)) < 1e-6
    except (TypeError, ValueError, OverflowError):
        # Non-numeric or missing value: simply not equal.
        return False
98+
9099
def soft_equal(self, pred, refer, step):
    """Check whether the tool output recorded in ``step`` matches ``refer``.

    Reads ``step['result']['text']`` and compares it numerically with the
    reference, using a tolerance of 1e-6.

    Args:
        pred: Final prediction. Unused here; kept for interface
            compatibility.
        refer: Reference answer, usually a string.
        step: Mapping expected to hold the tool output under
            ``step['result']['text']``.

    Returns:
        bool: True when the tool output equals the reference numerically;
        False otherwise, including when the result is missing or either
        value is not numeric.
    """
    try:
        soft_pred = step['result']['text']
        # Compare both sides as floats: int(refer) would reject
        # non-integer reference strings and truncate float references.
        return abs(float(soft_pred) - float(refer)) < 1e-6
    except (LookupError, TypeError, ValueError, OverflowError):
        # The result might not exist, or the text cannot be converted
        # to a float.
        return False
100109

101110
def get_action(self, step):
@@ -114,7 +123,7 @@ def score(self, predictions, references, steps):
114123
total = len(references)
115124
for pred, refer, step in zip(predictions, references, steps):
116125
# if final answer right
117-
if pred == refer:
126+
if self.is_equal(pred, refer):
118127
if self.get_action(step):
119128
final_scope += 1
120129
else:

‎opencompass/models/huggingface.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ class HuggingFace(BaseModel):
4949
use_fastchat_template (str, optional): Whether to use fastchat to get
5050
the conversation template. If True, fastchat needs to be
5151
implemented first. Defaults to False.
52+
end_str (str, optional): If set, each generated string is truncated at
53+
the first occurrence of end_str. Useful when the model emits special ending strings that are not handled well.
54+
Defaults to None.
5255
5356
Note:
5457
About ``extract_pred_after_decode``: Commonly, we should extract the
@@ -72,7 +75,8 @@ def __init__(self,
7275
batch_padding: bool = False,
7376
pad_token_id: Optional[int] = None,
7477
mode: str = 'none',
75-
use_fastchat_template: bool = False):
78+
use_fastchat_template: bool = False,
79+
end_str: Optional[str] = None):
7680
super().__init__(path=path,
7781
max_seq_len=max_seq_len,
7882
tokenizer_only=tokenizer_only,
@@ -96,6 +100,7 @@ def __init__(self,
96100
peft_path=peft_path)
97101
self.generation_kwargs = generation_kwargs
98102
self.use_fastchat_template = use_fastchat_template
103+
self.end_str = end_str
99104

100105
def _load_tokenizer(self, path: str, tokenizer_path: Optional[str],
101106
tokenizer_kwargs: dict):
@@ -266,6 +271,8 @@ def _batch_generate(self, inputs: List[str], max_out_len: int,
266271
token[len_:] for token, len_ in zip(decodeds, prompt_lens)
267272
]
268273

274+
if self.end_str:
275+
decodeds = [token.split(self.end_str)[0] for token in decodeds]
269276
return decodeds
270277

271278
def _single_generate(self, inputs: List[str], max_out_len: int,
@@ -329,6 +336,8 @@ def _single_generate(self, inputs: List[str], max_out_len: int,
329336
token[len_:] for token, len_ in zip(decodeds, prompt_lens)
330337
]
331338

339+
if self.end_str:
340+
decodeds = [token.split(self.end_str)[0] for token in decodeds]
332341
return decodeds
333342

334343
def get_logits(self, inputs: List[str]):

0 commit comments

Comments
 (0)
Please sign in to comment.