自学内容网 自学内容网

RoseTTAFold QueryEncoding类解读

QueryEncoding 类用于在输入张量 x 上添加一种查询序列的特殊编码。这里的查询编码将第一个序列标记为查询序列,并将其与其他序列区分开。以下是代码中的细节和每一步的作用。

源码:

class QueryEncoding(nn.Module):
    def __init__(self, d_model):
        super(QueryEncoding, self).__init__()
        self.pe = nn.Embedding(2, d_model) # (0 for query, 1 for others)
    
    def forward(self, x):
        B, N, L, K = x.shape
        idx = torch.ones((B, N, L), device=x.device).long()
        idx[:,0,:] = 0 # first sequence is the query
        x = x + self.pe(idx)
        return x 

代码解读:

class QueryEncoding(nn.Module):
    def __init__(self, d_model):
        super(QueryEncoding, self).__init__()
        self.pe = nn.Embedding(2, d_model) # (0 for query, 1 for others)
    
    def forward(self, x):
        B, N, L, K = x.shape
        idx = torch.ones((B, N, L), device=x.device).long()
        i

原文地址:https://blog.csdn.net/qq_27390023/article/details/143750334

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!