Skip to content

Commit 34d46cb

Browse files
committed
Ensure that input_guardrails can block tools from running
1 parent d88bf14 commit 34d46cb

File tree

5 files changed

+423
-14
lines changed

5 files changed

+423
-14
lines changed

src/agents/guardrail.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,12 @@ class InputGuardrail(Generic[TContext]):
9696
function's name.
9797
"""
9898

99+
block_tool_calls: bool = True
100+
"""Whether this guardrail should block tool calls until it completes. If any input guardrail
101+
has this set to True, tool execution will be delayed until all blocking guardrails finish.
102+
Defaults to True for backwards compatibility and safety.
103+
"""
104+
99105
def get_name(self) -> str:
100106
if self.name:
101107
return self.name
@@ -208,6 +214,7 @@ def input_guardrail(
208214
def input_guardrail(
209215
*,
210216
name: str | None = None,
217+
block_tool_calls: bool = True,
211218
) -> Callable[
212219
[_InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co]],
213220
InputGuardrail[TContext_co],
@@ -220,6 +227,7 @@ def input_guardrail(
220227
| None = None,
221228
*,
222229
name: str | None = None,
230+
block_tool_calls: bool = True,
223231
) -> (
224232
InputGuardrail[TContext_co]
225233
| Callable[
@@ -234,14 +242,14 @@ def input_guardrail(
234242
@input_guardrail
235243
def my_sync_guardrail(...): ...
236244
237-
@input_guardrail(name="guardrail_name")
245+
@input_guardrail(name="guardrail_name", block_tool_calls=False)
238246
async def my_async_guardrail(...): ...
239247
"""
240248

241249
def decorator(
242250
f: _InputGuardrailFuncSync[TContext_co] | _InputGuardrailFuncAsync[TContext_co],
243251
) -> InputGuardrail[TContext_co]:
244-
return InputGuardrail(guardrail_function=f, name=name)
252+
return InputGuardrail(guardrail_function=f, name=name, block_tool_calls=block_tool_calls)
245253

246254
if func is not None:
247255
# Decorator was used without parentheses

src/agents/run.py

Lines changed: 162 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -394,15 +394,26 @@ async def run(
394394
)
395395

