自学内容网 自学内容网

【连续学习之VCL算法】2017年论文:Variational continual learning

1 介绍

年份:2017

期刊: arXiv preprint

Nguyen C V, Li Y, Bui T D, et al. Variational continual learning[J]. arXiv preprint arXiv:1710.10628, 2017.

本文提出的算法是变分连续学习(Variational Continual Learning, VCL),它是一种基于变分推断的在线学习方法,结合了在线变分推断(VI)和蒙特卡洛VI的最新进展,用于训练深度判别模型和生成模型,以实现在连续学习设置中避免灾难性遗忘并适应新任务的能力。关键步骤包括使用变分推断来近似后验分布,并通过核心集(coreset)数据摘要方法增强模型的记忆能力。本文算法属于基于变分推断的算法,它通过在线更新模型参数的后验分布来实现连续学习,这可以归类为基于正则化的算法,因为它利用KL散度最小化来正则化模型参数,以平衡对新数据的适应性和对旧数据的保留。

2 创新点

  1. 变分连续学习框架(VCL)
    • 提出了一种新的连续学习框架,即变分连续学习(VCL),它结合了在线变分推断(VI)和蒙特卡洛VI,适用于复杂的连续学习环境。
  2. 深度模型的连续学习
    • 将VCL框架应用于深度判别模型和深度生成模型,展示了该框架在这些复杂神经网络模型中的有效性。
  3. 核心集(coreset)数据摘要
    • 引入了核心集的概念,这是一种小型的代表性数据集,用于保留先前任务的关键信息,帮助算法在新任务学习中避免遗忘旧任务。
  4. 自动和无参数的连续学习
    • VCL框架避免了传统方法中需要手动调整的超参数,实现了完全自动化的学习过程,且无需额外的验证集来调整参数。
  5. 实验结果的优越性
    • 在多个任务上的实验结果显示,VCL在避免灾难性遗忘方面优于现有的连续学习方法,且不需要调整任何超参数。
  6. 理论基础和扩展性
    • 基于贝叶斯推断的理论基础,VCL提供了一种原则性强、可扩展的解决方案,可以应用于多种不同的模型和学习场景。
  7. 适用于复杂任务演化
    • VCL能够处理任务随时间演变以及全新任务出现的情况,这对于现实世界中任务不断变化的场景具有重要意义。

3 算法

3.1 算法原理

  1. 贝叶斯推断框架
    • 贝叶斯推断提供了一个自然框架来处理连续学习问题。它通过保留模型参数的分布来表示参数的不确定性,这有助于在新数据到来时更新知识,同时保留旧知识。
  2. 在线变分推断(Online VI)
    • 在线VI是一种近似贝叶斯推断的方法,它通过迭代更新近似后验分布来处理新数据。VCL利用在线VI来递归地更新模型参数的后验分布。
  3. 变分连续学习(VCL)
    • VCL通过最小化KL散度(Kullback-Leibler divergence)来找到最佳近似后验分布。具体来说,对于每一步新数据的到来,VCL通过结合之前的后验分布和新数据的似然函数,然后通过变分推断找到新的近似后验分布。
  4. 核心集(Coreset)
    • 为了缓解连续学习中累积的近似误差,VCL引入了核心集的概念。核心集是从先前任务中提取的代表性数据点集合,用于在训练过程中刷新模型对旧任务的记忆。
  5. 递归更新
    • VCL递归地更新模型参数的近似后验分布。给定前一步的后验分布和新数据,VCL通过乘以似然函数并重新归一化来获得新的后验分布。
  6. 预测和参数更新
    • 在测试时,VCL使用最终的变分分布来进行预测。在训练时,VCL通过最大化变分下界(variational lower bound)来更新变分参数,这涉及到计算期望对数似然和KL散度。
  7. 蒙特卡洛方法
    • 为了处理期望对数似然的计算,VCL采用蒙特卡洛方法来近似这些期望值,这通常涉及到使用重参数化技巧(reparameterization trick)来计算梯度。

3.2 算法步骤

  1. 初始化:选择一个先验分布 p ( θ ) p(\theta) p(θ)并初始化变分近似 q 0 ( θ ) = p ( θ ) q_0(\theta) = p(\theta) q0(θ)=p(θ)
  2. 核心集初始化:初始化核心集 C 0 = ∅ C_0 = \emptyset C0=
  3. 对于每一个新任务 t = 1 , 2 , … , T t = 1, 2, \ldots, T t=1,2,,T执行以下步骤:a. 观察新数据集 D t D_t Dt。b. 更新核心集 C t C_t Ct,使用 C t − 1 C_{t-1} Ct1 D t D_t Dt来选择新的代表性数据点。c. 更新非核心集数据点的变分分布:

