Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

relax foreach to handle optional #4901

Merged
merged 3 commits into from
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions reflex/components/core/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from reflex.components.tags import CondTag, Tag
from reflex.constants import Dirs
from reflex.style import LIGHT_COLOR_MODE, resolved_color_mode
from reflex.utils import types
from reflex.utils.imports import ImportDict, ImportVar
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
Expand Down Expand Up @@ -145,20 +146,20 @@ def cond(condition: Any, c1: Any, c2: Any = None) -> Component | Var:
if c2 is None:
raise ValueError("For conditional vars, the second argument must be set.")

def create_var(cond_part: Any) -> Var[Any]:
return LiteralVar.create(cond_part)

# convert the truth and false cond parts into vars so the _var_data can be obtained.
c1 = create_var(c1)
c2 = create_var(c2)
c1_var = Var.create(c1)
c2_var = Var.create(c2)

if condition is c1_var:
c1_var = c1_var.to(types.value_inside_optional(c1_var._var_type))

# Create the conditional var.
return ternary_operation(
cond_var.bool()._replace(
merge_var_data=VarData(imports=_IS_TRUE_IMPORT),
),
c1,
c2,
c1_var,
c2_var,
)


Expand Down
6 changes: 5 additions & 1 deletion reflex/components/core/foreach.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

from reflex.components.base.fragment import Fragment
from reflex.components.component import Component
from reflex.components.core.cond import cond
from reflex.components.tags import IterTag
from reflex.constants import MemoizationMode
from reflex.state import ComponentState
from reflex.utils import types
from reflex.utils.exceptions import UntypedVarError
from reflex.vars.base import LiteralVar, Var

Expand Down Expand Up @@ -85,6 +87,9 @@ def create(
"See https://reflex.dev/docs/library/dynamic-rendering/foreach/"
)

if types.is_optional(iterable._var_type):
iterable = cond(iterable, iterable, [])

component = cls(
iterable=iterable,
render_fn=render_fn,
Expand Down Expand Up @@ -164,7 +169,6 @@ def render(self):
iterable_state=str(tag.iterable),
arg_name=tag.arg_var_name,
arg_index=tag.get_index_var_arg(),
iterable_type=tag.iterable._var_type.mro()[0].__name__,
)


Expand Down
19 changes: 5 additions & 14 deletions reflex/components/tags/iter_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@

import dataclasses
import inspect
from typing import TYPE_CHECKING, Any, Callable, Iterable, Type, Union, get_args
from typing import TYPE_CHECKING, Callable, Iterable

from reflex.components.tags.tag import Tag
from reflex.utils.types import GenericType
from reflex.vars import LiteralArrayVar, Var, get_unique_variable_name
from reflex.vars.sequence import _determine_value_of_array_index

if TYPE_CHECKING:
from reflex.components.component import Component
Expand All @@ -31,24 +33,13 @@ class IterTag(Tag):
# The name of the index var.
index_var_name: str = dataclasses.field(default_factory=get_unique_variable_name)

def get_iterable_var_type(self) -> Type:
def get_iterable_var_type(self) -> GenericType:
"""Get the type of the iterable var.

Returns:
The type of the iterable var.
"""
iterable = self.iterable
try:
if iterable._var_type.mro()[0] is dict:
# Arg is a tuple of (key, value).
return tuple[get_args(iterable._var_type)] # pyright: ignore [reportReturnType]
elif iterable._var_type.mro()[0] is tuple:
# Arg is a union of any possible values in the tuple.
return Union[get_args(iterable._var_type)] # pyright: ignore [reportReturnType]
else:
return get_args(iterable._var_type)[0]
except Exception:
return Any # pyright: ignore [reportReturnType]
return _determine_value_of_array_index(self.iterable._var_type)

def get_index_var(self) -> Var:
"""Get the index var for the tag (with curly braces).
Expand Down
64 changes: 47 additions & 17 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1598,7 +1598,14 @@ def var_operation( # pyright: ignore [reportOverlappingOverload]

@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[bool]],
func: Callable[P, CustomVarOperationReturn[None]],
) -> Callable[P, NoneVar]: ...


@overload
def var_operation( # pyright: ignore [reportOverlappingOverload]
func: Callable[P, CustomVarOperationReturn[bool]]
| Callable[P, CustomVarOperationReturn[bool | None]],
) -> Callable[P, BooleanVar]: ...


Expand All @@ -1607,13 +1614,15 @@ def var_operation(

@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[NUMBER_T]],
func: Callable[P, CustomVarOperationReturn[NUMBER_T]]
| Callable[P, CustomVarOperationReturn[NUMBER_T | None]],
) -> Callable[P, NumberVar[NUMBER_T]]: ...


@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[str]],
func: Callable[P, CustomVarOperationReturn[str]]
| Callable[P, CustomVarOperationReturn[str | None]],
) -> Callable[P, StringVar]: ...


Expand All @@ -1622,7 +1631,8 @@ def var_operation(

@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[LIST_T]],
func: Callable[P, CustomVarOperationReturn[LIST_T]]
| Callable[P, CustomVarOperationReturn[LIST_T | None]],
) -> Callable[P, ArrayVar[LIST_T]]: ...


Expand All @@ -1631,13 +1641,15 @@ def var_operation(

@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]],
func: Callable[P, CustomVarOperationReturn[OBJECT_TYPE]]
| Callable[P, CustomVarOperationReturn[OBJECT_TYPE | None]],
) -> Callable[P, ObjectVar[OBJECT_TYPE]]: ...


@overload
def var_operation(
func: Callable[P, CustomVarOperationReturn[T]],
func: Callable[P, CustomVarOperationReturn[T]]
| Callable[P, CustomVarOperationReturn[T | None]],
) -> Callable[P, Var[T]]: ...


Expand Down Expand Up @@ -3278,53 +3290,71 @@ def __set__(self, instance: Any, value: FIELD_TYPE):
"""

