斯坦福iDP3源码剖析:逐步分解Improved 3D Diffusion Policy的实现(人形机器人的动作策略之一)
前言
今25年1.14日起,我和同事孙老师连续出差苏州、无锡、南京、上海
- 1.14日在苏州,一家探讨人形合作研发,一家是客户
- 1.15-1.16两天在南京,和同事姚博士、合作商一块接待一机器人集团客户
客户表示高校偏科研,但我们做到了科研与落地并重,很希望合作——主动提出拉群保持逐月推进 - 1.17日在无锡,参观一集团工厂、交流可合作开发的业务场景,并约定年后再去一趟电器厂
- 1.18日则在上海约了4位,分别来自两人形公司、一国家级实验室、一大模型独角兽
我们连连感慨,绝大部分工厂都将在今2025年开始做一系列智能升级、智能改造
- 包括且不限于物料分拣、线缆插拔(项目组里包含清华、国防科大等院校的博士生)、智能装配(和国科大一教授团队合作)、打螺钉
而背后用的策略方法,也将从传统的深度学习方法,往大模型 + 模仿学习 + RL方面迁移,这是势不可挡的大趋势 - 而我司则在加大力度赋能工厂,比如 此前如此文《斯坦福iDP3——改进3D扩散策略以赋能人形机器人的训练:不再依赖相机校准和点云分割(含3D扩散策略DP3的详解)》所述 “ 截止到25年1.12日,我们idp3的复现迎来大进展,idp3架构拆解完了,且还弄了一个通用架构——可以同时跑dexcap和ipd3 ”
- 因此在出差的间隙,我于昨天把人形动作预测策略——ipd3源码的所有代码文件整体看了下,确实如姚博士所说,模块清晰 各司其职
本想着这几天出差完后 把ipd3的源码也做下解读,想了下,只要有时间空闲,我便开始解读吧
包括1.17日从无锡来上海的路上——高铁上 酒店大堂里 网约车上 餐厅里,我都拿出了MacBook Pro修订本篇《iDP3源码剖析》博客
可能这就是为何每次出差,和做AI 大模型 具身的技术人交流时,十之八九都看过我博客(不管在哪个TOP高校 不管在哪个大厂)的原因吧,背后毕竟有着十多年的积累」
于此,今天便有了本文「注意,看本文之前,建议先通过此文了解 iDP3的原理」,且重点分析其learning的代码:Improved-3D-Diffusion-Policy
而为了让本文的源码剖析足够清晰,我是花了不少心思的,因为源码分析其实很容易变成各种堆砌代码,所以我特意做了以下这几点措施
- 待解读的代码,尽可能控制在10行以内,因为按我的经验,超过10行 看着就累了
- 即便有解读,贴的代码 也要逐行都有对应的注释
因为这样 可以更加一目了然 - 为了随时让读者知道某个被分析的函数处在哪个文件夹下,以及在整体中的位置及与前后代码文件的关联
对于较长的代码文件,我会特意在分析代码文件之前,贴一下对应的代码结构截图
如此,还是为了一目了然 - 给每个章节的代码文件名称都加上了对应的一句话说明,这样让大家一目了然被分析的代码文件是具体干什么的,且可让整个目录更有全局感,更清晰
第一部分 数据集:diffusion_policy_3d的common、config、dataset
1.1 diffusion_policy_3d/common
本文件夹下 有一些代码文件 暂未解读,比如
-
checkpoint_util.py该类用于管理模型训练过程中的检查点(checkpoint),确保只保留性能最好的k个检查点
-
// 待更
1.1.1 common/gr1_action_util.py:转换和处理与关节和末端执行器EEF相关的数据
该代码片段主要用于转换和处理与机器人关节和末端执行器EEF相关的数据
- 首先,导入了numpy、torch以及自定义的rotation_util模块,并定义了若干初始姿态与位置变量(init_arm_pos、init_arm_quat等)
- joint32_to_joint25函数将包含32个关节数据的数组转换为只包含25个关节数据的数组,主要通过选择和映射腰部、头部、手臂与手的关节索引
- joint25_to_joint32函数则执行反向操作,将25个关节数据填充回32个
- extract_eef_action函数从传入的eef_action向量中提取身体动作、双臂位置和旋转,以及手部动作
这里的手臂旋转采用6D表示法,用rotation_util模块可进一步转换至四元数 - 最后,extract_abs_eef 函数基于增量位置和旋转,计算得到新的绝对位置和旋转。它会先将四元数转换至6D旋转进行相加,再通过rotation_util还原为新的四元数,以便完整表达最终的末端执行器位姿
1.1.3 common/multi_realsense.py:管理和处理多个 RealSense 摄像头的数据流
该代码片段主要用于管理和处理多个 RealSense 摄像头的数据流。它包括初始化摄像头、获取摄像头数据、处理点云数据等功能
定个各个类之前
- 首先,导入了必要的库和模块,包括 multiprocessing、`numpy` 和 `pyrealsense2`
设置了多进程的启动方法为 `fork`,并配置了 numpy 的打印选项 - get_realsense_id 函数用于获取连接到系统的所有 RealSense 摄像头的序列号,并返回这些序列号的列表
- init_given_realsense 函数用于初始化指定的 RealSense 摄像头
它接受多个参数,包括设备序列号、是否启用 RGB 和深度流、是否启用点云、同步模式等
根据这些参数配置摄像头,并返回摄像头的管道、对齐对象、深度比例和相机信息 - grid_sample_pcd 函数用于对点云数据进行网格采样。它接受一个点云数组和网格大小,返回采样后的点云数组
CameraInfo 类用于存储相机的内参信息,包括宽度、高度、焦距、主点坐标和比例
SingleVisionProcess 类继承自 Process,用于管理单个摄像头的数据流。它在初始化时接受多个参数,包括设备序列号、队列、是否启用 RGB 和深度流、是否启用点云、同步模式、点云数量、远近裁剪距离、是否使用网格采样和图像大小
def __init__(self, device, queue, # 初始化方法,接受设备和队列作为参数
enable_rgb=True, # 是否启用 RGB 流,默认值为 True
enable_depth=False, # 是否启用深度流,默认值为 False
enable_pointcloud=False, # 是否启用点云,默认值为 False
sync_mode=0, # 同步模式,默认值为 0
num_points=2048, # 点云数量,默认值为 2048
z_far=1.0, # 远裁剪距离,默认值为 1.0
z_near=0.1, # 近裁剪距离,默认值为 0.1
use_grid_sampling=True, # 是否使用网格采样,默认值为 True
img_size=224) -> None: # 图像大小,默认值为 224
类中定义了
- get_vision 方法用于获取摄像头数据
- run 方法用于启动摄像头数据流
- terminate 方法用于终止数据流
- create_colored_point_cloud 方法用于创建带颜色的点云
MultiRealSense 类用于管理多个 RealSense 摄像头。它在初始化时接受多个参数,包括是否使用前置和右侧摄像头、摄像头索引、点云数量、远近裁剪距离、是否使用网格采样和图像大小,详见如下
# 初始化方法,接受多个参数,默认使用前置摄像头,不使用右侧摄像头
def __init__(self, use_front_cam=True, use_right_cam=False,
# 前置摄像头和右侧摄像头的索引,默认值分别为 0 和 1
front_cam_idx=0, right_cam_idx=1,
# 前置摄像头和右侧摄像头的点云数量,默认值分别为 4096 和 1024
front_num_points=4096, right_num_points=1024,
# 前置摄像头的远近裁剪距离,默认值分别为 1.0 和 0.1
front_z_far=1.0, front_z_near=0.1,
# 右侧摄像头的远近裁剪距离,默认值分别为 0.5 和 0.01
right_z_far=0.5, right_z_near=0.01,
use_grid_sampling=True, # 是否使用网格采样,默认值为 True
img_size=384): # 图像大小,默认值为 384
类中定义了
- _call方法,用于获取摄像头数据
- finalize方法,用于终止所有摄像头的数据流
- _del_方法用于在对象销毁时调用finalize方法
通过这些类和函数,代码实现了对多个 RealSense 摄像头的数据管理和处理,适用于需要同时处理多个摄像头数据的应用场景
// 待更
1.2 diffusion_policy_3d/config
1.3 diffusion_policy_3d/dataset:各种数据集及相关处理
1.3.1 dataset/base_dataset.py:低维、图像、点云、通用等4类数据集
该代码文件定义了四个基类,分别用于处理低维数据集、图像数据集、点云数据集和通用数据集。这些基类继承自 torch.utils.data.Dataset,并定义了一些抽象方法和默认行为
下面逐一阐述这4个基类
处理低维数据集:BaseLowdimDataset类
class BaseLowdimDataset(torch.utils.data.Dataset):
def get_validation_dataset(self) -> 'BaseLowdimDataset':
# 默认返回一个空的数据集
return BaseLowdimDataset()
def get_normalizer(self, **kwargs) -> LinearNormalizer:
raise NotImplementedError() # 抛出未实现的异常
def get_all_actions(self) -> torch.Tensor:
raise NotImplementedError() # 抛出未实现的异常
# 定义 __len__ 方法,返回数据集的长度
def __len__(self) -> int:
return 0 # 默认返回 0
# 定义 __getitem__ 方法,返回一个包含观察和动作的字典
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""
output:
obs: T, Do # 观察数据的形状为 (T, Do)
action: T, Da # 动作数据的形状为 (T, Da)
"""
raise NotImplementedError() # 抛出未实现的异常
另外三个类的实现差不多,就不再一一贴它们的代码了
- 处理图像数据集:BaseImageDataset 类
- 处理点云数据集:BasePointcloudDataset 类
- 通用的数据集基类:BaseDataset 类
这些基类为不同类型的数据集提供了统一的接口和默认行为,子类可以继承这些基类并实现具体的方法,以处理特定类型的数据集
1.3.2 dataset/gr1_dex_dataset_3d.py:处理 3D 数据集
GR1DexDataset3D 类继承自 BaseDataset,用于处理 3D 数据集
其构造函数接受多个参数
def __init__(self,
zarr_path, # 数据集路径
horizon=1, # 时间跨度
pad_before=0, # 前填充
pad_after=0, # 后填充
seed=42, # 随机种子
val_ratio=0.0, # 验证集比例
max_train_episodes=None, # 最大训练集数量
task_name=None, # 任务名称
num_points=4096, # 点云数量
):
在初始化过程中,使用 cprint 打印加载数据集的信息,并设置类的属性
在构造函数中
- 首先调用父类的构造函数 super().__init__() 进行初始化。然后,使用 cprint 打印加载数据集的信息,并设置类的属性 task_name 和 num_points
super().__init__() # 调用父类的构造函数 cprint(f'Loading GR1DexDataset from {zarr_path}', 'green') # 打印加载数据集的信息 self.task_name = task_name # 设置任务名称 self.num_points = num_points # 设置点云数量
- 接下来,定义一个包含 `state` 和 `action` 的 buffer_keys 列表,并将 `point_cloud` 添加到该列表中
通过调用 ReplayBuffer.copy_from_path 方法,从指定路径加载数据,并生成一个 ReplayBuffer 对象buffer_keys = [ # 定义缓冲区键列表 'state', # 状态 'action', # 动作 ] buffer_keys.append('point_cloud') # 添加点云键
self.replay_buffer = ReplayBuffer.copy_from_path( # 从指定路径加载重放缓冲区 zarr_path, keys=buffer_keys)
- 接着,使用 get_val_mask 方法生成验证集掩码 val_mask,并通过取反操作生成训练集掩码 train_mask
为了控制训练集的大小,使用 downsample_mask 方法对训练集掩码进行下采样val_mask = get_val_mask( # 获取验证集掩码 n_episodes=self.replay_buffer.n_episodes, # 重放缓冲区中的集数 val_ratio=val_ratio, # 验证集比例 seed=seed) # 随机种子 train_mask = ~val_mask # 训练集掩码为验证集掩码的取反 train_mask = downsample_mask( # 对训练集掩码进行下采样 mask=train_mask, # 掩码 max_n=max_train_episodes, # 最大训练集数量 seed=seed) # 随机种子
- 最后,创建一个 SequenceSampler 对象 self.sampler,用于从重放缓冲区中采样数据
SequenceSampler 对象的初始化参数包括重放缓冲区 replay_buffer、时间跨度 sequence_length、填充参数 pad_before 和 pad_after 以及训练集掩码 episode_mask
构造函数还设置了类的其他属性,包括 train_mask、horizon、pad_before 和 pad_afterself.sampler = SequenceSampler( # 创建序列采样器 replay_buffer=self.replay_buffer, # 重放缓冲区 sequence_length=horizon, # 序列长度 pad_before=pad_before, # 前填充 pad_after=pad_after, # 后填充 episode_mask=train_mask) # 训练集掩码
通过这些步骤,构造函数完成了数据集对象的初始化,为后续的数据处理和模型训练提供了基础self.train_mask = train_mask # 设置训练集掩码 self.horizon = horizon # 设置时间跨度 self.pad_before = pad_before # 设置前填充 self.pad_after = pad_after # 设置后填充
接下来,get_validation_dataset 方法用于生成验证数据集
它通过浅拷贝当前对象,并创建一个新的 SequenceSampler 对象,使用验证集掩码来替换训练集掩码
其次,get_normalizer 方法用于生成数据归一化器
它首先从重放缓冲区中提取 `action` 数据,并使用 LinearNormalizer 对其进行拟合。然后,为 point_cloud 和 agent_pos 创建身份归一化器,并返回归一化器对象
而剩下的方法有
- __len__ 方法返回数据集的长度,即采样器的长度
- _sample_to_data 方法将采样的数据转换为所需的格式
包括将状态和点云数据转换为浮点数,并对点云数据进行均匀采样 - __getitem__ 方法根据索引从采样器中获取数据样本,并将其转换为 PyTorch 张量
通过 dict_apply 方法,将数据字典中的所有 NumPy 数组转换为 PyTorch 张量,并返回转换后的数据
1.3.3 dataset/gr1_dex_dataset_image.py:处理图像和深度信息
GR1DexDatasetImage 类继承自 BaseDataset,用于处理包含图像和深度信息的数据集
其构造函数接受多个参数
def __init__(self,
zarr_path, # 数据集路径
horizon=1, # 时间跨度
pad_before=0, # 前填充
pad_after=0, # 后填充
seed=42, # 随机种子
val_ratio=0.0, # 验证集比例
max_train_episodes=None, # 最大训练集数量
task_name=None, # 任务名称
use_img=True, # 是否使用图像
use_depth=False, # 是否使用深度信息
):
在初始化过程中,使用 cprint 打印加载数据集的信息,并设置类的属性
- 该类首先定义了一个包含 `state` 和 `action` 的 buffer_keys 列表
如果 use_img 为真,则将 `img` 添加到 buffer_keys 列表中;如果 use_depth 为真,则将 depth 添加到 buffer_keys 列表中self.task_name = task_name # 设置任务名称 self.use_img = use_img # 设置是否使用图像 self.use_depth = use_depth # 设置是否使用深度信息 buffer_keys = [ # 定义缓冲区键列表 'state', # 状态 'action', # 动作
然后,通过调用 ReplayBuffer.copy_from_path 方法从指定路径加载数据,并生成一个 ReplayBuffer 对象if self.use_img: # 如果使用图像 buffer_keys.append('img') # 添加图像键 if self.use_depth: # 如果使用深度信息 buffer_keys.append('depth') # 添加深度键
self.replay_buffer = ReplayBuffer.copy_from_path( # 从指定路径加载重放缓冲区 zarr_path, keys=buffer_keys)
- 接着,使用 get_val_mask 方法生成验证集掩码 val_mask,并通过取反操作生成训练集掩码 train_mask
为了控制训练集的大小,使用 downsample_mask 方法对训练集掩码进行下采样val_mask = get_val_mask( # 获取验证集掩码 n_episodes=self.replay_buffer.n_episodes, # 重放缓冲区中的集数 val_ratio=val_ratio, # 验证集比例 seed=seed) # 随机种子 train_mask = ~val_mask # 训练集掩码为验证集掩码的取反 train_mask = downsample_mask( # 对训练集掩码进行下采样 mask=train_mask, # 掩码 max_n=max_train_episodes, # 最大训练集数量 seed=seed) # 随机种子
- 最后,创建一个 SequenceSampler 对象 self.sampler,用于从重放缓冲区中采样数据
self.sampler = SequenceSampler( # 创建序列采样器 replay_buffer=self.replay_buffer, # 重放缓冲区 sequence_length=horizon, # 序列长度 pad_before=pad_before, # 前填充 pad_after=pad_after, # 后填充 episode_mask=train_mask) # 训练集掩码 self.train_mask = train_mask # 设置训练集掩码 self.horizon = horizon # 设置时间跨度 self.pad_before = pad_before # 设置前填充 self.pad_after = pad_after # 设置后填充
接下来,get_validation_dataset 方法用于生成验证数据集
它通过浅拷贝当前对象,并创建一个新的 SequenceSampler 对象,使用验证集掩码来替换训练集掩码
且get_normalizer 方法用于生成数据归一化器
- 它首先从重放缓冲区中提取 `action` 数据,并使用 LinearNormalizer 对其进行拟合
如果 use_img 为真,则为 image 创建身份归一化器;如果 use_depth 为真,则为 depth 创建身份归一化器 - 最后,为 agent_pos 创建身份归一化器,并返回归一化器对象
至于剩下的方法和上节的gr1_dex_dataset_3d.py一样
- __len__ 方法返回数据集的长度,即采样器的长度
- _sample_to_data 方法将采样的数据转换为所需的格式,包括将状态数据转换为浮点数,并根据需要处理图像和深度数据
- __getitem__ 方法根据索引从采样器中获取数据样本,并将其转换为 PyTorch 张量。通过 dict_apply 方法,将数据字典中的所有 NumPy 数组转换为 PyTorch 张量,并返回转换后的数据
第二部分 扩散模型与3D点云编码器的实现:diffusion_policy_3d/model
2.1 model/common
2.2 model/diffusion
2.2.1 diffusion/conditional_unet1d.py:分别实现交叉注意力、条件残差块、条件U-Net 网络
2.2.1.1 CrossAttention 类:实现交叉注意力
CrossAttention 类是一个用于实现交叉注意力机制的 PyTorch 模块
它在初始化时接受三个参数:输入维度 in_dim、条件维度 cond_dim 和输出维度 out_dim
def __init__(self, in_dim, cond_dim, out_dim):
- 在 __init__ 方法中,定义了三个线性投影层 query_proj、key_proj 和 value_proj,分别用于将输入 x 和条件 cond 投影到查询、键和值
super().__init__() self.query_proj = nn.Linear(in_dim, out_dim) self.key_proj = nn.Linear(cond_dim, out_dim) self.value_proj = nn.Linear(cond_dim, out_dim)
- 在 forward 方法中
首先将输入 x 和条件 cond 投影到查询、键和值
然后计算注意力权重,并通过软最大化函数进行归一化def forward(self, x, cond): # x: [batch_size, t_act, in_dim] # cond: [batch_size, t_obs, cond_dim] # Project x and cond to query, key, and value query = self.query_proj(x) # [batch_size, horizon, out_dim] key = self.key_proj(cond) # [batch_size, horizon, out_dim] value = self.value_proj(cond) # [batch_size, horizon, out_dim]
最后,应用注意力权重到值上,得到注意力输出# Compute attention attn_weights = torch.matmul(query, key.transpose(-2, -1)) # [batch_size, horizon, horizon] attn_weights = F.softmax(attn_weights, dim=-1)
# Apply attention attn_output = torch.matmul(attn_weights, value) # [batch_size, horizon, out_dim] return attn_output
2.2.1.2 ConditionalResidualBlock1D 类:条件残差块,在一维卷积网络中实现条件处理
ConditionalResidualBlock1D 类是一个条件残差块,用于在一维卷积网络中实现条件处理
它在初始化时接受多个参数,如下所示
def __init__(self, # 定义构造函数
in_channels, # 输入通道数
out_channels, # 输出通道数
cond_dim, # 条件维度
kernel_size=3, # 卷积核大小,默认值为3
n_groups=8, # 组归一化的组数,默认值为8
condition_type='film'): # 条件类型,默认值为'film'
在初始化过程中,定义了两个一维卷积块,并根据条件类型初始化条件编码器 cond_encoder
在构造函数中
- 首先创建了一个包含两个 Conv1dBlock 的 nn.ModuleList,每个 Conv1dBlock 包含一维卷积、组归一化和 Mish 激活函数
self.blocks = nn.ModuleList([ # 定义一个包含两个卷积块的模块列表 Conv1dBlock(in_channels, # 第一个卷积块,输入通道数为 in_channels out_channels, # 输出通道数为 out_channels kernel_size, # 卷积核大小 n_groups=n_groups), # 组归一化的组数 Conv1dBlock(out_channels, # 第二个卷积块,输入通道数为 out_channels out_channels, # 输出通道数为 out_channels kernel_size, # 卷积核大小 n_groups=n_groups), # 组归一化的组数 ])
- 接着,根据条件类型 condition_type 初始化条件编码器 cond_encoder
如果条件类型为 `film`,则创建一个 nn.Sequential,包含 Mish 激活函数、线性层和 Rearrange 操作,用于预测每个通道的缩放和偏移self.condition_type = condition_type # 设置条件类型 cond_channels = out_channels # 条件通道数初始为输出通道数
如果条件类型为 `add`,则创建一个包含 Mish 激活函数、线性层和 Rearrange 操作的 nn.Sequentialif condition_type == 'film': # 如果条件类型为 'film' # 预测每个通道的缩放和偏移 cond_channels = out_channels * 2 # 条件通道数为输出通道数的两倍 self.cond_encoder = nn.Sequential( # 定义条件编码器 nn.Mish(), # Mish 激活函数 nn.Linear(cond_dim, cond_channels), # 线性层 Rearrange('batch t -> batch t 1'), # 重新排列张量维度 )
如果条件类型为 `cross_attention_add` 或 `cross_attention_film`,则使用 CrossAttention类进行交叉注意力计算elif condition_type == 'add': # 如果条件类型为 'add' self.cond_encoder = nn.Sequential( # 定义条件编码器 nn.Mish(), # Mish 激活函数 nn.Linear(cond_dim, out_channels), # 线性层 Rearrange('batch t -> batch t 1'), # 重新排列张量维度 )
如果条件类型为 `mlp_film`,则创建一个包含两个 Mish 激活函数和两个线性层的 nn.Sequentialelif condition_type == 'cross_attention_add': # 如果条件类型为 'cross_attention_add' self.cond_encoder = CrossAttention(in_channels, cond_dim, out_channels) # 定义交叉注意力编码器 elif condition_type == 'cross_attention_film': # 如果条件类型为 'cross_attention_film' cond_channels = out_channels * 2 # 条件通道数为输出通道数的两倍 self.cond_encoder = CrossAttention(in_channels, cond_dim, cond_channels) # 定义交叉注意力编码器
如果条件类型未实现,则抛出 NotImplementedError 异常elif condition_type == 'mlp_film': # 如果条件类型为 'mlp_film' cond_channels = out_channels * 2 # 条件通道数为输出通道数的两倍 self.cond_encoder = nn.Sequential( # 定义条件编码器 nn.Mish(), # Mish 激活函数 nn.Linear(cond_dim, cond_dim), # 线性层 nn.Mish(), # Mish 激活函数 nn.Linear(cond_dim, cond_channels), # 线性层 Rearrange('batch t -> batch t 1'), # 重新排列张量维度 )
else: # 如果条件类型未实现 raise NotImplementedError(f"condition_type {condition_type} not implemented") # 抛出未实现的异常
在上述初始化的基础上,forward 方法中
- 首先通过第一个卷积块处理输入 x
- 如果提供了条件 cond,则根据条件类型对输出进行调整
如果条件类型为 `film`,则通过条件编码器生成缩放和偏移,并应用于输出
如果条件类型为 `add`,则将条件编码器的输出与当前输出相加
如果条件类型为 `cross_attention_add` 或 `cross_attention_film`,则通过交叉注意力计算生成嵌入,并应用于输出
如果条件类型为 `mlp_film`,则通过条件编码器生成缩放和偏移,并应用于输出 - 最后,通过第二个卷积块处理输出,并将其与残差连接相加,返回最终输出
2.2.1.3 ConditionalUnet1D:条件一维 U-Net 网络,在一维数据上实现条件生成任务
ConditionalUnet1D 类是一个条件一维 U-Net 网络,用于在一维数据上实现条件生成任务
它在初始化时接受多个参数,如下所示
def __init__(self, # 定义构造函数
input_dim, # 输入维度
local_cond_dim=None, # 局部条件维度,默认值为 None
global_cond_dim=None, # 全局条件维度,默认值为 None
diffusion_step_embed_dim=256, # 扩散步嵌入维度,默认值为 256
down_dims=[256, 512, 1024], # 下采样维度列表,默认值为 [256, 512, 1024]
kernel_size=3, # 卷积核大小,默认值为 3
n_groups=8, # 组归一化的组数,默认值为 8
condition_type='film', # 条件类型,默认值为 'film'
use_down_condition=True, # 是否使用下采样条件,默认值为 True
use_mid_condition=True, # 是否使用中间条件,默认值为 True
use_up_condition=True, # 是否使用上采样条件,默认值为 True
):
- 在 __init__ 方法中,定义了扩散步编码器、局部条件编码器、中间模块、下采样模块和上采样模块,并初始化最终的卷积层
- 在 forward 方法中,首先对时间步进行编码,然后根据条件类型对局部和全局条件进行处理,最后通过下采样、中间处理和上采样阶段生成最终输出
2.2.2 diffusion/conv1d_components.py:涉及一维卷积、下采样、上采样
该代码定义了几个用于一维卷积操作的 PyTorch 模块,包括 Downsample1d、Upsample1d 和 Conv1dBlock
- Downsample1d 类是一个用于一维下采样的模块。它在初始化时接受一个参数 dim,并定义了一个一维卷积层 self.conv,该卷积层的卷积核大小为 3,步幅为 2,填充为 1
在 forward 方法中,输入 x 通过卷积层进行下采样class Downsample1d(nn.Module): def __init__(self, dim): # 定义构造函数,接受一个参数 dim super().__init__() # 调用父类的构造函数 self.conv = nn.Conv1d(dim, dim, 3, 2, 1) # 定义一个一维卷积层,卷积核大小为 3,步幅为 2,填充为 1
def forward(self, x): # 定义前向传播函数 return self.conv(x) # 返回卷积后的结果
- Upsample1d 类是一个用于一维上采样的模块。它在初始化时同样接受一个参数 dim,并定义了一个一维反卷积层 self.conv,该反卷积层的卷积核大小为 4,步幅为 2,填充为 1
在 forward 方法中,输入 x 通过反卷积层进行上采样class Upsample1d(nn.Module): def __init__(self, dim): # 定义构造函数,接受一个参数 dim super().__init__() # 调用父类的构造函数 self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) # 定义一个一维反卷积层,卷积核大小为 4,步幅为 2,填充为 1
def forward(self, x): # 定义前向传播函数 return self.conv(x) # 返回反卷积后的结果
- Conv1dBlock 类是一个包含一维卷积、组归一化和 Mish 激活函数的模块
它在初始化时接受多个参数,如下所示
在 __init__ 方法中,定义了一个顺序容器 self.block,其中包含一维卷积层、组归一化层和 Mish 激活函数class Conv1dBlock(nn.Module): ''' Conv1d --> GroupNorm --> Mish # 一维卷积 --> 组归一化 --> Mish 激活函数 ''' # 定义构造函数,接受输入通道数、输出通道数、卷积核大小、组数 def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): super().__init__()
在 forward 方法中,输入 x 通过顺序容器中的各层进行处理# 定义一个顺序容器 self.block = nn.Sequential( # 一维卷积层,填充为卷积核大小的一半 nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), # 重新排列张量维度(已注释) # Rearrange('batch channels horizon -> batch channels 1 horizon'), # 组归一化层 nn.GroupNorm(n_groups, out_channels), # 重新排列张量维度(已注释) # Rearrange('batch channels 1 horizon -> batch channels horizon'), # Mish 激活函数 nn.Mish(), )
def forward(self, x): # 定义前向传播函数 return self.block(x) # 返回顺序容器处理后的结果
- 最后,定义了一个 test 函数,用于测试 Conv1dBlock 模块
该函数创建了一个 Conv1dBlock 实例 cb,并生成一个形状为 `(1, 256, 16)` 的全零张量 x
然后,将 x 传递给 cb 进行处理,并将输出存储在变量 o 中# 定义测试函数 def test(): # 创建一个 Conv1dBlock 实例 cb = Conv1dBlock(256, 128, kernel_size=3) # 创建一个全零张量,形状为 (1, 256, 16) x = torch.zeros((1, 256, 16)) # 将张量传入 Conv1dBlock 实例,并获取输出 o = cb(x)
2.2.3 diffusion/ema_model.py:实现模型权重的指数移动平均EMA
该代码定义了一个名为 EMAModel 的类,用于实现模型权重的指数移动平均(EMA)。EMA 是一种常用的技术,通过对模型权重进行平滑处理,可以提高模型的稳定性和泛化能力。
- 在 EMAModel类的初始化方法 __init__ 中,接受多个参数,如下所示
初始化过程中,将传入的模型设置为评估模式,并禁用其梯度计算。还初始化了一些其他属性,如 EMA 衰减率 decay 和优化步数 optimization_stepclass EMAModel: # 定义 EMAModel 类 """ 模型权重的指数移动平均 """ # 定义构造函数 def __init__( self, model, # 模型 update_after_step=0, # 在多少步之后开始更新 EMA 的步数 update_after_step inv_gamma=1.0, # EMA 预热的逆乘法因子,默认值为 1.0 power=2 / 3, # EMA 预热的指数因子,默认值为 2/3 min_value=0.0, # EMA 的最小衰减率,默认值为 0.0 max_value=0.9999 # EMA 的最大衰减率,默认值为 0.9999 ):
""" @crowsonkb 关于 EMA 预热的笔记: 如果 gamma=1 且 power=1,则实现简单平均。gamma=1,power=2/3 是适合训练一百万步或更多步的模型的好值 (在 31.6K 步时达到衰减因子 0.999,在 1M 步时达到 0.9999), gamma=1,power=3/4 适合训练较少步数的模型(在 10K 步时达到衰减因子 0.999,在 215.4K 步时达到 0.9999)。 参数: inv_gamma (float): EMA 预热的逆乘法因子。默认值: 1。 power (float): EMA 预热的指数因子。默认值: 2/3。 min_value (float): EMA 的最小衰减率。默认值: 0。 """ self.averaged_model = model # 设置平均模型 self.averaged_model.eval() # 将平均模型设置为评估模式 self.averaged_model.requires_grad_(False) # 禁用平均模型的梯度计算 self.update_after_step = update_after_step # 设置在多少步之后开始更新 EMA self.inv_gamma = inv_gamma # 设置 EMA 预热的逆乘法因子 self.power = power # 设置 EMA 预热的指数因子 self.min_value = min_value # 设置 EMA 的最小衰减率 self.max_value = max_value # 设置 EMA 的最大衰减率 self.decay = 0.0 # 初始化衰减率 self.optimization_step = 0 # 初始化优化步数
- get_decay 方法用于计算 EMA 的衰减因子。它根据当前的优化步数计算衰减因子,并确保其在 min_value 和 max_value 之间。如果当前步数小于等于 0,则返回 0.0
- step 方法用于更新 EMA 模型的权重。该方法使用 torch.no_grad() 装饰器,以确保在更新权重时不会计算梯度
首先,计算当前步数的衰减因子
然后,遍历新模型和 EMA 模型的所有模块和参数,并根据参数类型和条件更新 EMA 模型的权重@torch.no_grad() # 使用 torch.no_grad() 装饰器,禁用梯度计算 def step(self, new_model): # 定义更新 EMA 模型的方法 self.decay = self.get_decay(self.optimization_step) # 获取当前步数的衰减因子
如果参数是批归一化层的参数或不需要梯度计算的参数,则直接复制新模型的参数值
否则,使用 EMA 衰减因子对参数进行加权更新all_dataptrs = set() # 初始化数据指针集合 # 遍历新模型和平均模型的所有模块 for module, ema_module in zip(new_model.modules(), self.averaged_model.modules()): # 遍历模块的所有参数 for param, ema_param in zip(module.parameters(recurse=False), ema_module.parameters(recurse=False)): # 仅迭代直接参数 if isinstance(param, dict): # 如果参数是字典 raise RuntimeError('Dict parameter not supported') # 抛出运行时异常 if isinstance(module, _BatchNorm): # 如果模块是批归一化层 # 跳过批归一化层 ema_param.copy_(param.to(dtype=ema_param.dtype).data) # 复制参数数据 # 如果参数不需要梯度计算 elif not param.requires_grad: ema_param.copy_(param.to(dtype=ema_param.dtype).data) # 复制参数数据
最后,增加优化步数else: # 乘以衰减因子 ema_param.mul_(self.decay) # 加上参数数据乘以 (1 - 衰减因子) ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
# 验证遍历模块然后参数与递归遍历参数是否相同 self.optimization_step += 1 # 增加优化步数
通过这种方式,EMAModel 类可以在训练过程中平滑地更新模型权重,从而提高模型的稳定性和性能
2.2.4 diffusion/mask_generator.py
该代码片段定义了几个用于生成掩码的函数和类,这些掩码生成器类通过不同的配置和条件,生成适用于各种深度学习任务的掩码,方便模型处理不同的输入维度和条件
// 待更
2.2.5 diffusion/positional_embedding.py:为输入数据添加位置信息
SinusoidalPosEmb 类是一个用于生成正弦位置嵌入的 PyTorch 模块,用于为输入数据添加位置信息,其对应的公式 如下
在 __init__ 方法中,接受一个参数 dim,表示嵌入的维度。调用 super().__init__() 初始化父类 nn.Module,并将 dim 存储为实例属性
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
forward 方法用于计算位置嵌入
- 首先,获取输入张量 x 的设备信息 device
def forward(self, x): device = x.device
- 然后,计算嵌入维度的一半 half_dim
half_dim = self.dim // 2
- 接下来,计算一个常数 emb,该常数用于缩放位置索引
对应的公式为emb = math.log(10000) / (half_dim - 1)
然后使用 torch.arange 生成一个从 0 到 half_dim 的张量,并将其乘以 `-emb`,然后通过 torch.exp 计算指数
其对应的公式为emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
简化下是
- 接着,将输入张量 x 扩展维度并与生成的指数张量相乘
对应公式为emb = x[:, None] * emb[None, :]
- 最后,通过 torch.cat 将正弦和余弦嵌入拼接在一起,得到最终的嵌入张量
最终对应的公式为emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb
注意,transformer原始论文中对位置编码的公式为
如不太理解,详见此文《一文通透位置编码:从标准位置编码、旋转位置编码RoPE到ALiBi、LLaMA 2 Long(含NTK-aware简介)》
2.3 model/vision
2.4 model/vision_3d
2.4.1 vision_3d/multi_stage_pointnet.py:对点云数据进行编码
该代码定义了一个名为 MultiStagePointNetEncoder 的 PyTorch 模块,用于对点云数据进行编码。该模块包含两个辅助函数 meanpool 和 maxpool,以及一个主要的编码器类 MultiStagePointNetEncoder
- meanpool 函数用于在指定维度上对输入张量 x 进行平均池化操作
def meanpool(x, dim=-1, keepdim=False): out = x.mean(dim=dim, keepdim=keepdim) return out
- maxpool 函数用于在指定维度上对输入张量 x 进行最大池化操作
def maxpool(x, dim=-1, keepdim=False): out = x.max(dim=dim, keepdim=keepdim).values return out
- MultiStagePointNetEncoder 类继承自 nn.Module,用于实现多阶段的点云编码器。其构造函数接受多个参数,如下所示
在初始化过程中,定义了激活函数 LeakyReLU、输入卷积层 conv_in、多个隐藏层 layers 和全局层 global_layers,以及输出卷积层 conv_outclass MultiStagePointNetEncoder(nn.Module): # 定义构造函数,接受隐藏维度、输出通道数、层数和其他参数 def __init__(self, h_dim=128, out_channels=128, num_layers=4, **kwargs): super().__init__()
在 forward 方法中self.h_dim = h_dim # 设置隐藏维度 self.out_channels = out_channels # 设置输出通道数 self.num_layers = num_layers # 设置层数 # 定义 LeakyReLU 激活函数 self.act = nn.LeakyReLU(negative_slope=0.0, inplace=False) # 定义输入卷积层,输入通道数为 3,输出通道数为 h_dim,卷积核大小为 1 self.conv_in = nn.Conv1d(3, h_dim, kernel_size=1) # 定义两个模块列表,分别用于存储局部卷积层和全局卷积层 self.layers, self.global_layers = nn.ModuleList(), nn.ModuleList() # 遍历层数 for i in range(self.num_layers): # 添加局部卷积层,输入和输出通道数均为 h_dim,卷积核大小为 1 self.layers.append(nn.Conv1d(h_dim, h_dim, kernel_size=1)) # 添加全局卷积层,输入通道数为 h_dim * 2,输出通道数为 h_dim,卷积核大小为 1 self.global_layers.append(nn.Conv1d(h_dim * 2, h_dim, kernel_size=1)) # 定义输出卷积层,输入通道数为 h_dim * 层数,输出通道数为 out_channels,卷积核大小为 1 self.conv_out = nn.Conv1d(h_dim * self.num_layers, out_channels, kernel_size=1)
首先将输入张量 x 的维度进行转换
然后通过输入卷积层和激活函数进行初步处理
接着,遍历每一层,对输入进行卷积和激活处理,并计算全局特征,将其与当前特征拼接。将所有层的特征拼接后,通过输出卷积层进行处理
最后在指定维度上进行最大池化,得到全局特征 x_global 并返回
该编码器模块通过多层卷积和全局特征提取,能够有效地对点云数据进行编码,提取出有用的全局特征。
2.4.2 vision_3d/point_process.py:针对点云的打乱/填充/采样操作(含NumPy和PyTorch实现)
该代码提供了一些用于点云处理的 PyTorch 和 NumPy 实现——点云处理在计算机视觉和3D建模中非常重要,特别是在处理和分析3D数据时
- 首先,导入了必要的库 torch 和 numpy
然后,定义了一个 __all__ 列表,指定了该模块中可以被外部导入的函数,包括 shuffle_point_torch、pad_point_torch 和 uniform_sampling_torch - 对点云数据进行随机打乱:shuffle_point_numpy。它接受一个形状为 `(B, N, C)` 的点云张量,其中 B 是批量大小,N 是点的数量,C 是每个点的特征维度。函数通过 np.random.permutation 生成一个随机排列的索引,并返回打乱后的点云
- 对点云数据进行填充:pad_point_numpy
如果点的数量少于指定的 num_points,则用零点进行填充。填充后,调用 shuffle_point_numpy 函数对点云进行随机打乱 - 对点云数据进行均匀采样:uniform_sampling_numpy
如果点的数量少于指定的 num_points,则调用 pad_point_numpy 进行填充。否则,通过 np.random.permutation 生成随机索引,并返回采样后的点云 - 打乱之shuffle_point_torch 函数是 shuffle_point_numpy 的 PyTorch 实现
它使用 torch.randperm 生成随机排列的索引,并返回打乱后的点云 - 填充之pad_point_torch 函数是 pad_point_numpy 的 PyTorch 实现
它首先检查点的数量是否少于指定的 num_points,如果是,则用零点进行填充。填充后,调用 shuffle_point_torch 函数对点云进行随机打乱 - 采样之uniform_sampling_torch 函数是 uniform_sampling_numpy 的 PyTorch 实现
如果点的数量等于指定的 num_points,则直接返回点云。如果点的数量少于指定的 num_points,则调用 pad_point_torch 进行填充。否则,通过 torch.randperm 生成随机索引,并返回采样后的点云
这些函数为点云数据的处理提供了基础操作,包括随机打乱、填充和均匀采样,适用于不同的框架——NumPy 和 PyTorch
2.4.3 vision_3d/pointnet_extractor.py:包含点云编码器iDP3Encoder的实现
该代码片段定义了一个用于创建多层感知机(MLP)的函数 create_mlp,以及两个编码器类 StateEncoder 和 iDP3Encoder,用于处理状态和点云数据
首先,create_mlp 函数用于创建一个多层感知机(MLP),即一系列全连接层,每个全连接层后面跟随一个激活函数
- 函数接受五个参数:如下所示
def create_mlp( input_dim: int, # 输入维度 output_dim: int, # 输出维度 net_arch: List[int], # 神经网络的架构,表示每层的单元数 activation_fn: Type[nn.Module] = nn.ReLU, # 每层之后使用的激活函数,默认值为 nn.ReLU squash_output: bool = False, # 是否使用 Tanh 激活函数压缩输出,默认值为 False ) -> List[nn.Module]: # 返回值为 nn.Module 的列表
- 函数首先根据 net_arch 创建第一层全连接层和激活函数
然后遍历 net_arch 创建中间层if len(net_arch) > 0: modules = [nn.Linear(input_dim, net_arch[0]), activation_fn()] else: modules = []
最后添加输出层和可选的 Tanh 激活函数for idx in range(len(net_arch) - 1): modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1])) modules.append(activation_fn())
返回值是一个包含所有层的模块列表if output_dim > 0: last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim modules.append(nn.Linear(last_layer_dim, output_dim)) if squash_output: modules.append(nn.Tanh())
return modules
其次,StateEncoder 类继承自 nn.Module,用于对状态数据进行编码。其构造函数接受三个参数:如下所示
class StateEncoder(nn.Module):
def __init__(self,
observation_space: Dict, # 观察空间的字典
state_mlp_size=(64, 64), # 状态 MLP 的大小,默认值为 (64, 64)
state_mlp_activation_fn=nn.ReLU): # 状态 MLP 的激活函数,默认值为 nn.ReLU
super().__init__()
- 在初始化过程中,首先获取状态的形状,并根据 state_mlp_size 创建 MLP
self.state_key = 'full_state' # 设置状态键 self.state_shape = observation_space[self.state_key] # 获取状态的形状 cprint(f"[StateEncoder] state shape: {self.state_shape}", "yellow") # 打印状态形状 if len(state_mlp_size) == 0: # 如果状态 MLP 的大小为空 raise RuntimeError(f"State mlp size is empty") # 抛出运行时异常 elif len(state_mlp_size) == 1: # 如果状态 MLP 的大小为 1 net_arch = [] else: net_arch = state_mlp_size[:-1] # 网络架构为状态 MLP 大小的前 n-1 个元素 output_dim = state_mlp_size[-1] # 输出维度为状态 MLP 大小的最后一个元素 self.state_mlp = nn.Sequential(*create_mlp(self.state_shape[0], output_dim, net_arch, state_mlp_activation_fn)) # 创建状态 MLP cprint(f"[StateEncoder] output dim: {output_dim}", "red") # 打印输出维度 self.output_dim = output_dim # 设置输出维度
- forward 方法接受一个包含状态数据的字典 observations,并通过 MLP 对状态进行编码,返回编码后的特征
最后,iDP3Encoder 类同样继承自 nn.Module,用于对点云数据和状态数据进行联合编码
其构造函数接受多个参数,包括
def __init__(self,
observation_space: Dict, # 观察空间的字典
state_mlp_size=(64, 64), # 状态 MLP 的大小
state_mlp_activation_fn=nn.ReLU, # 状态 MLP 的激活函数
pointcloud_encoder_cfg=None, # 点云编码器的配置
use_pc_color=False, # 是否使用点云颜色
pointnet_type='dp3_encoder', # 点网类型
point_downsample=True, # 是否对点云进行下采样
):
- 在初始化过程中,设置了状态和点云的键值,并根据配置初始化点云预处理方法和点网编码器
在构造函数中,首先获取点云和状态的形状,并根据配置选择点云预处理方法super().__init__() # 调用父类的构造函数 self.state_key = 'agent_pos' # 状态键 self.point_cloud_key = 'point_cloud' # 点云键 self.n_output_channels = pointcloud_encoder_cfg.out_channels # 输出通道数
如果 pointnet_type 为 "multi_stage_pointnet",则导入并实例化 MultiStagePointNetEncoder 作为点云特征提取器self.point_cloud_shape = observation_space[self.point_cloud_key] # 获取点云的形状 self.state_shape = observation_space[self.state_key] # 获取状态的形状 self.num_points = pointcloud_encoder_cfg.num_points # 点的数量,默认为 4096
否则,抛出 NotImplementedError 异常
接着,根据 state_mlp_size 创建状态 MLP,并计算输出通道数self.downsample = point_downsample # 是否对点云进行下采样 if self.downsample: # 如果需要下采样 self.point_preprocess = point_process.uniform_sampling_torch # 使用均匀采样 else: # 否则 self.point_preprocess = nn.Identity() # 使用 Identity 层 if pointnet_type == "multi_stage_pointnet": # 如果点网类型为 "multi_stage_pointnet" from .multi_stage_pointnet import MultiStagePointNetEncoder # 导入 MultiStagePointNetEncoder self.extractor = MultiStagePointNetEncoder(out_channels=pointcloud_encoder_cfg.out_channels) # 实例化点云特征提取器 else: # 否则 raise NotImplementedError(f"pointnet_type: {pointnet_type}") # 抛出未实现的异常
if len(state_mlp_size) == 0: # 如果状态 MLP 的大小为空 raise RuntimeError(f"State mlp size is empty") # 抛出运行时异常 elif len(state_mlp_size) == 1: # 如果状态 MLP 的大小为 1 net_arch = [] # 网络架构为空 else: # 否则 net_arch = state_mlp_size[:-1] # 网络架构为状态 MLP 大小的前 n-1 个元素 output_dim = state_mlp_size[-1] # 输出维度为状态 MLP 大小的最后一个元素 self.n_output_channels += output_dim # 输出通道数加上输出维度 self.state_mlp = nn.Sequential(*create_mlp(self.state_shape[0], output_dim, net_arch, state_mlp_activation_fn)) # 创建状态 MLP cprint(f"[DP3Encoder] output dim: {self.n_output_channels}", "red") # 打印输出通道数
- forward 方法用于根据输入的观察字典 observations 生成编码特征
首先,获取点云数据并检查其形状是否为三维。如果需要下采样,则对点云数据进行预处理
然后,通过点云特征提取器提取点云特征def forward(self, observations: Dict) -> torch.Tensor: # 定义前向传播函数 points = observations[self.point_cloud_key] # 获取点云数据 assert len(points.shape) == 3, cprint(f"point cloud shape: {points.shape}, length should be 3", "red") # 确保点云数据的形状为三维 if self.downsample: # 如果需要下采样 points = self.point_preprocess(points, self.num_points) # 对点云数据进行预处理
接着,获取状态数据并通过状态 MLP 进行编码pn_feat = self.extractor(points) # 提取点云特征
最后,将点云特征和状态特征拼接在一起,返回最终的编码特征state = observations[self.state_key] # 获取状态数据 state_feat = self.state_mlp(state) # 对状态数据进行编码
final_feat = torch.cat([pn_feat, state_feat], dim=-1) # 拼接点云特征和状态特征 return final_feat # 返回最终的编码特征
output_shape 方法返回编码器的输出通道数。
总的来说,iDP3Encoder 类通过点云特征提取器和状态 MLP,实现了对点云数据和状态数据的联合编码,适用于各种深度学习任务
第三部分 基于图像和点云的扩散策略:diffusion_policy_3d/policy
3.1 policy/base_policy.py:基类策略模型
该代码定义了一个名为 BasePolicy 的基类,用于实现策略模型。该类继承自 ModuleAttrMixin,并包含一些方法和接口,用于处理策略模型的基本功能
- 首先,导入了必要的库和模块,包括 Dict 类型提示、torch 和 torch.nn,以及自定义的 ModuleAttrMixin 和 LinearNormalizer 模块
- BasePolicy 类的构造函数接受一个关键字参数 `shape_meta`,该参数在配置文件中定义(例如 `config/task/*_image.yaml`)。然而,构造函数的具体实现并未在代码中展示。
predict_action 方法是一个抽象方法,用于根据输入的观察字典 obs_dict 预测动作
obs_dict 是一个字典,键为字符串,值为形状为 `(B, To, *)` 的张量。该方法的返回值是一个字典,键为字符串,值为形状为 `(B, Ta, Da)` 的张量。由于这是一个抽象方法,具体实现需要在子类中完成,因此在该方法中抛出了 NotImplementedError 异常
reset 方法用于重置状态,对于有状态的策略模型非常重要。该方法在基类中实现为空方法,具体实现可以在子类中覆盖def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # 定义 predict_action 方法,接受一个包含观察数据的字典,返回一个包含动作数据的字典 """ obs_dict: # 观察数据字典 str: B,To,* # 键为字符串,值为形状为 (B, To, *) 的张量 return: B,Ta,Da # 返回形状为 (B, Ta, Da) 的张量 """ raise NotImplementedError() # 抛出未实现的异常
set_normalizer 方法用于设置归一化器 normalizer,该归一化器是 LinearNormalizer 类型。由于没有标准的训练接口,该方法在基类中同样抛出了 NotImplementedError 异常,具体实现需要在子类中完成
总的来说,BasePolicy 类提供了一个策略模型的基本框架,定义了预测动作、重置状态和设置归一化器的方法接口。具体的策略模型需要继承该基类,并实现这些抽象方法
3.2 policy/diffusion_image_policy.py:基于图像的扩散策略
DiffusionImagePolicy 类继承自 BasePolicy,用于实现基于扩散模型的图像策略
3.2.1 __init__
该类的构造函数接受多个参数,包括且不限于
def __init__(self,
shape_meta: dict,
noise_scheduler: DDPMScheduler, // 噪声调度器
horizon, // 时间跨度
n_action_steps, // 动作步数
n_obs_steps, // 观察步数
num_inference_steps=None, // 推理步数
obs_as_global_cond=True, // 是否将观察作为全局条件
crop_shape=(76, 76), // 裁剪形状
diffusion_step_embed_dim=256, // 扩散步嵌入维度
down_dims=(256,512,1024), // 下采样维度
kernel_size=5, // 卷积核大小
n_groups=8, // 组数
condition_type='film', // 条件类型
use_depth=False, // 是否使用深度信息
use_depth_only=False, // 是否仅使用深度信息
obs_encoder: TimmObsEncoder = None, // 观察编码器
# parameters passed to step
**kwargs):
在初始化过程中,解析形状元数据,设置动作和观察的形状,并根据配置创建模型和相关组件
3.2.2 forward:根据输入的观察字典 obs_dict 生成动作
forward 方法用于根据输入的观察字典 obs_dict 生成动作
- 首先,复制输入的观察字典 obs_dict,以避免对原始数据进行修改
def forward(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: # 定义前向传播函数 obs_dict = obs_dict.copy() # 复制观察字典
- 接着,对输入进行归一化处理,并将图像数据的像素值从 0-255 范围缩放到 0-1 范围
如果图像数据的最后一个维度为 3(表示 RGB 图像),则根据图像数据的维度进行维度转换# 归一化输入 nobs = self.normalizer.normalize(obs_dict) # 归一化观察字典 nobs['image'] /= 255.0 # 将图像归一化到 [0, 1] 范围
如果使用深度信息且不只使用深度信息,则将深度信息与图像数据沿着通道维度拼接if nobs['image'].shape[-1] == 3: # 如果图像的最后一个维度为 3 if len(nobs['image'].shape) == 5: # 如果图像的形状长度为 5 nobs['image'] = nobs['image'].permute(0, 1, 4, 2, 3) # 重新排列图像维度 if len(nobs['image'].shape) == 4: # 如果图像的形状长度为 4 nobs['image'] = nobs['image'].permute(0, 3, 1, 2) # 重新排列图像维度
如果只使用深度信息,则将深度信息作为图像数据if self.use_depth and not self.use_depth_only: # 如果使用深度信息但不只使用深度信息 nobs['image'] = torch.cat([nobs['image'], nobs['depth'].unsqueeze(-3)], dim=-3) # 将深度信息添加到图像中
if self.use_depth and self.use_depth_only: # 如果仅使用深度信息 nobs['image'] = nobs['depth'].unsqueeze(-3) # 将深度信息作为图像
- 接下来,从归一化后的观察字典中获取一个值,并提取其形状信息,包括批量大小 B 和观察步数 To
然后,设置时间跨度 T、动作维度 Da、观察特征维度 Do 和观察步数 Tovalue = next(iter(nobs.values())) # 获取观察字典中的第一个值 B, To = value.shape[:2] # 获取批量大小和观察步数
构建输入数据时,获取设备信息 device 和数据类型 dtypeT = self.horizon # 设置时间跨度 Da = self.action_dim # 设置动作维度 Do = self.obs_feature_dim # 设置观察特征维度 To = self.n_obs_steps # 设置观察步数
处理不同的观察传递方式时,初始化局部条件 local_cond 和全局条件 global_cond# 构建输入 device = self.device # 获取设备 dtype = self.dtype # 获取数据类型
通过全局特征进行条件处理时,使用 dict_apply 函数对观察数据进行处理,并通过观察编码器 self.obs_encoder 提取观察特征# 处理不同的观察传递方式 local_cond = None # 局部条件 global_cond = None # 全局条件
将提取的观察特征重新调整形状为 `(B, Do)`,并赋值给 global_cond# 通过全局特征进行条件处理 # 获取前 n_obs_steps 步的观察数据 this_nobs = dict_apply(nobs, lambda x: x[:,:self.n_obs_steps,...]) # 编码观察数据 nobs_features = self.obs_encoder(this_nobs)
创建一个空的动作数据张量 cond_data 和一个全为 `False` 的掩码张量 cond_mask# 重新调整形状为 B, Do global_cond = nobs_features.reshape(B, -1) # 重新调整观察特征的形状
然后,调用 conditional_sample 方法进行采样(下一节 会解释该方法),传入动作数据、掩码、局部条件和全局条件等参数# 创建空的动作数据 cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype) # 创建空的动作数据张量 cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) # 创建空的动作掩码张量
采样完成后,对预测的动作数据进行反归一化处理# 运行采样 # 调用 conditional_sample 方法进行采样 nsample = self.conditional_sample( cond_data, cond_mask, local_cond=local_cond, global_cond=global_cond, **self.kwargs)
最后,从预测的动作数据中提取所需的动作步数,并返回最终的动作# 反归一化预测 naction_pred = nsample[...,:Da] # 获取采样结果中的动作预测 action_pred = self.normalizer['action'].unnormalize(naction_pred) # 反归一化动作预测
# 获取动作 start = To - 1 # 设置起始步数 end = start + self.n_action_steps # 设置结束步数 action = action_pred[:,start:end] # 获取动作预测结果 # 获取预测结果 return action # 返回动作预测结果
通过这些步骤,forward 方法实现了从输入观察数据生成动作的过程,适用于基于扩散模型的图像策略
3.2.3 conditional_sample:给定条件下的采样
conditional_sample 方法用于在给定条件下生成样本
- 该方法接受多个参数,具体如下所示
def conditional_sample(self, # 定义 conditional_sample 方法 condition_data, condition_mask, # 接受条件数据和条件掩码 local_cond=None, global_cond=None, # 接受局部条件和全局条件,默认值为 None generator=None, # 接受随机数生成器,默认值为 None # 此外,还可以传递其他关键字参数 kwargs 给调度器的 step 方法 **kwargs ):
- 首先,方法获取模型 self.model 和噪声调度器 self.noise_scheduler
model = self.model # 获取模型 scheduler = self.noise_scheduler # 获取噪声调度器
- 然后,使用 torch.randn 函数生成一个与 condition_data 形状相同的随机轨迹张量 trajectory,并指定数据类型、设备和随机数生成器
trajectory = torch.randn( # 生成一个与条件数据形状相同的随机轨迹张量 size=condition_data.shape, # 形状与条件数据相同 dtype=condition_data.dtype, # 数据类型与条件数据相同 device=condition_data.device, # 设备与条件数据相同 generator=generator) # 使用指定的随机数生成器
- 接下来,设置调度器的时间步数 scheduler.set_timesteps(self.num_inference_steps)。在每个时间步 t 中
# 设置时间步数 scheduler.set_timesteps(self.num_inference_steps)
- 首先应用条件数据,将 condition_data 中满足条件掩码 condition_mask 的部分赋值给轨迹张量 trajectory
然后,使用模型预测输出 model_output,传入当前轨迹、时间步 t、局部条件 local_cond 和全局条件 global_cond# 遍历调度器的时间步数 for t in scheduler.timesteps: # 1. 应用条件 # 将条件数据中满足条件掩码的部分赋值给轨迹张量 trajectory[condition_mask] = condition_data[condition_mask]
接着,调用调度器的 step 方法,计算前一个时间步的样本 `x_t-1`,并更新轨迹张量 trajectory# 2. 预测模型输出 model_output = model(trajectory, t, # 使用模型预测输出 local_cond=local_cond, global_cond=global_cond) # 传入当前轨迹、时间步、局部条件和全局条件
# 3. 计算前一个时间步的样本:x_t -> x_t-1 trajectory = scheduler.step( # 调用调度器的 step 方法 model_output, t, trajectory, # 传入模型输出、时间步和当前轨迹 generator=generator, # 使用指定的随机数生成器 # **kwargs ).prev_sample # 获取前一个时间步的样本
- 最后,确保条件数据被强制应用,再次将 condition_data 中满足条件掩码 condition_mask 的部分赋值给轨迹张量 trajectory
# 最后确保条件被强制应用 # 再次将条件数据中满足条件掩码的部分赋值给轨迹张量 trajectory[condition_mask] = condition_data[condition_mask]
- 方法返回最终生成的轨迹张量 trajectory
return trajectory # 返回最终生成的轨迹张量
通过这些步骤,conditional_sample 方法实现了在给定条件下的样本生成过程,适用于基于扩散模型的图像策略
3.2.4 predict_action:根据输入的观察字典obs_dict预测动作
predict_action 方法用于根据输入的观察字典预测动作,该方法与 forward 方法类似
- 首先对输入进行归一化处理,并根据配置处理图像和深度信息
- 然后,构建输入数据,包括局部和全局条件
- 通过调用 conditional_sample 方法进行采样,得到未归一化的动作预测,并将其反归一化,返回最终的动作和动作预测结果
3.2.5 compute_loss:计算给定批次数据的损失
set_normalizer 方法用于设置归一化器 normalizer,通过加载归一化器的状态字典实现
compute_loss 方法用于计算给定批次数据的损失
- 首先,对输入进行归一化处理,并根据配置处理图像和深度信息
- 然后,构建输入数据,包括局部和全局条件。生成掩码,并添加噪声到轨迹中。应用条件数据,预测模型输出,并根据调度器的配置计算目标
- 最后,计算均方误差损失,并返回损失值
总的来说,DiffusionImagePolicy 类通过扩散模型和条件采样,实现了基于图像的策略生成和训练
3.3 policy/diffusion_pointcloud_policy.py:基于点云的扩散策略(与3.2节有相似)
DiffusionPointcloudPolicy 类继承自 BasePolicy,用于实现基于扩散模型的点云策略
3.3.1 __init__
该类的构造函数接受多个参数,包括
def __init__(self,
shape_meta: dict,
noise_scheduler: DDPMScheduler, // 噪声调度器
horizon, // 时间跨度
n_action_steps, // 动作步数
n_obs_steps, // 观察步数
num_inference_steps=None, // 推理步数
obs_as_global_cond=True, // 是否将观察作为全局条件
diffusion_step_embed_dim=256, // 扩散步嵌入维度
down_dims=(256,512,1024), // 下采样维度
kernel_size=5, // 卷积核大小
n_groups=8, // 组数
condition_type="film", // 条件类型
use_down_condition=True, // 是否使用下采样条件
use_mid_condition=True, // 是否使用中间条件
use_up_condition=True, // 是否使用上采样条件
use_pc_color=False, // 是否使用点云颜色
pointnet_type="pointnet", // 点网类型
pointcloud_encoder_cfg=None, // 点云编码器配置
point_downsample=False, // 是否对点云进行下采样
在初始化过程中,解析形状元数据,设置动作和观察的形状,并根据配置创建模型和相关组件。
3.3.2 forward:根据输入的观察字典 obs_dict 生成动作
forward 方法用于根据输入的观察字典 obs_dict 生成动作
- 首先,对输入进行归一化处理,并根据配置处理点云和颜色信息
- 然后,构建输入数据,包括局部和全局条件
- 通过调用 conditional_sample 方法进行采样,得到未归一化的动作预测,并将其反归一化,返回最终的动作
3.3.3 conditional_sample:在给定条件下进行采样
conditional_sample 方法用于在给定条件下进行采样
- 首先,生成一个随机的轨迹张量,并设置调度器的时间步数
- 在每个时间步中,应用条件数据,预测模型输出,并计算前一个时间步的样本
- 最后,确保条件数据被强制应用,返回最终的轨迹
3.3.4 predict_action:根据输入的观察字典 obs_dict 生成动作(与forward类似)
predict_action 方法用于根据输入的观察字典 obs_dict 生成动作,该方法与上面的forward类似
- 首先,对输入的观察字典进行归一化处理
对于点云数据,如果不使用点云颜色,则只保留前三个通道(通常是坐标信息);# 定义 predict_action 方法,接受一个包含观察数据的字典,返回一个包含动作数据的字典 def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ obs_dict: 必须包含 "obs" 键 result: 必须包含 "action" 键 """ # 归一化输入 nobs = self.normalizer.normalize(obs_dict) # 对观察数据进行归一化处理
如果使用点云颜色,则将颜色信息归一化到 0-1 范围if not self.use_pc_color: # 如果不使用点云颜色 nobs['point_cloud'] = nobs['point_cloud'][..., :3] # 只保留前三个通道(通常是坐标信息)
if self.use_pc_color: # 如果使用点云颜色 nobs['point_cloud'][..., 3:] /= 255.0 # 将颜色信息归一化到 0-1 范围
- 接下来,从归一化后的观察字典中获取一个值,并提取其形状信息,包括批量大小 B 和观察步数 To
然后,设置时间跨度 T、动作维度 Da、观察特征维度 Do 和观察步数 Tovalue = next(iter(nobs.values())) # 获取归一化后的观察数据中的一个值 B, To = value.shape[:2] # 提取批量大小和观察步数
构建输入数据时,获取设备信息 device 和数据类型 dtypeT = self.horizon # 设置时间跨度 Da = self.action_dim # 设置动作维度 Do = self.obs_feature_dim # 设置观察特征维度 To = self.n_obs_steps # 设置观察步数
处理不同的观察传递方式时,初始化局部条件 local_cond 和全局条件 global_cond# 构建输入 device = self.device # 获取设备信息 dtype = self.dtype # 获取数据类型
如果将观察作为全局条件 obs_as_global_cond,则通过全局特征进行条件处理。使用 dict_apply 函数对观察数据进行处理,并通过观察编码器 self.obs_encoder 提取观察特征# 处理不同的观察传递方式 local_cond = None # 初始化局部条件 global_cond = None # 初始化全局条件
根据条件类型 condition_type,将提取的观察特征调整形状为 `(B, self.n_obs_steps, -1)` 或 `(B, -1)`,并赋值给 global_condif self.obs_as_global_cond: # 如果将观察作为全局条件 # 通过全局特征进行条件处理 this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:])) # 对观察数据进行处理 nobs_features = self.obs_encoder(this_nobs) # 提取观察特征
创建一个空的动作数据张量 cond_data 和一个全为 `False` 的掩码张量 cond_maskif "cross_attention" in self.condition_type: # 如果条件类型为 "cross_attention" # 作为序列处理 global_cond = nobs_features.reshape(B, self.n_obs_steps, -1) # 将观察特征调整形状为 (B, self.n_obs_steps, -1) else: # 重新调整形状为 (B, Do) global_cond = nobs_features.reshape(B, -1) # 将观察特征调整形状为 (B, -1)
如果不将观察作为全局条件,则通过填充的方式进行条件处理,即使用 dict_apply 函数对观察数据进行处理,并通过观察编码器提取观察特征# 空的动作数据 # 创建一个空的动作数据张量 cond_data = torch.zeros(size=(B, T, Da), device=device, dtype=dtype) # 创建一个全为 False 的掩码张量 cond_mask = torch.zeros_like(cond_data, dtype=torch.bool)
将提取的观察特征调整形状为 `(B, To, -1)`,并将其赋值给 cond_data 的相应部分,同时更新 cond_maskelse: # 通过填充进行条件处理 this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:])) # 对观察数据进行处理 nobs_features = self.obs_encoder(this_nobs) # 提取观察特征
# 重新调整形状为 (B, T, Do) # 将观察特征调整形状为 (B, To, -1) nobs_features = nobs_features.reshape(B, To, -1) # 创建一个空的动作数据张量 cond_data = torch.zeros(size=(B, T, Da+Do), device=device, dtype=dtype) # 创建一个全为 False 的掩码张量 cond_mask = torch.zeros_like(cond_data, dtype=torch.bool) # 将观察特征赋值给动作数据张量的相应部分 cond_data[:,:To,Da:] = nobs_features # 更新掩码张量 cond_mask[:,:To,Da:] = True
- 接下来,调用 conditional_sample 方法进行采样,传入动作数据、掩码、局部条件和全局条件等参数
采样完成后,对预测的动作数据进行反归一化处理# 运行采样 # 调用 conditional_sample 方法进行采样 nsample = self.conditional_sample( cond_data, # 动作数据 cond_mask, # 掩码 local_cond=local_cond, # 局部条件 global_cond=global_cond, # 全局条件 **self.kwargs) # 其他关键字参数
# 反归一化预测 naction_pred = nsample[...,:Da] # 获取预测的动作数据 action_pred = self.normalizer['action'].unnormalize(naction_pred) # 对预测的动作数据进行反归一化处理
- 最后,从预测的动作数据中提取所需的动作步数
并返回最终的动作和动作预测结果# 获取动作 start = To - 1 # 设置起始步数 end = start + self.n_action_steps # 设置结束步数 action = action_pred[:,start:end] # 从预测的动作数据中提取所需的动作步数
# 获取预测结果 result = { 'action': action, # 动作 'action_pred': action_pred, # 动作预测 } return result # 返回最终的动作和动作预测结果
通过这些步骤,predict_action 方法实现了从输入观察数据生成动作的过程,适用于基于扩散模型的点云策略
3.3.5 compute_loss:计算给定批次数据的损失
set_normalizer 方法用于设置归一化器 normalizer,通过加载归一化器的状态字典实现
compute_loss 方法用于计算给定批次数据的损失
- 首先,对输入进行归一化处理,并根据配置处理点云和颜色信息
- 然后,构建输入数据,包括局部和全局条件。生成掩码,并添加噪声到轨迹中。应用条件数据,预测模型输出,并根据调度器的配置计算目标
- 最后,计算均方误差损失,并返回损失值和损失字典
总的来说,DiffusionPointcloudPolicy 类通过扩散模型和条件采样,实现了基于点云的策略生成和训练
// 待更
原文地址:https://blog.csdn.net/v_JULY_v/article/details/145183110
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!