From 9779bbdbe6cedb3e28457e3849b3c190d8e0baca Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 9 Jan 2025 12:13:53 +0100 Subject: [PATCH 1/3] Create custom environment to render haystack dataclasses --- .../components/routers/conditional_router.py | 55 ++++++++++++++++++- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index 5fbe399248..528e914dd1 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -8,14 +8,17 @@ from warnings import warn from jinja2 import Environment, TemplateSyntaxError, meta -from jinja2.nativetypes import NativeEnvironment +from jinja2.nativetypes import NativeEnvironment, NativeTemplate from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_from_dict, default_to_dict, logging +from haystack.dataclasses import Answer, ByteStream, ChatMessage, Document, SparseEmbedding, StreamingChunk from haystack.utils import deserialize_callable, deserialize_type, serialize_callable, serialize_type logger = logging.getLogger(__name__) +haystack_dataclass_types = (ByteStream, Document, ChatMessage, Answer, SparseEmbedding, StreamingChunk) + class NoRouteSelectedException(Exception): """Exception raised when no route is selected in ConditionalRouter.""" @@ -25,6 +28,42 @@ class RouteConditionException(Exception): """Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter.""" +class NativeSandboxedTemplate(NativeTemplate): + """ + A template class that returns native Python objects and also respects the sandbox security checks. + """ + + pass + + +class NativeSandboxedEnvironment(SandboxedEnvironment, NativeEnvironment): + """ + An environment that combines sandbox restrictions with native rendering. + """ + + # We tell the environment to use our custom template class by default. + template_class = NativeSandboxedTemplate + + def from_string(self, source, template_class=None): + """ + Override from_string to ensure the sandbox logic + native logic are used together. + """ + if template_class is None: + template_class = self.template_class + return SandboxedEnvironment.from_string(self, source, template_class=template_class) + + def is_safe_attribute(self, obj, attr, value): + """ + Whitelist attributes or slicing on your custom classes so the sandbox won't block them. + """ + # If it's a ChatMessage object, you can whitelist certain attributes: + if isinstance(obj, haystack_dataclass_types): + return True + + # Otherwise, fallback to the default sandbox behavior + return super().is_safe_attribute(obj, attr, value) + + @component class ConditionalRouter: """ @@ -195,8 +234,17 @@ def __init__( # pylint: disable=too-many-positional-arguments warn(msg) self._env = NativeEnvironment() if self._unsafe else SandboxedEnvironment() + self._custom_env = NativeSandboxedEnvironment() self._env.filters.update(self.custom_filters) + # Add custom types to the custom environment + self._custom_env.globals["Document"] = Document + self._custom_env.globals["ChatMessage"] = ChatMessage + self._custom_env.globals["ByteStream"] = ByteStream + self._custom_env.globals["Answer"] = Answer + self._custom_env.globals["SparseEmbedding"] = SparseEmbedding + self._custom_env.globals["StreamingChunk"] = StreamingChunk + self._validate_routes(routes) # Inspect the routes to determine input and output types. input_types: Set[str] = set() # let's just store the name, type will always be Any @@ -309,14 +357,15 @@ def run(self, **kwargs): if not rendered: continue # We now evaluate the `output` expression to determine the route output - t_output = self._env.from_string(route["output"]) + t_output = self._custom_env.from_string(route["output"]) output = t_output.render(**kwargs) + # We suppress the exception in case the output is already a string, otherwise # we try to evaluate it and would fail. # This must be done cause the output could be different literal structures. # This doesn't support any user types. with contextlib.suppress(Exception): - if not self._unsafe: + if not self._unsafe and isinstance(output, str): output = ast.literal_eval(output) except Exception as e: msg = f"Error evaluating condition for route '{route}': {e}" From 94f9fd1e40b246e28406253c07f2e966b13486cb Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 9 Jan 2025 12:39:07 +0100 Subject: [PATCH 2/3] Small fixes --- haystack/components/routers/conditional_router.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index 528e914dd1..a36b4355ee 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -52,16 +52,15 @@ def from_string(self, source, template_class=None): template_class = self.template_class return SandboxedEnvironment.from_string(self, source, template_class=template_class) - def is_safe_attribute(self, obj, attr, value): + def is_safe_attribute(self, obj): """ - Whitelist attributes or slicing on your custom classes so the sandbox won't block them. + Whitelist Haystack dataclasses so the sandbox won't block them. """ - # If it's a ChatMessage object, you can whitelist certain attributes: if isinstance(obj, haystack_dataclass_types): return True # Otherwise, fallback to the default sandbox behavior - return super().is_safe_attribute(obj, attr, value) + return super().is_safe_attribute(obj) @component From a82586808b2fdca5595f9cf8250d1ad0b4d7ee38 Mon Sep 17 00:00:00 2001 From: Amna Mubashar Date: Thu, 9 Jan 2025 15:48:34 +0100 Subject: [PATCH 3/3] Update from_string method --- .../components/routers/conditional_router.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/haystack/components/routers/conditional_router.py b/haystack/components/routers/conditional_router.py index a36b4355ee..a2ca1c0d76 100644 --- a/haystack/components/routers/conditional_router.py +++ b/haystack/components/routers/conditional_router.py @@ -8,7 +8,7 @@ from warnings import warn from jinja2 import Environment, TemplateSyntaxError, meta -from jinja2.nativetypes import NativeEnvironment, NativeTemplate +from jinja2.nativetypes import NativeEnvironment, NativeTemplate, Template from jinja2.sandbox import SandboxedEnvironment from haystack import component, default_from_dict, default_to_dict, logging @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) -haystack_dataclass_types = (ByteStream, Document, ChatMessage, Answer, SparseEmbedding, StreamingChunk) +haystack_dataclass_types = (ByteStream, ChatMessage, Document, Answer, SparseEmbedding, StreamingChunk) class NoRouteSelectedException(Exception): @@ -28,7 +28,7 @@ class RouteConditionException(Exception): """Exception raised when there is an error parsing or evaluating the condition expression in ConditionalRouter.""" -class NativeSandboxedTemplate(NativeTemplate): +class NativeSandboxedTemplate(NativeTemplate, Template): """ A template class that returns native Python objects and also respects the sandbox security checks. """ @@ -42,25 +42,25 @@ class NativeSandboxedEnvironment(SandboxedEnvironment, NativeEnvironment): """ # We tell the environment to use our custom template class by default. - template_class = NativeSandboxedTemplate - def from_string(self, source, template_class=None): + def from_string(self, source): """ Override from_string to ensure the sandbox logic + native logic are used together. """ - if template_class is None: - template_class = self.template_class + template_class = NativeSandboxedTemplate + return SandboxedEnvironment.from_string(self, source, template_class=template_class) - def is_safe_attribute(self, obj): + def is_safe_attribute(self, obj, attr="", value=""): """ Whitelist Haystack dataclasses so the sandbox won't block them. """ - if isinstance(obj, haystack_dataclass_types): - return True + + if not isinstance(obj, haystack_dataclass_types): + return False # Otherwise, fallback to the default sandbox behavior - return super().is_safe_attribute(obj) + return SandboxedEnvironment.is_safe_attribute(self, obj, attr, value) @component @@ -236,14 +236,6 @@ def __init__( # pylint: disable=too-many-positional-arguments self._custom_env = NativeSandboxedEnvironment() self._env.filters.update(self.custom_filters) - # Add custom types to the custom environment - self._custom_env.globals["Document"] = Document - self._custom_env.globals["ChatMessage"] = ChatMessage - self._custom_env.globals["ByteStream"] = ByteStream - self._custom_env.globals["Answer"] = Answer - self._custom_env.globals["SparseEmbedding"] = SparseEmbedding - self._custom_env.globals["StreamingChunk"] = StreamingChunk - self._validate_routes(routes) # Inspect the routes to determine input and output types. input_types: Set[str] = set() # let's just store the name, type will always be Any @@ -359,13 +351,21 @@ def run(self, **kwargs): t_output = self._custom_env.from_string(route["output"]) output = t_output.render(**kwargs) + # Check if output is a list/sequence and validate accordingly + if isinstance(output, (list, tuple)): + if all(self._custom_env.is_safe_attribute(item) for item in output): + pass + elif self._custom_env.is_safe_attribute(output): + pass + # We suppress the exception in case the output is already a string, otherwise # we try to evaluate it and would fail. # This must be done cause the output could be different literal structures. # This doesn't support any user types. - with contextlib.suppress(Exception): - if not self._unsafe and isinstance(output, str): - output = ast.literal_eval(output) + else: + with contextlib.suppress(Exception): + if not self._unsafe and isinstance(output, str): + output = ast.literal_eval(output) except Exception as e: msg = f"Error evaluating condition for route '{route}': {e}" raise RouteConditionException(msg) from e