自学内容网 自学内容网

AF3 Transition和ConditionedTransitionBlock类解读

AlphaFold3的Transition/ ConditionedTransitionBlock类
1. Transition 类
  • 作用

    • 提升模型的表达能力,通过扩展和收缩通道,学习不同层次的特征。
    • 作为残差块的一部分,帮助模型捕捉更复杂的序列-结构映射关系。
  • 生物学意义

    • 帮助捕捉蛋白质序列中局部和全局特征,为后续模块提供更丰富的特征。
2. ConditionedTransitionBlock 类
  • 作用

    • 通过条件张量 s 自适应调整输入特征 a 的分布。
    • 门控机制控制特征的更新量,避免过拟合,增强模型的条件依赖能力。
  • 生物学意义

    • 模拟蛋白质在不同环境或上下文(如特定配体、化学环境)下的行为。
    • 允许模型根据条件信息动态调整特征表示,捕捉更细粒度的结构变化。

源代码:

"""Transition blocks in AlphaFold3"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from src.models.components.primitives import AdaLN
from src.models.components.primitives import Linear, LinearNoBias


class Transition(nn.Module):
    """A transition block for a residual update."""
    def __init__(self, input_dim: int, n: int = 4):
        """
        Args:
            input_dim:
                Channels of the input tensor
            n:
                channel expansion factor for hidden dimensions
        """
        super(Transition, self).__init__()
        self.layer_norm = LayerNorm(input_dim)
        self.linear_1 = LinearNoBias(input_dim, n * input_dim, init='relu')
        self.linear_2 = LinearNoBias(input_dim, n * input_dim, init='default')
        self.output_linear = LinearNoBias(input_dim * n, input_dim, init='final')

    def forward(self, x):
        x = self.layer_norm(x)
        x = F.silu(self.linear_1(x)) * self.linear_2(x)
        return self.output_linear(x)


class ConditionedTransitionBlock(nn.Module):
    """SwiGLU transition block with adaptive layer norm."""
    def __init__(self,
                 input_dim: int,
                 n: int = 2):
        """
        Args:
            input_dim:
                Channels of the input tensor
            n:
                channel expansion factor for hidden dimensions
        """
        super(ConditionedTransitionBlock, self).__init__()
        self.ada_ln = AdaLN(input_dim)
        self.hidden_gating_linear = LinearNoBias(input_dim, n * input_dim, init='relu')
        self.hidden_linear = LinearNoBias(input_dim, n * input_dim, init='default')
        self.output_linear = L

原文地址:https://blog.csdn.net/qq_27390023/article/details/145097581

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!