396396
if current_turn == 1:
397-
input_guardrail_results, turn_result = await asyncio.gather(
397+
# Separate blocking and non-blocking guardrails
398+
all_guardrails = starting_agent.input_guardrails + (
399+
run_config.input_guardrails or []
400+
)
401+
blocking_guardrails, non_blocking_guardrails = (
402+
self._separate_blocking_guardrails(all_guardrails)
403+
)
404+
405+
# Start all guardrails and model call in parallel
406+
all_guardrails_task = asyncio.create_task(
398407
self._run_input_guardrails(
399408
starting_agent,
400-
starting_agent.input_guardrails
401-
+ (run_config.input_guardrails or []),
409+
all_guardrails,
402410
copy.deepcopy(input),
403411
context_wrapper,
404-
),
405-
self._run_single_turn(
412+
)
413+
)
414+
415+
model_response_task = asyncio.create_task(
416+
self._get_model_response_only(
406417
agent=current_agent,
407418
all_tools=all_tools,
408419
original_input=original_input,
@@ -413,8 +424,52 @@ async def run(
413424
should_run_agent_start_hooks=should_run_agent_start_hooks,
414425
tool_use_tracker=tool_use_tracker,
415426
previous_response_id=previous_response_id,
416-
),
427+
)
417428
)
429+
430+
# Get model response first (runs in parallel with guardrails)
431+
model_response, output_schema, handoffs = await model_response_task
432+
433+
# Now handle tool execution based on blocking behavior
434+
if blocking_guardrails:
435+
# Wait for all guardrails to complete before executing tools
436+
input_guardrail_results = await all_guardrails_task
437+
438+
# Now execute tools after guardrails complete
439+
turn_result = await self._execute_tools_from_model_response(
440+
agent=current_agent,
441+
all_tools=all_tools,
442+
original_input=original_input,
443+
generated_items=generated_items,
444+
new_response=model_response,
445+
output_schema=output_schema,
446+
handoffs=handoffs,
447+
hooks=hooks,
448+
context_wrapper=context_wrapper,
449+
run_config=run_config,
450+
tool_use_tracker=tool_use_tracker,
451+
)
452+
else:
453+
# No blocking guardrails - execute tools in parallel with remaining guardrails
454+
tool_execution_task = asyncio.create_task(
455+
self._execute_tools_from_model_response(
456+
agent=current_agent,
457+
all_tools=all_tools,
458+
original_input=original_input,
459+
generated_items=generated_items,
460+
new_response=model_response,
461+
output_schema=output_schema,
462+
handoffs=handoffs,
463+
hooks=hooks,
464+
context_wrapper=context_wrapper,
465+
run_config=run_config,
466+
tool_use_tracker=tool_use_tracker,
467+
)
468+
)
469+
470+
input_guardrail_results, turn_result = await asyncio.gather(
471+
all_guardrails_task, tool_execution_task
472+
)
418473
else:
419474
turn_result = await self._run_single_turn(
420475
agent=current_agent,
@@ -973,6 +1028,107 @@ async def _get_single_step_result_from_response(
9731028
run_config=run_config,
9741029
)
9751030

1031+
@classmethod
1032+
def _separate_blocking_guardrails(
1033+
cls,
1034+
guardrails: list[InputGuardrail[TContext]],
1035+
) -> tuple[list[InputGuardrail[TContext]], list[InputGuardrail[TContext]]]:
1036+
"""Separate guardrails into blocking and non-blocking lists."""
1037+
blocking = []
1038+
non_blocking = []
1039+
1040+
for guardrail in guardrails:
1041+
if guardrail.block_tool_calls:
1042+
blocking.append(guardrail)
1043+
else:
1044+
non_blocking.append(guardrail)
1045+
1046+
return blocking, non_blocking
1047+
1048+
@classmethod
1049+
async def _get_model_response_only(
1050+
cls,
1051+
*,
1052+
agent: Agent[TContext],
1053+
all_tools: list[Tool],
1054+
original_input: str | list[TResponseInputItem],
1055+
generated_items: list[RunItem],
1056+
hooks: RunHooks[TContext],
1057+
context_wrapper: RunContextWrapper[TContext],
1058+
run_config: RunConfig,
1059+
should_run_agent_start_hooks: bool,
1060+
tool_use_tracker: AgentToolUseTracker,
1061+
previous_response_id: str | None,
1062+
) -> tuple[ModelResponse, AgentOutputSchemaBase | None, list[Handoff]]:
1063+
"""Get model response without executing tools. Returns model response and processed metadata."""
1064+
# Ensure we run the hooks before anything else
1065+
if should_run_agent_start_hooks:
1066+
await asyncio.gather(
1067+
hooks.on_agent_start(context_wrapper, agent),
1068+
(
1069+
agent.hooks.on_start(context_wrapper, agent)
1070+
if agent.hooks
1071+
else _coro.noop_coroutine()
1072+
),
1073+
)
1074+
1075+
system_prompt, prompt_config = await asyncio.gather(
1076+
agent.get_system_prompt(context_wrapper),
1077+
agent.get_prompt(context_wrapper),
1078+
)
1079+
1080+
output_schema = cls._get_output_schema(agent)
1081+
handoffs = await cls._get_handoffs(agent, context_wrapper)
1082+
input = ItemHelpers.input_to_new_input_list(original_input)
1083+
input.extend([generated_item.to_input_item() for generated_item in generated_items])
1084+
1085+
new_response = await cls._get_new_response(
1086+
agent,
1087+
system_prompt,
1088+
input,
1089+
output_schema,
1090+
all_tools,
1091+
handoffs,
1092+
context_wrapper,
1093+
run_config,
1094+
tool_use_tracker,
1095+
previous_response_id,
1096+
prompt_config,
1097+
)
1098+
1099+
return new_response, output_schema, handoffs
1100+
1101+
@classmethod
1102+
async def _execute_tools_from_model_response(
1103+
cls,
1104+
*,
1105+
agent: Agent[TContext],
1106+
all_tools: list[Tool],
1107+
original_input: str | list[TResponseInputItem],
1108+
generated_items: list[RunItem],
1109+
new_response: ModelResponse,
1110+
output_schema: AgentOutputSchemaBase | None,
1111+
handoffs: list[Handoff],
1112+
hooks: RunHooks[TContext],
1113+
context_wrapper: RunContextWrapper[TContext],
1114+
run_config: RunConfig,
1115+
tool_use_tracker: AgentToolUseTracker,
1116+
) -> SingleStepResult:
1117+
"""Execute tools and side effects from a model response."""
1118+
return await cls._get_single_step_result_from_response(
1119+
agent=agent,
1120+
original_input=original_input,
1121+
pre_step_items=generated_items,
1122+
new_response=new_response,
1123+
output_schema=output_schema,
1124+
all_tools=all_tools,
1125+
handoffs=handoffs,
1126+
hooks=hooks,
1127+
context_wrapper=context_wrapper,
1128+
run_config=run_config,
1129+
tool_use_tracker=tool_use_tracker,
1130+
)
1131+
9761132
@classmethod
9771133
async def _run_input_guardrails(
9781134
cls,

0 commit comments

Comments
 (0)