Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] fix fiddle #4500

Merged
merged 1 commit into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs_nnx/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
sys.path.insert(0, os.path.abspath('..'))
# Include local extension.
sys.path.append(os.path.abspath('./_ext'))
# Set environment variable to indicate that we are building the docs.
os.environ['FLAX_DOC_BUILD'] = 'true'

# patch sphinx
# -- Project information -----------------------------------------------------
Expand Down
6 changes: 5 additions & 1 deletion flax/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import dataclasses
import inspect
import os
import threading
import typing as tp
from abc import ABCMeta
Expand All @@ -38,6 +39,7 @@

G = tp.TypeVar('G', bound='Object')

BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ

def _collect_stats(
node: tp.Any, node_stats: dict[int, dict[type[Variable], SizeBytes]]
Expand Down Expand Up @@ -157,7 +159,9 @@ def __init_subclass__(cls) -> None:
init=cls._graph_node_init, # type: ignore
)

cls.__signature__ = inspect.signature(cls.__init__)
if BUILDING_DOCS:
# set correct signature for sphinx
cls.__signature__ = inspect.signature(cls.__init__)

if not tp.TYPE_CHECKING:

Expand Down
Loading