
RFdiffusion Denoise Class Walkthrough

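The Denoise class provides the core functionality of the protein structure diffusion model. Through flexible noise schedules, guidance from potentials, and substructure (motif) alignment, it lets the model generate a physically plausible series of structures, iteratively updating the protein's coordinates and structural information at each time step.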

Source code:


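import numpy as np
import torch
from scipy.spatial.transform import Rotation as scipy_R

# Assumed import path for the frame helper in the RFdiffusion repository
from rfdiffusion.util import rigid_from_3_points


def get_next_frames(xt, px0, t, diffuser, so3_type, diffusion_mask, noise_scale=1.0):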
    """
    Compute updated backbone frames at t-1 using IGSO(3) and score-based reverse diffusion.

    The update type is selected by so3_type; only 'igso3' is implemented.

    Rather than sampling random rotations (as in the forward process), the
    rotation is computed from the frames of xt and px0.

    Args:
        xt: noised coordinates of shape [L, 14, 3]
        px0: prediction of coordinates at t=0, of shape [L, 14, 3]
        t: integer time step
        diffuser: Diffuser object for reverse igSO3 sampling
        so3_type: The type of SO3 noising being used ('igso3')
        diffusion_mask: of shape [L] of type bool, True means not to be
            updated (e.g. mask is true for motif residues)
        noise_scale: scale factor for the noise added (IGSO3 only)

    Returns:
        backbone coordinates for step x_t-1 of shape [L, 3, 3]
    """
    N_0 = px0[None, :, 0, :]
    Ca_0 = px0[None, :, 1, :]
    C_0 = px0[None, :, 2, :]

    R_0, Ca_0 = rigid_from_3_points(N_0, Ca_0, C_0)

    N_t = xt[None, :, 0, :]
    Ca_t = xt[None, :, 1, :]
    C_t = xt[None, :, 2, :]

    R_t, Ca_t = rigid_from_3_points(N_t, Ca_t, C_t)

    # Round-trip through scipy's Rotation; this presumably re-orthonormalizes the matrices
    R_0 = scipy_R.from_matrix(R_0.squeeze().numpy()).as_matrix()
    R_t = scipy_R.from_matrix(R_t.squeeze().numpy()).as_matrix()

    L = R_t.shape[0]
    all_rot_transitions = np.broadcast_to(np.identity(3), (L, 3, 3)).copy()
    # Sample next frame for each residue
    if so3_type == "igso3":
        # don't do calculations on masked positions since they end up as identity matrix
        all_rot_transitions[
            ~diffusion_mask
        ] = diffuser.so3_diffuser.reverse_sample_vectorized(
            R_t[~diffusion_mask],
            R_0[~diffusion_mask],
            t,
            noise_level=noise_scale,
            mask=None,
            return_perturb=True,
        )
    else:
        # Raise explicitly: assert statements are skipped when Python runs with -O
        raise NotImplementedError("so3 diffusion type %s not implemented" % so3_type)

    all_rot_transitions = all_rot_transitions[:, None, :, :]

    # Apply the interpolated rotation matrices to the coordinates
    next_crds = (
        np.einsum(
            "lrij,laj->lrai",
            all_rot_transitions,
            xt[:, :3, :] - Ca_t.squeeze()[:, None, ...].numpy(),
        )
        + Ca_t.squeeze()[:, None, None, ...].numpy()
    )

    # (L,3,3) set of backbone coordinates with slight rotation
    return next_crds.squeeze(1)
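
The coordinate update at the end of get_next_frames can be sketched in isolation. The snippet below is a simplified illustration, not the RFdiffusion code: it uses plain NumPy arrays, and the helper names rot_z and rotate_backbone_about_ca are hypothetical. It centers each residue's N, CA and C atoms on CA, applies a per-residue rotation, and shifts back, while residues flagged in the mask keep the identity rotation.

```python
import numpy as np


def rot_z(theta):
    """Rotation matrix for an angle theta (radians) about the z axis."""
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def rotate_backbone_about_ca(bb, rots, fixed_mask):
    """
    bb:         (L, 3, 3) backbone coordinates ordered N, CA, C
    rots:       (L, 3, 3) one rotation matrix per residue
    fixed_mask: (L,) bool, True means the residue is left untouched
    """
    L = bb.shape[0]
    # Start from identity everywhere, then fill in the residues that may move
    transitions = np.broadcast_to(np.identity(3), (L, 3, 3)).copy()
    transitions[~fixed_mask] = rots[~fixed_mask]

    ca = bb[:, 1, :]                  # (L, 3)
    centered = bb - ca[:, None, :]    # each residue expressed relative to its CA
    # out[l, a, i] = sum_j transitions[l, i, j] * centered[l, a, j]
    rotated = np.einsum("lij,laj->lai", transitions, centered)
    return rotated + ca[:, None, :]


# Two toy residues; only the first is allowed to rotate
bb = np.array([
    [[-1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    [[2.0, 0.0, 0.0], [3.0, 0.0, 0.0], [3.0, 1.0, 0.0]],
])
rots = np.stack([rot_z(np.pi / 2), rot_z(np.pi / 2)])
fixed_mask = np.array([False, True])
next_bb = rotate_backbone_about_ca(bb, rots, fixed_mask)
```

Because the rotation is applied about CA, the CA positions themselves do not change in this step, and a masked residue comes out exactly as it went in.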


def get_mu_xt_x0(xt, px0, t, beta_schedule, alphabar_schedule, eps=1e-6):
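    """
    Given xt, the predicted x0, and the timestep t, return the mean and
    variance of x(t-1), computed on CA coordinates.

    The schedules are indexed with t_idx = t - 1, so t is effectively
    1-indexed. For t = 1, t_idx - 1 equals -1, which indexes the last
    schedule entry.
    """
    # sigma is fixed by the beta schedule; it is the quantity often called beta tilde t
    t_idx = t - 1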
    sigma = (
        (1 - alphabar_schedule[t_idx - 1]) / (1 - alphabar_schedule[t_idx])
    ) * beta_schedule[t_idx]

    xt_ca = xt[:, 1, :]
    px0_ca = px0[:, 1, :]

    a = (
        (torch.sqrt(alphabar_schedule[t_idx - 1] + eps) * beta_schedule[t_idx])
        / (1 - alphabar_schedule[t_idx])
    ) * px0_ca
    b = (
        (
            torch.sqrt(1 - beta_schedule[t_idx] + eps)
            * (1 - alphabar_schedule[t_idx - 1])
        )
        / (1 - alphabar_schedule[t_idx])
    ) * xt_ca

    mu = a + b

    return mu, sigma
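
The arithmetic in get_mu_xt_x0 has the same form as the standard DDPM posterior over x(t-1) given x(t) and x(0). The sketch below restates it in NumPy as a sanity check, under assumptions that are not from the source: the function name posterior_mean_variance is hypothetical, the linear beta schedule and its values are made up for illustration, and the small eps stabilizer is omitted. It checks one property of these formulas: if x(t) sits exactly at its forward-process mean sqrt(alphabar_t) * x0, the posterior mean lands at sqrt(alphabar_{t-1}) * x0.

```python
import numpy as np


def posterior_mean_variance(xt_ca, x0_ca, t, betas, alphabars):
    """Posterior mean and variance of x(t-1); t is 1-indexed and must be >= 2."""
    i = t - 1
    # Variance, often written beta tilde t
    var = (1 - alphabars[i - 1]) / (1 - alphabars[i]) * betas[i]
    # Weight on the predicted x0
    coef_x0 = np.sqrt(alphabars[i - 1]) * betas[i] / (1 - alphabars[i])
    # Weight on the current noisy sample xt
    coef_xt = np.sqrt(1 - betas[i]) * (1 - alphabars[i - 1]) / (1 - alphabars[i])
    return coef_x0 * x0_ca + coef_xt * xt_ca, var


# Illustrative linear schedule (assumed values, not RFdiffusion's settings)
T = 50
betas = np.linspace(1e-2, 7e-2, T)
alphabars = np.cumprod(1 - betas)

t = 10
x0_ca = np.array([[1.0, -2.0, 0.5]])
xt_ca = np.sqrt(alphabars[t - 1]) * x0_ca   # xt placed at its forward-process mean
mu, var = posterior_mean_variance(xt_ca, x0_ca, t, betas, alphabars)
```

The variance returned here is strictly smaller than beta_t, since it is beta_t scaled by (1 - alphabar_{t-1}) / (1 - alphabar_t), which is below 1.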


def get_next_ca(
    xt,
    px0,
    t,
    diffusion_mask,
    crd_scale,
    beta_schedule,
    alphabar_schedule,
    noise_scale=1.0,
):
    """
    Given full atom x0 prediction (xyz coordinates), diffuse to x(t-1)

    Parameters:

        

Original article: https://blog.csdn.net/qq_27390023/article/details/144300070
