litgpt框架笔记
litgpt的fsdp执行原理
python __main__.py finetune_full meta-llama/Lflama-2-7b-hf --config /home/xingzhuang/llm/litgpt/config_hub/finetune/llama-2-7b/full.yaml
( __main__.py 在litgpt/litgpt目录下)
执行该命令大致流程:
-
先讲讲full.yaml 中的global_batch_size和micro_batch_size参数的含义
- global_batch_size表示optimizer做一次step的总batch数,global_batch_size会均分给所有GPU,不妨记为local_batch,当某个GPU完成了自己的local_batch后optimizer才能做step更新参数
- micro_batch_size,每个GPU会将自己的local_batch进一步拆分成micro_batch,拆分大小为micro_batch_size
-
大致执行流程,主要在litgpt/finetune/full.py文件的fit函数中
-
batch = next(train_iterator)每次拿到一个micro_batch做forward
-
is_accumulating表示本轮micro_batch forward完成后,该GPU是否完成了local_batch
-
is_accumulating参数会传给fabric.no_backward_sync判断本轮forward对应的backward是否需要同步其他GPU的local_batch的梯度,其实就是保证local_batch累加的梯度都是自身local_batch的梯度
- 若is_accumulating为True表示该GPU还未完成local_batch,所以不需要同步其他GPU的local_batch的梯度(具体来讲,就是当某个GPU拉取某个layer的全部权重并算出该layer的梯度后,并不将梯度scatter给其他的GPU)
- 若is_accumulating为False表示该GPU已完成local_batch,所以会同步其他GPU的local_batch梯度
-
当所有GPU都完成了自己的local_batch后,则会执行optimizer.step()做一次梯度优化
pytorch-lightning的fsdp+tp原理
python train.py
(train.py在pytorch-lightning/examples/fabric/tensor_parallel/train.py)
执行该命令大致执行流程如下:
litgpt适配fsdp+tp
1.把pytorch-lightning/examples/tensor_parallel文件夹下的parallelism.py和model.py复制到litgpt/litgpt/finetune/下
2. 把full.py中的strategy改为
strategy = ModelParallelStrategy(
# User-defined function that applies the desired parallelizations specific to the model
# (TP, FSDP2, activation checkpointing, ...)
parallelize_fn=parallelize,
# Define the size of the 2D parallelism
# Set to "auto" to apply TP intra-node and DP inter-node
data_parallel_size="auto",
tensor_parallel_size="auto",
)
3.在litgpt/litgpt/model.py下的class CausalSelfAttention的__init__.py方法中加上代码
self.n_heads = config.n_head
self.n_kv_heads = config.n_head
4.修改parallelism.py文件
5.修改litgpt/litgpt/utils.py的load_checkpoint方法
def load_checkpoint(fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True) -> None:
if isinstance(fabric.strategy, FSDPStrategy):
fabric.load_raw(checkpoint_path, model, strict=strict)
elif isinstance(fabric.strategy, ModelParallelStrategy):
fabric.load_raw(checkpoint_path, model, strict=False)
else:
state_dict = lazy_load(checkpoint_path)
state_dict = state_dict.get("model", state_dict)
model.load_state_dict(state_dict, strict=strict)
6.修改litgpt/litgpt/model.py 下的class CausalSelfAttention:
把self.attn改成self.attn_w
7.有个包貌似有问题
/home/xingzhuang/software/anaconda3/envs/litgpt/lib/python3.9/site-packages/torch/distributed/tensor/parallel/api.py
临时解决方法:把/home/xingzhuang/software/anaconda3/envs/litgpt/lib/python3.9/site-packages/torch/distributed/tensor/parallel/style.py的_apply函数中
NotImplementedError改为print,不终止报错
原文地址:https://blog.csdn.net/weixin_46347213/article/details/142327462
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!