Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support rest-style docstrings when loading tools from function #9004

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 57 additions & 2 deletions haystack/tools/from_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@
# SPDX-License-Identifier: Apache-2.0

import inspect
import re
import textwrap
from typing import Any, Callable, Dict, Optional

from pydantic import create_model

from haystack.tools.errors import SchemaGenerationError
from haystack.tools.tool import Tool

# Define constants for ReST directives
REST_DIRECTIVES = r"param|return|returns|raise|raises"


def create_tool_from_function(
function: Callable, name: Optional[str] = None, description: Optional[str] = None
Expand Down Expand Up @@ -72,7 +77,53 @@ def get_weather(
If there is an error generating the JSON schema for the Tool.
"""

tool_description = description if description is not None else (function.__doc__ or "")
tool_description = ""
param_descriptions_from_rest: Dict[str, str] = {}
return_description = ""
raises_descriptions = []

if description is not None:
tool_description = description
else:
# Process docstring if available
if function.__doc__:
docstring = textwrap.dedent(function.__doc__).strip()

# Check if this is a ReST-style docstring
if re.search(rf":({REST_DIRECTIVES})\s+", docstring):
# Extract main description (everything before first directive)
main_parts = re.split(rf":({REST_DIRECTIVES})\s+", docstring, 1)
tool_description = main_parts[0].strip()

# Parse parameter descriptions (handling both :param name: and :param type name: formats)
param_pattern = re.compile(rf":param\s+(\w+)\s*:(.*?)(?=:(?:{REST_DIRECTIVES})|$)", re.DOTALL)
param_descriptions_from_rest = {name: desc.strip() for name, desc in param_pattern.findall(docstring)}

# Parse return descriptions
return_pattern = re.compile(rf":return:\s*(.*?)(?=:(?:{REST_DIRECTIVES})|$)", re.DOTALL)
return_matches = return_pattern.findall(docstring)
if return_matches:
return_description = return_matches[0].strip()

# Parse raises descriptions
raises_pattern = re.compile(
rf":raises?\s+(\w+(?:,\s*\w+)*)\s*:\s*(.*?)(?=:(?:{REST_DIRECTIVES})|$)", re.DOTALL
)
for exc_types, desc in raises_pattern.findall(docstring):
for exc_type in re.split(r",\s*", exc_types):
raises_descriptions.append(f"{exc_type}: {desc.strip()}")
else:
# Not a ReST-style docstring, use the whole thing
tool_description = docstring.strip()

# Build a comprehensive description including return values and exceptions
full_description = tool_description

if return_description:
full_description += f"\n\nReturns: {return_description}"

if raises_descriptions:
full_description += "\n\nRaises:\n" + "\n".join(f"- {r}" for r in raises_descriptions)

signature = inspect.signature(function)

Expand All @@ -89,8 +140,12 @@ def get_weather(
default = param.default if param.default is not param.empty else ...
fields[param_name] = (param.annotation, default)

# Priority 1: Get descriptions from Annotated type hints
if hasattr(param.annotation, "__metadata__"):
descriptions[param_name] = param.annotation.__metadata__[0]
# Priority 2: Get descriptions from ReST docstring
elif param_name in param_descriptions_from_rest:
descriptions[param_name] = param_descriptions_from_rest[param_name]

# create Pydantic model and generate JSON schema
try:
Expand All @@ -109,7 +164,7 @@ def get_weather(
if param_name in schema["properties"]:
schema["properties"][param_name]["description"] = param_description

return Tool(name=name or function.__name__, description=tool_description, parameters=schema, function=function)
return Tool(name=name or function.__name__, description=full_description, parameters=schema, function=function)


def tool(function: Callable) -> Tool:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
Supports ReST-style docstrings for functions when generating a tool from a function.
The ReST-style docstring will be automatically parsed to infer the tool's description and
argument descriptions.
30 changes: 30 additions & 0 deletions test/tools/test_from_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,19 @@ def function_with_docstring(city: str) -> str:
return f"Weather report for {city}: 20°C, sunny"


def function_with_rest_docstring(city: str) -> str:
"""
Get weather report for a city.

:param city: The city for which to get the weather.
:return: The weather report for the city.
:raises ValueError: If the city is not found.
"""
if city == "":
raise ValueError("City not found.")
return f"Weather report for {city}: 20°C, sunny"


def test_from_function_description_from_docstring():
tool = create_tool_from_function(function=function_with_docstring)

Expand All @@ -19,6 +32,23 @@ def test_from_function_description_from_docstring():
assert tool.function == function_with_docstring


def test_from_function_description_from_rest_docstring():
tool = create_tool_from_function(function=function_with_rest_docstring)

assert tool.name == "function_with_rest_docstring"
assert tool.description == (
"Get weather report for a city.\n\n"
"Returns: The weather report for the city.\n\n"
"Raises:\n- ValueError: If the city is not found."
)
assert tool.parameters == {
"type": "object",
"properties": {"city": {"type": "string", "description": "The city for which to get the weather."}},
"required": ["city"],
}
assert tool.function == function_with_rest_docstring


def test_from_function_with_empty_description():
tool = create_tool_from_function(function=function_with_docstring, description="")

Expand Down