-
Notifications
You must be signed in to change notification settings - Fork 186
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PPDiffusers] Not use recompute #374
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
from paddle import nn | ||
from paddle.distributed.fleet.utils import recompute | ||
|
||
from ..utils import recompute_use_reentrant | ||
from ..utils import recompute_use_reentrant, use_old_recompute | ||
from ..utils.paddle_utils import apply_freeu | ||
from .attention import Attention | ||
from .dual_transformer_2d import DualTransformer2DModel | ||
|
@@ -1013,7 +1013,12 @@ def forward( | |
|
||
blocks = zip(self.resnets, self.motion_modules) | ||
for resnet, motion_module in blocks: | ||
if self.training and self.gradient_checkpointing and not hidden_states.stop_gradient: | ||
if ( | ||
self.training | ||
and self.gradient_checkpointing | ||
and not hidden_states.stop_gradient | ||
and not use_old_recompute() | ||
): | ||
|
||
def create_custom_forward(module): | ||
def custom_forward(*inputs): | ||
|
@@ -1183,7 +1188,12 @@ def forward( | |
|
||
blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) | ||
for i, (resnet, attn, motion_module) in enumerate(blocks): | ||
if self.training and self.gradient_checkpointing and not hidden_states.stop_gradient: | ||
if ( | ||
self.training | ||
and self.gradient_checkpointing | ||
and not hidden_states.stop_gradient | ||
and not use_old_recompute() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这些应该不用加吧?之前都没有这些模块的。0.19.4 |
||
): | ||
|
||
def create_custom_forward(module, return_dict=None): | ||
def custom_forward(*inputs): | ||
|
@@ -1387,7 +1397,12 @@ def forward( | |
|
||
hidden_states = paddle.concat([hidden_states, res_hidden_states], axis=1) | ||
|
||
if self.training and self.gradient_checkpointing and not hidden_states.stop_gradient: | ||
if ( | ||
self.training | ||
and self.gradient_checkpointing | ||
and not hidden_states.stop_gradient | ||
and not use_old_recompute() | ||
): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这些应该不用加吧?之前都没有这些模块的。0.19.4 |
||
|
||
def create_custom_forward(module, return_dict=None): | ||
def custom_forward(*inputs): | ||
|
@@ -1542,7 +1557,12 @@ def forward( | |
|
||
hidden_states = paddle.concat([hidden_states, res_hidden_states], axis=1) | ||
|
||
if self.training and self.gradient_checkpointing and not hidden_states.stop_gradient: | ||
if ( | ||
self.training | ||
and self.gradient_checkpointing | ||
and not hidden_states.stop_gradient | ||
and not use_old_recompute() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这些应该不用加吧?之前就没有这 |
||
): | ||
|
||
def create_custom_forward(module): | ||
def custom_forward(*inputs): | ||
|
@@ -1699,7 +1719,12 @@ def forward( | |
|
||
blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) | ||
for attn, resnet, motion_module in blocks: | ||
if self.training and self.gradient_checkpointing and not hidden_states.stop_gradient: | ||
if ( | ||
self.training | ||
and self.gradient_checkpointing | ||
and not hidden_states.stop_gradient | ||
and not use_old_recompute() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这些应该不用加吧?之前都没有这些模块的。0.19.4 |
||
): | ||
|
||
def create_custom_forward(module, return_dict=None): | ||
def custom_forward(*inputs): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些应该不用加吧?之前都没有这些模块的。0.19.4
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
https://github.com/PaddlePaddle/PaddleMIX/blob/ppdiffusers0.19.4/ppdiffusers/ppdiffusers/models/unet_3d_blocks.py
DownBlockMotion 是没有这些模块,升级增加的,不是要测试 recompute吗
增加 and not use_old_recompute() 用来测试 recompute