自学内容网 自学内容网

c++ libtorch tensor 注意浅拷贝

错误代码示例

torch::Tensor multi_dim_identity = torch::zeros({ 2, 2, 2, 2 }, torch::kComplexDouble);
for (int i = 0; i < 2; ++i) {
    multi_dim_identity.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(), i, i }, 1);
}

torch::Tensor all_Kx = multi_dim_identity;
torch::Tensor all_Ky = multi_dim_identity;

for (int i = 0; i < 2; ++i) {
   torch::Tensor a = torch::zeros({ 2, 2, 2 }, torch::kComplexDouble);
   torch::Tensor b = torch::rand({ 2, 2, 2 }, torch::kComplexDouble);
   for (int j = 0; j < dim_x * dim_y; ++j) {
       all_Kx.index_put_({ i, torch::indexing::Slice(), j, j }, a.index({torch::indexing::Slice(), j, j}));
       all_Ky.index_put_({ i, torch::indexing::Slice(), j, j }, b.index({torch::indexing::Slice(), j, j}));
}

}

结果 all_Kx和all_Ky一样,在每个第1维度上都是一样的随机b,因为all_Kx和all_Ky都是multi_dim_identity的浅拷贝,all_Kx先赋值,其实是赋值给了multi_dim_identity,然后all_Ky再赋值,其实是赋值给了multi_dim_identity,导致all_Kx也跟着变,所以和all_Ky一样

正确代码如下

torch::Tensor multi_dim_identity = torch::zeros({ 2, 2, 2, 2 }, torch::kComplexDouble);
for (int i = 0; i < 2; ++i) {
    multi_dim_identity.index_put_({ torch::indexing::Slice(), torch::indexing::Slice(), i, i }, 1);
}

torch::Tensor all_Kx = multi_dim_identity.clone();
torch::Tensor all_Ky = multi_dim_identity.clone();

for (int i = 0; i < 2; ++i) {
   torch::Tensor a = torch::zeros({ 2, 2, 2 }, torch::kComplexDouble);
   torch::Tensor b = torch::rand({ 2, 2, 2 }, torch::kComplexDouble);
   for (int j = 0; j < dim_x * dim_y; ++j) {
       all_Kx.index_put_({ i, torch::indexing::Slice(), j, j }, a.index({torch::indexing::Slice(), j, j}));
       all_Ky.index_put_({ i, torch::indexing::Slice(), j, j }, b.index({torch::indexing::Slice(), j, j}));
}

}

用clone方法可以深拷贝,这样all_Kx和all_Ky就不一样


原文地址:https://blog.csdn.net/reyyy/article/details/142886170

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!