LeNet实验 四分类 与 四分类变为多个二分类
目录
1. 划分二分类
可以根据不同的类别进行多个划分,以实现NonDemented为例,划分为NonDemented和Demented两类,不属于NonDemented的全都属于Demented
2. 训练独立的二分类模型
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from 文件准备 import data_dir
# 数据生成器
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
validation_split=0.2 # 20%用于验证
)
train_generator = train_datagen.flow_from_directory(
data_dir,
target_size=(28, 28),
batch_size=32,
class_mode='binary',
subset='training'
)
validation_generator = train_datagen.flow_from_directory(
data_dir,
target_size=(28, 28),
batch_size=32,
class_mode='binary',
subset='validation'
)
# 构建LeNet-5模型
model = models.Sequential()
model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(28, 28, 3), padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(16, (5, 5), activation='relu', padding='same'))
model.add(layers.AveragePooling2D((2, 2)))
model.add(layers.Conv2D(120, (5, 5), activation='relu', padding='same'))
model.add(layers.Flatten())
model.add(layers.Dense(84, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
# 编译模型
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'])
# 训练模型
model.fit(
train_generator,
steps_per_epoch=train_generator.samples // train_generator.batch_size,
epochs=10,
validation_data=validation_generator,
validation_steps=validation_generator.samples // validation_generator.batch_size
)
# 保存模型
model.save('lenet_binary_classification_model.h5')
3. 预测结果代码
import tensorflow as tf
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
# 加载模型
model = tf.keras.models.load_model('lenet_binary_classification_model.h5')
# 预处理图像
def preprocess_image(img_path):
img = image.load_img(img_path, target_size=(28, 28))
img_array = image.img_to_array(img) / 255.0
img_array = np.expand_dims(img_array, axis=0)
return img_array
# 预测图像
img_path = 'D:\Pycharm_workspace\LeNet实验_二分类\Demented\moderateDem24.jpg' # 测试图像路径
img_array = preprocess_image(img_path)
prediction = model.predict(img_array)
predicted_class = 'Demented' if prediction[0][0] > 0.5 else 'NonDemented'
print(f'The predicted class is: {predicted_class}')
# 显示图像
img = image.load_img(img_path, target_size=(28, 28))
plt.imshow(img)
plt.title(f'Predicted: {predicted_class}')
plt.show()
4. 预测结果
Demented结果
NonDemented结果没有。。。。。。
竟然全都没有。。。。因为预测的全部都是Demented
疯狂找原因中
猜测是像素太低使得训练的模型准确率太低
于是重新训练
5 改进训练模型
进行重新训练
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
# 定义LeNet模型
def create_lenet_model(input_shape):
model = Sequential([
Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),
MaxPooling2D((2, 2), strides=2),
Conv2D(16, (5, 5), activation='relu'),
MaxPooling2D((2, 2), strides=2),
Flatten(),
Dense(120, activation='relu'),
Dense(84, activation='relu'),
Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
return model
# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)
# 训练数据生成器
train_generator = train_datagen.flow_from_directory(
'D:\Pycharm_workspace\LeNet实验_二分类\image',
target_size=(176, 208),
batch_size=32,
class_mode='binary',
subset='training'
)
# 验证数据生成器
validation_generator = train_datagen.flow_from_directory(
'D:\Pycharm_workspace\LeNet实验_二分类\image',
target_size=(176, 208),
batch_size=32,
class_mode='binary',
subset='validation'
)
# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)
# 保存模型
model.save('dementia_classification_model.h5')
# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()
# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()
这里还有图形画loss与准确率但是我忘记保存了,就用控制台的输出
可以看到loss值非常小而且准确率是100
6 优化后 预测结果代码
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
import os
# 加载模型
model = load_model('dementia_classification_model.h5')
# 定义类别标签
class_labels = ['Demented', 'NonDemented']
# 预测函数
def predict_image(img_path):
img = image.load_img(img_path, target_size=(176, 208))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
img_array /= 255.0
prediction = model.predict(img_array)
predicted_class = class_labels[int(prediction[0] > 0.5)]
# 显示图像和预测结果
plt.imshow(image.load_img(img_path))
plt.title(f'Predicted: {predicted_class}')
plt.axis('off')
plt.show()
# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_二分类\image\NonDemented\nonDem1.jpg' # 替换为你的图片路径
predict_image(img_path)
7 优化后预测结果
图片与预测结果对应上了(右侧是图片链接可以看到是Dem的类型)
NonDem的也是对应上了
就此训练完成
8 训练四分类模型
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
# 定义LeNet模型
def create_lenet_model(input_shape):
model = Sequential([
Conv2D(6, (5, 5), activation='relu', input_shape=input_shape, padding='same'),
MaxPooling2D((2, 2), strides=2),
Conv2D(16, (5, 5), activation='relu'),
MaxPooling2D((2, 2), strides=2),
Flatten(),
Dense(120, activation='relu'),
Dense(84, activation='relu'),
Dense(4, activation='softmax')
])
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
return model
# 数据增强和数据生成器
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)
# 训练数据生成器
train_generator = train_datagen.flow_from_directory(
'D:\Pycharm_workspace\LeNet实验_四分类\image',
target_size=(176, 208),
batch_size=32,
class_mode='categorical',
subset='training'
)
# 验证数据生成器
validation_generator = train_datagen.flow_from_directory(
'D:\Pycharm_workspace\LeNet实验_四分类\image',
target_size=(176, 208),
batch_size=32,
class_mode='categorical',
subset='validation'
)
# 创建并训练模型
input_shape = (176, 208, 3)
model = create_lenet_model(input_shape)
history = model.fit(train_generator, epochs=10, validation_data=validation_generator)
# 保存模型
model.save('dementia_classification_model.h5')
# 绘制训练和验证损失
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.title('训练和验证损失')
plt.xlabel('时期')
plt.ylabel('损失')
plt.legend()
plt.show()
# 绘制训练和验证准确率
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.title('训练和验证准确率')
plt.xlabel('时期')
plt.ylabel('准确率')
plt.legend()
plt.show()
loss值与准确率的变化图
可以看到才第四轮准确率就已经很高了
9 预测结果代码
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import matplotlib.pyplot as plt
# 加载模型
model = load_model('dementia_classification_model.h5')
# 定义类别标签
class_labels = ['MildDemented', 'ModerateDemented', 'NonDemented', 'VeryMildDemented']
# 预测函数
def predict_image(img_path):
img = image.load_img(img_path, target_size=(176, 208))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
img_array /= 255.0
prediction = model.predict(img_array)
predicted_class = class_labels[np.argmax(prediction)]
# 显示图像和预测结果
plt.imshow(image.load_img(img_path))
plt.title(f'Predicted: {predicted_class}')
plt.axis('off')
plt.show()
# 预测并展示结果
img_path = r'D:\Pycharm_workspace\LeNet实验_四分类\image\VeryMildDemented\verymildDem0.jpg' # 你的图片路径
predict_image(img_path)
10 四分类结果识别
1 MildDem成功识别(右侧有图片名称)
2 ModerateDem 成功识别
3 NonDem成功识别
4 VeryMildDem成功识别
原文地址:https://blog.csdn.net/2301_78488802/article/details/140561179
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!