[论文阅读]MaIL: Improving Imitation Learning with Mamba
Abstract
这项工作介绍了mamba模仿学习(mail),这是一种新颖的模仿学习(il)架构,为最先进的(sota)变换器策略提供了一种计算高效的替代方案。基于变压器的策略由于能够处理具有固有非马尔可夫行为的人类记录数据而取得了显著成果。然而,它们的高性能伴随着大型模型的缺点,这使得有效的训练变得复杂。虽然状态空间模型(ssms)以其效率而闻名,但它们无法与变压器的性能相匹配。mamba显著提高了ssms和竞争对手对transformers的性能,使其成为il政策的一个有吸引力的替代方案。mail利用mamba作为骨干,并引入了一种形式化,允许在编码器-解码器结构中使用mamba。这种形式化使其成为一种通用的架构,既可以用作独立的策略,也可以用作更高级架构的一部分,例如扩散过程中的扩散器。
对LIBERO IL benchmark和三个真实机器人实验的广泛评估表明,mail:i)在所有libero任务中都优于transformer,ii)即使在小数据集下也能实现良好的性能,iii)能够有效地处理多模态感官输入,iv)与transformer相比,对输入噪声更具鲁棒性
Introduction
这里,当前的方法要么使用仅解码器结构[5],要么使用解码器-编码器架构[6]。这些架构中哪一个擅长通常取决于任务。变压器的性能通常伴随着难以训练的大型模型,特别是在数据稀缺的领域。处理观测序列的另一种概念是状态空间模型[12]。这些模型假设观测值(嵌入)之间存在线性关系,通常在计算上更高效。最近的方法,如选择性状态空间模型mamba[13],严格提高了状态空间模型的性能,并在许多任务中与变压器竞争。由于其在推理速度、内存使用和效率方面的特性,mamba是一个有吸引力的il策略模型
邮件可以用作独立的策略,也可以用作更高级流程的一部分,例如扩散流程中的扩散器。我们以两种变体实现邮件。在仅解码器的变体中,mail处理噪声动作和观测特征[5]以及扩散过程的时间嵌入,并输出去噪动作。
Related Works
Sequence Models.
变压器中的自我关注机制允许并行处理序列,有效地解决了rnn在顺序数据处理中的局限性[17,18,19,20]。然而,结构化状态空间模型[12,22,22,13]为变压器提供了一种有吸引力的替代方案。变压器在序列长度上按二次缩放,而结构化状态空间模型则按线性缩放[13]
最近的工作[13]依赖于关联扫描,它也允许并行计算,但还允许输入相关的可学习矩阵[13]
Imitation Learning (IL).
早期的模仿学习方法主要侧重于学习状态-动作对之间的一一映射。但这些方法忽略了历史中包含的丰富时间信息。随后的方法结合了rnns来编码观测序列,证明了利用历史观测可以提高模型性能。然而,这些方法存在基于rnn架构的固有局限性,包括表示能力有限、序列建模时间长以及训练时间慢,因为它们不适合大规模并行化。
Transformer可以对长序列进行建模,同时通过并行序列处理保持训练效率。这一趋势延伸到具有多模态感官输入的il[36,37,38,39],其中变换器对图像和语言序列进行编码
最近,扩散模型在模仿学习中表现出了优越性[5,40,6,41,39]。由于其强大的泛化能力和丰富的表示能力来捕捉多模态动作分布,它们已成为模仿学习领域的sota
Preliminaries
3.1 Mamba: Selective State-Space Models
mamba[13]通过使用选择性扫描算子改进了结构化状态空间序列模型(ssms)
输入序列,输出序列,b、l、d分别表示批量大小、序列长度和维度
标准ssm定义了时不变参数和时间步长向量∆将x(l)的输入映射到隐藏状态,然后可以将其投影到输出y(l)
mamba通过使ssm参数成为输入的函数来实现选择机制
线性是指线性投影层,softplus是relu的平滑近似。那么输出可以通过以下方式计算
其中是离散化的[42]个具有时间步长∆的对应项。由于时变模型只能以循环方式计算,mamba进一步实现了一种硬件感知方法来高效计算选择性ssm。
图1:d-ma:mamba去噪架构集成了用于状态编码的resnet-18和用于动作编码的动作编码器。状态序列的长度为K,而扩散步骤t处的动作序列的长度是J。在将输入馈送到曼巴模块之前,位置编码(PE)和时间编码(TE)增强了输入,其中sk和ak共享相同的位置编码。曼巴模块有n×曼巴块,详细结构[13]如左图所示。mamba模块的输出由线性输出层处理,从而实现一步去噪操作。mamba块中的符号×表示矩阵乘法,σ表示silu激活函数。
3.2 Policy Representations
在这项工作中,我们使用了两种策略表示:行为克隆(bc)和去噪扩散策略(ddps)。为了清楚起见,我们关注的是非连续性的情况。
Behavioral Cloning
行为克隆假设参数化的条件高斯分布作为策略表示,即最大化模型参数θ的可能性简化为均方误差(mse)损失,其中使用演示数据中的状态-动作对来近似s、a的期望。
Denoising Diffusion Policies
去噪扩散策略利用去噪函数从马尔可夫链中采样从开始,对给定的观测值s产生无噪声动作。
训练去噪函数,通过最小化损失来预测噪声动作的源噪声
,t上的期望对应于中的均匀采样。
4 Mamba for Imitation Learning
从成功的仅解码器(D-Tr)和编码器-解码器(ED-Tr)变换器中汲取灵感,我们提出了两种基于mamba的架构:仅解码器mamba(D-Ma)和编码器解码器mamba(ED-Ma)
这些架构充当策略的参数化。具体来说,当采用行为克隆(bc)时,这些架构将条件高斯分布的均值μθ参数化。当使用去噪扩散策略(ddps)时,这些架构将去噪函数εθ参数化。鉴于前一种情况的简单性,我们将重点介绍ddps背景下的这些架构
4.1 Decoder-Only Mamba
与仅解码器转换器类似,我们使用mamba块来处理输入。图1显示了纯解码器mamba架构的概述。ddps的解码器专用mamba被设计为学习去噪函数εθ,该函数接受一系列观测值,噪声动作和diffusion step t,来生成噪声较小的动作序列使用时间嵌入te对扩散步骤进行编码。使用resnet-18对观测值进行编码,在不同时间步长的图像之间共享权重。动作编码器enca用于对有噪声的动作输入进行标记。此外,位置嵌入pe被应用于观测和动作。然后,时间嵌入、状态嵌入和动作嵌入将被输入到mamba解码器decm中。mamba解码器是通过堆叠多个具有残差连接和层归一化的mamba块来实现的。算法1中示出了完整的推理例程
4.2 Encoder-Decoder Mamba
与仅包含自注意机制的解码器变压器相比,具有交叉注意的编码器-解码器变压器是一种更灵活有效的设计,可以处理复杂的输入输出关系,特别是在输入和输出序列结构不同的情况下。然而,由于目标和源共享相同的序列长度,mamba没有提供这样的机制来支持编码器-解码器结构。
我们提出了一种称为mamba聚合的新方法,用于设计mamba的编解码器版本。可视化可以在图2中找到。mamba编码器encm用于处理时间嵌入和状态嵌入,mamba解码器decm用于处理噪声嵌入。由于em和dm的输入具有不同长度的序列,我们建议添加可学习变量来补充每个序列
图2:ed-ma:与d-ma模型不同,ed-ma包含用于处理时间嵌入和状态嵌入的mamba编码器,以及用于处理噪声动作的mamba解码器。为了聚合来自编码器和解码器的信息,将可学习的动作变量引入编码器输入,将可习得的时间变量和状态变量引入解码器输出,以进行序列对齐。
act对比:
5 Experiments
我们的调查侧重于以下关键问题:
q1)MaIL能否实现与变压器相当或更优的性能?
q2)MaIL可以使用多模式输入,如语言指令吗?
q3)MaIL如何有效地处理观察中的连续信息?
5.1 Baselines
我们的实验包含四种架构:去卷积变换器(d-tr)、编解码器变换器(ed-tr)、仅解码器mamba(d-ma)、编解码mamba(ed-ma)。
为了进行公平的比较,我们使用resnet18对每种方法的视觉输入进行编码。对于使用语言指令的任务,我们使用预训练的clip模型[43]来获得相应的语言嵌入,该嵌入用于所有方法的训练和推理。
基于上述设置,我们实施以下模仿学习策略:
行为克隆(bc)我们实现了一种用变压器和曼巴结构的mse损失训练的vanilla行为克隆策略。
基于bc中相同结构的去噪扩散策略(ddp),我们进一步使用离散去噪过程实现了一种扩散策略[44]。我们为每种架构使用16个扩散时间步长进行训练和采样。
5.2 Simulation Evaluation
LIBERO
评估是使用libero基准进行的,该基准包括五个不同的任务套件:LIBERO-Spatial, LIBERO-Object, LIBERO-Goal, LIBERO-Long, and LIBERO90。每个任务套件包括10个任务和50个人类演示,但libero-90除外,它包含90个任务,50个演示。每个任务套件都旨在测试机器人学习和操纵能力的不同方面。任务可视化如图3所示。更多细节见附录c。
Evaluation Protocol
我们分别在五个libero任务套件中比较了每种方法。除了libero-90包含900个轨迹外,我们没有使用完整的演示,而是为每个子任务只使用了20%的演示,每个任务套件总共使用了100个轨迹。我们调整变压器和曼巴的超参数,确保它们的参数量相似。所有模型都训练了50个epoch,我们使用最后一个检查点进行评估。遵循libero的官方基准设置,我们为每个子任务执行了20次部署 rollouts,每个任务套件总共进行了200次评估,但libero-90除外,它包括1800次评估。我们报告了超过3个种子的每个任务套件的平均成功率。
Main Results.
我们在表1中报告了主要结果。我们基于mamba的架构d-ma和ed-ma在基于bc策略的所有libero任务套件中的表现明显优于基于转换器的方法
表1:libero基准测试的性能,其中“w/o语言”表示我们不使用语言指令,“w/language”表示我们使用从预训练的剪辑模型生成的语言令牌,h1和h5分别表示使用当前状态和5步历史状态
具体来说,基于曼巴的模型在libero-object和libero-90中的成功率提高了近30%。当使用ddp策略时,我们的模型始终超过变压器基线,在大多数任务中性能提高超过5%。这些结果证实了q1,表明mail的性能优于变压器。
为了解决q2问题,我们使用额外的语言嵌入作为输入,将邮件与libero-target和libero-90上的transformers进行了比较。我们观察到,在这些任务中,基于mamba的方法有了显著改进,表明邮件有效地利用了多模态输入。
鉴于最近的视觉模仿学习作品使用历史观察作为输入,我们用1和5个历史观察来评估这些方法。我们发现历史信息并不总是能提高绩效。 只有在libero对象中h5模型的表现优于h1模型,而在其他任务中,h5模型取得了类似或更差的结果。
基于mamba的h5模型的性能再次始终优于基于transformer的模型,这表明mail能够有效地捕获连续的观察特征,回答了问题3
Ablation on Observation Occlusions
为了进一步了解transformer和mamba的顺序学习能力,我们随机屏蔽图像区域并测试模型的性能下降。结果如图4所示。而对于零遮挡,变压器架构可以与mamba相当,添加遮挡会更快地降低变压器的性能,表明mamba可以更好地从历史序列中提取重要信息。
Ablation on Dataset Size
鉴于邮件仅在20%的演示中表现良好,我们有兴趣随着数据集大小的增加来评估其可扩展性。我们在libero空间任务上使用bc策略将基于mamba的模型与transformer模型进行了比较。结果如图4所示。很明显,当数据稀缺时,基于mamba的模型明显优于transformer,并且随着数据集大小的增加,其性能也相当。
5.3 Real Robot Evaluation
我们基于7自由度franka熊猫机器人设计了三个具有挑战性的任务,利用模型的视觉输入。位于机器人前方不同角度的两个摄像头提供视觉数据。一个图像被裁剪并调整为(128,256,3),而另一个图像则调整为(256,256,3)。
整个设置如图5所示。这些图像在每个时间步上堆叠以形成观察结果。我们从输入中排除了机器人状态,因为之前的研究报告称,包括它们可能会导致性能不佳[7]。
动作空间是8维的,包括关节位置和夹持器状态。下面详细介绍的任务设置如图6-8所示。相应的结果如表2-4所示。
我们使用ddp-h1模型将ed-tr与ed-ma进行了比较。我们对每种方法训练了100个迭代周期(收敛),并使用最终的检查点对模型进行了评估。对于每个任务,我们为对象执行了20个具有不同初始状态的展开。为了确保公平比较,我们对变压器和曼巴评估使用了相同的初始状态。从结果来看,基于mamba的方法与变压器模型取得了相当的结果。
6 Limitations
虽然mail在较小的数据集大小下表现出了出色的性能,但随着数据集的扩展,它的优势变得不那么明显。当在更大的数据集上训练时,mail的结果与transformer模型相当,但并不超过后者。
此外,mamba的设计是为了快速高效地处理大规模序列。然而,在序列相对较短的模仿学习策略的背景下,Transformer的推理时间与mamba相似。这降低了mamba在这些场景中的性能效率优势。
7 Conclusion
总之,这项工作提出了一种新的模仿学习(il)策略架构mail,它弥合了处理观察序列的效率和性能之间的差距。通过利用状态空间模型的优势并对其进行严格改进,mail为传统上基于大型复杂变压器的策略提供了一种有竞争力的替代方案。在编码器-解码器结构中引入mamba增强了其通用性,使其既适合独立使用,也适合集成到扩散过程等高级架构中。对libero-il基准测试和真实机器人实验的广泛评估表明,mail不仅匹配而且超越了现有基线的性能,使其成为一种有前景的il任务方法。
B Transformer Architecture
我们描述了扩散策略中的两种基于变换器的架构:仅解码器模型(图9)和编码器-解码器模型(见图10)。这两种架构都利用了变压器模型的优势来有效地处理顺序数据并捕获长期依赖关系。
图9:仅解码器学习块。该架构集成了用于状态编码的resnet-18和用于地平线j动作的动作编码器,这两个组件都馈入了一个自我注意机制。位置编码(pe)和时间编码(te)增强了输入。最终,自我关注的输出被馈送到线性输出层,以预测未来的行为
图10:编码器-解码器学习块。此图说明了为策略学习设计的编码器-解码器转换器块的架构。在编码器中,状态使用resnet-18进行编码,通过时间编码(te)和位置编码(pe)进行增强,并通过自我注意进行处理。解码器然后利用对编码动作的自注意,并采用交叉注意来整合来自编码器的编码状态。最终,交叉注意力的输出被馈送到线性输出层,以预测未来的行动。
D Model Details
D.1 Parameter Comparison
我们还评估了配备rtx2060gpu的本地pc上的推理时间,使用32的批处理大小,以确保在表5中的相同条件下评估所有模型。
D.2 Training Details
我们在表6中列出了基于transformer和基于mamba的策略的训练超参数。为了确保公平比较,我们将两种策略的超参数调整到同一水平。
这些策略是使用libero提供的人类专家演示进行训练的,在主实验中,我们只对每个任务进行10次演示。
所有模型都在配备4个a100 gpu的集群上训练,批处理大小为256,使用3个不同的种子在50个迭代周期内进行训练。最后,我们计算了这3个种子的平均成功率。
E.3 Data Collection
遥操作用于收集所有真实机器人任务的数据,其中领导者机器人由人类控制,追随者机器人跟随领导者机器人,如图12所示。物体被放置在跟随机器人的前方,摄像头看不到领导机器人或人类。将引导机器人的当前关节状态作为期望的关节状态发送给跟随机器人。夹持器的状态被认为是二进制的,要么关闭,要么打开。为引导机器人的夹持器设置阈值;如果当前宽度低于阈值,跟随机器人的夹持器将关闭,否则将打开。
E.4 Evaluation
对于评估,使用模型的输出有时会激活机器人的安全机制,因为它违反了一定的约束。为了解决这个问题,在当前关节位置和模型输出之间生成轨迹。然后在每个时间步长将该轨迹的点提供给机器人,而不是模型的原始输出。该轨迹的长度取决于模型的输出与当前机器人状态的距离。
原文地址:https://blog.csdn.net/qq_33673253/article/details/140396789
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!