q ~ t ( θ ) = arg ⁡ min ⁡ q ∈ Q K L ( q ( θ ) ∥ q ~ t − 1 ( θ ) p ( D t ∪ C t − 1 ∖ C t ∣ θ ) Z ) \tilde{q}_t(\theta) = \arg\min_{q \in Q} KL \left( q(\theta) \parallel \frac{\tilde{q}_{t-1}(\theta) p(D_t \cup C_{t-1} \setminus C_t | \theta)}{Z} \right) q~t(θ)=argqQminKL(q(θ)Zq~t1(θ)p(DtCt1Ctθ))

其中, Z Z Z是归一化常数。

d. 计算最终的变分分布(仅用于预测):

q t ( θ ) = arg ⁡ min ⁡ q ∈ Q K L ( q ( θ ) ∥ q ~ t ( θ ) p ( C t ∣ θ ) Z ) q_t(\theta) = \arg\min_{q \in Q} KL \left( q(\theta) \parallel \frac{\tilde{q}_t(\theta) p(C_t | \theta)}{Z} \right) qt(θ)=argqQminKL(q(θ)Zq~t(θ)p(Ctθ))

e. 进行预测:在测试输入 x ∗ x^* x上,使用 q t ( θ ) q_t(\theta) qt(θ)来计算预测分布:

p ( y ∗ ∣ x ∗ , D 1 : t ) = ∫ q t ( θ ) p ( y ∗ ∣ θ , x ∗ ) d θ p(y^* | x^*, D_{1:t}) = \int q_t(\theta) p(y^* | \theta, x^*) d\theta p(yx,D1:t)=qt(θ)p(yθ,x)dθ

4 实验分析

图1展示了论文中测试的多头网络架构,包括判别模型(a)和生成模型(b),其中判别模型中低层网络参数θS在多个任务中共享,每个任务t有自己的“头部网络”θtH,映射到共同隐藏层的输出;生成模型中头部网络生成来自潜在变量z的中间层表示。

图6展示了在训练后各个任务生成器生成的图像,其中每列代表特定任务生成器的输出,每行显示所有训练任务生成器的结果,明显地,简单直接的在线学习方法遭受了灾难性遗忘,而其他方法(如VCL)成功地记住了之前的任务。实验结论是,与简单在线学习相比,VCL等方法在连续学习环境中能更好地保留对先前任务的记忆,避免了灾难性遗忘,展现出更好的长期记忆性能。

5 思考

(1)代码举例理解本文算法

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
from torch.nn.functional import softmax

# 假设我们有一个简单的神经网络模型
class SimpleNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 变分连续学习算法的实现
def variational_continual_learning(model, prior_mu, prior_sigma, tasks_num, lr=0.001):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for t in range(tasks_num):
        # 加载当前任务的数据
        datasets, labels = data_loader(t)
        
        # 遍历当前任务的数据进行训练
        for data, label in zip(datasets, labels):
            # 前向传播
            output = model(data)
            log_likelihood = softmax(output, dim=1).gather(1, label.unsqueeze(1)).squeeze(1).log()
            
            # 计算损失函数,包括负对数似然和KL散度
            loss = -log_likelihood + kl_divergence(model.fc2.weight, model.fc2.bias, prior_mu, prior_sigma)
            
            # 反向传播和优化
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
    return model

def kl_divergence(weights, biases, prior_mu, prior_sigma):
    # 计算权重和偏置的KL散度
    posterior_mu = weights
    posterior_sigma = torch.nn.functional.softplus(biases) + 1e-6  # 防止sigma为0
    
    # KL散度计算公式
    kl_w = 0.5 * (torch.log(prior_sigma) - torch.log(posterior_sigma) + posterior_sigma**2 + (posterior_mu - prior_mu)**2 / posterior_sigma**2 - 1)
    kl_b = 0.5 * (torch.log(prior_sigma) - torch.log(posterior_sigma) + posterior_sigma**2 - 1)
    
    return kl_w.sum() + kl_b.sum()

# 假设我们有一个数据加载器,用于加载连续的任务
def data_loader(task_id):
    # 这里只是一个示例,实际中需要根据task_id加载不同的数据
    # 返回当前任务的数据和标签
    pass

# 初始化模型
input_size = 784  # 例如MNIST数据集
hidden_size = 100
output_size = 10  # 假设有10个类别
model = SimpleNN(input_size, hidden_size, output_size)

# 设置先验分布的均值和标准差
prior_mu = torch.zeros(output_size)
prior_sigma = torch.ones(output_size)

# 执行变分连续学习算法
tasks_num = 5  # 假设有5个连续的任务
trained_model = variational_continual_learning(model, prior_mu, prior_sigma, tasks_num)

原文地址:https://blog.csdn.net/weixin_43935696/article/details/144759652

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!