[PPDiffusers] Not use recompute #374

Closed
37 changes: 31 additions & 6 deletions ppdiffusers/ppdiffusers/models/unet_3d_blocks.py
@@ -18,7 +18,7 @@
from paddle import nn
from paddle.distributed.fleet.utils import recompute

-from ..utils import recompute_use_reentrant
+from ..utils import recompute_use_reentrant, use_old_recompute
from ..utils.paddle_utils import apply_freeu
from .attention import Attention
from .dual_transformer_2d import DualTransformer2DModel
@@ -1013,7 +1013,12 @@ def forward(

blocks = zip(self.resnets, self.motion_modules)
for resnet, motion_module in blocks:
-if self.training and self.gradient_checkpointing and not hidden_states.stop_gradient:
+if (
+    self.training
+    and self.gradient_checkpointing
+    and not hidden_states.stop_gradient
+    and not use_old_recompute()
+):

Member:
These shouldn't need to be added, right? These modules weren't there before, in 0.19.4.

@co63oc (Contributor, Author), Jan 15, 2024:
https://github.com/PaddlePaddle/PaddleMIX/blob/ppdiffusers0.19.4/ppdiffusers/ppdiffusers/models/unet_3d_blocks.py
In 0.19.4, DownBlockMotion did not have these modules; they were added as part of the upgrade. Isn't the goal to test recompute? "and not use_old_recompute()" was added so that recompute can be tested.

def create_custom_forward(module):
def custom_forward(*inputs):
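
A minimal sketch of what the guard discussed in this hunk selects between, for readers following the review thread. run_block is a hypothetical helper, not code from the repository; the real forward methods also pass block-specific arguments and, presumably, a use_reentrant option derived from recompute_use_reentrant(), all omitted here.

from paddle.distributed.fleet.utils import recompute
from ppdiffusers.utils import use_old_recompute

def run_block(module, hidden_states, temb, training=True, gradient_checkpointing=True):
    # Mirror of the guard added in this PR: only take the recompute path when the
    # global switch does NOT ask for the old behaviour.
    if (
        training
        and gradient_checkpointing
        and not hidden_states.stop_gradient
        and not use_old_recompute()
    ):
        def create_custom_forward(mod):
            def custom_forward(*inputs):
                return mod(*inputs)
            return custom_forward

        # recompute drops intermediate activations and re-runs the forward during backward.
        hidden_states = recompute(create_custom_forward(module), hidden_states, temb)
    else:
        # use_old_recompute() is True (or checkpointing is off): plain forward call.
        hidden_states = module(hidden_states, temb)
    return hidden_states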
@@ -1183,7 +1188,12 @@ def forward(

blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
for i, (resnet, attn, motion_module) in enumerate(blocks):
-if self.training and self.gradient_checkpointing and not hidden_states.stop_gradient:
+if (
+    self.training
+    and self.gradient_checkpointing
+    and not hidden_states.stop_gradient
+    and not use_old_recompute()
+):

Member:
These shouldn't need to be added, right? These modules weren't there before, in 0.19.4.

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -1387,7 +1397,12 @@ def forward(

hidden_states = paddle.concat([hidden_states, res_hidden_states], axis=1)

-if self.training and self.gradient_checkpointing and not hidden_states.stop_gradient:
+if (
+    self.training
+    and self.gradient_checkpointing
+    and not hidden_states.stop_gradient
+    and not use_old_recompute()
+):

Member:
These shouldn't need to be added, right? These modules weren't there before, in 0.19.4.
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -1542,7 +1557,12 @@ def forward(

hidden_states = paddle.concat([hidden_states, res_hidden_states], axis=1)

-if self.training and self.gradient_checkpointing and not hidden_states.stop_gradient:
+if (
+    self.training
+    and self.gradient_checkpointing
+    and not hidden_states.stop_gradient
+    and not use_old_recompute()
+):

Member:
These shouldn't need to be added, right? These weren't there before.

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -1699,7 +1719,12 @@ def forward(

blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
for attn, resnet, motion_module in blocks:
-if self.training and self.gradient_checkpointing and not hidden_states.stop_gradient:
+if (
+    self.training
+    and self.gradient_checkpointing
+    and not hidden_states.stop_gradient
+    and not use_old_recompute()
+):

Member:
These shouldn't need to be added, right? These modules weren't there before, in 0.19.4.

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
6 changes: 3 additions & 3 deletions ppdiffusers/ppdiffusers/models/vae.py
@@ -19,7 +19,7 @@
import paddle.nn as nn
from paddle.distributed.fleet.utils import recompute

-from ..utils import BaseOutput, recompute_use_reentrant
+from ..utils import BaseOutput, recompute_use_reentrant, use_old_recompute
from ..utils.paddle_utils import randn_tensor
from .activations import get_activation
from .attention_processor import SpatialNorm
@@ -850,7 +850,7 @@ def __init__(

def forward(self, x: paddle.Tensor) -> paddle.Tensor:
r"""The forward method of the `EncoderTiny` class."""
-if self.training and self.gradient_checkpointing and not x.stop_gradient:
+if self.training and self.gradient_checkpointing and not x.stop_gradient and not use_old_recompute():

def create_custom_forward(module):
def custom_forward(*inputs):
@@ -932,7 +932,7 @@ def forward(self, x: paddle.Tensor) -> paddle.Tensor:
# Clamp.
x = nn.functional.tanh(x / 3) * 3

-if self.training and self.gradient_checkpointing and not x.stop_gradient:
+if self.training and self.gradient_checkpointing and not x.stop_gradient and not use_old_recompute():

def create_custom_forward(module):
def custom_forward(*inputs):
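
On the vae.py side the same switch appears as a one-line condition. Below is a minimal sketch of an EncoderTiny/DecoderTiny-style forward, under the assumptions that the sub-layers live in a self.layers Sequential and that the fallback branch simply calls them directly; TinyBlockSketch is a made-up class, not the repository's code.

import paddle
import paddle.nn as nn
from paddle.distributed.fleet.utils import recompute
from ppdiffusers.utils import use_old_recompute

class TinyBlockSketch(nn.Layer):
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.Sequential(*layers)
        self.gradient_checkpointing = False  # enabled externally when training with checkpointing

    def forward(self, x: paddle.Tensor) -> paddle.Tensor:
        if self.training and self.gradient_checkpointing and not x.stop_gradient and not use_old_recompute():
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)
                return custom_forward

            # Memory-saving path: activations inside self.layers are recomputed in backward.
            x = recompute(create_custom_forward(self.layers), x)
        else:
            # use_old_recompute() returned True (or checkpointing is off): plain forward call.
            x = self.layers(x)
        return x

Under this pattern, a True return from use_old_recompute() makes every guarded block fall through to the plain call without touching the individual forward methods.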