【漫话机器学习系列】043.提前停止训练(Early Stopping)
提前停止训练(Early Stopping)
提前停止(Early Stopping) 是一种在训练机器学习模型(尤其是深度学习模型)时常用的正则化技术,用于防止过拟合并提升模型的泛化能力。它通过监控验证集的性能,在性能不再提高或开始下降时终止训练,从而选择性能最佳的模型。
工作原理
提前停止的基本思想是:
- 在每个训练轮次(epoch)后,评估模型在验证集上的性能(通常使用损失函数值或评价指标,如准确率)。
- 如果验证集性能在多个轮次内未改善,则停止训练并恢复到性能最佳的模型状态。
实现步骤
-
分割数据集: 将训练数据分为训练集和验证集,训练集用于优化模型参数,验证集用于监控模型的泛化性能。
-
设定监控指标: 选择一个监控指标(如验证损失、验证准确率等),作为衡量模型性能的标准。
-
设定耐心值(Patience): 耐心值是指允许验证集性能在指定轮次内未改善的次数。如果超过耐心值还未见性能提升,则停止训练。
-
保存最佳模型: 在训练过程中,记录验证集性能最优的模型状态,停止训练后使用该状态作为最终模型。
优点
- 防止过拟合:通过终止训练,避免模型过度拟合训练数据。
- 提高泛化能力:选择验证集上性能最优的模型,提升模型对未见数据的表现。
- 节省训练时间:减少不必要的迭代,节约计算资源。
- 动态调整:适应数据集的不同复杂度,不需要预设固定的训练轮次。
缺点
- 需要验证集:需要分出一部分数据作为验证集,可能导致训练数据减少。
- 过早停止的风险:模型可能在某些训练阶段出现短暂波动,提前停止可能会错过更好的优化结果。
- 适合深度学习模型:对于小规模模型或简单问题,提前停止的效果可能不明显。
实现方式
1. 使用 TensorFlow/Keras 实现
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping
# 示例数据集
X_train = np.random.rand(1000, 20) # 1000个样本,每个样本20个特征
y_train = np.random.randint(2, size=(1000, 1)) # 1000个样本的二分类标签
X_val = np.random.rand(200, 20) # 200个样本,每个样本20个特征
y_val = np.random.randint(2, size=(200, 1)) # 200个样本的二分类标签
model = Sequential([
Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
Dense(1, activation='sigmoid')
])
# 定义 EarlyStopping 回调,监控验证集损失,如果连续5个epoch没有改善则停止训练,并恢复最佳权重
early_stopping = EarlyStopping(
monitor='val_loss', # 监控的指标
patience=5, # 在验证集性能不提升的轮数后停止
restore_best_weights=True # 恢复验证集性能最优的模型
)
# 编译模型
model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
# 训练模型,使用早停机制
model.fit(
X_train, y_train,
validation_data=(X_val, y_val),
epochs=100,
callbacks=[early_stopping]
)
运行结果
Epoch 1/100
32/32 [==============================] - 1s 8ms/step - loss: 0.2522 - accuracy: 0.5210 - val_loss: 0.2504 - val_accuracy: 0.5350
Epoch 2/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2491 - accuracy: 0.5320 - val_loss: 0.2502 - val_accuracy: 0.5300
Epoch 3/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2484 - accuracy: 0.5320 - val_loss: 0.2507 - val_accuracy: 0.4950
Epoch 4/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2468 - accuracy: 0.5260 - val_loss: 0.2521 - val_accuracy: 0.4950
Epoch 5/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2456 - accuracy: 0.5560 - val_loss: 0.2524 - val_accuracy: 0.5150
Epoch 6/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2452 - accuracy: 0.5450 - val_loss: 0.2540 - val_accuracy: 0.5000
Epoch 7/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2457 - accuracy: 0.5500 - val_loss: 0.2529 - val_accuracy: 0.4750
2. 使用 PyTorch 实现
import torch
import torch.nn as nn
class EarlyStopping:
def __init__(self, patience=5, delta=0, path='checkpoint.pt'):
self.patience = patience
self.delta = delta
self.best_loss = None
self.counter = 0
self.early_stop = False
self.path = path
def __call__(self, val_loss, model):
if self.best_loss is None or val_loss < self.best_loss - self.delta:
self.best_loss = val_loss
self.counter = 0
torch.save(model.state_dict(), self.path)
else:
self.counter += 1
if self.counter >= self.patience:
self.early_stop = True
# 定义示例数据集
X_train = torch.randn(1000, 20) # 1000个样本,每个样本20个特征
y_train = torch.randint(0, 2, (1000, 1)) # 1000个样本的二分类标签
X_val = torch.randn(200, 20) # 200个样本,每个样本20个特征
y_val = torch.randint(0, 2, (200, 1)) # 200个样本的二分类标签
# 定义模型
model = nn.Sequential(
nn.Linear(20, 64),
nn.ReLU(),
nn.Linear(64, 1),
nn.Sigmoid()
)
# 定义训练和验证函数
def train():
pass
def validate():
return torch.tensor(0.5) # 示例验证损失
# 使用示例
early_stopping = EarlyStopping(patience=5)
for epoch in range(100):
train() # 训练过程
val_loss = validate() # 验证损失
early_stopping(val_loss, model)
if early_stopping.early_stop:
model.load_state_dict(torch.load('checkpoint.pt'))
break
总结
提前停止训练是机器学习和深度学习中的一种简单高效的正则化方法,能够显著提升模型的泛化能力,同时减少训练时间。结合耐心值(patience)、监控指标以及最佳模型保存机制,可以灵活地应用到各种场景中。
原文地址:https://blog.csdn.net/IT_ORACLE/article/details/145048155
免责声明:本站文章内容转载自网络资源,如侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!