自学内容网 自学内容网

深度学习实验--初步探索数据增强、优化器对模型的影响

前言

  • 这次主要是探究,优化器、数据增强对模型训练的影响;
  • 基础篇还剩下几个, 后面的难度会逐步提升;
  • 越学,越觉得这个东西很玄学,没有过硬实力真的很难把控;
  • 欢迎收藏 + 关注, 本人会持续更新.

1、实验

💁‍♂ 提示:本次实验并不一定具有广泛意义,具体也是需要更具不同场景进行不同分析,这里仅仅作为简单的对比,初步了解不同优化器的区别。

优化器对比实验

  • 目的:对比AdamSGD优化器
  • 条件:用优化版VGG16(详情请看代码实现)对人脸进行识别分类,训练50轮,效果如下:

在这里插入图片描述

⛹️‍♀️ 分析

  • Adam 每一次训练的时候是结合动量和自适应学习方法,可以自动调整学习率;
  • SGD是每一次训练一批数据的时候,是用一个一个样本进行训练的,比较简单;
  • 效果分析Adam比较快提高准确率,在训练初期的时候效果好,但是在后期的时候SGD综合效果更好一点点。

优化器结合数据增强

分别进行与不进行数据增强,跑20轮

效果如图

  • 不进行

在这里插入图片描述

  • 进行

在这里插入图片描述

初步分析

  • 从准确率和损失率来看,发现数据增强后效果较好,但是都存在过拟合的情况,下面加大轮次进行训练。

分别进行与不进行数据增强,跑50轮

效果如图

  • 不进行

在这里插入图片描述

  • 进行

在这里插入图片描述

初步分析

  • 从准确率和损失率来看,发现在加大训练轮次后,进行数据增强效果更好,因为解决了过拟合的情况,不进行数据增强依然存在过拟合情况
  • 效果:从模型效果来看,想要继续优化,最好的办法就是换更好的模型(本文用的是VGG16),后面可以用Resnet这些模型进行再一次优化。

总结

  • 数据、优化器均对效果有不同影响,我感觉这也是深度学习很难得地方,很难把控,解决方法,我感觉只有多实践,多看论文,多复现论文积累经验。

2、代码实现

这里完整代码是50轮数据增强代码,不数据增强,只需要不允许

1、数据处理

1、导入库

import tensorflow as tf
import numpy as np 
from tensorflow.keras import models, datasets, layers

gpu = tf.config.list_physical_devices("GPU")

print(gpu)
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2、数据导入与划分数据集

这里查看图像数据分类信息

import os, pathlib

data_dir = './data/'
data_dir = pathlib.Path(data_dir)

personClassNames = os.listdir(data_dir)

personClassNames
['Angelina Jolie',
 'Brad Pitt',
 'Denzel Washington',
 'Hugh Jackman',
 'Jennifer Lawrence',
 'Johnny Depp',
 'Kate Winslet',
 'Leonardo DiCaprio',
 'Megan Fox',
 'Natalie Portman',
 'Nicole Kidman',
 'Robert Downey Jr',
 'Sandra Bullock',
 'Scarlett Johansson',
 'Tom Cruise',
 'Tom Hanks',
 'Will Smith']

划分: 训练集 :验证集 = 8 : 2

batch_size = 16
image_width, image_height = 336, 336

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    './data/',
    subset='training',
    validation_split=0.2,
    batch_size=batch_size,
    image_size=(image_width, image_height),
    seed=42,
    shuffle=True
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    './data/',
    subset='validation',
    validation_split=0.2,
    batch_size=batch_size,
    image_size=(image_width, image_height),
    seed=42,
    shuffle=True
)
Found 1800 files belonging to 17 classes.
Using 1440 files for training.
Found 1800 files belonging to 17 classes.
Using 360 files for validation.

查看数据划分后的格式

one_batch_data = next(iter(train_ds))

one_batch_data_images, one_batcg_data_labels = one_batch_data

print("images [N, W, H, C]: ",one_batch_data_images.shape)
print("labels: ", one_batcg_data_labels)
images [N, W, H, C]:  (16, 336, 336, 3)
labels:  tf.Tensor([13 11 15  4  5  3 13  5 15  7  3  5  9 16 13  1], shape=(16,), dtype=int32)

