1
1
import pytest
2
2
3
- from agents import Agent , ModelSettings , Runner , Tool
4
- from agents ._run_impl import RunImpl
3
+ from agents import Agent , ModelSettings , Runner
4
+ from agents ._run_impl import AgentToolUseTracker , RunImpl
5
5
6
6
from .fake_model import FakeModel
7
- from .test_responses import (
8
- get_function_tool ,
9
- get_function_tool_call ,
10
- get_text_message ,
11
- )
7
+ from .test_responses import get_function_tool , get_function_tool_call , get_text_message
12
8
13
9
14
10
class TestToolChoiceReset :
15
-
16
11
def test_should_reset_tool_choice_direct (self ):
17
12
"""
18
13
Test the maybe_reset_tool_choice method directly with various inputs
19
14
to ensure it correctly identifies cases where reset is needed.
20
15
"""
21
- # Case 1: tool_choice = None should not reset
16
+ agent = Agent (name = "test_agent" )
17
+
18
+ # Case 1: Empty tool use tracker should not change the "None" tool choice
22
19
model_settings = ModelSettings (tool_choice = None )
23
- tools1 : list [ Tool ] = [ get_function_tool ( "tool1" )]
24
- # Cast to list[Tool] to fix type checking issues
25
- assert not RunImpl . _should_reset_tool_choice ( model_settings , tools1 )
20
+ tracker = AgentToolUseTracker ()
21
+ new_settings = RunImpl . maybe_reset_tool_choice ( agent , tracker , model_settings )
22
+ assert new_settings . tool_choice == model_settings . tool_choice
26
23
27
- # Case 2: tool_choice = "auto" should not reset
24
+ # Case 2: Empty tool use tracker should not change the "auto" tool choice
28
25
model_settings = ModelSettings (tool_choice = "auto" )
29
- assert not RunImpl ._should_reset_tool_choice (model_settings , tools1 )
26
+ tracker = AgentToolUseTracker ()
27
+ new_settings = RunImpl .maybe_reset_tool_choice (agent , tracker , model_settings )
28
+ assert model_settings .tool_choice == new_settings .tool_choice
30
29
31
- # Case 3: tool_choice = "none" should not reset
32
- model_settings = ModelSettings (tool_choice = "none" )
33
- assert not RunImpl ._should_reset_tool_choice (model_settings , tools1 )
30
+ # Case 3: Empty tool use tracker should not change the "required" tool choice
31
+ model_settings = ModelSettings (tool_choice = "required" )
32
+ tracker = AgentToolUseTracker ()
33
+ new_settings = RunImpl .maybe_reset_tool_choice (agent , tracker , model_settings )
34
+ assert model_settings .tool_choice == new_settings .tool_choice
34
35
35
36
# Case 4: tool_choice = "required" should reset once the agent has used one tool
36
37
model_settings = ModelSettings (tool_choice = "required" )
37
- assert RunImpl ._should_reset_tool_choice (model_settings , tools1 )
38
+ tracker = AgentToolUseTracker ()
39
+ tracker .add_tool_use (agent , ["tool1" ])
40
+ new_settings = RunImpl .maybe_reset_tool_choice (agent , tracker , model_settings )
41
+ assert new_settings .tool_choice is None
38
42
39
- # Case 5: tool_choice = "required" with multiple tools should not reset
43
+ # Case 5: tool_choice = "required" should reset once the agent has used multiple tools
40
44
model_settings = ModelSettings (tool_choice = "required" )
41
- tools2 : list [Tool ] = [get_function_tool ("tool1" ), get_function_tool ("tool2" )]
42
- assert not RunImpl ._should_reset_tool_choice (model_settings , tools2 )
43
-
44
- # Case 6: Specific tool choice should reset
45
- model_settings = ModelSettings (tool_choice = "specific_tool" )
46
- assert RunImpl ._should_reset_tool_choice (model_settings , tools1 )
45
+ tracker = AgentToolUseTracker ()
46
+ tracker .add_tool_use (agent , ["tool1" , "tool2" ])
47
+ new_settings = RunImpl .maybe_reset_tool_choice (agent , tracker , model_settings )
48
+ assert new_settings .tool_choice is None
49
+
50
+ # Case 6: Tool usage on a different agent should not affect the tool choice
51
+ model_settings = ModelSettings (tool_choice = "foo_bar" )
52
+ tracker = AgentToolUseTracker ()
53
+ tracker .add_tool_use (Agent (name = "other_agent" ), ["foo_bar" , "baz" ])
54
+ new_settings = RunImpl .maybe_reset_tool_choice (agent , tracker , model_settings )
55
+ assert new_settings .tool_choice == model_settings .tool_choice
56
+
57
+ # Case 7: tool_choice = "foo_bar" should reset once the agent has used multiple tools
58
+ model_settings = ModelSettings (tool_choice = "foo_bar" )
59
+ tracker = AgentToolUseTracker ()
60
+ tracker .add_tool_use (agent , ["foo_bar" , "baz" ])
61
+ new_settings = RunImpl .maybe_reset_tool_choice (agent , tracker , model_settings )
62
+ assert new_settings .tool_choice is None
47
63
48
64
@pytest .mark .asyncio
49
65
async def test_required_tool_choice_with_multiple_runs (self ):
50
66
"""
51
- Test scenario 1: When multiple runs are executed with tool_choice="required"
52
- Ensure each run works correctly and doesn't get stuck in infinite loop
53
- Also verify that tool_choice remains "required" between runs
67
+ Test scenario 1: When multiple runs are executed with tool_choice="required", ensure each
68
+ run works correctly and doesn't get stuck in an infinite loop. Also verify that tool_choice
69
+ remains "required" between runs.
54
70
"""
55
71
# Set up our fake model with responses for two runs
56
72
fake_model = FakeModel ()
57
- fake_model .add_multiple_turn_outputs ([
58
- [get_text_message ("First run response" )],
59
- [get_text_message ("Second run response" )]
60
- ])
73
+ fake_model .add_multiple_turn_outputs (
74
+ [[get_text_message ("First run response" )], [get_text_message ("Second run response" )]]
75
+ )
61
76
62
77
# Create agent with a custom tool and tool_choice="required"
63
78
custom_tool = get_function_tool ("custom_tool" )
@@ -71,24 +86,26 @@ async def test_required_tool_choice_with_multiple_runs(self):
71
86
# First run should work correctly and preserve tool_choice
72
87
result1 = await Runner .run (agent , "first run" )
73
88
assert result1 .final_output == "First run response"
74
- assert agent .model_settings .tool_choice == "required" , "tool_choice should stay required"
89
+ assert fake_model .last_turn_args ["model_settings" ].tool_choice == "required" , (
90
+ "tool_choice should stay required"
91
+ )
75
92
76
93
# Second run should also work correctly with tool_choice still required
77
94
result2 = await Runner .run (agent , "second run" )
78
95
assert result2 .final_output == "Second run response"
79
- assert agent .model_settings .tool_choice == "required" , "tool_choice should stay required"
96
+ assert fake_model .last_turn_args ["model_settings" ].tool_choice == "required" , (
97
+ "tool_choice should stay required"
98
+ )
80
99
81
100
@pytest .mark .asyncio
82
101
async def test_required_with_stop_at_tool_name (self ):
83
102
"""
84
- Test scenario 2: When using required tool_choice with stop_at_tool_names behavior
85
- Ensure it correctly stops at the specified tool
103
+ Test scenario 2: When using required tool_choice with stop_at_tool_names behavior, ensure
104
+ it correctly stops at the specified tool.
86
105
"""
87
106
# Set up fake model to return a tool call for second_tool
88
107
fake_model = FakeModel ()
89
- fake_model .set_next_output ([
90
- get_function_tool_call ("second_tool" , "{}" )
91
- ])
108
+ fake_model .set_next_output ([get_function_tool_call ("second_tool" , "{}" )])
92
109
93
110
# Create agent with two tools and tool_choice="required" and stop_at_tool behavior
94
111
first_tool = get_function_tool ("first_tool" , return_value = "first tool result" )
@@ -109,8 +126,8 @@ async def test_required_with_stop_at_tool_name(self):
109
126
@pytest .mark .asyncio
110
127
async def test_specific_tool_choice (self ):
111
128
"""
112
- Test scenario 3: When using a specific tool choice name
113
- Ensure it doesn't cause infinite loops
129
+ Test scenario 3: When using a specific tool choice name, ensure it doesn't cause infinite
130
+ loops.
114
131
"""
115
132
# Set up fake model to return a text message
116
133
fake_model = FakeModel ()
@@ -135,17 +152,19 @@ async def test_specific_tool_choice(self):
135
152
@pytest .mark .asyncio
136
153
async def test_required_with_single_tool (self ):
137
154
"""
138
- Test scenario 4: When using required tool_choice with only one tool
139
- Ensure it doesn't cause infinite loops
155
+ Test scenario 4: When using required tool_choice with only one tool, ensure it doesn't cause
156
+ infinite loops.
140
157
"""
141
158
# Set up fake model to return a tool call followed by a text message
142
159
fake_model = FakeModel ()
143
- fake_model .add_multiple_turn_outputs ([
144
- # First call returns a tool call
145
- [get_function_tool_call ("custom_tool" , "{}" )],
146
- # Second call returns a text message
147
- [get_text_message ("Final response" )]
148
- ])
160
+ fake_model .add_multiple_turn_outputs (
161
+ [
162
+ # First call returns a tool call
163
+ [get_function_tool_call ("custom_tool" , "{}" )],
164
+ # Second call returns a text message
165
+ [get_text_message ("Final response" )],
166
+ ]
167
+ )
149
168
150
169
# Create agent with a single tool and tool_choice="required"
151
170
custom_tool = get_function_tool ("custom_tool" , return_value = "tool result" )
@@ -159,3 +178,33 @@ async def test_required_with_single_tool(self):
159
178
# Run should complete without infinite loops
160
179
result = await Runner .run (agent , "first run" )
161
180
assert result .final_output == "Final response"
181
+
182
+ @pytest .mark .asyncio
183
+ async def test_dont_reset_tool_choice_if_not_required (self ):
184
+ """
185
+ Test scenario 5: When agent.reset_tool_choice is False, ensure tool_choice is not reset.
186
+ """
187
+ # Set up fake model to return a tool call followed by a text message
188
+ fake_model = FakeModel ()
189
+ fake_model .add_multiple_turn_outputs (
190
+ [
191
+ # First call returns a tool call
192
+ [get_function_tool_call ("custom_tool" , "{}" )],
193
+ # Second call returns a text message
194
+ [get_text_message ("Final response" )],
195
+ ]
196
+ )
197
+
198
+ # Create agent with a single tool and tool_choice="required" and reset_tool_choice=False
199
+ custom_tool = get_function_tool ("custom_tool" , return_value = "tool result" )
200
+ agent = Agent (
201
+ name = "test_agent" ,
202
+ model = fake_model ,
203
+ tools = [custom_tool ],
204
+ model_settings = ModelSettings (tool_choice = "required" ),
205
+ reset_tool_choice = False ,
206
+ )
207
+
208
+ await Runner .run (agent , "test" )
209
+
210
+ assert fake_model .last_turn_args ["model_settings" ].tool_choice == "required"
0 commit comments