AF3 Transition和ConditionedTransitionBlock类解读
AlphaFold3的Transition/ ConditionedTransitionBlock类
1. Transition 类
-
作用:
- 提升模型的表达能力,通过扩展和收缩通道,学习不同层次的特征。
- 作为残差块的一部分,帮助模型捕捉更复杂的序列-结构映射关系。
-
生物学意义:
- 帮助捕捉蛋白质序列中局部和全局特征,为后续模块提供更丰富的特征。
2. ConditionedTransitionBlock 类
-
作用:
- 通过条件张量
s
自适应调整输入特征a
的分布。 - 门控机制控制特征的更新量,避免过拟合,增强模型的条件依赖能力。
- 通过条件张量
-
生物学意义:
- 模拟蛋白质在不同环境或上下文(如特定配体、化学环境)下的行为。
- 允许模型根据条件信息动态调整特征表示,捕捉更细粒度的结构变化。
源代码:
"""Transition blocks in AlphaFold3"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LayerNorm
from src.models.components.primitives import AdaLN
from src.models.components.primitives import Linear, LinearNoBias
class Transition(nn.Module):
"""A transition block for a residual update."""
def __init__(self, input_dim: int, n: int = 4):
"""
Args:
input_dim:
Channels of the input tensor
n:
channel expansion factor for hidden dimensions
"""
super(Transition, self).__init__()
self.layer_norm = LayerNorm(input_dim)
self.linear_1 = LinearNoBias(input_dim, n * input_dim, init='relu')
self.linear_2 = LinearNoBias(input_dim, n * input_dim, init='default')
self.output_linear = LinearNoBias(input_dim * n, input_dim, init='final')
def forward(self, x):
x = self.layer_norm(x)
x = F.silu(self.linear_1(x)) * self.linear_2(x)
return self.output_linear(x)
class ConditionedTransitionBlock(nn.Module):
"""SwiGLU transition block with adaptive layer norm."""
def __init__(self,
input_dim: int,
n: int = 2):
"""
Args:
input_dim:
Channels of the input tensor
n:
channel expansion factor for hidden dimensions
"""
super(ConditionedTransitionBlock, self).__init__()
self.ada_ln = AdaLN(input_dim)
self.hidden_gating_linear = LinearNoBias(input_dim, n * input_dim, init='relu')
self.hidden_linear = LinearNoBias(input_dim, n * input_dim, init='default')
self.output_linear = L
原文地址:https://blog.csdn.net/qq_27390023/article/details/145097581
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!