3、展示一部分数据

import matplotlib.pyplot as plt 

plt.figure(figsize=(20, 10))

for i in range(10):
    plt.subplot(5, 10, i + 1)
    
    plt.imshow(one_batch_data_images[i].numpy().astype('uint8'))
    
    plt.title(personClassNames[one_batcg_data_labels[i]])
    
    plt.axis('off')
    
plt.show()


在这里插入图片描述

4、数据归一化和内存加速

AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)

train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    
    .prefetch(buffer_size=AUTOTUNE)
)

5、数据增强

这里对图像进行旋转

data_augmental = tf.keras.Sequential([
    tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
    tf.keras.layers.experimental.preprocessing.RandomRotation(0.2)
])

随机取一张图片看效果

# 取归一化后数据

temp_data = next(iter(train_ds))

temp_data_images, _ = temp_data

test_image = tf.expand_dims(temp_data_images[i], 0)

# 随机旋转9张
plt.figure(figsize=(8, 8))
for i in range(9):
    data_strength = data_augmental(test_image)
    plt.subplot(3, 3, i + 1)
    plt.imshow(data_strength[0])
    plt.axis("off")


在这里插入图片描述

6、数据整合

对训练集进行数据增强

AUTOTUNE = tf.data.AUTOTUNE

def prepare(ds):
    ds = ds.map(lambda x, y : (data_augmental(x), y), num_parallel_calls=AUTOTUNE)
    return ds 

train_ds = prepare(train_ds)

2、模型构建和实验

这里采用VGG16模型,优化器默认使用“adam”,本次实验主要对比“adam”和“sgd", 池化层:采用平均池化
修改: 全连接层大幅度降低计算量

from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten, Dropout, BatchNormalization
from tensorflow.keras.models import Model

def create_model(optimizer='adam'):
    # VGG16模型
    vgg16 = VGG16(include_top=False, weights='imagenet', input_shape=(image_width, image_height, 3), pooling='avg')
    
    # 冻结权重
    for layer in vgg16.layers:
        layer.trainable = False
        
    # 获取卷积输出
    x = vgg16.output
    
    # 展开
    x = Flatten()(x)
    x = Dense(170, activation='relu')(x)  # 修改全连接层
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    
    # 分类输出
    output = Dense(len(personClassNames), activation='softmax')(x)
    
    # 模型整合
    vgg16_model = Model(inputs=vgg16.input, outputs=output)
    
    # 超参数设计
    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    
    return vgg16_model

实验:分别检验SGD和Aadm优化器的效果

model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())

# 输出一个结构展示
model1.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 336, 336, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 336, 336, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 336, 336, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 168, 168, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 168, 168, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 168, 168, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 84, 84, 128)       0         
                                                                 
 block3_conv1 (Conv2D)       (None, 84, 84, 256)       295168    
                                                                 
 block3_conv2 (Conv2D)       (None, 84, 84, 256)       590080    
                                                                 
 block3_conv3 (Conv2D)       (None, 84, 84, 256)       590080    
                                                                 
 block3_pool (MaxPooling2D)  (None, 42, 42, 256)       0         
                                                                 
 block4_conv1 (Conv2D)       (None, 42, 42, 512)       1180160   
                                                                 
 block4_conv2 (Conv2D)       (None, 42, 42, 512)       2359808   
                                                                 
 block4_conv3 (Conv2D)       (None, 42, 42, 512)       2359808   
                                                                 
 block4_pool (MaxPooling2D)  (None, 21, 21, 512)       0         
                                                                 
 block5_conv1 (Conv2D)       (None, 21, 21, 512)       2359808   
                                                                 
 block5_conv2 (Conv2D)       (None, 21, 21, 512)       2359808   
                                                                 
 block5_conv3 (Conv2D)       (None, 21, 21, 512)       2359808   
                                                                 
 block5_pool (MaxPooling2D)  (None, 10, 10, 512)       0         
                                                                 
 global_average_pooling2d (G  (None, 512)              0         
 lobalAveragePooling2D)                                          
                                                                 
 flatten (Flatten)           (None, 512)               0         
                                                                 
 dense (Dense)               (None, 170)               87210     
                                                                 
 batch_normalization (BatchN  (None, 170)              680       
 ormalization)                                                   
                                                                 
 dropout (Dropout)           (None, 170)               0         
                                                                 
 dense_1 (Dense)             (None, 17)                2907      
                                                                 
