计算梯度所需的张量被原地(in-place)修改,破坏了计算图
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 18, 32, 32]] is at version 2; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient.
出错的代码片段:
def forward_slice(self, x_slice, x_channel, color, hp, cc_tsf, ctx_tsf, h_tsf, color_tsf, parameter_aggregation, entropy, entropy_mode):
    """Reproduce the autograd in-place error (intentionally BUGGY version).

    Computes likelihoods of ``x_slice`` in two passes. The anchor pass feeds an
    all-zero tensor to ``ctx_tsf``. The non-anchor pass feeds a tensor holding
    the ``x_slice`` values at (even row, even col) and (odd row, odd col)
    positions. The returned ``probs`` takes anchor-pass likelihoods at those
    positions and non-anchor-pass likelihoods at the remaining ones.

    BUG: the non-anchor context is built by writing in place into
    ``x_slice_anchor``, which was already passed to ``ctx_tsf`` for the
    anchor pass. If ``ctx_tsf`` saved that input for backward (as a
    convolution does), ``backward()`` fails with "one of the variables needed
    for gradient computation has been modified by an inplace operation".
    """
    h = h_tsf(hp)
    support = h
    if color != None:
        clr = color_tsf(color)
        support = torch.cat([support, clr], dim=1)
    if x_channel != None:
        ch = cc_tsf(x_channel)
        support = torch.cat([support, ch], dim=1)
    # Anchor pass: the spatial context input is all zeros.
    x_slice_anchor = torch.zeros_like(x_slice).to(x_slice.device)
    ctx_anchor = ctx_tsf(x_slice_anchor)
    support_anchor = torch.cat([support, ctx_anchor], dim=1)
    parameters = parameter_aggregation(support_anchor)
    if entropy_mode == "gmm":
        # Mixture mode: split into three chunks along dim 1; weights are softmax-normalised over dim 1.
        mean_anchor,sigma_anchor,weight_anchor = torch.chunk(parameters, 3, dim=1)
        weight_anchor = F.softmax(weight_anchor, dim=1)
    else:
        mean_anchor,sigma_anchor = torch.chunk(parameters, 2, dim=1)
        weight_anchor = None
    probs_anchor = entropy.likelihood(x_slice, mean_anchor, sigma_anchor, weight_anchor)
    # Keep anchor-pass likelihoods at (even, even) and (odd, odd) positions.
    probs = torch.zeros_like(x_slice).to(x_slice.device)
    probs[:,:,0::2,0::2] = probs_anchor[:,:,0::2,0::2]
    probs[:,:,1::2,1::2] = probs_anchor[:,:,1::2,1::2]
    # BUG: these in-place writes modify x_slice_anchor after ctx_tsf already
    # consumed it above. Each write bumps its autograd version counter, so
    # backward() finds version 2 where version 0 was saved.
    x_slice_anchor[:, :, 0::2, 0::2] = x_slice[:,:, 0::2, 0::2]
    x_slice_anchor[:,:, 1::2, 1::2] = x_slice[:,:, 1::2, 1::2]
    # Non-anchor pass: the context now holds x_slice values at the anchor positions.
    ctx_non_anchor = ctx_tsf(x_slice_anchor)
    support_non_anchor = torch.cat([support, ctx_non_anchor], dim=1)
    parameters_non_anchor = parameter_aggregation(support_non_anchor)
    if entropy_mode == "gmm":
        mean_non_anchor,sigma_non_anchor,weight_non_anchor = torch.chunk(parameters_non_anchor, 3, dim=1)
        weight_non_anchor = F.softmax(weight_non_anchor, dim=1)
    else:
        mean_non_anchor,sigma_non_anchor = torch.chunk(parameters_non_anchor, 2, dim=1)
        weight_non_anchor = None
    probs_non_anchor = entropy.likelihood(x_slice, mean_non_anchor, sigma_non_anchor, weight_non_anchor)
    # Fill the remaining (even, odd) and (odd, even) positions from the non-anchor pass.
    probs[:,:,0::2,1::2] = probs_non_anchor[:,:,0::2,1::2]
    probs[:,:,1::2,0::2] = probs_non_anchor[:,:,1::2,0::2]
    return probs
错误原因:
x_slice_anchor[:, :, 0::2, 0::2] = x_slice[:,:, 0::2, 0::2]
x_slice_anchor[:,:, 1::2, 1::2] = x_slice[:,:, 1::2, 1::2]
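这两行对x_slice_anchor做了原地(in-place)修改。但x_slice_anchor在此之前已经作为ctx_tsf的输入参与了前向计算。若ctx_tsf会保存输入供反向传播使用(例如卷积层需要输入来计算权重梯度),autograd会同时记录它当时的版本号(version 0)。虽然前向计算时它的值全为0,但两次原地赋值使其版本号变为2。反向传播时autograd发现所保存张量的版本号不一致,于是抛出上述错误。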
解决办法是新建一个张量来存放x_slice在对应位置的值,而不是在已经参与计算的x_slice_anchor上做原地修改。
因此,对于已经参与前向计算、可能被autograd保存用于反向传播的张量,不要再对其做原地修改;需要新的值时,应创建新的张量。
解决代码:
def forward_slice(self, x_slice, x_channel, color, hp, cc_tsf, ctx_tsf, h_tsf, color_tsf, parameter_aggregation, entropy, entropy_mode):
    """Return per-element likelihoods of ``x_slice`` via a two-pass checkerboard context.

    Pass 1 ("anchor") feeds an all-zero context to ``ctx_tsf``. Pass 2
    ("non-anchor") feeds ``x_slice`` with every non-anchor element zeroed.
    Anchor elements are those where ``row + col`` is even, i.e. the
    (even, even) and (odd, odd) positions. The result takes pass-1
    likelihoods there and pass-2 likelihoods everywhere else.

    Both the masked context and the merged result are built out of place
    (``masked_fill`` / ``torch.where``). No tensor that an earlier operation
    may have saved for backward is ever written to afterwards, which avoids
    the autograd "modified by an inplace operation" RuntimeError.

    Args:
        self: unused.
        x_slice: tensor indexed as (batch, channel, row, col). 4-D is assumed;
            the checkerboard is laid over its last two dimensions.
        x_channel: optional input for ``cc_tsf``; skipped when ``None``.
        color: optional input for ``color_tsf``; skipped when ``None``.
        hp: input for ``h_tsf``, which produces the first part of the support.
        cc_tsf, ctx_tsf, h_tsf, color_tsf: callables producing feature maps
            that are concatenated along dim 1.
        parameter_aggregation: maps the concatenated support to entropy parameters.
        entropy: object providing ``likelihood(x, mean, sigma, weight)``.
        entropy_mode: ``"gmm"`` splits the parameters into three chunks
            (mean, sigma, weight), softmax-normalising weight over dim 1. Any
            other value splits them into two (mean, sigma) and passes weight=None.

    Returns:
        The merged likelihood tensor, cast to ``x_slice``'s dtype.
    """
    # Side information shared by both passes, concatenated along dim 1.
    support = h_tsf(hp)
    if color is not None:
        support = torch.cat([support, color_tsf(color)], dim=1)
    if x_channel is not None:
        support = torch.cat([support, cc_tsf(x_channel)], dim=1)

    def _split_parameters(parameters):
        """Split aggregated parameters into (mean, sigma, weight-or-None)."""
        if entropy_mode == "gmm":
            mean, sigma, weight = torch.chunk(parameters, 3, dim=1)
            return mean, sigma, F.softmax(weight, dim=1)
        mean, sigma = torch.chunk(parameters, 2, dim=1)
        return mean, sigma, None

    def _likelihood(context):
        """Run one pass: context model -> parameter aggregation -> likelihood."""
        ctx = ctx_tsf(context)
        parameters = parameter_aggregation(torch.cat([support, ctx], dim=1))
        mean, sigma, weight = _split_parameters(parameters)
        return entropy.likelihood(x_slice, mean, sigma, weight)

    # Anchor positions: (row + col) even -> (even, even) and (odd, odd).
    height, width = x_slice.shape[-2:]
    rows = torch.arange(height, device=x_slice.device).unsqueeze(1)
    cols = torch.arange(width, device=x_slice.device).unsqueeze(0)
    anchor_mask = (rows + cols) % 2 == 0  # (H, W); broadcasts over leading dims

    # Pass 1: no spatial context is available yet.
    probs_anchor = _likelihood(torch.zeros_like(x_slice))
    # Pass 2: context holds x_slice at anchor positions and zeros elsewhere.
    # masked_fill returns a NEW tensor, so nothing consumed in pass 1 is mutated.
    probs_non_anchor = _likelihood(x_slice.masked_fill(~anchor_mask, 0))

    # Out-of-place merge (no slice assignment into a buffer); cast to
    # x_slice's dtype, as the zeros_like buffer used previously implied.
    return torch.where(anchor_mask, probs_anchor, probs_non_anchor).to(x_slice.dtype)
原文地址:https://blog.csdn.net/weixin_51435884/article/details/140485215
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!