RFdiffusion get_init_xyz函数解读
get_init_xyz
是一个初始化坐标的函数,主要目标是为输入的原子坐标 xyz_t
提供初始值,同时确保初始化的坐标在数值上具有一定的稳定性和物理意义。
源代码:
# ideal N, CA, C initial coordinates
init_N = torch.tensor([-0.5272, 1.3593, 0.000]).float()
init_CA = torch.zeros_like(init_N)
init_C = torch.tensor([1.5233, 0.000, 0.000]).float()
INIT_CRDS = torch.full((27, 3), np.nan)
INIT_CRDS[:3] = torch.stack((init_N, init_CA, init_C), dim=0) # (3,3)
def get_init_xyz(xyz_t):
# input: xyz_t (B, T, L, 14, 3)
# ouput: xyz (B, T, L, 14, 3)
B, T, L = xyz_t.shape[:3]
init = INIT_CRDS.to(xyz_t.device).reshape(1,1,1,27,3).repeat(B,T,L,1,1)
if torch.isnan(xyz_t).all():
return init
mask = torch.isnan(xyz_t[:,:,:,:3]).any(dim=-1).any(dim=-1) # (B, T, L)
#
center_CA = ((~mask[:,:,:,None]) * torch.nan_to_num(xyz_t[:,:,:,1,:])).sum(dim=2) / ((~mask[:,:,:,None]).sum(dim=2)+1e-4) # (B, T, 3)
xyz_t = xyz_t - center_CA.view(B,T,1,1,3)
#
idx_s = list()
for i_b in range(B):
for i_T in range(T):
if mask[i_b, i_T].all():
continue
exist_in_templ = torch.where(~mask[i_b, i_T])[0] # (L_sub)
seqmap = (torch.arange(L, device=xyz_t.device)[:,None] - exist_in_templ[None,:]).abs() # (L, L_sub)
seqmap = torch.argmin(seqmap, dim=-1) # (L)
idx = torch.gather(exist_in_templ, -1, seqmap) # (L)
offset_CA = torch.gather(xyz_t[i_b, i_T, :, 1, :], 0, idx.reshape(L,1).expand(-1,3))
init[i_b,i_T] += offset_CA.reshape(L,1,3)
#
xyz = torch.where(mask.view(B, T, L, 1, 1), init, xyz_t)
return xyz
代码解读:
1. 理想化主链原子坐标定义
init_N = torch.tensor([-0.5272, 1.3593, 0.000]).float()
init_CA = torch.zeros_like(init_N)
init_C = torch.tensor([1.5233, 0.000, 0.000]).float()
INIT_CRDS = torch.full((27, 3), np.nan)
INIT_CRDS[:3] = torch.stack((init_N, init_CA, init_C), dim=0) # (3,3)
-
定义了理想化的氮原子 (
N
)、α-碳原子 (CA
) 和羰基碳原子 (C
) 的初始坐标:N
:[-0.5272, 1.3593, 0.000]CA
:[0.000, 0.000, 0.000](中心)C
:[1.5233, 0.000, 0.000]
-
作用:
- 这些坐标代表了蛋白质主链原子在一个标准几何构象下的理想化位置(假设理想肽键和角度)。
-
INIT_CRDS
:- 初始化一个
(27, 3)
的张量,表示所有可能原子类型的坐标。 - 前三个元素分别对应主链原子
N
、CA
和C
。
- 初始化一个
2. 初始化张量 init
- 根据输入的批量大小
B
,时间步数T
,和序列长度L
,扩展理想化坐标到输入张量的形状(B, T, L, 27, 3)
。 - 作用:
- 为每个序列中的每个残基创建初始化坐标(使用理想化的几何构象)。
3. 检查输入是否全为 NaN
if torch.isnan(xyz_t).all():
retu
原文地址:https://blog.csdn.net/qq_27390023/article/details/144450706
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!