【深度学习】提取模块的backbone是相同的,但是需要多个头来做分类,应该如何做
x= self.seq(x)
假设经过这一步,输出了需要分类的特征。
head1 = nn.Linear(linear_size, n_classes)
head2 = nn.Linear(linear_size, n_classes)
.........
然后,就可以做分类了:
x = self.seq(x)
head1_out = head1(x)
head2_out = head2(x)
......
问题:
-
那可不可以同时复用head1的而不用new这么多head?
若是你的head层的权重是不被调整更新的那可以,(但应该把它过滤掉,设置为不随训练更新,参考transformer的位置编码) -
head多了,那就一个个手动改代码么?
可以做成pipeline
# head_num 为头的数量,这个示例是 linear_size 和 n_classes是相同的,你做个列表,就可以自动提取不相同了。
nn.ModuleList([nn.Linear(linear_size3, n_classes) for _ in range(head_num)])
原文地址:https://blog.csdn.net/weixin_40293999/article/details/143758627
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!