RFdiffusion Denoise类解读
`Denoise` 类为蛋白质结构扩散模型的实现提供了核心功能：通过灵活的噪声调度、势能（potential）引导和子结构（motif）对齐，使模型能够生成物理上合理的结构序列，并在每个时间步迭代更新蛋白质的坐标和结构信息。
源代码:
def get_next_frames(xt, px0, t, diffuser, so3_type, diffusion_mask, noise_scale=1.0):
    """
    Compute backbone coordinates at step t-1 by reverse-diffusing residue frames.

    Instead of drawing random rotations (as in the forward process), a per-residue
    rotation is obtained from the SO(3) reverse sampler, given the current frame
    (built from ``xt``) and the predicted clean frame (built from ``px0``). That
    rotation is applied to the first three atoms of each residue, about the
    residue's current Ca position.

    Args:
        xt: noised coordinates of shape [L, 14, 3]. Atoms 0, 1, 2 are used as
            N, Ca, C to build the frames.
        px0: prediction of coordinates at t=0, of shape [L, 14, 3].
        t: integer time step, forwarded to the SO(3) reverse sampler.
        diffuser: diffuser object. ``diffuser.so3_diffuser.reverse_sample_vectorized``
            is called for the residues that are not masked.
        so3_type: type of SO(3) noising; only ``"igso3"`` is supported.
        diffusion_mask: bool mask of shape [L]. True means the residue is not
            updated (e.g. motif residues): it gets the identity rotation, so its
            first three atoms are returned unchanged.
        noise_scale: scale factor for the added noise, forwarded to the sampler
            as ``noise_level``.

    Returns:
        numpy array of backbone coordinates for step x_{t-1}, shape [L, 3, 3].

    Raises:
        NotImplementedError: if ``so3_type`` is not ``"igso3"``.
    """
    # Validate with a real exception: an ``assert`` would be stripped under
    # ``python -O`` and silently fall through to identity rotations.
    if so3_type != "igso3":
        raise NotImplementedError("so3 diffusion type %s not implemented" % so3_type)

    # Frames of the predicted clean structure (only the rotations are needed).
    R_0, _ = rigid_from_3_points(
        px0[None, :, 0, :], px0[None, :, 1, :], px0[None, :, 2, :]
    )
    # Frames and Ca origins of the current noised structure.
    R_t, Ca_t = rigid_from_3_points(
        xt[None, :, 0, :], xt[None, :, 1, :], xt[None, :, 2, :]
    )

    # Round-trip through scipy's Rotation. ``from_matrix`` orthogonalizes input
    # that is not exactly orthogonal, so this presumably cleans up numerical drift.
    # ``reshape`` (not ``squeeze``) keeps the residue axis when L == 1.
    R_0 = scipy_R.from_matrix(R_0.reshape(-1, 3, 3).numpy()).as_matrix()
    R_t = scipy_R.from_matrix(R_t.reshape(-1, 3, 3).numpy()).as_matrix()
    ca_t = Ca_t.reshape(-1, 3).numpy()  # (L, 3)

    num_res = R_t.shape[0]
    # Masked residues keep the identity rotation, so only unmasked ones are sampled.
    all_rot_transitions = np.broadcast_to(np.identity(3), (num_res, 3, 3)).copy()
    free = ~diffusion_mask
    all_rot_transitions[free] = diffuser.so3_diffuser.reverse_sample_vectorized(
        R_t[free],
        R_0[free],
        t,
        noise_level=noise_scale,
        mask=None,
        return_perturb=True,
    )
    all_rot_transitions = all_rot_transitions[:, None, :, :]  # (L, 1, 3, 3)

    # Rotate each residue's first three atoms about its current Ca, then shift back.
    next_crds = (
        np.einsum(
            "lrij,laj->lrai",
            all_rot_transitions,
            xt[:, :3, :] - ca_t[:, None, :],
        )
        + ca_t[:, None, None, :]
    )

    # (L, 3, 3) set of backbone coordinates with slight rotation
    return next_crds.squeeze(1)
def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6):
    """
    Compute the posterior mean and variance of the Ca coordinates at step t-1.

    Given the noised coordinates ``xt`` at step ``t`` and the predicted clean
    coordinates ``px0``, this evaluates, on the Ca atom (index 1 along the atom
    axis) only:

        mu    = sqrt(abar_{t-1}) * beta_t / (1 - abar_t) * x0
              + sqrt(1 - beta_t) * (1 - abar_{t-1}) / (1 - abar_t) * x_t
        sigma = (1 - abar_{t-1}) / (1 - abar_t) * beta_t      ("beta tilde")

    Args:
        xt: noised coordinates, indexed as [residue, atom, xyz].
        px0: predicted coordinates at t=0, same layout as ``xt``.
        t: 1-indexed timestep; schedule entry ``t - 1`` holds beta_t and abar_t.
            At t == 1 the previous cumulative product abar_0 is taken as 1
            (standard DDPM convention). sigma is then 0, and mu is approximately
            px0's Ca when abar_1 == 1 - beta_1.
        beta_schedule: per-step betas, indexable by timestep. Entries must be
            torch tensors, since ``torch.sqrt`` is applied to them.
        alphabar_schedule: cumulative alpha products, indexable like ``beta_schedule``.
        eps: small constant added inside the square roots.

    Returns:
        (mu, sigma): mu has the shape of ``xt[:, 1, :]``. sigma is a variance
        (not a standard deviation) with the shape of a single schedule entry.

    Raises:
        ValueError: if t < 1. The schedule would otherwise be indexed at a
            negative position and silently wrap around to its end.
    """
    t_idx = t - 1
    if t_idx < 0:
        raise ValueError(f"t must be >= 1 (1-indexed timestep), got {t}")

    beta_t = beta_schedule[t_idx]
    alphabar_t = alphabar_schedule[t_idx]
    if t_idx == 0:
        # There is no step before the first one: abar_0 = 1. Indexing at
        # t_idx - 1 == -1 would wrap around to the last schedule entry.
        alphabar_prev = torch.ones_like(alphabar_t)
    else:
        alphabar_prev = alphabar_schedule[t_idx - 1]

    # Posterior variance, often referred to as beta tilde t.
    sigma = ((1 - alphabar_prev) / (1 - alphabar_t)) * beta_t

    xt_ca = xt[:, 1, :]
    px0_ca = px0[:, 1, :]
    a = ((torch.sqrt(alphabar_prev + eps) * beta_t) / (1 - alphabar_t)) * px0_ca
    b = (
        (torch.sqrt(1 - beta_t + eps) * (1 - alphabar_prev)) / (1 - alphabar_t)
    ) * xt_ca
    mu = a + b
    return mu, sigma
def get_next_ca(
xt,
px0,
t,
diffusion_mask,
crd_scale,
beta_schedule,
alphabar_schedule,
noise_scale=1.0,
):
"""
Given full atom x0 prediction (xyz coordinates), diffuse to x(t-1)
Parameters:
原文地址:https://blog.csdn.net/qq_27390023/article/details/144300070
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!