【VOS源码解析-2024CVPR-Cutie】2、trainner 结构解析
源码解析
论文阅读
回顾
在train.py中,里面最重要的就是trainer变量的使用,该变量的使用记录如下:
- 构建trainer
- 加载checkpoint、权重
- 设为训练模式
- 进行一轮epoch训练
- 如果训练过程中中断,不管这一轮训练是否完成都保存权重和checkpoints,以便后续恢复训练
- 将pretrain的最终训练权重保存给变量weights_in_memory,传递给后续的mian_training
2、 trainer.py
2.1 主要模块
如图所示,共有以下主要模块:
- 主体model:CutieTrainWrapper
- 参数分组:get_parameter_groups
- 这里将model的backbone、embeding和其余参数进行了分组,不同的分组赋予不同的lr和weight decay
- 损失计算:LossComputer
- 整合器Integrator
- 用于记录各种训练参数,计算参数的平均值,并于log结合使用,将参数保存至日志中
- 图像可视化:vis
- 把中间训练的结果可视化,并保存下来,用于监控model训练
- 日志记录:TensorboardLogger
- 时间记录:TimeEstimator
2.2 大致流程
- 初始化
- 模型构建初始化
- 日志记录log和整合器Integrator初始化
- 优化器optimizer和损失计算初始化
- 学习率策略调整初始化
- 通过配置文件初始化其余参数
- do_pass(前向传播+反向传播)
- 前向传播
- loss计算、监控中间结果
- 保存相关信息、权重和checkpoints-1
- 反向传播
- 保存相关信息、权重和checkpoints-2
- 权重保存
- 模式选择
2.3 do_pass详细解析
2.3.1 前向传播
#---------------------------------------------------前向传播------------------------------------------------------
torch.set_grad_enabled(self._is_train) #启用梯度计算
for k, v in data.items():
if isinstance(v, torch.Tensor): #确保v是tendor格式
data[k] = v.cuda(non_blocking=True) #将v加载到cuda中
out = self.cutie(data) #前向传播得到预测值
num_filled_objects = out['num_filled_objects'] #对象数量
2.3.2 loss计算、中间过程监控
#----------------------------------------------训练状态下执行,计算loss值-------------------------------------------
if self._is_train:
losses = self.loss_computer.compute({**data, **out}, num_filled_objects) #{**data, **out}将数据进行合并处理
self.integrator.add_dict(losses) #整合器用于收集和汇总训练过程中的各种指标,如损失值、准确率、学习率等
# ---------------------------------------loging, 保存图像日志,监控学习效果--------------------------------------
if self._is_train:
if self.local_rank == 0 and it % self.log_image_interval == 0 and it != 0:
# ----------------------------保存图像日志------------------------------
images = {**data, **out}
self.log.log_image(self.stage, 'vis', vis(images, self.size,
num_filled_objects), it)
2.3.3 保存相关训练信息、权重、checkpoints
此次是按固定间隔(save_weights_interval、save_checkpoint_interval)保存权重
2.3.4 反向传播
(1)将model参数进行分组,并分配不同的lr和weight decay
-
对模型参数进行分组
现代深度学习模型结构复杂多样,不同部分的参数在模型中的作用和重要性不同。例如,在一些视觉模型中,像素编码器(backbone)用于提取图像特征,其参数量通常较大,且在训练初期需要较慢的学习速度来稳定地学习通用的特征表示;而一些特定的嵌入层(如位置嵌入、类别嵌入等)则用于为模型提供额外的先验信息或特定的编码方式,其参数量相对较少,学习速度可以稍快一些,以便更好地适应特定任务。
共分为三组:pixel_encoder参数、embeding参数、其余参数
model参数中以pixel_encoder开头的划分为pixel_encoder参数,以某些特定后缀结尾的划分为embeding参数,剩下的分为其余参数。
最后对这些参数分配不同的lr和weight decay
-
根据不同分组初始化optimizer
(2)模型参数、反向传播、loss、optimizer、scheduler之间的关系
- 获得需要被优化的模型参数
- parameter_groups = get_parameter_groups(self.cutie, stage_cfg, print_log=(local_rank == 0)) #获取模型的参数组用于优化器
- 依据得到的模型参数,和事先设定好的lr等参数,初始化optimizer
self.optimizer = optim.AdamW(parameter_groups, #设置优化器
lr=stage_cfg['learning_rate'],
weight_decay=stage_cfg['weight_decay'],
eps=1e-6 if self.use_amp else 1e-8,
foreach=True)
- 对model的预测结果和GT,使用损失函数计算损失
- losses = self.loss_computer.compute({**data, **out}, num_filled_objects) #{**data, **out}将数据进行合并处理
- 使用计算得到的损失进行反向传播,得到梯度
- losses[‘total_loss’].backward() #对总损失进行反向传播
- 反向传播的具体原理见lianjie
- 计算得到的梯度被保存在参数的.grad变量中
- 使用optimizer更新model参数
- self.optimizer.step() #更新优化器参数
- 使用scheduler对optimizer的lr进行更新
- self.scheduler.step() #更新学习率调度器。这一步会根据配置的学习率调度策略调整优化器的学习率。
- 学习率调整在这里给出了三种调整策略
# ------------------------------------setting up learning rate scheduler----------------------------------------
#选择不同的学习率调度策略
if stage_cfg['lr_schedule'] == 'constant': #学习率在训练过程中保持不变
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda _: 1)
elif stage_cfg['lr_schedule'] == 'poly': #学习率按照多项式衰减策略变化
total_num_iter = stage_cfg['iterations']
self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer,
lr_lambda=lambda x:
(1 - (x / total_num_iter))**0.9)
elif stage_cfg['lr_schedule'] == 'step': #阶梯式衰减
self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
stage_cfg['lr_schedule_steps'],
stage_cfg['lr_schedule_gamma'])
else:
raise NotImplementedError
"""
scheduler 的主要作用是根据预设的策略动态调整优化器的学习率。
模型参数需要在训练过程中被优化,通过反向传播计算损失函数对模型参数的梯度,这些梯度被存储在.grad属性中。optimizer根据计算得到的梯度更新模型参数
scheduler根据预设的调整策略调整optimizer的学习率,使lr动态调整。
"""
2.3.5 保存相关训练信息、权重、checkpoints
如果此时迭代至训练末尾,则开启频繁保存
原文地址:https://blog.csdn.net/weixin_43571113/article/details/145159443
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!