自学内容网 自学内容网

【mmengine】优化器封装(OptimWrapper)(入门)优化器封装 vs 优化器

  • MMEngine 实现了优化器封装,为用户提供了统一的优化器访问接口。优化器封装支持不同的训练策略,包括混合精度训练、梯度累加和梯度截断。用户可以根据需求选择合适的训练策略。优化器封装还定义了一套标准的参数更新流程,用户可以基于这一套流程,实现同一套代码,不同训练策略的切换。
  • 分别基于 Pytorch 内置的优化器和 MMEngine 的优化器封装(OptimWrapper)进行单精度训练混合精度训练梯度累加,对比二者实现上的区别。

一、 基于 Pytorch 的 SGD 优化器实现单精度训练

import torch
from torch.optim import SGD
import torch.nn as nn
import torch.nn.functional as F

inputs = [torch.zeros(10, 1, 1)] * 10
targets = [torch.ones(10, 1, 1)] * 10
model = nn.Linear(1, 1)
optimizer = SGD(model.parameters(), lr=0.01)
optimizer.zero_grad()

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

二、 使用 MMEngine 的优化器封装实现单精度训练

from mmengine.optim import OptimWrapper

optim_wrapper = OptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    output = model(input)
    loss = F.l1_loss(output, target)
    optim_wrapper.update_params(loss)

优化器封装的 update_params 实现了标准的梯度计算、参数更新和梯度清零流程,可以直接用来更新模型参数。
在这里插入图片描述

三、 基于 Pytorch 的 SGD 优化器实现混合精度训练

在这里插入图片描述

  • 混合精度训练:单精度 float和半精度 float16 混合,其优势为:
    • 内存占用更少
    • 计算更快
from torch.cuda.amp import autocast

model = model.cuda()
inputs = [torch.zeros(10, 1, 1, 1)] * 10
targets = [torch.ones(10, 1, 1, 1)] * 10

for input, target in zip(inputs, targets):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

四、 基于 MMEngine 的 优化器封装实现混合精度训练

from mmengine.optim import AmpOptimWrapper

optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

在这里插入图片描述

  • 混合精度训练需要使用 AmpOptimWrapper,他的 optim_context 接口类似 autocast,会开启混合精度训练的上下文。除此之外他还能加速分布式训练时的梯度累加,这个我们会在下一个示例中介绍

五、 基于 Pytorch 的 SGD 优化器实现混合精度训练和梯度累加

for idx, (input, target) in enumerate(zip(inputs, targets)):
    with autocast():
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    loss.backward()
    if idx % 2 == 0:
        optimizer.step()
        optimizer.zero_grad()

六、基于 MMEngine 的优化器封装实现混合精度训练和梯度累加

optim_wrapper = AmpOptimWrapper(optimizer=optimizer, accumulative_counts=2)

for input, target in zip(inputs, targets):
    with optim_wrapper.optim_context(model):
        output = model(input.cuda())
    loss = F.l1_loss(output, target.cuda())
    optim_wrapper.update_params(loss)

在这里插入图片描述
只需要配置 accumulative_counts 参数,并调用 update_params 接口就能实现梯度累加的功能。除此之外,分布式训练情况下,如果我们配置梯度累加的同时开启了 optim_wrapper 上下文,可以避免梯度累加阶段不必要的梯度同步。

七、 获取学习率/动量

优化器封装提供了 get_lrget_momentum 接口用于获取优化器的一个参数组的学习率:

import torch.nn as nn
from torch.optim import SGD

from mmengine.optim import OptimWrapper

model = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)
# 封装器
optim_wrapper = OptimWrapper(optimizer)

print("get info from optimizer ------")
print(optimizer.param_groups[0]['lr'])  # 0.01
print(optimizer.param_groups[0]['momentum'])  # 0
print("get info from wrapper ------")
print(optim_wrapper.get_lr())  # {'lr': [0.01]}
print(optim_wrapper.get_momentum())  # {'momentum': [0]}

在这里插入图片描述

八、 导出/加载状态字典

优化器封装和优化器一样,提供了 state_dictload_state_dict 接口,用于导出/加载优化器状态,对于 AmpOptimWrapper,优化器封装还会额外导出混合精度训练相关的参数:

import torch.nn as nn
from torch.optim import SGD
from mmengine.optim import OptimWrapper, AmpOptimWrapper

model = nn.Linear(1, 1)
# 优化器
optimizer = SGD(model.parameters(), lr=0.01)

# ---- 导出 ---- #
print("print state_dict")
# 单精度封装器
optim_wrapper = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper = AmpOptimWrapper(optimizer=optimizer)

# 导出状态字典
optim_state_dict = optim_wrapper.state_dict()
amp_optim_state_dict = amp_optim_wrapper.state_dict()
print(optim_state_dict)
print(amp_optim_state_dict)



# ---- 加载 ---- #
print("load state_dict")
# 单精度封装器
optim_wrapper_new = OptimWrapper(optimizer=optimizer)
# 混合精度封装器
amp_optim_wrapper_new = AmpOptimWrapper(optimizer=optimizer)

# 加载状态字典
amp_optim_wrapper_new.load_state_dict(amp_optim_state_dict)
optim_wrapper_new.load_state_dict(optim_state_dict)

在这里插入图片描述

九、 使用多个优化器

OptimWrapperDict 的核心功能是支持批量导出/加载所有优化器封装的状态字典;支持获取多个优化器封装的学习率、动量。如果没有 OptimWrapperDict,MMEngine 就需要在很多位置对优化器封装的类型做 if else 判断,以获取所有优化器封装的状态。

from torch.optim import SGD
import torch.nn as nn

from mmengine.optim import OptimWrapper, OptimWrapperDict

# model1
gen = nn.Linear(1, 1)
# model2
disc = nn.Linear(1, 1)

# optimizer1
optimizer_gen = SGD(gen.parameters(), lr=0.01)
# optimizer2
optimizer_disc = SGD(disc.parameters(), lr=0.01)

# wrapper1
optim_wapper_gen = OptimWrapper(optimizer=optimizer_gen)
# wrapper2
optim_wapper_disc = OptimWrapper(optimizer=optimizer_disc)

# wrapper_dict = wrapper1 + wrapper2
optim_dict = OptimWrapperDict(gen=optim_wapper_gen, disc=optim_wapper_disc)

print("wrapper_dict = wrapper1 + wrapper2")
print(optim_dict.get_lr())  # {'gen.lr': [0.01], 'disc.lr': [0.01]}
print(optim_dict.get_momentum())  # {'gen.momentum': [0], 'disc.momentum': [0]}

在这里插入图片描述


原文地址:https://blog.csdn.net/m0_51579041/article/details/142667735

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!