RoseTTAFold QueryEncoding类解读
QueryEncoding
类用于在输入张量 x
上添加一种查询序列的特殊编码。这里的查询编码将第一个序列标记为查询序列,并将其与其他序列区分开。以下是代码中的细节和每一步的作用。
源码:
class QueryEncoding(nn.Module):
def __init__(self, d_model):
super(QueryEncoding, self).__init__()
self.pe = nn.Embedding(2, d_model) # (0 for query, 1 for others)
def forward(self, x):
B, N, L, K = x.shape
idx = torch.ones((B, N, L), device=x.device).long()
idx[:,0,:] = 0 # first sequence is the query
x = x + self.pe(idx)
return x
代码解读:
class QueryEncoding(nn.Module):
def __init__(self, d_model):
super(QueryEncoding, self).__init__()
self.pe = nn.Embedding(2, d_model) # (0 for query, 1 for others)
def forward(self, x):
B, N, L, K = x.shape
idx = torch.ones((B, N, L), device=x.device).long()
i
原文地址:https://blog.csdn.net/qq_27390023/article/details/143750334
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!