Skip to content

Commit

Permalink
[ENG-3953] Support pydantic BaseModel (v1 and v2) as state var (#4338)
Browse files Browse the repository at this point in the history
* [ENG-3953] Support pydantic BaseModel (v1 and v2) as state var

Provide serializers and mutable proxy tracking for pydantic models directly.

* conditionally define v2 serializer

Co-authored-by: Khaleel Al-Adhami <[email protected]>

* Add `MutableProxy._is_mutable_value` to avoid duplicate logic

* Conditionally import BaseModel to handle older pydantic v1 versions

* pre-commit fu

---------

Co-authored-by: Khaleel Al-Adhami <[email protected]>
  • Loading branch information
masenf and adhami3310 authored Nov 22, 2024
1 parent 5702a18 commit a6b324b
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 6 deletions.
38 changes: 32 additions & 6 deletions reflex/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@
except ModuleNotFoundError:
import pydantic

from pydantic import BaseModel as BaseModelV2

try:
from pydantic.v1 import BaseModel as BaseModelV1
except ModuleNotFoundError:
BaseModelV1 = BaseModelV2

import wrapt
from redis.asyncio import Redis
from redis.exceptions import ResponseError
Expand Down Expand Up @@ -1250,7 +1257,7 @@ def __getattribute__(self, name: str) -> Any:
if parent_state is not None:
return getattr(parent_state, name)

if isinstance(value, MutableProxy.__mutable_types__) and (
if MutableProxy._is_mutable_type(value) and (
name in super().__getattribute__("base_vars") or name in backend_vars
):
# track changes in mutable containers (list, dict, set, etc)
Expand Down Expand Up @@ -3558,7 +3565,16 @@ class MutableProxy(wrapt.ObjectProxy):
pydantic.BaseModel.__dict__
)

__mutable_types__ = (list, dict, set, Base, DeclarativeBase)
# These types will be wrapped in MutableProxy
__mutable_types__ = (
list,
dict,
set,
Base,
DeclarativeBase,
BaseModelV2,
BaseModelV1,
)

def __init__(self, wrapped: Any, state: BaseState, field_name: str):
"""Create a proxy for a mutable object that tracks changes.
Expand Down Expand Up @@ -3598,6 +3614,18 @@ def _mark_dirty(
if wrapped is not None:
return wrapped(*args, **(kwargs or {}))

@classmethod
def _is_mutable_type(cls, value: Any) -> bool:
"""Check if a value is of a mutable type and should be wrapped.
Args:
value: The value to check.
Returns:
Whether the value is of a mutable type.
"""
return isinstance(value, cls.__mutable_types__)

def _wrap_recursive(self, value: Any) -> Any:
"""Wrap a value recursively if it is mutable.
Expand All @@ -3608,9 +3636,7 @@ def _wrap_recursive(self, value: Any) -> Any:
The wrapped value.
"""
# Recursively wrap mutable types, but do not re-wrap MutableProxy instances.
if isinstance(value, self.__mutable_types__) and not isinstance(
value, MutableProxy
):
if self._is_mutable_type(value) and not isinstance(value, MutableProxy):
return type(self)(
wrapped=value,
state=self._self_state,
Expand Down Expand Up @@ -3668,7 +3694,7 @@ def __getattr__(self, __name: str) -> Any:
self._wrap_recursive_decorator,
)

if isinstance(value, self.__mutable_types__) and __name not in (
if self._is_mutable_type(value) and __name not in (
"__wrapped__",
"_self_state",
):
Expand Down
47 changes: 47 additions & 0 deletions reflex/utils/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,53 @@ def serialize_base(value: Base) -> dict:
}


try:
from pydantic.v1 import BaseModel as BaseModelV1

@serializer(to=dict)
def serialize_base_model_v1(model: BaseModelV1) -> dict:
"""Serialize a pydantic v1 BaseModel instance.
Args:
model: The BaseModel to serialize.
Returns:
The serialized BaseModel.
"""
return model.dict()

from pydantic import BaseModel as BaseModelV2

if BaseModelV1 is not BaseModelV2:

@serializer(to=dict)
def serialize_base_model_v2(model: BaseModelV2) -> dict:
"""Serialize a pydantic v2 BaseModel instance.
Args:
model: The BaseModel to serialize.
Returns:
The serialized BaseModel.
"""
return model.model_dump()
except ImportError:
# Older pydantic v1 import
from pydantic import BaseModel as BaseModelV1

@serializer(to=dict)
def serialize_base_model_v1(model: BaseModelV1) -> dict:
"""Serialize a pydantic v1 BaseModel instance.
Args:
model: The BaseModel to serialize.
Returns:
The serialized BaseModel.
"""
return model.dict()


@serializer
def serialize_set(value: Set) -> list:
"""Serialize a set to a JSON serializable list.
Expand Down
49 changes: 49 additions & 0 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import pytest
import pytest_asyncio
from plotly.graph_objects import Figure
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1

import reflex as rx
import reflex.config
Expand Down Expand Up @@ -3413,6 +3415,53 @@ class TypedState(rx.State):
_ = TypedState(field="str")


class ModelV1(BaseModelV1):
"""A pydantic BaseModel v1."""

foo: str = "bar"


class ModelV2(BaseModelV2):
"""A pydantic BaseModel v2."""

foo: str = "bar"


@dataclasses.dataclass
class ModelDC:
"""A dataclass."""

foo: str = "bar"


class PydanticState(rx.State):
"""A state with pydantic BaseModel vars."""

v1: ModelV1 = ModelV1()
v2: ModelV2 = ModelV2()
dc: ModelDC = ModelDC()


def test_mutable_models():
"""Test that dataclass and pydantic BaseModel v1 and v2 use dep tracking."""
state = PydanticState()
assert isinstance(state.v1, MutableProxy)
state.v1.foo = "baz"
assert state.dirty_vars == {"v1"}
state.dirty_vars.clear()

assert isinstance(state.v2, MutableProxy)
state.v2.foo = "baz"
assert state.dirty_vars == {"v2"}
state.dirty_vars.clear()

# Not yet supported ENG-4083
# assert isinstance(state.dc, MutableProxy)
# state.dc.foo = "baz"
# assert state.dirty_vars == {"dc"}
# state.dirty_vars.clear()


def test_get_value():
class GetValueState(rx.State):
foo: str = "FOO"
Expand Down

0 comments on commit a6b324b

Please sign in to comment.