自学内容网 自学内容网

pytorch nn.Dropout类介绍

在 PyTorch 中,nn.Dropout 是一种正则化方法,随机将输入张量的一部分元素置为零,以防止过拟合并提高模型的泛化能力。其基本用法如下:

import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)  # 丢弃概率为 50%
x = torch.ones((2, 3, 4))  # 输入张量
output = dropout(x)  # 输出的部分元素会被置为零
  • 它在训练阶段,对于输入张量中的每个元素,会以概率p将其置为 0。对于未被置为 0 的元素,需要进行数值缩放,缩放因子为1 / (1 - p)
  • 在给定的代码中,p = 0.5,这意味着每个元素有 0.5 的概率被置为 0,而未被置为 0 的元素将乘以1 / (1 - 0.5)=2

注: 输入张量的每个元素会以概率p将其置为 0,没有维度限制。

如何在指定维度上进行 Dropout?

PyTorch 的标准 nn.Dropout 无法直接指定某个维度进行 Dropout,但可以通过以下几种方法实现在指定维度共享 Dropout 掩码

方法 1:自定义 Dropout 类(参考上文)

可以继承 nn.Module,实现一个支持沿指定


原文地址:https://blog.csdn.net/qq_27390023/article/details/145094229

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!