自学内容网 自学内容网

训练Diffusion Models节省内存的五个常用技巧(附代码)

Diffusion Models专栏文章汇总:入门与实战

前言:随着Diffusion视频生成模型的兴起,模型越做越大,计算资源显得愈发珍贵,很多时候感觉A100都已经不够用了。本篇博客讨论在训练Diffusion Models的时候一些常用的节省内存技巧,不涉及内存切片/模型切片等知识。

技巧一:trainable_modules

一般情况下不需要训练模型的全部权重,因此可以设置训练一部分权重。

    def set_trainable_parameters(self, trainable_modules):
        self.requires_grad_(False)
        for param_name, param in self.named_parameters():
            for trainable_module_name in trainable_modules:
                if trainable_module_name in param_name:
                    param.requires_grad = True
                    break

技巧二:Efficient training

考虑使用LoRA、Adapter等高效微调的手段;

首选diffusers和PEFT框架训练

技巧三:Delete

即使释放没用的变量。

例如在做完vae encode 之后,及时释放。

            with torch.no_grad():
                if not image_finetune:
                    video_length = image.shape[2]
                    image = rearrange(image, "b c f h w -> (b f) c h w")

                latents = vae.encode(image).latent_dist
                latents = latents.sample()
                latents = latents * vae.config.scaling_factor

                ref_latents = vae.encode(ref_image).latent_dist
                ref_latents = ref_latents.sample()
                ref_latents = ref_latents * vae.config.scaling_factor

                clip_latents = image_encoder(ref_image_clip).last_hidden_state
                # clip_latents = image_encoder.vision_model.post_layernorm(clip_latents)

                if not image_finetune:
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)

                # Sample noise that we"ll add to the latents
                bsz = latents.shape[0]
                noise = torch.randn_like(latents)

                del image, ref_image, ref_image_clip
                torch.cuda.empty_cache()

此外video的读取经常也会出现内存泄漏,需要及时del reader

video_reader = VideoReader(video_path)
del video_reader

技巧四:梯度禁用

务必将不需要被训练的部分梯度禁用了!

    # Freeze vae and image_encoder
    vae.eval()
    vae.requires_grad_(False)
    image_encoder.eval()
    image_encoder.requires_grad_(False)

还需要注意vae encode的时候启用 with torch.no_grad() 

技巧五:Accelerate/DeepSpeed

最后一条绝杀,能够解决所有的问题!

但是注意要优先使用accelerate中的deepspeed,原生的deepspeed可能会有问题!


原文地址:https://blog.csdn.net/qq_41895747/article/details/143695415

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!