Skip to content

Added support for "return" handoffs (#1) #869

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 36 additions & 10 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,11 @@ class NextStepHandoff:
new_agent: Agent[Any]


@dataclass
class NextStepHandoffReturnControl:
previous_agent: Agent[Any]


@dataclass
class NextStepFinalOutput:
output: Any
Expand All @@ -201,7 +206,9 @@ class SingleStepResult:
new_step_items: list[RunItem]
"""Items generated during this current step."""

next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain
next_step: (
NextStepHandoff | NextStepFinalOutput | NextStepRunAgain | NextStepHandoffReturnControl
)
"""The next step to take."""

@property
Expand Down Expand Up @@ -238,6 +245,7 @@ async def execute_tools_and_side_effects(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
previous_agents: list[Agent],
) -> SingleStepResult:
# Make a copy of the generated items
pre_step_items = list(pre_step_items)
Expand Down Expand Up @@ -286,6 +294,7 @@ async def execute_tools_and_side_effects(
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
previous_agents=previous_agents,
)

# Next, we'll check if the tool use should result in a final output
Expand Down Expand Up @@ -316,6 +325,7 @@ async def execute_tools_and_side_effects(
final_output=check_tool_use.final_output,
hooks=hooks,
context_wrapper=context_wrapper,
previous_agents=previous_agents,
)

# Now we can check if the model also produced a final output
Expand All @@ -340,6 +350,7 @@ async def execute_tools_and_side_effects(
final_output=final_output,
hooks=hooks,
context_wrapper=context_wrapper,
previous_agents=previous_agents,
)
elif (
not output_schema or output_schema.is_plain_text()
Expand All @@ -353,6 +364,7 @@ async def execute_tools_and_side_effects(
final_output=potential_final_output_text or "",
hooks=hooks,
context_wrapper=context_wrapper,
previous_agents=previous_agents,
)
else:
# If there's no final output, we can just run again
Expand Down Expand Up @@ -663,6 +675,7 @@ async def execute_handoffs(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
previous_agents: list[Agent[TContext]],
) -> SingleStepResult:
# If there is more than one handoff, add tool responses that reject those handoffs
multiple_handoffs = len(run_handoffs) > 1
Expand All @@ -684,6 +697,8 @@ async def execute_handoffs(
actual_handoff = run_handoffs[0]
with handoff_span(from_agent=agent.name) as span_handoff:
handoff = actual_handoff.handoff
if handoff.should_return_control:
previous_agents.append(agent)
new_agent: Agent[Any] = await handoff.on_invoke_handoff(
context_wrapper, actual_handoff.tool_call.arguments
)
Expand Down Expand Up @@ -825,16 +840,21 @@ async def execute_final_output(
final_output: Any,
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
previous_agents: list[Agent[TContext]],
) -> SingleStepResult:
is_returning_control = len(previous_agents) > 0
# Run the on_end hooks
await cls.run_final_output_hooks(agent, hooks, context_wrapper, final_output)

await cls.run_final_output_hooks(
agent, hooks, context_wrapper, final_output, is_returning_control
)
return SingleStepResult(
original_input=original_input,
model_response=new_response,
pre_step_items=pre_step_items,
new_step_items=new_step_items,
next_step=NextStepFinalOutput(final_output),
next_step=NextStepHandoffReturnControl(previous_agents.pop())
if is_returning_control
else NextStepFinalOutput(final_output),
)

@classmethod
Expand All @@ -844,13 +864,19 @@ async def run_final_output_hooks(
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
final_output: Any,
is_returning_control: bool,
):
await asyncio.gather(
hooks.on_agent_end(context_wrapper, agent, final_output),
agent.hooks.on_end(context_wrapper, agent, final_output)
if agent.hooks
else _coro.noop_coroutine(),
)
# If the agent is not returning control, run the hooks
if not is_returning_control:
await asyncio.gather(
hooks.on_agent_end(context_wrapper, agent, final_output),
agent.hooks.on_end(context_wrapper, agent, final_output)
if agent.hooks
else _coro.noop_coroutine(),
)
# If the agent is returning control, only run the current agent's hooks
elif agent.hooks:
await agent.hooks.on_end(context_wrapper, agent, final_output)

@classmethod
async def run_single_input_guardrail(
Expand Down
13 changes: 12 additions & 1 deletion src/agents/handoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ class Handoff(Generic[TContext]):
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
a handoff based on your context/state."""

should_return_control: bool = False
"""Whether the Agent that receives control during a handoff should return control to the
original (previous) Agent upon completion of its work. If False, after the Agent that received
the handoff completes its work, the interaction will end.
"""

def get_transfer_message(self, agent: Agent[Any]) -> str:
return json.dumps({"assistant": agent.name})

Expand All @@ -128,6 +134,7 @@ def handoff(
tool_description_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
should_return_control: bool = False,
) -> Handoff[TContext]: ...


Expand All @@ -141,6 +148,7 @@ def handoff(
tool_name_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
should_return_control: bool = False,
) -> Handoff[TContext]: ...


Expand All @@ -153,6 +161,7 @@ def handoff(
tool_name_override: str | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
should_return_control: bool = False,
) -> Handoff[TContext]: ...


Expand All @@ -164,6 +173,7 @@ def handoff(
input_type: type[THandoffInput] | None = None,
input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
should_return_control: bool = False,
) -> Handoff[TContext]:
"""Create a handoff from an agent.

Expand All @@ -181,7 +191,7 @@ def handoff(
hidden from the LLM at runtime.
"""
assert (on_handoff and input_type) or not (on_handoff and input_type), (
"You must provide either both on_handoff and input_type, or neither"
"You must provide either both on_input and input_type, or neither"
)
type_adapter: TypeAdapter[Any] | None
if input_type is not None:
Expand Down Expand Up @@ -247,4 +257,5 @@ async def _invoke_handoff(
input_filter=input_filter,
agent_name=agent.name,
is_enabled=is_enabled,
should_return_control=should_return_control,
)
Loading