=================================================================
Total params: 14,805,485
Trainable params: 90,457
Non-trainable params: 14,715,028
_________________________________________________________________

3、模型训练与实验验证

epochs = 50

history_model1 = model1.fit(
    train_ds,
    validation_data=val_ds,
    verbose=1,
    epochs=epochs
)

history_model2 = model2.fit(
    train_ds,
    validation_data=val_ds,
    verbose=1,
    epochs=epochs
)
Epoch 1/50
2024-11-29 20:04:09.371628: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8101
 5/90 [>.............................] - ETA: 3s - loss: 3.6247 - accuracy: 0.0500    

2024-11-29 20:04:11.269447: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
90/90 [==============================] - 11s 65ms/step - loss: 2.9974 - accuracy: 0.1243 - val_loss: 2.7417 - val_accuracy: 0.0889
Epoch 2/50
90/90 [==============================] - 7s 65ms/step - loss: 2.4747 - accuracy: 0.2313 - val_loss: 2.5382 - val_accuracy: 0.1611
Epoch 3/50
90/90 [==============================] - 8s 68ms/step - loss: 2.1940 - accuracy: 0.2771 - val_loss: 2.3752 - val_accuracy: 0.2222
Epoch 4/50
90/90 [==============================] - 9s 62ms/step - loss: 2.0498 - accuracy: 0.3292 - val_loss: 2.1286 - val_accuracy: 0.3194
Epoch 5/50
90/90 [==============================] - 7s 61ms/step - loss: 1.9032 - accuracy: 0.3688 - val_loss: 2.0959 - val_accuracy: 0.2500
Epoch 6/50
90/90 [==============================] - 7s 64ms/step - loss: 1.8406 - accuracy: 0.3958 - val_loss: 1.8965 - val_accuracy: 0.3361
Epoch 7/50
90/90 [==============================] - 7s 60ms/step - loss: 1.7796 - accuracy: 0.4160 - val_loss: 1.8721 - val_accuracy: 0.3444
Epoch 8/50
90/90 [==============================] - 7s 67ms/step - loss: 1.6878 - accuracy: 0.4417 - val_loss: 1.8793 - val_accuracy: 0.3694
Epoch 9/50
90/90 [==============================] - 7s 67ms/step - loss: 1.5962 - accuracy: 0.4806 - val_loss: 2.1096 - val_accuracy: 0.3778
Epoch 10/50
90/90 [==============================] - 8s 65ms/step - loss: 1.5721 - accuracy: 0.4736 - val_loss: 1.8630 - val_accuracy: 0.3639
Epoch 11/50
90/90 [==============================] - 7s 67ms/step - loss: 1.5559 - accuracy: 0.4771 - val_loss: 1.8641 - val_accuracy: 0.3667
Epoch 12/50
90/90 [==============================] - 8s 59ms/step - loss: 1.5054 - accuracy: 0.4993 - val_loss: 1.6805 - val_accuracy: 0.4222
Epoch 13/50
90/90 [==============================] - 7s 62ms/step - loss: 1.4850 - accuracy: 0.5153 - val_loss: 1.7722 - val_accuracy: 0.4139
Epoch 14/50
90/90 [==============================] - 7s 64ms/step - loss: 1.4857 - accuracy: 0.5014 - val_loss: 1.9245 - val_accuracy: 0.4000
Epoch 15/50
90/90 [==============================] - 8s 63ms/step - loss: 1.4274 - accuracy: 0.5299 - val_loss: 1.6763 - val_accuracy: 0.4306
Epoch 16/50
90/90 [==============================] - 7s 58ms/step - loss: 1.4398 - accuracy: 0.5347 - val_loss: 1.8342 - val_accuracy: 0.4139
Epoch 17/50
90/90 [==============================] - 7s 68ms/step - loss: 1.3925 - accuracy: 0.5514 - val_loss: 1.9606 - val_accuracy: 0.3722
Epoch 18/50
90/90 [==============================] - 7s 64ms/step - loss: 1.3788 - accuracy: 0.5576 - val_loss: 1.9656 - val_accuracy: 0.4056
Epoch 19/50
90/90 [==============================] - 7s 66ms/step - loss: 1.3547 - accuracy: 0.5507 - val_loss: 1.7001 - val_accuracy: 0.4472
Epoch 20/50
90/90 [==============================] - 7s 63ms/step - loss: 1.3617 - accuracy: 0.5299 - val_loss: 1.7058 - val_accuracy: 0.4139
Epoch 21/50
90/90 [==============================] - 8s 61ms/step - loss: 1.3571 - accuracy: 0.5500 - val_loss: 1.7327 - val_accuracy: 0.4111
Epoch 22/50
90/90 [==============================] - 7s 64ms/step - loss: 1.3489 - accuracy: 0.5562 - val_loss: 1.8420 - val_accuracy: 0.4139
Epoch 23/50
90/90 [==============================] - 7s 71ms/step - loss: 1.3550 - accuracy: 0.5444 - val_loss: 1.5965 - val_accuracy: 0.4944
Epoch 24/50
90/90 [==============================] - 7s 63ms/step - loss: 1.3284 - accuracy: 0.5597 - val_loss: 1.7336 - val_accuracy: 0.3944
Epoch 25/50
90/90 [==============================] - 7s 66ms/step - loss: 1.3051 - accuracy: 0.5604 - val_loss: 2.1655 - val_accuracy: 0.3333
Epoch 26/50
90/90 [==============================] - 7s 66ms/step - loss: 1.2735 - accuracy: 0.5813 - val_loss: 1.8967 - val_accuracy: 0.3972
Epoch 27/50
90/90 [==============================] - 8s 58ms/step - loss: 1.2642 - accuracy: 0.5813 - val_loss: 1.8246 - val_accuracy: 0.4556
Epoch 28/50
90/90 [==============================] - 7s 63ms/step - loss: 1.2473 - accuracy: 0.5778 - val_loss: 1.8763 - val_accuracy: 0.4167
Epoch 29/50
90/90 [==============================] - 7s 61ms/step - loss: 1.2544 - accuracy: 0.5868 - val_loss: 1.8783 - val_accuracy: 0.4278
Epoch 30/50
90/90 [==============================] - 7s 62ms/step - loss: 1.2677 - accuracy: 0.5958 - val_loss: 1.9555 - val_accuracy: 0.3667
Epoch 31/50
90/90 [==============================] - 7s 63ms/step - loss: 1.2191 - accuracy: 0.5847 - val_loss: 2.1074 - val_accuracy: 0.3333
Epoch 32/50
90/90 [==============================] - 7s 64ms/step - loss: 1.1922 - accuracy: 0.5903 - val_loss: 1.8136 - val_accuracy: 0.4111
Epoch 33/50
90/90 [==============================] - 7s 71ms/step - loss: 1.2122 - accuracy: 0.5826 - val_loss: 1.8247 - val_accuracy: 0.4417
Epoch 34/50
90/90 [==============================] - 7s 65ms/step - loss: 1.1953 - accuracy: 0.5944 - val_loss: 1.7219 - val_accuracy: 0.4861
Epoch 35/50
90/90 [==============================] - 7s 59ms/step - loss: 1.2409 - accuracy: 0.5750 - val_loss: 2.0469 - val_accuracy: 0.4028
Epoch 36/50
90/90 [==============================] - 7s 64ms/step - loss: 1.1802 - accuracy: 0.5972 - val_loss: 1.7781 - val_accuracy: 0.4000
Epoch 37/50
90/90 [==============================] - 7s 65ms/step - loss: 1.2651 - accuracy: 0.5715 - val_loss: 1.7563 - val_accuracy: 0.4278
Epoch 38/50
90/90 [==============================] - 7s 62ms/step - loss: 1.1964 - accuracy: 0.5806 - val_loss: 1.7468 - val_accuracy: 0.4778
Epoch 39/50
90/90 [==============================] - 7s 61ms/step - loss: 1.1818 - accuracy: 0.6118 - val_loss: 1.5965 - val_accuracy: 0.4917
Epoch 40/50
90/90 [==============================] - 7s 69ms/step - loss: 1.2117 - accuracy: 0.5868 - val_loss: 2.2473 - val_accuracy: 0.3306
Epoch 41/50
90/90 [==============================] - 7s 66ms/step - loss: 1.1878 - accuracy: 0.6014 - val_loss: 1.6355 - val_accuracy: 0.4944
Epoch 42/50
90/90 [==============================] - 7s 60ms/step - loss: 1.1425 - accuracy: 0.6104 - val_loss: 2.2436 - val_accuracy: 0.3472
Epoch 43/50
90/90 [==============================] - 7s 65ms/step - loss: 1.1815 - accuracy: 0.5979 - val_loss: 1.7965 - val_accuracy: 0.4667
Epoch 44/50
90/90 [==============================] - 7s 63ms/step - loss: 1.1452 - accuracy: 0.6104 - val_loss: 1.8330 - val_accuracy: 0.4500
Epoch 45/50
90/90 [==============================] - 7s 63ms/step - loss: 1.1664 - accuracy: 0.5972 - val_loss: 1.8414 - val_accuracy: 0.4306
Epoch 46/50
90/90 [==============================] - 7s 59ms/step - loss: 1.1557 - accuracy: 0.6153 - val_loss: 2.9897 - val_accuracy: 0.3333
Epoch 47/50
90/90 [==============================] - 7s 72ms/step - loss: 1.1534 - accuracy: 0.6167 - val_loss: 1.9379 - val_accuracy: 0.4222
Epoch 48/50
90/90 [==============================] - 7s 61ms/step - loss: 1.1549 - accuracy: 0.6229 - val_loss: 1.6648 - val_accuracy: 0.4667
Epoch 49/50
90/90 [==============================] - 7s 59ms/step - loss: 1.1369 - accuracy: 0.6319 - val_loss: 1.7084 - val_accuracy: 0.4833
Epoch 50/50
90/90 [==============================] - 7s 60ms/step - loss: 1.1110 - accuracy: 0.6313 - val_loss: 1.7481 - val_accuracy: 0.4639
Epoch 1/50
90/90 [==============================] - 8s 60ms/step - loss: 3.1046 - accuracy: 0.0993 - val_loss: 2.8330 - val_accuracy: 0.0444
Epoch 2/50
90/90 [==============================] - 7s 61ms/step - loss: 2.7295 - accuracy: 0.1694 - val_loss: 2.6644 - val_accuracy: 0.1694
Epoch 3/50
90/90 [==============================] - 7s 62ms/step - loss: 2.5302 - accuracy: 0.2174 - val_loss: 2.5356 - val_accuracy: 0.1944
Epoch 4/50
90/90 [==============================] - 7s 62ms/step - loss: 2.3906 - accuracy: 0.2333 - val_loss: 2.3604 - val_accuracy: 0.2333
Epoch 5/50
90/90 [==============================] - 7s 68ms/step - loss: 2.2756 - accuracy: 0.2632 - val_loss: 2.3378 - val_accuracy: 0.2167
Epoch 6/50
90/90 [==============================] - 7s 60ms/step - loss: 2.1989 - accuracy: 0.2861 - val_loss: 2.2364 - val_accuracy: 0.2639
Epoch 7/50
90/90 [==============================] - 7s 63ms/step - loss: 2.1255 - accuracy: 0.3187 - val_loss: 2.1736 - val_accuracy: 0.3111
Epoch 8/50
90/90 [==============================] - 7s 64ms/step - loss: 2.1089 - accuracy: 0.3076 - val_loss: 1.9729 - val_accuracy: 0.3583
Epoch 9/50
90/90 [==============================] - 7s 69ms/step - loss: 2.0349 - accuracy: 0.3271 - val_loss: 1.9062 - val_accuracy: 0.3833
Epoch 10/50
90/90 [==============================] - 7s 66ms/step - loss: 1.9601 - accuracy: 0.3479 - val_loss: 1.8394 - val_accuracy: 0.4056
Epoch 11/50
90/90 [==============================] - 7s 65ms/step - loss: 1.9124 - accuracy: 0.3646 - val_loss: 1.8802 - val_accuracy: 0.3833
Epoch 12/50
90/90 [==============================] - 7s 66ms/step - loss: 1.8987 - accuracy: 0.3757 - val_loss: 1.8779 - val_accuracy: 0.4083
Epoch 13/50
90/90 [==============================] - 7s 61ms/step - loss: 1.8571 - accuracy: 0.3993 - val_loss: 1.7853 - val_accuracy: 0.4000
Epoch 14/50
90/90 [==============================] - 7s 63ms/step - loss: 1.7928 - accuracy: 0.4174 - val_loss: 1.8882 - val_accuracy: 0.4000
Epoch 15/50
90/90 [==============================] - 7s 63ms/step - loss: 1.7943 - accuracy: 0.4167 - val_loss: 1.8541 - val_accuracy: 0.3917
Epoch 16/50
90/90 [==============================] - 7s 59ms/step - loss: 1.7808 - accuracy: 0.4062 - val_loss: 1.8767 - val_accuracy: 0.3722
Epoch 17/50
90/90 [==============================] - 7s 71ms/step - loss: 1.7786 - accuracy: 0.4174 - val_loss: 1.7747 - val_accuracy: 0.3944
Epoch 18/50
90/90 [==============================] - 7s 67ms/step - loss: 1.7188 - accuracy: 0.4264 - val_loss: 1.7444 - val_accuracy: 0.4306
Epoch 19/50
90/90 [==============================] - 7s 63ms/step - loss: 1.7334 - accuracy: 0.4264 - val_loss: 1.9065 - val_accuracy: 0.3806
Epoch 20/50
90/90 [==============================] - 7s 67ms/step - loss: 1.6817 - accuracy: 0.4583 - val_loss: 1.7351 - val_accuracy: 0.4028
Epoch 21/50
90/90 [==============================] - 7s 68ms/step - loss: 1.6734 - accuracy: 0.4493 - val_loss: 1.7435 - val_accuracy: 0.4250
Epoch 22/50
90/90 [==============================] - 7s 67ms/step - loss: 1.6751 - accuracy: 0.4313 - val_loss: 1.8199 - val_accuracy: 0.3667
Epoch 23/50
90/90 [==============================] - 7s 59ms/step - loss: 1.6424 - accuracy: 0.4431 - val_loss: 1.7815 - val_accuracy: 0.3833
Epoch 24/50
90/90 [==============================] - 7s 64ms/step - loss: 1.6310 - accuracy: 0.4611 - val_loss: 2.1978 - val_accuracy: 0.3194
Epoch 25/50
90/90 [==============================] - 7s 58ms/step - loss: 1.6379 - accuracy: 0.4500 - val_loss: 1.7472 - val_accuracy: 0.4111
Epoch 26/50
90/90 [==============================] - 7s 66ms/step - loss: 1.5982 - accuracy: 0.4549 - val_loss: 1.7060 - val_accuracy: 0.4250
Epoch 27/50
90/90 [==============================] - 7s 64ms/step - loss: 1.6151 - accuracy: 0.4479 - val_loss: 1.7618 - val_accuracy: 0.3833
Epoch 28/50
90/90 [==============================] - 7s 67ms/step - loss: 1.5639 - accuracy: 0.4667 - val_loss: 1.7085 - val_accuracy: 0.4278
Epoch 29/50
90/90 [==============================] - 7s 68ms/step - loss: 1.5825 - accuracy: 0.4660 - val_loss: 1.7090 - val_accuracy: 0.4417
Epoch 30/50
90/90 [==============================] - 7s 59ms/step - loss: 1.5632 - accuracy: 0.4660 - val_loss: 1.6777 - val_accuracy: 0.4278
Epoch 31/50
90/90 [==============================] - 7s 59ms/step - loss: 1.5528 - accuracy: 0.4840 - val_loss: 1.7472 - val_accuracy: 0.4250
Epoch 32/50
90/90 [==============================] - 7s 59ms/step - loss: 1.5206 - accuracy: 0.4826 - val_loss: 1.6585 - val_accuracy: 0.4361
Epoch 33/50
90/90 [==============================] - 7s 64ms/step - loss: 1.5087 - accuracy: 0.5069 - val_loss: 1.8085 - val_accuracy: 0.3861
Epoch 34/50
90/90 [==============================] - 7s 62ms/step - loss: 1.4967 - accuracy: 0.4917 - val_loss: 1.7083 - val_accuracy: 0.4222
Epoch 35/50
90/90 [==============================] - 7s 64ms/step - loss: 1.5370 - accuracy: 0.4736 - val_loss: 1.8167 - val_accuracy: 0.3806
Epoch 36/50
90/90 [==============================] - 7s 63ms/step - loss: 1.5184 - accuracy: 0.4951 - val_loss: 1.7889 - val_accuracy: 0.4139
Epoch 37/50
90/90 [==============================] - 7s 61ms/step - loss: 1.4898 - accuracy: 0.5069 - val_loss: 1.7182 - val_accuracy: 0.4111
Epoch 38/50
90/90 [==============================] - 7s 59ms/step - loss: 1.4586 - accuracy: 0.5139 - val_loss: 1.7317 - val_accuracy: 0.4083
Epoch 39/50
90/90 [==============================] - 7s 66ms/step - loss: 1.4991 - accuracy: 0.4986 - val_loss: 1.8129 - val_accuracy: 0.3917
Epoch 40/50
90/90 [==============================] - 7s 62ms/step - loss: 1.4643 - accuracy: 0.5146 - val_loss: 1.7422 - val_accuracy: 0.4194
Epoch 41/50
90/90 [==============================] - 7s 61ms/step - loss: 1.4585 - accuracy: 0.5132 - val_loss: 1.6986 - val_accuracy: 0.4389
Epoch 42/50
90/90 [==============================] - 7s 63ms/step - loss: 1.4478 - accuracy: 0.5174 - val_loss: 1.6919 - val_accuracy: 0.4556
Epoch 43/50
90/90 [==============================] - 7s 58ms/step - loss: 1.4544 - accuracy: 0.5104 - val_loss: 1.6879 - val_accuracy: 0.4472
Epoch 44/50
90/90 [==============================] - 7s 58ms/step - loss: 1.4157 - accuracy: 0.5146 - val_loss: 1.6645 - val_accuracy: 0.4278
Epoch 45/50
90/90 [==============================] - 7s 63ms/step - loss: 1.4005 - accuracy: 0.5292 - val_loss: 1.7376 - val_accuracy: 0.4556
Epoch 46/50
90/90 [==============================] - 7s 67ms/step - loss: 1.4008 - accuracy: 0.5250 - val_loss: 1.7319 - val_accuracy: 0.4222
Epoch 47/50
90/90 [==============================] - 7s 59ms/step - loss: 1.4031 - accuracy: 0.5271 - val_loss: 1.7014 - val_accuracy: 0.4389
Epoch 48/50
90/90 [==============================] - 7s 67ms/step - loss: 1.3762 - accuracy: 0.5347 - val_loss: 1.7294 - val_accuracy: 0.4222
Epoch 49/50
90/90 [==============================] - 7s 60ms/step - loss: 1.4056 - accuracy: 0.5188 - val_loss: 1.7143 - val_accuracy: 0.4417
Epoch 50/50
90/90 [==============================] - 7s 64ms/step - loss: 1.4376 - accuracy: 0.5312 - val_loss: 1.7685 - val_accuracy: 0.4417

4、结果检验

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 #图片像素
plt.rcParams['figure.dpi']  = 300 #分辨率

acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']

loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']

epochs_range = range(len(acc1))

plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
   
# 设置刻度间隔,x轴每1一个刻度
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.show()


在这里插入图片描述


原文地址:https://blog.csdn.net/weixin_74085818/article/details/144145798

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!