
litgpt framework notes

How litgpt executes FSDP

python __main__.py finetune_full meta-llama/Llama-2-7b-hf --config /home/xingzhuang/llm/litgpt/config_hub/finetune/llama-2-7b/full.yaml
(__main__.py lives in the litgpt/litgpt directory)

Rough flow when running this command:

  • First, the meaning of the global_batch_size and micro_batch_size parameters in full.yaml

    • global_batch_size is the total number of samples consumed per optimizer step. It is split evenly across all GPUs; call each GPU's share its local_batch. Only after a GPU has finished its local_batch can the optimizer step and update the parameters.
    • micro_batch_size: each GPU further splits its local_batch into micro_batches of size micro_batch_size.
  • The rough execution flow lives mainly in the fit function of litgpt/finetune/full.py; a sketch of that loop is given after this list

  • batch = next(train_iterator) fetches one micro_batch for each forward pass

  • is_accumulating indicates whether this GPU is still accumulating, i.e. whether it has not yet finished its local_batch after this micro_batch's forward pass

  • is_accumulating is passed to fabric.no_backward_sync, which decides whether the backward pass for this forward needs to synchronize gradients with the other GPUs' local_batches; in essence, it guarantees that the gradients a GPU accumulates over a local_batch come only from its own local_batch

    • If is_accumulating is True, the GPU has not yet finished its local_batch, so it does not synchronize gradients with the other GPUs (concretely: after a GPU gathers a layer's full weights and computes that layer's gradients, it does not reduce-scatter those gradients to the other GPUs)
    • If is_accumulating is False, the GPU has finished its local_batch, so gradients are synchronized with the other GPUs' local_batches
  • Once every GPU has finished its own local_batch, optimizer.step() runs to apply one gradient update
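
Below is a minimal sketch of that accumulation loop, paraphrasing the fit function in litgpt/finetune/full.py. The free variables (fabric, model, optimizer, train_iterator, loss_fn and the hyperparameters) stand in for what fit already has in scope, and the exact variable names differ from the source:

# one local_batch = global_batch_size / world_size samples;
# it is consumed as gradient_accumulation_iters micro_batches
gradient_accumulation_iters = (global_batch_size // fabric.world_size) // micro_batch_size

iter_num, step_count = 0, 0
while step_count < max_steps:
    iter_num += 1
    batch = next(train_iterator)                       # one micro_batch
    is_accumulating = iter_num % gradient_accumulation_iters != 0

    # while still accumulating, skip FSDP's gradient reduce-scatter
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        logits = model(batch["input_ids"])
        loss = loss_fn(logits, batch["labels"])        # loss_fn: any token-level loss
        fabric.backward(loss / gradient_accumulation_iters)

    if not is_accumulating:
        # local_batch finished: gradients are synced across ranks, step once
        optimizer.step()
        optimizer.zero_grad()
        step_count += 1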

How pytorch-lightning does FSDP+TP

python train.py
(train.py is at pytorch-lightning/examples/fabric/tensor_parallel/train.py)

Running this command, the rough execution flow is as follows:
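
The example's core idea is a user-supplied parallelize function handed to ModelParallelStrategy. The following toy sketch (my own minimal example, not the actual train.py/parallelism.py; ToyMLP and the dimensions are made up) shows that structure: a TP plan applied on the intra-node mesh and FSDP2 (fully_shard) on the inter-node mesh:

import lightning as L
import torch
import torch.nn as nn
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed._composable.fsdp import fully_shard   # newer torch also exposes torch.distributed.fsdp.fully_shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class ToyMLP(nn.Module):
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim, bias=False)
        self.fc2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def parallelize(model: nn.Module, device_mesh) -> nn.Module:
    # TP inside a node: shard fc1 column-wise, fc2 row-wise
    parallelize_module(model, device_mesh["tensor_parallel"],
                       {"fc1": ColwiseParallel(), "fc2": RowwiseParallel()})
    # FSDP2 across nodes: shard the (already TP-sharded) parameters
    fully_shard(model, mesh=device_mesh["data_parallel"])
    return model


if __name__ == "__main__":
    strategy = ModelParallelStrategy(parallelize_fn=parallelize,
                                     data_parallel_size="auto",
                                     tensor_parallel_size="auto")
    fabric = L.Fabric(accelerator="cuda", strategy=strategy)
    fabric.launch()
    model = fabric.setup(ToyMLP())       # parallelize_fn is invoked during setup
    optimizer = fabric.setup_optimizers(torch.optim.AdamW(model.parameters()))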

Adapting litgpt to FSDP+TP

1. Copy parallelism.py and model.py from the pytorch-lightning/examples/fabric/tensor_parallel folder into litgpt/litgpt/finetune/
2. In full.py, change the strategy to:

strategy = ModelParallelStrategy(
    # User-defined function that applies the desired parallelizations specific to the model
    # (TP, FSDP2, activation checkpointing, ...)
    parallelize_fn=parallelize,
    # Define the size of the 2D parallelism
    # Set to "auto" to apply TP intra-node and DP inter-node
    data_parallel_size="auto",
    tensor_parallel_size="auto",
)
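
For context (this usage is not shown in the original snippet), the strategy is then handed to Fabric as usual, and parallelize_fn is only invoked once fabric.setup(model) wraps the model:

fabric = L.Fabric(devices="auto", strategy=strategy)
fabric.launch()
model = fabric.setup(model)   # applies the TP plan and FSDP2 sharding via parallelize_fn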

3. In litgpt/litgpt/model.py, add the following lines to the __init__ method of class CausalSelfAttention:

self.n_heads = config.n_head
# Llama-2-7B uses no grouped-query attention, so the key/value head count equals n_head;
# presumably these attributes are what the copied parallelism.py reads
self.n_kv_heads = config.n_head

4. Modify the parallelism.py file (the post gives no diff; presumably its TP plan has to be re-keyed to litgpt's module names, see the hypothetical sketch below)
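
A purely hypothetical sketch of what such an adapted parallelize() might look like for litgpt's GPT layout (transformer.h blocks with attn.attn_w, attn.proj, mlp.fc_1, mlp.fc_2, mlp.proj); this is a guess at the shape of the change, not the author's actual diff:

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


def parallelize(model, device_mesh):
    tp_mesh = device_mesh["tensor_parallel"]
    dp_mesh = device_mesh["data_parallel"]

    for block in model.transformer.h:
        plan = {
            "attn.attn_w": ColwiseParallel(),   # fused qkv projection (renamed in step 6)
            "attn.proj": RowwiseParallel(),
            "mlp.fc_1": ColwiseParallel(),      # LLaMAMLP up-projections
            "mlp.fc_2": ColwiseParallel(),
            "mlp.proj": RowwiseParallel(),      # LLaMAMLP down-projection
        }
        parallelize_module(block, tp_mesh, plan)
        # keep per-rank attention reshapes consistent with the sharded qkv weights
        block.attn.n_heads //= tp_mesh.size()
        block.attn.n_kv_heads //= tp_mesh.size()
        # FSDP2 over the data-parallel mesh, on top of the TP sharding
        fully_shard(block, mesh=dp_mesh)

    fully_shard(model, mesh=dp_mesh)
    return model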

5. Modify the load_checkpoint method in litgpt/litgpt/utils.py:

# alongside the existing FSDPStrategy import, this needs:
# from lightning.fabric.strategies import ModelParallelStrategy
def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
    if isinstance(fabric.strategy, FSDPStrategy):
        fabric.load_raw(checkpoint_path, model, strict=strict)
    elif isinstance(fabric.strategy, ModelParallelStrategy):
        # load non-strictly so that renamed or re-sharded parameters do not abort loading
        fabric.load_raw(checkpoint_path, model, strict=False)
    else:
        state_dict = lazy_load(checkpoint_path)
        state_dict = state_dict.get("model", state_dict)
        model.load_state_dict(state_dict, strict=strict)

6. In class CausalSelfAttention in litgpt/litgpt/model.py, rename self.attn to self.attn_w

7. One package seems to be problematic:
/home/xingzhuang/software/anaconda3/envs/litgpt/lib/python3.9/site-packages/torch/distributed/tensor/parallel/api.py
Temporary workaround: in the _apply function of /home/xingzhuang/software/anaconda3/envs/litgpt/lib/python3.9/site-packages/torch/distributed/tensor/parallel/style.py, replace the NotImplementedError with a print so execution is not aborted by the error.


Original post: https://blog.csdn.net/weixin_46347213/article/details/142327462
