自学内容网 自学内容网

神经网络基础-神经网络搭建和参数计算

1.构建神经网络

在 pytorch 中定义深度神经网络其实就是层堆叠的过程,继承自nn.Module,实现两个方法:

  • __init__方法中定义网络中的层结构,主要是全连接层,并进行初始化。
  • forward方法,在实例化模型的时候,底层会自动调用该函数。该函数中可以定义学习率,为初始化定义的layer传入数据等。

我们来构建如下图所示的神经网络模型:
在这里插入图片描述

编码设计如下:

  1. 第1个隐藏层:权重初始化采用标准化的xavier初始化 激活函数使用sigmoid。
  2. 第2个隐藏层:权重初始化采用标准化的He初始化 激活函数采用relu。
  3. out输出层线性层 假若二分类,采用softmax做数据归一化。
# 创建神经网络
import torch
import torch.nn as nn
# pip install torchsummary
from torchsummary import summary # 计算模型参数,查看模型结构 pip install torchsummary
# 创建神经网络模型类
class Model(nn.Module):
    # 初始化属性值
    def __init__(self):
        # 调用父类的初始化属性值
        super(Model, self).__init__()
        # 创建第一个隐藏层模型,3个输入特征,3个输出特征
        self.linear1 = nn.Linear(3, 3)
        # 初始化权重 xavier 均匀分布初始化
        nn.init.xavier_uniform_(self.linear1.weight)
        # 创建第二个隐藏层,3个输入特征(上一层的输出特征),2个输出特征
        self.linear2 = nn.Linear(3, 2)
        # 初始化权重 kaiming 正太分布初始化
        nn.init.kaiming_normal_(self.linear2.weight)
        # 创建输出层模型
        self.out = nn.Linear(2, 2)
    # 创建向前传播方法,自动执行 forward()方法
    def forward(self, x):
        # 数据经过第一个线性层
        x = self.linear1(x)
        # 使用 sigmoid 激活函数
        x = torch.sigmoid(x)
        # 数据经过第二个线性层
        x = self.linear2(x)
        # 使用 relu 激活函数
        x = torch.relu(x)
        # 数据经过输出层
        x = self.out(x)
        # 使用 softmax 激活函数
        # dim=-1:每一维度行数据相机为1
        x = torch.softmax(x, dim=-1)
        return x

if __name__ == '__main__':
    # 实例化model对象
    model = Model()
    # 随机产生数据
    data = torch.randn(5,3)
    print('data.shape',data.shape)
    # 数据经过神经网络模型训练
    out = model(data)
    print('out.shape',out.shape)
    # 计算模型参数
    # 计算每层每个神经元的 w 和 b 个数总和
    summary(model,input_size=(3,),batch_size=5)
    # 查看模型参数
    print("======查看模型参数w和b======")
    for name, param in model.named_parameters():
        print(name, param)
  • 神经网络的输入数据是为[batch_size, in_features]的张量经过网络处理后获取了[batch_size, out_features]的输出张量。

  • 在上述例子中,batch_size=5, in_features=3,out_features=2,结果如下所示:

    data.shape torch.Size([5, 3])
    out.shape torch.Size([5, 2])
    

    模型参数输出:

    ----------------------------------------------------------------
            Layer (type)               Output Shape         Param #
    ================================================================
                Linear-1                     [5, 3]              12
                Linear-2                     [5, 2]               8
                Linear-3                     [5, 2]               6
    ================================================================
    Total params: 26
    Trainable params: 26
    Non-trainable params: 0
    ----------------------------------------------------------------
    Input size (MB): 0.00
    Forward/backward pass size (MB): 0.00
    Params size (MB): 0.00
    Estimated Total Size (MB): 0.00
    ----------------------------------------------------------------
    ======查看模型参数w和b======
    linear1.weight Parameter containing:
    tensor([[ 0.3857,  0.4809, -0.0346],
            [ 0.3645,  0.2803, -0.6291],
            [ 0.1999, -0.6617,  0.7724]], requires_grad=True)
    linear1.bias Parameter containing:
    tensor([0.3084, 0.5636, 0.4501], requires_grad=True)
    linear2.weight Parameter containing:
    tensor([[ 0.1063,  0.7494,  0.4311],
            [-1.4152,  0.3396, -0.8590]], requires_grad=True)
    linear2.bias Parameter containing:
    tensor([-0.3771,  0.2937], requires_grad=True)
    out.weight Parameter containing:
    tensor([[-0.6012,  0.4727],
            [-0.2953, -0.5854]], requires_grad=True)
    out.bias Parameter containing:
    tensor([-0.3271,  0.4940], requires_grad=True)
    

模型参数的计算:

  1. 以第一个隐层为例:该隐层有3个神经元,每个神经元的参数为:4个(w1,w2,w3,b1),所以一共用3x4=12个参数。
  2. 输入数据和网络权重是两个不同的事儿!对于初学者理解这一点十分重要,要分得清。
    在这里插入图片描述

2. 神经网络的优缺点

  1. 优点
    ➢ 精度高,性能优于其他的机器学习算法,甚至在某些领域超过了人类。
    ➢ 可以近似任意的非线性函数。
    ➢ 近年来在学界和业界受到了热捧,有大量的框架和库可供调。
  2. 缺点
    ➢ 黑箱,很难解释模型是怎么工作的。
    ➢ 训练时间长,需要大量的计算资源。
    ➢ 网络结构复杂,需要调整超参数。
    ➢ 部分数据集上表现不佳,容易发生过拟合。

原文地址:https://blog.csdn.net/dwjf321/article/details/144457551

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!