@overload
def __get__(self: Field[bool], instance: None, owner: Any) -> BooleanVar: ...
def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ...

@overload
def __get__(
self: Field[int] | Field[float] | Field[int | float], instance: None, owner: Any
) -> NumberVar: ...
self: Field[bool] | Field[bool | None], instance: None, owner: Any
) -> BooleanVar: ...

@overload
def __get__(self: Field[str], instance: None, owner: Any) -> StringVar: ...
def __get__(
self: Field[int]
| Field[float]
| Field[int | float]
| Field[int | None]
| Field[float | None]
| Field[int | float | None],
instance: None,
owner: Any,
) -> NumberVar: ...

@overload
def __get__(self: Field[None], instance: None, owner: Any) -> NoneVar: ...
def __get__(
self: Field[str] | Field[str | None], instance: None, owner: Any
) -> StringVar: ...

@overload
def __get__(
self: Field[list[V]] | Field[set[V]],
self: Field[list[V]]
| Field[set[V]]
| Field[list[V] | None]
| Field[set[V] | None],
instance: None,
owner: Any,
) -> ArrayVar[Sequence[V]]: ...

@overload
def __get__(
self: Field[SEQUENCE_TYPE],
self: Field[SEQUENCE_TYPE] | Field[SEQUENCE_TYPE | None],
instance: None,
owner: Any,
) -> ArrayVar[SEQUENCE_TYPE]: ...

@overload
def __get__(
self: Field[MAPPING_TYPE], instance: None, owner: Any
self: Field[MAPPING_TYPE] | Field[MAPPING_TYPE | None],
instance: None,
owner: Any,
) -> ObjectVar[MAPPING_TYPE]: ...

@overload
def __get__(
self: Field[BASE_TYPE], instance: None, owner: Any
self: Field[BASE_TYPE] | Field[BASE_TYPE | None], instance: None, owner: Any
) -> ObjectVar[BASE_TYPE]: ...

@overload
def __get__(
self: Field[SQLA_TYPE], instance: None, owner: Any
self: Field[SQLA_TYPE] | Field[SQLA_TYPE | None], instance: None, owner: Any
) -> ObjectVar[SQLA_TYPE]: ...

if TYPE_CHECKING:

@overload
def __get__(
self: Field[DATACLASS_TYPE], instance: None, owner: Any
self: Field[DATACLASS_TYPE] | Field[DATACLASS_TYPE | None],
instance: None,
owner: Any,
) -> ObjectVar[DATACLASS_TYPE]: ...

@overload
Expand Down
27 changes: 21 additions & 6 deletions reflex/vars/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,14 @@ def object_keys_operation(value: ObjectVar):
Returns:
The keys of the object.
"""
if not types.is_optional(value._var_type):
return var_operation_return(
js_expression=f"Object.keys({value})",
var_type=list[str],
)
return var_operation_return(
js_expression=f"Object.keys({value})",
var_type=list[str],
js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.keys(value))({value})",
var_type=(list[str] | None),
)


Expand All @@ -457,9 +462,14 @@ def object_values_operation(value: ObjectVar):
Returns:
The values of the object.
"""
if not types.is_optional(value._var_type):
return var_operation_return(
js_expression=f"Object.values({value})",
var_type=list[value._value_type()],
)
return var_operation_return(
js_expression=f"Object.values({value})",
var_type=list[value._value_type()],
js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.values(value))({value})",
var_type=(list[value._value_type()] | None),
)


Expand All @@ -473,9 +483,14 @@ def object_entries_operation(value: ObjectVar):
Returns:
The entries of the object.
"""
if not types.is_optional(value._var_type):
return var_operation_return(
js_expression=f"Object.entries({value})",
var_type=list[tuple[str, value._value_type()]],
)
return var_operation_return(
js_expression=f"Object.entries({value})",
var_type=list[tuple[str, value._value_type()]],
js_expression=f"((value) => value ?? undefined === undefined ? undefined : Object.entries(value))({value})",
var_type=(list[tuple[str, value._value_type()]] | None),
)


Expand Down
20 changes: 20 additions & 0 deletions tests/integration/test_var_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class VarOperationState(rx.State):
list2: rx.Field[list] = rx.field([3, 4])
list3: rx.Field[list] = rx.field(["first", "second", "third"])
list4: rx.Field[list] = rx.field([Object(name="obj_1"), Object(name="obj_2")])
optional_list: rx.Field[list | None] = rx.field(None)
optional_dict: rx.Field[dict[str, str] | None] = rx.field(None)
optional_list_value: rx.Field[list[str] | None] = rx.field(["red", "yellow"])
optional_dict_value: rx.Field[dict[str, str] | None] = rx.field({"name": "red"})
str_var1: rx.Field[str] = rx.field("first")
str_var2: rx.Field[str] = rx.field("second")
str_var3: rx.Field[str] = rx.field("ThIrD")
Expand Down Expand Up @@ -645,6 +649,22 @@ def index():
),
id="typed_dict_in_foreach",
),
rx.box(
rx.foreach(VarOperationState.optional_list, rx.text.span),
id="optional_list",
),
rx.box(
rx.foreach(VarOperationState.optional_dict, rx.text.span),
id="optional_dict",
),
rx.box(
rx.foreach(VarOperationState.optional_list_value, rx.text.span),
id="optional_list_value",
),
rx.box(
rx.foreach(VarOperationState.optional_dict_value, rx.text.span),
id="optional_dict_value",
),
)


Expand Down
Loading