自学内容网 自学内容网

使用 Keras 训练一个卷积神经网络(CNN)(入门篇)

在上一篇文章中,我们介绍了如何使用 Keras 训练一个简单的全连接神经网络(MLP)。本文将带你深入学习如何使用 Keras 构建和训练一个卷积神经网络(CNN),用于图像分类任务。我们将继续使用 MNIST 数据集,但这次我们将采用更适合图像数据的 CNN 架构。

目录

  1. 环境准备
  2. 导入必要的库
  3. 加载和预处理数据
  4. 构建卷积神经网络模型
  5. 编译模型
  6. 训练模型
  7. 评估模型
  8. 保存和加载模型
  9. 可视化训练过程
  10. 总结

1. 环境准备

确保你已经安装了 Python(推荐 3.6 及以上版本)和 TensorFlow(Keras 已集成在 TensorFlow 中)。如果尚未安装,请运行以下命令:

pip install tensorflow

2. 导入必要的库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
  • tensorflow: 深度学习框架,Keras 已集成其中。
  • numpy: 用于数值计算。
  • matplotlib.pyplot: 用于数据可视化。

3. 加载和预处理数据

我们继续使用 Keras 自带的 MNIST 数据集。

# 加载 MNIST 数据集
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# 查看数据形状
print(f"训练数据形状: {x_train.shape}, 训练标签形状: {y_train.shape}")
print(f"测试数据形状: {x_test.shape}, 测试标签形状: {y_test.shape}")

# 数据预处理
# 归一化:将像素值缩放到 0-1 之间
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# CNN 需要添加通道维度
x_train = np.expand_dims(x_train, -1)  # 形状变为 (60000, 28, 28, 1)
x_test = np.expand_dims(x_test, -1)    # 形状变为 (10000, 28, 28, 1)

# 将标签转换为分类编码
num_classes = 10
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# 可视化部分数据
plt.figure(figsize=(10,10))
for i in range(25):
    plt.subplot(5,5,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_train[i].reshape(28, 28), cmap=plt.cm.binary)
    plt.xlabel(np.argmax(y_train[i]))
plt.show()

说明:

  • CNN 需要输入数据具有通道维度,因此使用 np.expand_dims 添加一个维度。
  • MNIST 数据集是灰度图像,因此通道维度为 1。

4. 构建卷积神经网络模型

我们将构建一个简单的 CNN 模型,包含两个卷积层和两个池化层,最后接上全连接层进行分类。

model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),  # 卷积层,32 个 3x3 卷积核
    layers.MaxPooling2D((2, 2)),  # 最大池化层,池化窗口 2x2

    layers.Conv2D(64, (3, 3), activation='relu'),  # 卷积层,64 个 3x3 卷积核
    layers.MaxPooling2D((2, 2)),  # 最大池化层

    layers.Flatten(),  # 展平层
    layers.Dense(64, activation='relu'),  # 全连接层,64 个神经元
    layers.Dense(num_classes, activation='softmax')  # 输出层,10 个神经元
])

# 查看模型结构
model.summary()

说明:

  • Conv2D: 二维卷积层,用于提取图像特征。
  • MaxPooling2D: 最大池化层,用于下采样,减少参数数量。
  • Flatten: 将多维输入一维化,以便连接全连接层。
  • Dense: 全连接层,用于分类。

5. 编译模型

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

说明:

  • 使用 Adam 优化器和交叉熵损失函数。
  • 评估指标为准确率。

6. 训练模型

# 设置训练参数
batch_size = 128
epochs = 10

# 训练模型
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    validation_split=0.1)  # 使用 10% 的训练数据作为验证集

说明:

  • 使用 10% 的训练数据作为验证集,以监控模型在验证集上的性能。

7. 评估模型

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"\n测试准确率: {test_acc:.4f}")

8. 保存和加载模型

# 保存模型
model.save("mnist_cnn_model.h5")

# 加载模型
new_model = keras.models.load_model("mnist_cnn_model.h5")

9. 可视化训练过程

# 绘制训练 & 验证的准确率和损失值
plt.figure(figsize=(12,4))

# 准确率
plt.subplot(1,2,1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend(loc='lower right')
plt.title('训练与验证准确率')

# 损失值
plt.subplot(1,2,2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend(loc='upper right')
plt.title('训练与验证损失')

plt.show()

说明:

  • 通过可视化训练过程中的准确率和损失值,可以帮助我们了解模型的训练情况,判断是否存在过拟合或欠拟合。

10. 本节回顾

本节介绍了如何使用 Keras 构建和训练一个简单的卷积神经网络(CNN),用于手写数字识别任务。主要步骤包括:

  1. 环境准备和库导入: 确保安装了必要的库,并导入所需模块。
  2. 数据加载和预处理: 加载 MNIST 数据集,进行归一化,并添加通道维度。
  3. 构建 CNN 模型: 使用 Conv2D、MaxPooling2D、Flatten、Dense 等层构建模型。
  4. 编译模型: 指定优化器、损失函数和评估指标。
  5. 训练模型: 使用训练数据训练模型,并使用验证集监控性能。
  6. 评估模型: 在测试集上评估模型性能。
  7. 保存和加载模型: 将训练好的模型保存到磁盘,并可加载进行预测。
  8. 可视化训练过程: 通过绘制准确率和损失值曲线,了解模型的训练情况。

通过这个基础教程,你可以开始自行探索更复杂的 CNN 模型和更深入的应用,如图像分类、目标检测、图像分割等。

导师简介

前腾讯电子签的前端负责人,现 whentimes tech CTO,专注于前端技术的大咖一枚!一路走来,从小屏到大屏,从 Web 到移动,什么前端难题都见过。热衷于用技术打磨产品,带领团队把复杂的事情做到极简,体验做到极致。喜欢探索新技术,也爱分享一些实战经验,帮助大家少走弯路!

温馨提示:可搜老码小张公号联系导师


原文地址:https://blog.csdn.net/bravekingzhang/article/details/143781772

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!