Center Loss 和 ArcFace Loss 笔记
一、Center Loss
1. 定义
Center Loss
旨在最小化类内特征的离散程度,通过约束样本特征与其类别中心之间的距离,提高类内特征的聚合性。
2. 公式
对于样本 xi 和其类别yi,Center Loss
的公式为:
- xi: 当前样本的特征向量(通常来自网络的最后一层)。
- Cyi: 类别 yi 的特征中心。
- m: 样本数量。
3. 作用
- 减小类内样本的特征分布范围。
- 提高分类模型对相似类别样本的区分能力。
4. 实现
import torch
import torch.nn as nn
class CenterLoss(nn.Module):
def __init__(self, num_classes, feat_dim, weight=1.0):
"""
:param num_classes: 类别数量
:param feat_dim: 特征向量维度
:param weight: 损失的权重
"""
super(CenterLoss, self).__init__()
self.weight = weight
self.centers = nn.Parameter(torch.randn(num_classes, feat_dim)) # 初始化类别中心
def forward(self, features, labels):
"""
:param features: 网络输出的特征向量 (batch_size, feat_dim)
:param labels: 样本对应的类别标签 (batch_size,)
"""
centers = self.centers[labels] # 获取对应标签的中心
loss = torch.sum((features - centers) ** 2, dim=1).mean() # 欧几里得距离平方和
return self.weight * loss
5. 结合 Cross-Entropy Loss
将 Center Loss
与交叉熵损失结合,联合优化网络:
center_loss = CenterLoss(num_classes=10, feat_dim=512)
cross_entropy_loss = nn.CrossEntropyLoss()
# 训练时
features, logits = model(input_data)
loss_ce = cross_entropy_loss(logits, labels)
loss_center = center_loss(features, labels)
total_loss = loss_ce + 0.1 * loss_center # 合并损失
二、ArcFace Loss
1. 定义
ArcFace Loss
是基于角度的损失函数,用于增强特征的判别性。通过在角度空间引入额外的边际约束,强迫同类样本之间更加接近,而不同类样本之间更加远离。
2. 公式
ArcFace Loss
的公式为:
- θ: 特征和分类权重之间的角度。
- m: 边际(margin)。
最终损失使用交叉熵计算:
- s: 缩放因子,用于平衡模型的学习难度。
3. 作用
- 强化特征的角度判别能力,使得分类更加鲁棒。
- 在人脸识别任务中,显著提高模型的性能。
4. 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ArcFaceLoss(nn.Module):
def __init__(self, in_features, out_features, s=30.0, m=0.50):
"""
:param in_features: 特征向量维度
:param out_features: 类别数量
:param s: 缩放因子
:param m: 边际约束
"""
super(ArcFaceLoss, self).__init__()
self.s = s
self.m = m
self.weight = nn.Parameter(torch.randn(out_features, in_features)) # 分类权重
def forward(self, embeddings, labels):
# Normalize embeddings and weight
embeddings = F.normalize(embeddings, p=2, dim=1)
weight = F.normalize(self.weight, p=2, dim=1)
# Cosine similarity
cosine = F.linear(embeddings, weight)
# Add margin
phi = cosine - self.m
one_hot = torch.zeros_like(cosine)
one_hot.scatter_(1, labels.view(-1, 1), 1)
cosine_with_margin = one_hot * phi + (1 - one_hot) * cosine
# Scale
logits = self.s * cosine_with_margin
loss = F.cross_entropy(logits, labels)
return loss
解释:
ArcFaceLoss在最后一层网络,输入是上一层的输出特征值x,初始化当前层的w权重。
cos(角度)=w×x/|w|×|x|,由于ArcLoss会对w和x进行归一化到和为1的概率值。所以|w|×|x|=1。则推导出cos(角度)=w×x,那么真实标签位置给角度+m则让角度变大了,cos值变小。w×x变小,输出的预测为真实标签的概率变低。让模型更难训练,那么在一遍又一遍的模型读取图片提取特征的过程中,会让模型逐渐的将真实标签位置的w×x值变大==cos(角度+m)变大,那么角度就会变的更小。只有角度更小的时候,cos余弦相似度才会大,从而让模型认为这个类别是真实的类别。
所以arcloss主要加入了一个m,增大角度,让模型更难训练,让模型把角度变的更小,从而让w的值调整的更加让类间距增大。
简而言之:加入m的值,让真实类和其他类相似度更高,让模型更难训练。迫使模型为了让真实和其他类相似度更低,而让w权重的值更合理。
三、对比分析
四、如何选择
- 如果任务需要提升类内特征的聚合性(如样本分布紧密性),优先考虑
Center Loss
。 - 如果任务需要增强类间特征的判别能力(如人脸识别),优先选择
ArcFace Loss
。 - 可以同时使用两者,将特征聚合和判别性结合,提高模型的鲁棒性。
五、推荐学习资源
- ArcFace: Additive Angular Margin Loss for Deep Face Recognition (论文)
- Center Loss: A Discriminative Feature Learning Approach for Deep Face Recognition (论文)
- PyTorch 官方文档
原文地址:https://blog.csdn.net/jenny88889999/article/details/145041224
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!