attention
import math
import pdb

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def import_class(name):
    components = name.split('.')
    mod = __import__(components[0])
    for comp in components[1:]:
        mod = getattr(mod, comp)
    return mod


def conv_branch_init(conv, branches):
    weight = conv.weight
    n = weight.size(0)
    k1 = weight.size(1)
    k2 = weight.size(2)
    nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches)))
    nn.init.constant_(conv.bias, 0)


def conv_init(conv):
    if conv.weight is not None:
        nn.init.kaiming_normal_(conv.weight, mode='fan_out')
    if conv.bias is not None:
        nn.init.constant_(conv.bias, 0)


def bn_init(bn, scale):
    nn.init.constant_(bn.weight, scale)
    nn.init.constant_(bn.bias, 0)


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        if hasattr(m, 'weight'):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
        if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor):
            nn.init.constant_(m.bias, 0)
    elif classname.find('BatchNorm') != -1:
        if hasattr(m, 'weight') and m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if hasattr(m, 'bias') and m.bias is not None:
            m.bias.data.fill_(0)


class TemporalConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=5, stride=1, dilation=1):
        super(TemporalConv, self).__init__()
        pad = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            padding=(pad, 0),
            stride=(stride, 1),
            dilation=(dilation, 1))
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class SelfAttention(nn.Module):
    def __init__(self, in_channels, out_channels, num_heads=4):
        super(SelfAttention, self).__init__()
        assert out_channels % num_heads == 0, "out_channels must be a multiple of num_heads"
        self.num_heads = num_heads
        self.head_dim = out_channels // num_heads
        self.conv_q = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv_k = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv_v = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.fc_out = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def forward(self, x):
        N, C, T, V = x.shape
        q = self.conv_q(x)
        k = self.conv_k(x)
        v = self.conv_v(x)
        # split heads: (N, heads, head_dim, T, V)
        q = q.view(N, self.num_heads, self.head_dim, T, V)
        k = k.view(N, self.num_heads, self.head_dim, T, V)
        v = v.view(N, self.num_heads, self.head_dim, T, V)
        # move head_dim last: (N, heads, T, V, head_dim)
        q = q.permute(0, 1, 3, 4, 2).contiguous()
        k = k.permute(0, 1, 3, 4, 2).contiguous()
        # attention over the joint dimension V: scores are (N, heads, T, V, V)
        attention_scores = torch.matmul(q, k.transpose(-2, -1))
        attention_scores = attention_scores / math.sqrt(self.head_dim)
        attention_weights = F.softmax(attention_scores, dim=-1)
        v = v.permute(0, 1, 3, 4, 2).contiguous()
        out = torch.matmul(attention_weights, v)  # (N, heads, T, V, head_dim)
        # merge heads back into the channel dimension: (N, heads * head_dim, T, V)
        out = out.permute(0, 1, 4, 2, 3).contiguous().view(N, self.num_heads * self.head_dim, T, V)
        out = self.fc_out(out)
        return out


class MultiScale_TemporalConv(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilations=[1, 2, 3, 4],
                 residual=True,
                 residual_kernel_size=1):
        super().__init__()
        assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches'

        # Multiple branches of temporal convolution
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches
        if type(kernel_size) == list:
            assert len(kernel_size) == len(dilations)
        else:
            kernel_size = [kernel_size] * len(dilations)

        # Temporal convolution branches
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
                nn.BatchNorm2d(branch_channels),
                nn.ReLU(inplace=True),
                TemporalConv(branch_channels, branch_channels, kernel_size=ks, stride=stride, dilation=dilation),
            )
            for ks, dilation in zip(kernel_size, dilations)
        ])

        # Additional max-pool & 1x1 branches
        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(branch_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(3, 1), stride=(stride, 1), padding=(1, 0)),
            nn.BatchNorm2d(branch_channels)  # why add another BN here?
        ))
        self.branches.append(nn.Sequential(
            nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride, 1)),
            nn.BatchNorm2d(branch_channels)
        ))

        # Residual connection
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride)

        # self.selfattention = SelfAttention(out_channels, out_channels)

        # initialize
        self.apply(weights_init)

    def forward(self, x):
        # Input dim: (N, C, T, V)
        res = self.residual(x)
        branch_outs = []
        for tempconv in self.branches:
            out = tempconv(x)
            branch_outs.append(out)

        # concatenate all branch outputs along dim=1
        out = torch.cat(branch_outs, dim=1)
        # experiment: add self-attention on top of the multi-scale temporal convolution
        # out = self.selfattention(out) + out
        out += res
        return out


class CTRGC(nn.Module):
    def __init__(self, in_channels, out_channels, rel_reduction=8, mid_reduction=1):
        super(CTRGC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if in_channels == 3 or in_channels == 9:
            self.rel_channels = 8
            self.mid_channels = 16
        else:
            self.rel_channels = in_channels // rel_reduction
            self.mid_channels = in_channels // mid_reduction
        # note: the original post hard-coded 6 input channels for conv1, which breaks
        # with the 3-channel default of Model; use self.in_channels, as conv2 does
        self.conv1 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1)
        self.conv3 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1)
        self.conv4 = nn.Conv2d(self.rel_channels, self.out_channels, kernel_size=1)
        self.tanh = nn.Tanh()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)

    def forward(self, x, A=None, alpha=1):
        x1, x2, x3 = self.conv1(x), self.conv2(x), self.conv3(x)
        graph = self.tanh(x1.mean(-2).unsqueeze(-1) - x2.mean(-2).unsqueeze(-2))
        graph = self.conv4(graph)
        graph_c = graph * alpha + (A.unsqueeze(0).unsqueeze(0) if A is not None else 0)  # N,C,V,V
        y = torch.einsum('ncuv,nctv->nctu', graph_c, x3)
        return y, graph


class unit_tcn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super(unit_tcn, self).__init__()
        pad = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0),
                              stride=(stride, 1))
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        x = self.bn(self.conv(x))
        return x


class unit_gcn(nn.Module):
    def __init__(self, in_channels, out_channels, A, coff_embedding=4, adaptive=True, residual=True):
        super(unit_gcn, self).__init__()
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        self.out_c = out_channels
        self.in_c = in_channels
        self.adaptive = adaptive
        self.num_subset = A.shape[0]
        self.convs = nn.ModuleList()
        for i in range(self.num_subset):
            self.convs.append(CTRGC(in_channels, out_channels))

        if residual:
            if in_channels != out_channels:
                self.down = nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, 1),
                    nn.BatchNorm2d(out_channels)
                )
            else:
                self.down = lambda x: x
        else:
            self.down = lambda x: 0

        if self.adaptive:
            self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32)))
        else:
            self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False)
        self.alpha = nn.Parameter(torch.zeros(1))
        self.bn = nn.BatchNorm2d(out_channels)
        self.soft = nn.Softmax(-2)
        self.relu = nn.ReLU(inplace=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
        bn_init(self.bn, 1e-6)

    def forward(self, x):
        y = None
        graph_list = []
        if self.adaptive:
            A = self.PA
        else:
            A = self.A.cuda(x.get_device())
        for i in range(self.num_subset):
            z, graph = self.convs[i](x, A[i], self.alpha)
            graph_list.append(graph)
            y = z + y if y is not None else z
        y = self.bn(y)
        y += self.down(x)
        y = self.relu(y)
        return y, torch.stack(graph_list, 1)


class TCN_GCN_unit(nn.Module):
    def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, kernel_size=5,
                 dilations=[1, 2]):
        super(TCN_GCN_unit, self).__init__()
        self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive)
        # self.tcn1 = TemporalConv(out_channels, out_channels, stride=stride)
        self.tcn1 = MultiScale_TemporalConv(out_channels, out_channels,
                                            kernel_size=kernel_size,
                                            stride=stride,
                                            dilations=dilations,
                                            residual=True)
        self.relu = nn.ReLU(inplace=True)
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        z, graph = self.gcn1(x)
        y = self.relu(self.tcn1(z) + self.residual(x))
        return y, graph


class Model(nn.Module):
    def __init__(self, num_class=155, num_point=17, num_person=2, graph=None, graph_args=dict(), in_channels=3,
                 drop_out=0, adaptive=True):
        super(Model, self).__init__()

        if graph is None:
            raise ValueError()
        else:
            Graph = import_class(graph)
            self.graph = Graph(**graph_args)

        A = self.graph.A  # (num_subset, V, V), e.g. 3 x 17 x 17 here
        self.num_class = num_class
        self.num_point = num_point
        self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point)
        base_channel = 64

        # self.l1 = TCN_GCN_unit(in_channels, base_channel, A, residual=False, adaptive=adaptive)
        # self.l2 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        # self.l3 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        # self.l4 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        # self.l5 = TCN_GCN_unit(base_channel, base_channel * 2, A, stride=2, adaptive=adaptive)
        # self.l6 = TCN_GCN_unit(base_channel * 2, base_channel * 2, A, adaptive=adaptive)
        # self.l7 = TCN_GCN_unit(base_channel * 2, base_channel * 2, A, adaptive=adaptive)
        # self.l8 = TCN_GCN_unit(base_channel * 2, base_channel * 4, A, stride=2, adaptive=adaptive)
        # self.l9 = TCN_GCN_unit(base_channel * 4, base_channel * 4, A, adaptive=adaptive)
        # self.l10 = TCN_GCN_unit(base_channel * 4, base_channel * 4, A, adaptive=adaptive)
        # self.fc = nn.Linear(base_channel * 4, num_class)

        self.l1 = TCN_GCN_unit(in_channels, base_channel, A, residual=False, adaptive=adaptive)
        self.l2 = TCN_GCN_unit(base_channel, base_channel, A, adaptive=adaptive)
        self.l3 = TCN_GCN_unit(base_channel, base_channel * 2, A, stride=2, adaptive=adaptive)
        self.l4 = TCN_GCN_unit(base_channel * 2, base_channel * 2, A, adaptive=adaptive)
        self.l5 = TCN_GCN_unit(base_channel * 2, base_channel * 2, A, adaptive=adaptive)
        self.l6 = TCN_GCN_unit(base_channel * 2, base_channel * 4, A, stride=2, adaptive=adaptive)
        self.l7 = TCN_GCN_unit(base_channel * 4, base_channel * 4, A, adaptive=adaptive)
        self.fc = nn.Linear(base_channel * 4, num_class)

        nn.init.normal_(self.fc.weight, 0, math.sqrt(2. / num_class))
        bn_init(self.data_bn, 1)
        if drop_out:
            self.drop_out = nn.Dropout(drop_out)
        else:
            self.drop_out = lambda x: x

    def partDivison(self, graph):
        # divide the 17 COCO joints into body parts and average the graph columns per part
        # (this helper is not called in forward as the code stands)
        # _, num_joints, _ = graph.size()
        _, k, u, v = graph.size()  # n k u v
        head = [0, 1, 2, 3, 4, 5, 6]  # nose, eyes, and ears
        left_arm = [5, 7, 9]          # arm connections
        right_arm = [6, 8, 10]        # arm connections
        # arm = [5, 6, 7, 8, 9, 10]
        torso = [5, 6, 11, 12]        # torso connections
        left_leg = [11, 13, 15]       # leg connections
        right_leg = [12, 14, 16]
        graph_list = []
        part_list = [head, left_arm, right_arm, torso, left_leg, right_leg]
        for part in part_list:
            part_graph = graph[:, :, :, part].mean(dim=-1, keepdim=True)
            graph_list.append(part_graph)
        return torch.cat(graph_list, -1)

    def forward(self, x):
        if torch.isnan(x).any() or torch.isinf(x).any():
            print("Input data contains NaN or Inf.")
        if len(x.shape) == 3:
            N, T, VC = x.shape
            x = x.view(N, T, self.num_point, -1).permute(0, 3, 1, 2).contiguous().unsqueeze(-1)

        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V)

        # x, _ = self.l1(x)
        # x, _ = self.l2(x)
        # x, _ = self.l3(x)
        # x, _ = self.l4(x)
        # x, _ = self.l5(x)
        # x, _ = self.l6(x)
        # x, _ = self.l7(x)
        # x, _ = self.l8(x)
        # x, _ = self.l9(x)
        # x, graph = self.l10(x)

        x, _ = self.l1(x)
        x, _ = self.l2(x)
        x, _ = self.l3(x)
        x, _ = self.l4(x)
        x, _ = self.l5(x)
        x, _ = self.l6(x)
        x, graph = self.l7(x)  # N*M, C, T, V

        c_new = x.size(1)
        x = x.view(N, M, c_new, -1)
        x = x.mean(3).mean(1)
        x = self.drop_out(x)

        graph2 = graph.view(N, M, -1, c_new, V, V)
        # graph4 = torch.einsum('n m k c u v, n m k c v l -> n m k c u l', graph2, graph2)
        graph2 = graph2.view(N, M, -1, c_new, V, V).mean(1).mean(2).view(N, -1)
        # graph4 = graph4.view(N, M, -1, c_new, V, V).mean(1).mean(2).view(N, -1)
        # graph = torch.cat([graph2, graph4], -1)
        return self.fc(x), graph2
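As a quick sanity check (my own addition, not part of the original post), the sketch below builds a single TCN_GCN_unit with a placeholder identity adjacency for 17 COCO joints and pushes a random (N, C, T, V) tensor through it. The adjacency, batch size, and frame count are arbitrary assumptions chosen only to illustrate the expected shapes of the block output and of the returned channel-wise graphs.

if __name__ == '__main__':
    # placeholder adjacency: 3 subsets x 17 joints (identity matrices, illustration only)
    A = np.stack([np.eye(17, dtype=np.float32)] * 3)
    block = TCN_GCN_unit(in_channels=3, out_channels=64, A=A, adaptive=True)
    x = torch.randn(2, 3, 64, 17)   # (N, C, T, V): 2 clips, 3 coordinates, 64 frames, 17 joints
    y, graph = block(x)
    print(y.shape)      # torch.Size([2, 64, 64, 17])
    print(graph.shape)  # torch.Size([2, 3, 64, 17, 17]): per-subset, per-channel joint graphs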
Original article: https://blog.csdn.net/yyfhq/article/details/143697036