diff --git a/ppdiffusers/examples/stable_diffusion/train_txt2img_laion400m_trainer.py b/ppdiffusers/examples/stable_diffusion/train_txt2img_laion400m_trainer.py index ae46b2093..3c8b51363 100644 --- a/ppdiffusers/examples/stable_diffusion/train_txt2img_laion400m_trainer.py +++ b/ppdiffusers/examples/stable_diffusion/train_txt2img_laion400m_trainer.py @@ -27,6 +27,11 @@ from paddlenlp.utils.log import logger +def use_fusedlinear_for_speed(default=False): + if default: + paddle.nn.Linear = paddle.incubate.nn.FusedLinear + + def main(): parser = PdArgumentParser( (SDModelArguments, SDDataArguments, SDTrainingArguments)) @@ -124,4 +129,7 @@ def main(): if __name__ == "__main__": + # for higher ips + use_fusedlinear_for_speed(True) + main()