自学内容网 自学内容网

DANN & GRL

域自适应是指在目标域与源域的数据分布不同但任务相同下的迁移学习,从而将模型在源域上的良好性能迁移到目标域上,极大地缓解目标域标签缺失严重导致模型性能受损的问题。

介绍一篇经典工作 DANN

模型结构

model

在训练阶段需要预测如下两个任务:

  • 实现源域数据集准确分类,即图像分类误差的最小化,这与正常分类任务保持一致
  • 实现源域和目标域准确分类,即域分类器的误差最小化。而特征提取器的目标是最大化域分类误差,使得域分类器无法分辨数据是来自源域还是目标域,从而让特征提取器学习到域不变特征(domain-invariant)。也就是说特征提取器和域分类器的目标是相反的
    • 本质上就是让特征提取器不要过拟合源域,要学习出源域和目标域的泛化特征
    • 这两个网络对抗训练,DANN通过GRL层使特征提取器更新的梯度与域判别器的梯度相反,构造出了类似于GAN的对抗损失,又通过该层避免了GAN的两阶段训练过程,提升模型训练稳定性

GRL

GRL是作用在特征提取器上的,对其参数梯度取反。

具体实现如下:

class ReverseLayerF(Function):

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha

        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha

        return output, None

调用如下:

def forward(self, input_data, alpha):
    input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28)
    feature = self.feature(input_data)
    feature = feature.view(-1, 50 * 4 * 4)
    reverse_feature = ReverseLayerF.apply(feature, alpha)
    class_output = self.class_classifier(feature)
    domain_output = self.domain_classifier(reverse_feature)

    return class_output, domain_output

参考


原文地址:https://blog.csdn.net/transformer_WSZ/article/details/142472661

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!