
Hands-On Deep Learning (29): Converting DDP Training Code to DeepSpeed

Hands-On Deep Learning (0): Series Article Index

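Preface: DeepSpeed has become a standard tool for training models in the era of large models. Using a DDP-based Stable Diffusion training script as an example, this post explains how to convert DDP training code to DeepSpeed.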

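Contents

Building the optimizer

Building the scheduler

Argument

Init Process

Loading the training components

Modifying the training code

Launch command

weight_dtype

DDP parallel strategy

Config file

Loss is NaN

Training precision

Tuning steps

Example 1

Example 2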


Building the optimizer

The original DDP code:

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

With DeepSpeed, you only need to write these parameters into the config file:

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 1e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-08,
            "weight_decay": 1e-2
        }
    }
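One way to keep the two in sync is to generate the "optimizer" block from the same hyperparameters the DDP script passed to torch.optim.AdamW. The following is a minimal sketch; the helper name build_optimizer_config is hypothetical and not part of DeepSpeed, and the default values are simply the ones used in this post.

```python
import json


def build_optimizer_config(lr=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2):
    """Return the "optimizer" section of a DeepSpeed config as a plain dict."""
    return {
        "type": "AdamW",
        "params": {
            "lr": lr,
            "betas": [beta1, beta2],
            "eps": eps,
            "weight_decay": weight_decay,
        },
    }


if __name__ == "__main__":
    # Print the fragment so it can be pasted into the JSON config file.
    print(json.dumps({"optimizer": build_optimizer_config()}, indent=4))
```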

If you need a different optimizer, see the official documentation: DeepSpeed Configuration JSON - DeepSpeed

Building the scheduler

The scheduler was built like this in the original DDP code:

    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

With DeepSpeed, the scheduler is likewise defined only in the config file:

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000
        }
    }
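To see what these three parameters do, here is a minimal sketch of a warmup schedule in plain Python. It assumes a linear ramp for readability; as far as I know, DeepSpeed's WarmupLR uses a logarithmic ramp unless "warmup_type" says otherwise, so treat the shape as an illustration and not as the library's exact curve. The function name warmup_lr is hypothetical.

```python
def warmup_lr(step, warmup_min_lr=0.0, warmup_max_lr=0.001, warmup_num_steps=1000):
    """Learning rate at `step`: linear ramp from min to max, then constant."""
    if step >= warmup_num_steps:
        return warmup_max_lr
    fraction = step / warmup_num_steps
    return warmup_min_lr + (warmup_max_lr - warmup_min_lr) * fraction


if __name__ == "__main__":
    for step in (0, 250, 500, 1000, 5000):
        print(step, warmup_lr(step))
```

After warmup the rate stays at warmup_max_lr, so with the values above training would run at 1e-3 even though the optimizer block says 1e-5. Set warmup_max_lr to the learning rate you actually want.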

Argument

You must add the following code to train.py:

import argparse
import deepspeed

parser = argparse.ArgumentParser(description='My training script.')
parser.add_argument('--local_rank', type=int, default=-1,
                    help='local rank passed from distributed launcher')
# Include DeepSpeed configuration arguments
parser = deepspeed.add_config_arguments(parser)
cmd_args = parser.parse_args()

Otherwise the script raises an error, and the DeepSpeed config file is not recognized!
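To make the effect concrete, the sketch below parses the kind of argument list the launcher hands to each process. It is a stand-in that declares --deepspeed and --deepspeed_config by hand so that it runs without DeepSpeed installed; in the real script those flags are registered by deepspeed.add_config_arguments(parser). The example argv is an assumption modelled on the launch command used later in this post.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="My training script.")
    # The launcher passes the per-node GPU index of each process.
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="local rank passed from distributed launcher")
    # Stand-ins for the flags registered by deepspeed.add_config_arguments.
    parser.add_argument("--deepspeed", action="store_true")
    parser.add_argument("--deepspeed_config", type=str, default=None)
    return parser


if __name__ == "__main__":
    argv = ["--local_rank=3", "--deepspeed",
            "--deepspeed_config", "configs/deepspeed/stage_2.json"]
    args = build_parser().parse_args(argv)
    print(args.local_rank, args.deepspeed, args.deepspeed_config)
```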

Init Process

The original DDP code:

import torch.distributed as dist
dist.init_process_group(backend=backend, **kwargs)

Now change it to:

deepspeed.init_distributed()

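Loading the training components

Once the optimizer and scheduler are defined in the config file, load everything directly as shown here; no extra construction code is needed (deep_args is assumed to be the parsed argument namespace, i.e. the result of parser.parse_args()):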

    unet, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=unet,
        # optimizer=optimizer,
        model_parameters=trainable_params,
        args=deep_args
    )

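Modifying the training code

The standard DeepSpeed training loop is shown below. It differs only slightly from a plain torch loop, so just edit your loop accordingly: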

for step, batch in enumerate(data_loader):
    #forward() method
    loss = model_engine(batch)

    #runs backpropagation
    model_engine.backward(loss)

    #weight update
    model_engine.step()

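Note that model_engine.step() already steps the optimizer, and also the scheduler when it is defined in the config file or passed to deepspeed.initialize, so do not call them again in that case. Only a scheduler that you manage yourself (one not handed to DeepSpeed) needs an explicit step after model_engine.step():

            lr_scheduler.step()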
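With gradient accumulation, the loop above stays the same: step() is called after every micro-batch, and DeepSpeed applies the weight update only when an accumulation boundary is reached. The sketch below models just that counting in plain Python; it illustrates the bookkeeping and is not DeepSpeed's implementation, and the function name is hypothetical.

```python
def optimizer_update_steps(num_micro_batches, gradient_accumulation_steps):
    """Return the micro-batch indices after which the weights are updated."""
    return [
        i for i in range(num_micro_batches)
        if (i + 1) % gradient_accumulation_steps == 0
    ]


if __name__ == "__main__":
    # 8 micro-batches with 4 accumulation steps: updates after index 3 and 7.
    print(optimizer_update_steps(8, 4))
```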

Launch command

Here is the command I use, for reference:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 deepspeed --num_gpus=8 train_deepspeed.py \
    --deepspeed --deepspeed_config 'configs/deepspeed/stage_2.json'

The deepspeed_config argument points to the config file described below.
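Each process started by the launcher can find out who it is from environment variables. The sketch below reads RANK, LOCAL_RANK and WORLD_SIZE, which the deepspeed launcher (like torchrun) sets for every worker; the helper name read_dist_env and the single-process fallback values are assumptions for illustration.

```python
import os


def read_dist_env(environ=None):
    """Collect rank information, falling back to single-process values."""
    env = os.environ if environ is None else environ
    return {
        "rank": int(env.get("RANK", 0)),
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "world_size": int(env.get("WORLD_SIZE", 1)),
    }


if __name__ == "__main__":
    info = read_dist_env()
    print(f"rank {info['rank']} of {info['world_size']} (local rank {info['local_rank']})")
```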

weight_dtype

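Although the training precision is set in the config file, you still have to pick the matching dtype in the code and cast to it manually!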

    if mixed_precision in ("fp16", "bf16"):
        weight_dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
    else:
        weight_dtype = torch.float32
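Because the precision is now stated twice (once in the config file, once in code), the two can easily drift apart. Below is a minimal sketch of a consistency check, assuming the config has already been loaded into a dict. The helper names are hypothetical, and the check accepts both the "bf16" and the "bfloat16" section name.

```python
def config_precision(ds_config):
    """Return "fp16", "bf16" or "fp32" according to the enabled config section."""
    if ds_config.get("fp16", {}).get("enabled", False):
        return "fp16"
    for key in ("bf16", "bfloat16"):
        if ds_config.get(key, {}).get("enabled", False):
            return "bf16"
    return "fp32"


def check_precision(mixed_precision, ds_config):
    """Raise if the script's mixed_precision disagrees with the config."""
    expected = config_precision(ds_config)
    actual = mixed_precision if mixed_precision in ("fp16", "bf16") else "fp32"
    if actual != expected:
        raise ValueError(f"script uses {actual}, config enables {expected}")
    return actual


if __name__ == "__main__":
    config = {"bfloat16": {"enabled": False}, "fp16": {"enabled": True}}
    print(check_precision("fp16", config))
```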

DDP parallel strategy

The original code:

    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

Simply delete this line; DeepSpeed sets up the parallelism for us automatically.

Config file

Below is the stage 2 config file I use, for reference:

{
    "bfloat16": {
        "enabled": false
    },
    "fp16": {
        "enabled": true,
        "auto_cast": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1e-4
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 1e-5,
            "betas": [0.9, 0.999],
            "eps": 1e-08,
            "weight_decay": 1e-2
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000
        }
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "allgather_partitions": true,
        "allgather_bucket_size": 2e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 2e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 1.0,
    "train_batch_size": 8,
    "train_micro_batch_size_per_gpu": 1,
    "steps_per_print": 1e5
}
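DeepSpeed expects train_batch_size to equal train_micro_batch_size_per_gpu × gradient_accumulation_steps × number of GPUs, and it stops with an error when the values disagree. The sketch below checks that relation before launching; the function name is hypothetical, and world_size=8 matches the launch command in this post.

```python
def check_batch_sizes(ds_config, world_size):
    """Verify train_batch_size = micro batch per GPU * accumulation steps * GPUs."""
    micro = ds_config["train_micro_batch_size_per_gpu"]
    accumulation = ds_config["gradient_accumulation_steps"]
    expected = micro * accumulation * world_size
    if ds_config["train_batch_size"] != expected:
        raise ValueError(
            f"train_batch_size is {ds_config['train_batch_size']}, expected {expected}"
        )
    return expected


if __name__ == "__main__":
    config = {
        "gradient_accumulation_steps": 1,
        "train_batch_size": 8,
        "train_micro_batch_size_per_gpu": 1,
    }
    print(check_batch_sizes(config, world_size=8))
```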

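Loss is NaN

This typically happens when a model was pretrained in bf16 but is then run in fp16. It is common with models that Google trained on TPUs, such as T5. In that case, use fp32 or bf16 instead.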
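The underlying reason is numeric range: fp16 tops out at 65504, while bf16 keeps the exponent range of fp32 at the cost of precision. The sketch below shows this with the standard library only. Conversion to bf16 is approximated by truncating an fp32 value to its upper 16 bits (real hardware usually rounds instead), and both helper names are hypothetical.

```python
import struct


def fits_fp16(value):
    """Return True if `value` can be stored as an IEEE half-precision float."""
    try:
        struct.pack("<e", value)
        return True
    except (OverflowError, struct.error):
        return False


def to_bf16(value):
    """Approximate bf16 by keeping only the top 16 bits of the fp32 encoding."""
    top_two_bytes = struct.pack(">f", value)[:2]
    return struct.unpack(">f", top_two_bytes + b"\x00\x00")[0]


if __name__ == "__main__":
    activation = 70000.0  # larger than the fp16 maximum of 65504
    print("fits in fp16:", fits_fp16(activation))
    print("as bf16:", to_bf16(activation))
```

A value that a bf16-pretrained model produces without trouble can therefore overflow as soon as the same weights are run in fp16.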

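Training precision

  • Because fp16 mixed precision greatly reduces memory requirements and runs faster, only consider giving it up when the model performs poorly in this mode. This typically happens when the model was not pretrained in fp16 mixed precision (for example, a model pretrained in bf16). Such a model may overflow, producing a NaN loss. If that is the case, use full fp32 mode.

  • On Ampere-architecture GPUs, PyTorch 1.7 and later can use the more efficient tf32 format for some operations, while the results are still stored in fp32. Note that since PyTorch 1.12, tf32 is on by default only for cuDNN convolutions; for matrix multiplications it has to be enabled with torch.backends.cuda.matmul.allow_tf32 = True.

  • With the HF Trainer, enable it with --tf32, or disable it with --tf32 0 or --no_tf32. If neither is given, the PyTorch default applies.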

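Tuning steps

  • Set batch_size to 1 and use gradient accumulation to reach any effective batch_size.

  • If you hit OOM, set --gradient_checkpointing 1 (HF Trainer) or call model.gradient_checkpointing_enable().

  • If still OOM, try ZeRO stage 2.

  • If still OOM, try ZeRO stage 2 + offload_optimizer.

  • If still OOM, try ZeRO stage 3.

  • If still OOM, try offload_param to CPU.

  • If still OOM, try offload_optimizer to CPU.

  • If still OOM, try lowering some default parameters, for example a narrower beam search when using generate.

  • If still OOM, use mixed-precision training: bf16 on Ampere GPUs, fp16 on older GPUs.

  • If still OOM, use ZeRO-Infinity, with offload_param and offload_optimizer going to NVMe.

  • Once batch_size=1 no longer causes OOM, measure the effective throughput, then increase batch_size as far as possible.

  • Then start optimizing: turn off offloading or lower the ZeRO stage, adjust batch_size, and keep measuring throughput until the performance is satisfactory (tuning can improve performance by 66%).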
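The ZeRO-related rungs of this ladder can be written down as a list of "zero_optimization" settings to try in order. This is a minimal sketch: the helper name and the nvme_path value /local_nvme are placeholders I made up, while the keys (stage, offload_optimizer, offload_param, device, pin_memory, nvme_path) are the ones used in the DeepSpeed config.

```python
def zero_ladder(nvme_path="/local_nvme"):
    """Return zero_optimization settings in the order they are tried on OOM."""
    cpu = {"device": "cpu", "pin_memory": True}
    nvme = {"device": "nvme", "nvme_path": nvme_path}
    return [
        {"stage": 2},
        {"stage": 2, "offload_optimizer": dict(cpu)},
        {"stage": 3},
        {"stage": 3, "offload_param": dict(cpu)},
        {"stage": 3, "offload_param": dict(cpu), "offload_optimizer": dict(cpu)},
        {"stage": 3, "offload_param": dict(nvme), "offload_optimizer": dict(nvme)},
    ]


if __name__ == "__main__":
    for rung, settings in enumerate(zero_ladder(), start=1):
        print(rung, settings)
```

When tuning for throughput afterwards, walk the same list in the opposite direction until OOM returns.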

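Example 1

The official repository provides complete examples, and much of the code can be copied directly: GitHub - microsoft/DeepSpeedExamples: Example models using DeepSpeed

Beginners can start with the ImageNet example: DeepSpeedExamples/training/imagenet/main.py at master · microsoft/DeepSpeedExamples · GitHub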

import argparse
import os
import pdb
import random
import shutil
import time
import warnings
from enum import Enum

import torch
import deepspeed
from openpyxl import Workbook
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Subset

model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR', nargs='?', default='imagenet',
                    help='path to dataset (default: imagenet)')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                        ' | '.join(model_names) +
                        ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N',
                    help='mini-batch size (default: 256), this is the total '
                         'batch size of all GPUs on the current node when '
                         'using Data Parallel or Distributed Data Parallel')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--world-size', default=-1, type=int,
                    help='number of nodes for distributed training')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training. ')
parser.add_argument('--gpu', default=None, type=int,
                    help='GPU id to use.')
parser.add_argument('--multiprocessing_distributed', action='store_true',
                    help='Use multi-processing distributed training to launch '
                         'N processes per node, which has N GPUs. This is the '
                         'fastest way to use PyTorch for either single node or '
                         'multi node data parallel training')
parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
parser.add_argument('--local_rank', type=int, default=-1, help="local rank for distributed training on gpus")

parser = deepspeed.add_config_arguments(parser)
best_acc1 = 0

def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True
        cudnn.benchmark = False
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
        
    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    if torch.cuda.is_available():
        ngpus_per_node = torch.cuda.device_count()
    else:
        ngpus_per_node = 1

    args.world_size = ngpus_per_node * args.world_size
    t_losses, t_acc1s = main_worker(args.gpu, ngpus_per_node, args)
    #dist.barrier()
    
    # Write the losses to an excel file
    if dist.get_rank() ==0:
        all_losses = [torch.empty_like(t_losses) for _ in range(ngpus_per_node)]
        dist.gather(tensor=t_losses, gather_list=all_losses,dst=0)
    else:
        dist.gather(tensor=t_losses, dst=0)

    if dist.get_rank() ==0:
        all_acc1s = [torch.empty_like(t_acc1s) for _ in range(ngpus_per_node)]
        dist.gather(tensor=t_acc1s, gather_list=all_acc1s,dst=0)
    else:
        dist.gather(tensor=t_acc1s, dst=0)

    if dist.get_rank() == 0:
        outputfile = "Acc_loss_log.xlsx"
        workbook = Workbook()
        sheet1 = workbook.active
        sheet1.cell(row= 1, column = 1, value = "Loss")
        sheet1.cell(row= 1, column = ngpus_per_node + 4, value = "Acc")
        for rank in range(ngpus_per_node):
            for row_idx, (gpu_losses, gpu_acc1s) in enumerate(zip(all_losses[rank], all_acc1s[rank])):
                sheet1.cell(row=row_idx + 2, column = rank+1, value = float(gpu_losses))
                sheet1.cell(row=row_idx + 2, column = rank+1 + ngpus_per_node + 3, value = float(gpu_acc1s))
        workbook.save(outputfile)

def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    # In case of distributed process, initializes the distributed backend
    # which will take care of sychronizing nodes/GPUs
    if args.local_rank == -1:
        if args.gpu:
            device = torch.device('cuda:{}'.format(args.gpu))
        else:
            device = torch.device("cuda")
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        deepspeed.init_distributed()
    def print_rank_0(msg):
        if args.local_rank <=0:
            print(msg)

    args.batch_size = int(args.batch_size / ngpus_per_node)
    if not torch.cuda.is_available():# and not torch.backends.mps.is_available():
        print('using CPU, this will be slow')
        device = torch.device("cpu")
        model = model.to(device)

    # define loss function (criterion), optimizer, and learning rate scheduler
    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            elif torch.cuda.is_available():
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # Initialize DeepSpeed for the model
    model, optimizer, _, _ = deepspeed.initialize(
        model = model,
        optimizer = optimizer,
        args = args,
        lr_scheduler = None,#scheduler,
        dist_init_required=True
        )

    # Data loading code
    if args.dummy:
        print("=> Dummy data is used!")
        train_dataset = datasets.FakeData(1281167, (3, 224, 224), 1000, transforms.ToTensor())
        val_dataset = datasets.FakeData(50000, (3, 224, 224), 1000, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    if args.local_rank != -1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    else:
        train_sampler = None
        val_sampler = None

    print("Batch_size:",args.batch_size)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)


    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    losses = torch.empty(args.epochs).cuda()
    acc1s = torch.empty(args.epochs).cuda()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        # train for one epoch
        this_loss = train(train_loader, model, criterion, optimizer, epoch, device, args)
        losses[epoch] = this_loss

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)
        acc1s[epoch] = acc1

        scheduler.step()
        
        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        if not args.multiprocessing_distributed or (args.multiprocessing_distributed
                and args.gpu is None):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer' : optimizer.state_dict(),
                'scheduler' : scheduler.state_dict()
            }, is_best)

    return (losses, acc1s)

def train(train_loader, model, criterion, optimizer, epoch, device, args):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):

        # measure data loading time
        data_time.update(time.time() - end)

        # move data to the same device as model
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        model.backward(loss)
        model.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i + 1)

    return (float(losses.val))


def validate(val_loader, model, criterion, args):

    def run_validate(loader, base_progress=0):
        with torch.no_grad():
            end = time.time()
            for i, (images, target) in enumerate(loader):
                i = base_progress + i

                if torch.cuda.is_available():
                    target = target.cuda(args.gpu, non_blocking=True)
                    images = images.cuda(args.gpu, non_blocking=True)

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % args.print_freq == 0:
                    progress.display(i + 1)

    batch_time = AverageMeter('Time', ':6.3f', Summary.NONE)
    losses = AverageMeter('Loss', ':.4e', Summary.NONE)
    top1 = AverageMeter('Acc@1', ':6.2f', Summary.AVERAGE)
    top5 = AverageMeter('Acc@5', ':6.2f', Summary.AVERAGE)
    progress = ProgressMeter(
        len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    run_validate(val_loader)
    if args.distributed:
        top1.all_reduce()
        top5.all_reduce()

    if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
        aux_val_dataset = Subset(val_loader.dataset,
                                 range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
        aux_val_loader = torch.utils.data.DataLoader(
            aux_val_dataset, batch_size=args.batch_size, shuffle=False,
            num_workers=args.workers, pin_memory=True)
        run_validate(aux_val_loader, len(val_loader))

    progress.display_summary()

    return top1.avg


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        self.sum, self.count = total.tolist()
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
    
    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)
        
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))
        
    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == '__main__':
    main()

Example 2

The original DDP code:

import os
import math
import logging
import inspect
import argparse
import datetime
import random
import subprocess

from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from typing import Dict, Tuple, Optional

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from transformers import AutoModel

from diffusers import AutoencoderKL, DDPMScheduler, AutoencoderKLTemporalDecoder
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.training_utils import compute_snr

from src.models.motion_module import VanillaTemporalModule, zero_module
from src.models.rd_unet import RealisDanceUnet
from src.utils.util import get_distributed_dataloader, sanity_check


def init_dist(launcher="slurm", backend="nccl", port=29500, **kwargs):
    """Initializes distributed environment."""
    if launcher == "pytorch":
        rank = int(os.environ["RANK"])
        num_gpus = torch.cuda.device_count()
        local_rank = rank % num_gpus
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend, **kwargs)

    elif launcher == "slurm":
        proc_id = int(os.environ["SLURM_PROCID"])
        ntasks = int(os.environ["SLURM_NTASKS"])
        node_list = os.environ["SLURM_NODELIST"]
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
        torch.cuda.set_device(local_rank)
        addr = subprocess.getoutput(
            f"scontrol show hostname {node_list} | head -n1")
        os.environ["MASTER_ADDR"] = addr
        os.environ["WORLD_SIZE"] = str(ntasks)
        os.environ["RANK"] = str(proc_id)
        port = os.environ.get("PORT", port)
        os.environ["MASTER_PORT"] = str(port)
        dist.init_process_group(backend=backend)
        print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; "
              f"node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")

    else:
        raise NotImplementedError(f"Not implemented launcher type: `{launcher}`!")

    return local_rank


def main(
    image_finetune: bool,

    name: str,
    launcher: str,

    output_dir: str,
    pretrained_model_path: str,
    pretrained_clip_path: str,

    train_data: Dict,
    train_cfg: bool = True,
    cfg_uncond_ratio: float = 0.05,
    pose_shuffle_ratio: float = 0.0,

    pretrained_vae_path: str = "",
    pretrained_mm_path: str = "",
    unet_checkpoint_path: str = "",
    unet_additional_kwargs: Dict = None,
    noise_scheduler_kwargs: Dict = None,
    pose_guider_kwargs: Dict = None,
    fusion_blocks: str = "full",
    clip_projector_kwargs: Dict = None,
    fix_ref_t: bool = False,
    zero_snr: bool = False,
    snr_gamma: Optional[float] = None,
    v_pred: bool = False,

    max_train_epoch: int = -1,
    max_train_steps: int = 100,

    learning_rate: float = 5e-5,
    scale_lr: bool = False,
    lr_warmup_steps: int = 0,
    lr_scheduler: str = "constant",

    trainable_modules: Tuple[str] = (),
    num_workers: int = 4,
    train_batch_size: int = 1,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    checkpointing_epochs: int = 5,
    checkpointing_steps: int = -1,
    checkpointing_steps_tuple: Tuple[int] = (),

    mixed_precision: str = "fp16",
    resume: bool = False,

    global_seed: int or str = 42,
    is_debug: bool = False,

    *args,
    **kwargs,
):
    # check version
    check_min_version("0.30.0.dev0")

    # Initialize distributed training
    local_rank = init_dist(launcher=launcher)
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    if global_seed == "random":
        global_seed = int(datetime.datetime.now().timestamp()) % 65535

    seed = global_seed + global_rank
    torch.manual_seed(seed)

    # Logging folder
    if resume:
        # the first split: "a/b/c/checkpoints/xxx.ckpt" -> "a/b/c/checkpoints",
        # the second split: "a/b/c/checkpoints" -> "a/b/c"
        output_dir = os.path.split(os.path.split(unet_checkpoint_path)[0])[0]
    else:
        folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H")
        output_dir = os.path.join(output_dir, folder_name)
    if is_debug and os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")
    *_, config = inspect.getargvalues(inspect.currentframe())

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, "config.yaml"))

    if train_cfg and is_main_process:
        logging.info(f"Enable CFG training with drop rate {cfg_uncond_ratio}.")

    # Load scheduler, tokenizer and models
    if is_main_process:
        logging.info("Load scheduler, tokenizer and models.")
    if pretrained_vae_path != "":
        if 'SVD' in pretrained_vae_path:
            vae = AutoencoderKLTemporalDecoder.from_pretrained(pretrained_vae_path, subfolder="vae")
        else:
            vae = AutoencoderKL.from_pretrained(pretrained_vae_path, subfolder="sd-vae-ft-mse")
    else:
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")

    image_encoder = AutoModel.from_pretrained(pretrained_clip_path)

    noise_scheduler_kwargs_dict = OmegaConf.to_container(
        noise_scheduler_kwargs
    ) if noise_scheduler_kwargs is not None else {}
    if zero_snr:
        if is_main_process:
            logging.info("Enable Zero-SNR")
        noise_scheduler_kwargs_dict["rescale_betas_zero_snr"] = True
        if v_pred:
            noise_scheduler_kwargs_dict["prediction_type"] = "v_prediction"
            noise_scheduler_kwargs_dict["timestep_spacing"] = "linspace"
    noise_scheduler = DDPMScheduler.from_pretrained(
        pretrained_model_path,
        subfolder="scheduler",
        **noise_scheduler_kwargs_dict,
    )

    unet = RealisDanceUnet(
        pretrained_model_path=pretrained_model_path,
        image_finetune=image_finetune,
        unet_additional_kwargs=unet_additional_kwargs,
        pose_guider_kwargs=pose_guider_kwargs,
        clip_projector_kwargs=clip_projector_kwargs,
        fix_ref_t=fix_ref_t,
        fusion_blocks=fusion_blocks,
    )

    # Load pretrained unet weights
    unet_state_dict = {}
    if pretrained_mm_path != "" and not image_finetune:
        if is_main_process:
            logging.info(f"mm from checkpoint: {pretrained_mm_path}")
        mm_ckpt = torch.load(pretrained_mm_path, map_location="cpu")
        state_dict = mm_ckpt[
            "state_dict"] if "state_dict" in mm_ckpt else mm_ckpt
        unet_state_dict.update(
            {name: param for name, param in state_dict.items() if "motion_modules." in name})
        unet_state_dict.pop("animatediff_config", "")
        m, u = unet.unet_main.load_state_dict(unet_state_dict, strict=False)
        print(f"mm ckpt missing keys: {len(m)}, unexpected keys: {len(u)}")
        assert len(u) == 0
        for unet_main_module in unet.unet_main.children():
            if isinstance(unet_main_module, VanillaTemporalModule) and unet_main_module.zero_initialize:
                unet_main_module.temporal_transformer.proj_out = zero_module(
                    unet_main_module.temporal_transformer.proj_out
                )

    resume_step = 0
    if unet_checkpoint_path != "":
        if is_main_process:
            logging.info(f"from checkpoint: {unet_checkpoint_path}")
        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
        if "global_step" in unet_checkpoint_path:
            if is_main_process:
                logging.info(f"global_step: {unet_checkpoint_path['global_step']}")
            if resume:
                resume_step = unet_checkpoint_path['global_step']
        state_dict = unet_checkpoint_path["state_dict"]
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("module."):
                new_k = k[7:]
            else:
                new_k = k
            new_state_dict[new_k] = state_dict[k]
        m, u = unet.load_state_dict(new_state_dict, strict=False)
        if is_main_process:
            logging.info(f"Load from checkpoint with missing keys:\n{m}")
            logging.info(f"Load from checkpoint with unexpected keys:\n{u}")
        assert len(u) == 0

    # Set unet trainable parameters
    unet.requires_grad_(False)
    unet.set_trainable_parameters(trainable_modules)

    # Set learning rate and optimizer
    if scale_lr:
        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)

    trainable_parameter_keys = []
    trainable_params = []
    for param_name, param in unet.named_parameters():
        if param.requires_grad:
            trainable_parameter_keys.append(param_name)
            trainable_params.append(param)
    if is_main_process:
        logging.info(f"trainable params keys: {trainable_parameter_keys}")
        logging.info(f"trainable params number: {len(trainable_params)}")
        logging.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    # Set learning rate scheduler
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Freeze vae and image_encoder
    vae.eval()
    vae.requires_grad_(False)
    image_encoder.eval()
    image_encoder.requires_grad_(False)

    # move to cuda
    vae.to(local_rank)
    image_encoder.to(local_rank)
    unet.to(local_rank)
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=True)

    # Get the training dataloader
    train_dataloader = get_distributed_dataloader(
        dataset_config=train_data,
        batch_size=train_batch_size,
        num_processes=num_processes,
        num_workers=num_workers,
        shuffle=True,
        global_rank=global_rank,
        seed=global_seed,)

    # Get the training iteration
    overrode_max_train_steps = False
    if max_train_steps == -1:
        assert max_train_epoch != -1
        max_train_steps = max_train_epoch * len(train_dataloader)
        overrode_max_train_steps = True

    if checkpointing_steps == -1:
        assert checkpointing_epochs != -1
        checkpointing_steps = checkpointing_epochs * len(train_dataloader)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    if overrode_max_train_steps:
        max_train_steps = max_train_epoch * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps

    if is_main_process:
        logging.info("***** Running training *****")
        logging.info(f"  Num examples = {len(train_dataloader)}")
        logging.info(f"  Num Epochs = {num_train_epochs}")
        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f"  Total optimization steps = {max_train_steps}")
    global_step = resume_step
    first_epoch = int(resume_step / num_update_steps_per_epoch)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
    progress_bar.set_description("Steps")

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if mixed_precision in ("fp16", "bf16") else None

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            # Data batch sanity check
            # if epoch == first_epoch and step == 0:
            #     sanity_check(batch, f"{output_dir}/sanity_check", image_finetune, global_rank)

            """ >>>> Training >>>> """
            # Get images
            image = batch["image"].to(local_rank)
            pose = batch["pose"].to(local_rank)
            hamer = batch["hamer"].to(local_rank)
            smpl = batch["smpl"].to(local_rank)
            ref_image = batch["ref_image"].to(local_rank)
            ref_image_clip = batch["ref_image_clip"].to(local_rank)

            if train_cfg and random.random() < cfg_uncond_ratio:
                ref_image_clip = torch.zeros_like(ref_image_clip)
                drop_reference = True
            else:
                drop_reference = False

            if not image_finetune and pose_shuffle_ratio > 0:
                B, C, L, H, W = pose.shape
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    pose[:, :, rand_idx[0], :, :] = pose[:, :, rand_idx[rand_idx[0]], :, :]
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    hamer[:, :, rand_idx[0], :, :] = hamer[:, :, rand_idx[rand_idx[0]], :, :]
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    smpl[:, :, rand_idx[0], :, :] = smpl[:, :, rand_idx[rand_idx[0]], :, :]

            # Convert images to latent space
            with torch.no_grad():
                if not image_finetune:
                    video_length = image.shape[2]
                    image = rearrange(image, "b c f h w -> (b f) c h w")

                latents = vae.encode(image).latent_dist
                latents = latents.sample()
                latents = latents * vae.config.scaling_factor

                ref_latents = vae.encode(ref_image).latent_dist
                ref_latents = ref_latents.sample()
                ref_latents = ref_latents * vae.config.scaling_factor

                clip_latents = image_encoder(ref_image_clip).last_hidden_state
                # clip_latents = image_encoder.vision_model.post_layernorm(clip_latents)

                if not image_finetune:
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)

                # Sample noise that we'll add to the latents
                bsz = latents.shape[0]
                noise = torch.randn_like(latents)

                del image, ref_image, ref_image_clip
                torch.cuda.empty_cache()

            # Sample a random timestep for each video
            train_timesteps = noise_scheduler.config.num_train_timesteps
            timesteps = torch.randint(0, train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            # Mixed-precision training
            if mixed_precision in ("fp16", "bf16"):
                weight_dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
            else:
                weight_dtype = torch.float32
            with torch.cuda.amp.autocast(
                enabled=mixed_precision in ("fp16", "bf16"),
                dtype=weight_dtype
            ):
                model_pred = unet(
                    sample=noisy_latents,
                    ref_sample=ref_latents,
                    pose=pose,
                    hamer=hamer,
                    smpl=smpl,
                    timestep=timesteps,
                    encoder_hidden_states=clip_latents,
                    drop_reference=drop_reference,
                ).sample

                if snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    snr = compute_snr(noise_scheduler, timesteps)
                    mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                        dim=1
                    )[0]
                    if noise_scheduler.config.prediction_type == "epsilon":
                        mse_loss_weights = mse_loss_weights / snr.clamp(min=1e-8)  # in case of zero SNR
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        mse_loss_weights = mse_loss_weights / (snr + 1)
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()

            # Backpropagate
            if mixed_precision in ("fp16", "bf16"):
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()

            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            global_step += 1
            """ <<<< Training <<<< """
            
            del noisy_latents, ref_latents, pose, hamer, smpl
            torch.cuda.empty_cache()

            # Save checkpoint
            if is_main_process and (global_step % checkpointing_steps == 0 or global_step in checkpointing_steps_tuple):
                save_path = os.path.join(output_dir, f"checkpoints")
                state_dict = {
                    "epoch": epoch,
                    "global_step": global_step,
                    "state_dict": unet.state_dict(),
                }
                torch.save(state_dict, os.path.join(save_path, f"checkpoint-iter-{global_step}.ckpt"))
                logging.info(f"Saved state to {save_path} (global_step: {global_step})")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            if is_main_process and global_step % 500 == 0:
                logging.info(f"step: {global_step} / {max_train_steps}:  {logs}")
            if global_step >= max_train_steps:
                break

    # save the final checkpoint
    if is_main_process:
        save_path = os.path.join(output_dir, f"checkpoints")
        state_dict = {
            "epoch": num_train_epochs - 1,
            "global_step": global_step,
            "state_dict": unet.state_dict(),
        }
        torch.save(state_dict, os.path.join(save_path, f"checkpoint-final.ckpt"))
        logging.info(f"Saved final state to {save_path}")

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
    parser.add_argument("--output", type=str, required=False, default="")
    parser.add_argument("--resume", type=str, default="")
    args = parser.parse_args()

    exp_name = Path(args.config).stem
    exp_config = OmegaConf.load(args.config)

    if args.resume != "":
        exp_config["unet_checkpoint_path"] = args.resume
        exp_config["lr_warmup_steps"] = 0
        exp_config["resume"] = True

    if args.output != "":
        exp_config["output_dir"] = args.output

    main(name=exp_name, launcher=args.launcher, **exp_config)
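
The checkpoint-loading step in the listing above strips the `module.` prefix that a DDP-wrapped model adds to every parameter name. Below is a minimal standalone sketch of that step. The helper name `strip_module_prefix` is hypothetical (it is not part of the original script), and plain integers stand in for tensors.

```python
from typing import Any, Dict


def strip_module_prefix(state_dict: Dict[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Return a copy of state_dict with a leading wrapper prefix removed from each key."""
    cleaned = {}
    for key, value in state_dict.items():
        # Only a leading prefix is stripped; "module." inside a name is left alone.
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Keys as a wrapped model would save them (values are placeholders, not real tensors).
saved = {"module.unet_main.conv_in.weight": 1, "pose_guider.module.bias": 2}
print(strip_module_prefix(saved))
```

A checkpoint saved from the DeepSpeed engine's `state_dict()` typically carries the same `module.` prefix, so the same cleanup should apply when such a checkpoint is loaded back into the bare model.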

After the changes:

import os
import math
import logging
import inspect
import argparse
import datetime
import random
import subprocess

from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from typing import Dict, Tuple, Optional

import torch
import torch.nn.functional as F
import torch.distributed as dist

from transformers import AutoModel

from diffusers import AutoencoderKL, DDPMScheduler, AutoencoderKLTemporalDecoder
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.training_utils import compute_snr

from src.models.motion_module import VanillaTemporalModule, zero_module
from src.models.rd_unet import RealisDanceUnet
from src.utils.util import get_deepspeed_dataloader

import deepspeed

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--output", type=str, default="")
parser.add_argument("--resume", type=str, default="")
parser.add_argument("--local_rank", type=int, default=-1,
                    help="local rank passed from distributed launcher")
parser = deepspeed.add_config_arguments(parser)
deep_args = parser.parse_args()

exp_name = Path(deep_args.config).stem
exp_config = OmegaConf.load(deep_args.config)

if deep_args.resume != "":
    exp_config["unet_checkpoint_path"] = deep_args.resume
    exp_config["lr_warmup_steps"] = 0
    exp_config["resume"] = True

if deep_args.output != "":
    exp_config["output_dir"] = deep_args.output

def main(
    image_finetune: bool,

    name: str,

    output_dir: str,
    pretrained_model_path: str,
    pretrained_clip_path: str,

    train_data: Dict,
    train_cfg: bool = True,
    cfg_uncond_ratio: float = 0.05,
    pose_shuffle_ratio: float = 0.0,

    pretrained_vae_path: str = "",
    pretrained_mm_path: str = "",
    unet_checkpoint_path: str = "",
    unet_additional_kwargs: Dict = None,
    noise_scheduler_kwargs: Dict = None,
    pose_guider_kwargs: Dict = None,
    fusion_blocks: str = "full",
    clip_projector_kwargs: Dict = None,
    fix_ref_t: bool = False,
    zero_snr: bool = False,
    snr_gamma: Optional[float] = None,
    v_pred: bool = False,

    max_train_epoch: int = -1,
    max_train_steps: int = 100,

    learning_rate: float = 5e-5,
    scale_lr: bool = False,
    lr_warmup_steps: int = 0,
    lr_scheduler: str = "constant",

    trainable_modules: Tuple[str] = (),
    num_workers: int = 4,
    train_batch_size: int = 1,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    checkpointing_epochs: int = 5,
    checkpointing_steps: int = -1,
    checkpointing_steps_tuple: Tuple[int] = (),

    mixed_precision: str = "fp16",
    resume: bool = False,

    global_seed: int or str = 42,
    is_debug: bool = False,

    *args,
    **kwargs,
):
    # check version
    check_min_version("0.30.0.dev0")

    # Initialize distributed training
    deepspeed.init_distributed()

    global_rank = dist.get_rank()

    if global_seed == "random":
        global_seed = int(datetime.datetime.now().timestamp()) % 65535

    seed = global_seed
    torch.manual_seed(seed)

    # Logging folder
    if resume:
        # first split "a/b/c/checkpoints/xxx.ckpt" -> "a/b/c/checkpoints",
        # the second split "a/b/c/checkpoints" -> "a/b/c"
        output_dir = os.path.split(os.path.split(unet_checkpoint_path)[0])[0]
    else:
        folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H")
        output_dir = os.path.join(output_dir, folder_name)
    if is_debug and os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")
    *_, config = inspect.getargvalues(inspect.currentframe())

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    if mixed_precision in ("fp16", "bf16"):
        weight_dtype = torch.bfloat16 if mixed_precision == "bf16" else torch.float16
    else:
        weight_dtype = torch.float32


    if train_cfg:
        logging.info(f"Enable CFG training with drop rate {cfg_uncond_ratio}.")

    # Load scheduler, tokenizer and models
    if pretrained_vae_path != "":
        if 'SVD' in pretrained_vae_path:
            vae = AutoencoderKLTemporalDecoder.from_pretrained(pretrained_vae_path, subfolder="vae")
        else:
            vae = AutoencoderKL.from_pretrained(pretrained_vae_path, subfolder="sd-vae-ft-mse")
    else:
        vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")

    image_encoder = AutoModel.from_pretrained(pretrained_clip_path)

    noise_scheduler_kwargs_dict = OmegaConf.to_container(
        noise_scheduler_kwargs
    ) if noise_scheduler_kwargs is not None else {}
    if zero_snr:
        logging.info("Enable Zero-SNR")
        noise_scheduler_kwargs_dict["rescale_betas_zero_snr"] = True
        if v_pred:
            noise_scheduler_kwargs_dict["prediction_type"] = "v_prediction"
            noise_scheduler_kwargs_dict["timestep_spacing"] = "linspace"
    noise_scheduler = DDPMScheduler.from_pretrained(
        pretrained_model_path,
        subfolder="scheduler",
        **noise_scheduler_kwargs_dict,
    )

    unet = RealisDanceUnet(
        pretrained_model_path=pretrained_model_path,
        image_finetune=image_finetune,
        unet_additional_kwargs=unet_additional_kwargs,
        pose_guider_kwargs=pose_guider_kwargs,
        clip_projector_kwargs=clip_projector_kwargs,
        fix_ref_t=fix_ref_t,
        fusion_blocks=fusion_blocks,
    )

    # Load pretrained unet weights
    unet_state_dict = {}
    if pretrained_mm_path != "" and not image_finetune:
        logging.info(f"mm from checkpoint: {pretrained_mm_path}")
        mm_ckpt = torch.load(pretrained_mm_path, map_location="cpu")
        state_dict = mm_ckpt[
            "state_dict"] if "state_dict" in mm_ckpt else mm_ckpt
        unet_state_dict.update(
            {name: param for name, param in state_dict.items() if "motion_modules." in name})
        unet_state_dict.pop("animatediff_config", "")
        m, u = unet.unet_main.load_state_dict(unet_state_dict, strict=False)
        print(f"mm ckpt missing keys: {len(m)}, unexpected keys: {len(u)}")
        assert len(u) == 0
        for unet_main_module in unet.unet_main.children():
            if isinstance(unet_main_module, VanillaTemporalModule) and unet_main_module.zero_initialize:
                unet_main_module.temporal_transformer.proj_out = zero_module(
                    unet_main_module.temporal_transformer.proj_out
                )

    resume_step = 0
    if unet_checkpoint_path != "":
        logging.info(f"from checkpoint: {unet_checkpoint_path}")
        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
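        if "global_step" in unet_checkpoint_path:
            logging.info(f"global_step: {unet_checkpoint_path['global_step']}")
            if resume:
                # Restore the step counter so training continues from the checkpoint
                resume_step = unet_checkpoint_path['global_step']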
        state_dict = unet_checkpoint_path["state_dict"]
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("module."):
                new_k = k[7:]
            else:
                new_k = k
            new_state_dict[new_k] = state_dict[k]
        m, u = unet.load_state_dict(new_state_dict, strict=False)
        logging.info(f"Load from checkpoint with missing keys:\n{m}")
        logging.info(f"Load from checkpoint with unexpected keys:\n{u}")


    # Set unet trainable parameters (freeze everything first, as in the DDP version)
    unet.requires_grad_(False)
    unet.set_trainable_parameters(trainable_modules)

    trainable_parameter_keys = []
    trainable_params = []
    for param_name, param in unet.named_parameters():
        if param.requires_grad:
            trainable_parameter_keys.append(param_name)
            trainable_params.append(param)

    # logging.info(f"trainable params number: {trainable_parameter_keys}")
    logging.info(f"trainable params number: {len(trainable_params)}")
    logging.info(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    torch.cuda.set_device(deep_args.local_rank)
    device = torch.device("cuda", deep_args.local_rank)
    unet, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=unet,
        # optimizer=optimizer,
        model_parameters=trainable_params,
        args=deep_args
    )

    # Freeze vae and image_encoder
    vae.eval()
    vae.requires_grad_(False)
    image_encoder.eval()
    image_encoder.requires_grad_(False)

    # move to cuda
    vae.to(weight_dtype)
    image_encoder.to(weight_dtype)
    unet.to(weight_dtype)
    vae.to(device)
    image_encoder.to(device)
    unet.to(device)

    # Get the training dataloader
    train_dataloader = get_deepspeed_dataloader(
        dataset_config=train_data,
        batch_size=train_batch_size,
        num_workers=num_workers,
        shuffle=True,
        seed=global_seed,)

    # Get the training iteration
    overrode_max_train_steps = False
    if max_train_steps == -1:
        assert max_train_epoch != -1
        max_train_steps = max_train_epoch * len(train_dataloader)
        overrode_max_train_steps = True

    if checkpointing_steps == -1:
        assert checkpointing_epochs != -1
        checkpointing_steps = checkpointing_epochs * len(train_dataloader)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    if overrode_max_train_steps:
        max_train_steps = max_train_epoch * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Handle the output folder creation
    if unet.global_rank == 0:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, "config.yaml"))

    # Train!
    if unet.global_rank == 0:
        logging.info("***** Running training *****")
        logging.info(f"  Num examples = {len(train_dataloader)}")
        logging.info(f"  Num Epochs = {num_train_epochs}")
        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f"  Total optimization steps = {max_train_steps}")
    global_step = resume_step
    first_epoch = int(resume_step / num_update_steps_per_epoch)

    # Only show the progress bar on the global main process.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=unet.global_rank != 0)
    progress_bar.set_description("Steps")


    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            # Data batch sanity check
            """ >>>> Training >>>> """
            # Get images
            image = batch["image"].to(weight_dtype)
            pose = batch["pose"].to(weight_dtype)
            hamer = batch["hamer"].to(weight_dtype)
            smpl = batch["smpl"].to(weight_dtype)
            ref_image = batch["ref_image"].to(weight_dtype)
            ref_image_clip = batch["ref_image_clip"].to(weight_dtype)

            image = image.to(device)
            pose = pose.to(device)
            hamer = hamer.to(device)
            smpl = smpl.to(device)
            ref_image = ref_image.to(device)
            ref_image_clip = ref_image_clip.to(device)

            if train_cfg and random.random() < cfg_uncond_ratio:
                ref_image_clip = torch.zeros_like(ref_image_clip)
                drop_reference = True
            else:
                drop_reference = False

            if not image_finetune and pose_shuffle_ratio > 0:
                B, C, L, H, W = pose.shape
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    pose[:, :, rand_idx[0], :, :] = pose[:, :, rand_idx[rand_idx[0]], :, :]
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    hamer[:, :, rand_idx[0], :, :] = hamer[:, :, rand_idx[rand_idx[0]], :, :]
                if random.random() < pose_shuffle_ratio:
                    rand_idx = torch.randperm(L).long()
                    smpl[:, :, rand_idx[0], :, :] = smpl[:, :, rand_idx[rand_idx[0]], :, :]

            # Convert images to latent space
            with torch.no_grad():
                if not image_finetune:
                    video_length = image.shape[2]
                    image = rearrange(image, "b c f h w -> (b f) c h w")

                latents = vae.encode(image).latent_dist
                latents = latents.sample()
                latents = latents * vae.config.scaling_factor

                ref_latents = vae.encode(ref_image).latent_dist
                ref_latents = ref_latents.sample()
                ref_latents = ref_latents * vae.config.scaling_factor

                clip_latents = image_encoder(ref_image_clip).last_hidden_state
                # clip_latents = image_encoder.vision_model.post_layernorm(clip_latents)

                if not image_finetune:
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)

                # Sample noise that we'll add to the latents
                bsz = latents.shape[0]
                noise = torch.randn_like(latents)

                latents = latents.to(weight_dtype)
                noise = noise.to(weight_dtype)

            # Sample a random timestep for each video
            train_timesteps = noise_scheduler.config.num_train_timesteps
            timesteps = torch.randint(0, train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            model_pred = unet(
                sample=noisy_latents,
                ref_sample=ref_latents,
                pose=pose,
                hamer=hamer,
                smpl=smpl,
                timestep=timesteps,
                encoder_hidden_states=clip_latents,
                drop_reference=drop_reference,
            ).sample

            if snr_gamma is None:
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
            else:
                snr = compute_snr(noise_scheduler, timesteps)
                mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                    dim=1
                )[0]
                if noise_scheduler.config.prediction_type == "epsilon":
                    mse_loss_weights = mse_loss_weights / snr.clamp(min=1e-8)  # in case of zero SNR
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    mse_loss_weights = mse_loss_weights / (snr + 1)
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                loss = loss.mean()

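            # The DeepSpeed engine handles loss scaling and the backward pass
            unet.backward(loss)
            # Gradient clipping is set with "gradient_clipping" in the DeepSpeed config.
            # unet.step() already steps the optimizer and the LR scheduler created by
            # deepspeed.initialize and zeroes the gradients, so neither
            # lr_scheduler.step() nor optimizer.step() should be called again here.
            unet.step()
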
            progress_bar.update(1)
            global_step += 1
            """ <<<< Training <<<< """

            # Save checkpoint
            if unet.global_rank == 0 and (global_step % checkpointing_steps == 0 or global_step in checkpointing_steps_tuple):
                save_path = os.path.join(output_dir, f"checkpoints")
                state_dict = {
                    "epoch": epoch,
                    "global_step": global_step,
                    "state_dict": unet.state_dict(),
                }
                torch.save(state_dict, os.path.join(save_path, f"checkpoint-iter-{global_step}.ckpt"))
                logging.info(f"Saved state to {save_path} (global_step: {global_step})")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            if global_step >= max_train_steps:
                break

    # save the final checkpoint
    if unet.global_rank == 0:
        save_path = os.path.join(output_dir, f"checkpoints")
        state_dict = {
            "epoch": num_train_epochs - 1,
            "global_step": global_step,
            "state_dict": unet.state_dict(),
        }
        torch.save(state_dict, os.path.join(save_path, f"checkpoint-final.ckpt"))
        logging.info(f"Saved final state to {save_path}")

    dist.destroy_process_group()

main(name=exp_name, **exp_config)
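
Because the rewritten script no longer builds the optimizer, `GradScaler`, or gradient clipping by hand, those settings have to live in the DeepSpeed JSON config passed with `--deepspeed_config`. The sketch below shows one possible way to derive that config from the same hyperparameters that `main()` receives. The function name `build_ds_config` is hypothetical. The keys follow the DeepSpeed Configuration JSON documentation, and the values are simply the defaults from the signature above, not tuned recommendations.

```python
import json
from typing import Any, Dict


def build_ds_config(
    train_batch_size: int = 1,
    gradient_accumulation_steps: int = 1,
    learning_rate: float = 5e-5,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_epsilon: float = 1e-8,
    adam_weight_decay: float = 1e-2,
    max_grad_norm: float = 1.0,
    mixed_precision: str = "fp16",
) -> Dict[str, Any]:
    """Map the script's hyperparameters onto DeepSpeed config keys."""
    config: Dict[str, Any] = {
        # Per-GPU batch size; DeepSpeed derives the global batch size from
        # this, the accumulation steps, and the number of processes.
        "train_micro_batch_size_per_gpu": train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        # Takes over from the manual torch.nn.utils.clip_grad_norm_ call.
        "gradient_clipping": max_grad_norm,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": learning_rate,
                "betas": [adam_beta1, adam_beta2],
                "eps": adam_epsilon,
                "weight_decay": adam_weight_decay,
            },
        },
    }
    # Takes over from torch.cuda.amp.autocast / GradScaler in the DDP version.
    if mixed_precision in ("fp16", "bf16"):
        config[mixed_precision] = {"enabled": True}
    return config


ds_config = build_ds_config(mixed_precision="bf16")
print(json.dumps(ds_config, indent=2))
```

Writing the printed JSON to a file and passing its path with `--deepspeed_config` keeps the values in the YAML experiment config and the DeepSpeed config from drifting apart.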

One more tip: when training diffusion models, prefer accelerate, which already integrates DeepSpeed!


Original article: https://blog.csdn.net/qq_41895747/article/details/143589904
