Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENG-3953] Support pydantic BaseModel (v1 and v2) as state var #4338

Merged
merged 6 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@
except ModuleNotFoundError:
import pydantic

from pydantic import BaseModel as BaseModelV2

try:
from pydantic.v1 import BaseModel as BaseModelV1
except ModuleNotFoundError:
BaseModelV1 = BaseModelV2

import wrapt
from redis.asyncio import Redis
from redis.exceptions import ResponseError
Expand Down Expand Up @@ -1243,7 +1250,7 @@ def __getattribute__(self, name: str) -> Any:
if parent_state is not None:
return getattr(parent_state, name)

if isinstance(value, MutableProxy.__mutable_types__) and (
if MutableProxy._is_mutable_type(value) and (
name in super().__getattribute__("base_vars") or name in backend_vars
):
# track changes in mutable containers (list, dict, set, etc)
Expand Down Expand Up @@ -3526,7 +3533,16 @@ class MutableProxy(wrapt.ObjectProxy):
pydantic.BaseModel.__dict__
)

__mutable_types__ = (list, dict, set, Base, DeclarativeBase)
# These types will be wrapped in MutableProxy
__mutable_types__ = (
list,
dict,
set,
Base,
DeclarativeBase,
BaseModelV2,
BaseModelV1,
)

def __init__(self, wrapped: Any, state: BaseState, field_name: str):
"""Create a proxy for a mutable object that tracks changes.
Expand Down Expand Up @@ -3566,6 +3582,18 @@ def _mark_dirty(
if wrapped is not None:
return wrapped(*args, **(kwargs or {}))

@classmethod
def _is_mutable_type(cls, value: Any) -> bool:
"""Check if a value is of a mutable type and should be wrapped.

Args:
value: The value to check.

Returns:
Whether the value is of a mutable type.
"""
return isinstance(value, cls.__mutable_types__)

def _wrap_recursive(self, value: Any) -> Any:
"""Wrap a value recursively if it is mutable.

Expand All @@ -3576,9 +3604,7 @@ def _wrap_recursive(self, value: Any) -> Any:
The wrapped value.
"""
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
if isinstance(value, self.__mutable_types__) and not isinstance(
value, MutableProxy
):
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
return type(self)(
wrapped=value,
state=self._self_state,
Expand Down Expand Up @@ -3636,7 +3662,7 @@ def __getattr__(self, __name: str) -> Any:
self._wrap_recursive_decorator,
)

if isinstance(value, self.__mutable_types__) and __name not in (
if self._is_mutable_type(value) and __name not in (
"__wrapped__",
"_self_state",
):
Expand Down
47 changes: 47 additions & 0 deletions reflex/utils/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,53 @@ def serialize_base(value: Base) -> dict:
return {k: v for k, v in value.dict().items() if not callable(v)}


try:
from pydantic.v1 import BaseModel as BaseModelV1

@serializer(to=dict)
def serialize_base_model_v1(model: BaseModelV1) -> dict:
"""Serialize a pydantic v1 BaseModel instance.

Args:
model: The BaseModel to serialize.

Returns:
The serialized BaseModel.
"""
return model.dict()

from pydantic import BaseModel as BaseModelV2

if BaseModelV1 is not BaseModelV2:

@serializer(to=dict)
def serialize_base_model_v2(model: BaseModelV2) -> dict:
"""Serialize a pydantic v2 BaseModel instance.

Args:
model: The BaseModel to serialize.

Returns:
The serialized BaseModel.
"""
return model.model_dump()
except ImportError:
# Older pydantic v1 import
from pydantic import BaseModel as BaseModelV1

@serializer(to=dict)
def serialize_base_model_v1(model: BaseModelV1) -> dict:
"""Serialize a pydantic v1 BaseModel instance.

Args:
model: The BaseModel to serialize.

Returns:
The serialized BaseModel.
"""
return model.dict()


@serializer
def serialize_set(value: Set) -> list:
"""Serialize a set to a JSON serializable list.
Expand Down
49 changes: 49 additions & 0 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import pytest
import pytest_asyncio
from plotly.graph_objects import Figure
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1

import reflex as rx
import reflex.config
Expand Down Expand Up @@ -3413,6 +3415,53 @@ class TypedState(rx.State):
_ = TypedState(field="str")


class ModelV1(BaseModelV1):
"""A pydantic BaseModel v1."""

foo: str = "bar"


class ModelV2(BaseModelV2):
"""A pydantic BaseModel v2."""

foo: str = "bar"


@dataclasses.dataclass
class ModelDC:
"""A dataclass."""

foo: str = "bar"


class PydanticState(rx.State):
"""A state with pydantic BaseModel vars."""

v1: ModelV1 = ModelV1()
v2: ModelV2 = ModelV2()
dc: ModelDC = ModelDC()


def test_mutable_models():
"""Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking."""
state = PydanticState()
assert isinstance(state.v1, MutableProxy)
state.v1.foo = "baz"
assert state.dirty_vars == {"v1"}
state.dirty_vars.clear()

assert isinstance(state.v2, MutableProxy)
state.v2.foo = "baz"
assert state.dirty_vars == {"v2"}
state.dirty_vars.clear()

# Not yet supported ENG-4083
# assert isinstance(state.dc, MutableProxy)
# state.dc.foo = "baz"
# assert state.dirty_vars == {"dc"}
# state.dirty_vars.clear()


def test_get_value():
class GetValueState(rx.State):
foo: str = "FOO"
Expand Down
Loading