卷积神经网络随记
1.问题描述:一般而言,几个小滤波器卷积层的组合比一个大滤波器卷积层要好,比如层层堆叠了3个3x3的卷积层,中间含有非线性激活层,在这种排列下面,第一个卷积层中每个神经元对输入数据的感受野是3x3,第二层卷积层对第一层卷积层的感受野也是3x3,这样对于输入数据的感受野就是5x5,同样,第三层卷积层上对第二层卷积层的感受野是3x3,这样第三层卷积层对于第一层输入数据的感受野就是7x7。
问题:如何理解“对于输入数据的感受野就是5x5”“第三层卷积层对于第一层输入数据的感受野就是7x7”用计算公式说明
解答:
在卷积神经网络中,感受野是指卷积神经网络每一层输出的特征图(feature map)上的像素点在输入图像上映射的区域大小。换句话说,它定义了输出特征图中的一个元素与输入图像之间的空间范围。
对于堆叠的卷积层,每一层的感受野可以通过以下公式计算:
其中,是第 n 层的感受野大小, 是第 n 层的核(滤波器)大小。
以3x3的卷积层为例,如果连续堆叠3个这样的卷积层,我们可以这样计算每层的感受野:
1. 第一层卷积层使用3x3的滤波器,所以其感受野为 。
2. 第二层卷积层对第一层的输出应用另一个3x3的滤波器,因此感受野增加到 。
3. 第三层卷积层再次应用一个3x3的滤波器,感受野进一步增加到 。
因此,第三层卷积层对于第一层输入数据的感受野是7x7,这意味着第三层的每一个元素都对应于原始输入图像中的一个7x7的区域。
这种层层堆叠小滤波器的策略可以增加网络的深度和非线性,同时保持参数数量相对较少,这通常可以提高网络的性能,因为它能够捕捉到更复杂的模式和特征。
2.问题描述:Traceback (most recent call last): File "/home/xc/PycharmProjects/Liuyafei/cnn/net.py", line 57, in <module> conv_model.add_module(layer[0], layer[1]) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 390, in add_module raise KeyError("module name can't contain \".\", got: {}".format(name)) KeyError: 'module name can\'t contain ".", got: layer1.conv1'
解决方法:这个错误是因为在添加模块时,模块名称不能包含"."。你可以尝试将层的名称中的"."替换为其他字符,例如下划线"_"。
conv_model = SimpleCNN()
for layer in model.named_modules():
if isinstance(layer[1], nn.Conv2d):
# 将层名称中的"."替换为"_"
new_name = layer[0].replace('.', '_')
conv_model.add_module(new_name, layer[1])
3.问题:怎么提取卷积模块,使用print打印出来
解决方法:要提取卷积模块并使用print打印出来,可以遍历模型的层,检查每一层是否为卷积层(nn.Conv2d),如果是,则打印出该层的详细信息。
import torch
import torch.nn as nn
# 假设你已经定义了一个名为model的卷积神经网络模型
model = ...
# 遍历模型的所有层
for name, module in model.named_modules():
# 检查当前层是否为卷积层
if isinstance(module, nn.Conv2d):
# 打印卷积层的名称和参数
print(f"Layer: {name}")
print(f"Kernel size: {module.kernel_size}")
print(f"Stride: {module.stride}")
print(f"Padding: {module.padding}")
print(f"Dilation: {module.dilation}")
print(f"Groups: {module.groups}")
print("-----")
这段代码将遍历模型的所有层,如果发现某一层是卷积层(nn.Conv2d),就会打印出该层的名称、核大小、步长、填充、扩张和分组等信息。
4.问题:/home/xc/PycharmProjects/Liuyafei/cnn/net.py:68: UserWarning: nn.init.normal is now deprecated in favor of nn.init.normal_. init.normal(m.weight.data) /home/xc/PycharmProjects/Liuyafei/cnn/net.py:69: UserWarning: nn.init.xavier_normal is now deprecated in favor of nn.init.xavier_normal_. init.xavier_normal(m.weight.data) /home/xc/PycharmProjects/Liuyafei/cnn/net.py:70: UserWarning: nn.init.kaiming_normal is now deprecated in favor of nn.init.kaiming_normal_. init.kaiming_normal(m.weight.data)
解决方法:要修改这些警告,需要将nn.init.normal
、nn.init.xavier_normal
和nn.init.kaiming_normal
替换为它们的下划线版本,即nn.init.normal_
、nn.init.xavier_normal_
和nn.init.kaiming_normal_
。
5.问题:Traceback (most recent call last): File "/home/xc/PycharmProjects/Liuyafei/MNIST/cnn.py", line 68, in <module> model = CNN(28 * 28, 300, 100, 10) TypeError: __init__() takes 1 positional argument but 5 were given
解决方法:CNN
类定义中并没有接受这些参数的构造函数。实际上,已经在__init__
方法中初始化了网络的各个层,因此不需要在创建CNN
对象时传递任何参数。
model = CNN()
6.问题:
Traceback (most recent call last): File "/home/xc/PycharmProjects/Liuyafei/MNIST/cnn.py", line 85, in <module> outputs = model(inputs.view(-1, 28 * 28)) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, **kwargs) File "/home/xc/PycharmProjects/Liuyafei/MNIST/cnn.py", line 42, in forward x = self.layer1(x) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, **kwargs) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward input = module(input) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, **kwargs) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 457, in forward return self._conv_forward(input, self.weight, self.bias) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 453, in _conv_forward return F.conv2d(input, weight, bias, self.stride, RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [64, 784]
解决方法:
这个错误是因为卷积层期望输入的数据维度为3D(未批量化)或4D(批量化),但实际输入的数据维度为[64, 784]。要解决这个问题,需要将输入数据调整为正确的维度。可以通过在输入数据上添加一个额外的维度来实现这一点。
- 首先,确保输入数据的维度为[batch_size, channels, height, width]。在这个例子中,channels应该为1,因为MNIST数据集是灰度图像。
- 修改代码,将输入数据调整为正确的维度。
# 假设 inputs 是一个形状为 [batch_size, 784] 的张量
inputs = inputs.view(-1, 1, 28, 28) # 将输入数据调整为 [batch_size, 1, 28, 28]
outputs = model(inputs)
7.问题:
Traceback (most recent call last): File "/home/xc/PycharmProjects/Liuyafei/MNIST/cnn.py", line 118, in <module> out = model(img) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, **kwargs) File "/home/xc/PycharmProjects/Liuyafei/MNIST/cnn.py", line 42, in forward x = self.layer1(x) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, **kwargs) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward input = module(input) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, **kwargs) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 457, in forward return self._conv_forward(input, self.weight, self.bias) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 453, in _conv_forward return F.conv2d(input, weight, bias, self.stride, RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [64, 784]
解决方法:
在测试阶段,需要将输入数据调整为适合卷积层的形状。具体来说,卷积层期望的输入形状是 (batch_size, channels, height, width)
,而当前的输入数据形状是 (batch_size, 784)
。
要解决这个问题,需要在测试阶段的输入数据上添加一个通道维度。可以使用 unsqueeze()
函数来实现这一点。
# eval
eval_loss = 0
eval_acc = 0
for data in test_loader:
img, label = data
img = img.view(img.size(0), 1, 28, 28) # 添加通道维度
if torch.cuda.is_available():
img = Variable(img).cuda()
label = Variable(label).cuda()
else:
img = Variable(img)
label = Variable(label)
out = model(img)
loss = criterion(out, label)
eval_loss += loss.item() * label.size(0)
_, pred = torch.max(out, 1)
num_correct = (pred == label).sum()
eval_acc += num_correct.item()
print('Test Loss: {:.6f}, ACC: {:.6f}'.format(eval_loss / (len(test_dataset)), eval_acc / (len(test_dataset))))
8.问题:
Traceback (most recent call last): File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 1354, in do_open h.request(req.get_method(), req.selector, req.data, headers, File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1256, in request self._send_request(method, url, body, headers, encode_chunked) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1302, in _send_request self.endheaders(body, encode_chunked=encode_chunked) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1251, in endheaders self._send_output(message_body, encode_chunked=encode_chunked) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1011, in _send_output self.send(msg) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 951, in send self.connect() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1425, in connect self.sock = self._context.wrap_socket(self.sock, File "/home/xc/anaconda3/envs/share_env/lib/python3.8/ssl.py", line 500, in wrap_socket return self.sslsocket_class._create( File "/home/xc/anaconda3/envs/share_env/lib/python3.8/ssl.py", line 1040, in _create self.do_handshake() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/ssl.py", line 1309, in do_handshake self._sslobj.do_handshake()
或者这个:
Traceback (most recent call last):
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 1354, in do_open
h.request(req.get_method(), req.selector, req.data, headers,
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1256, in request
self._send_request(method, url, body, headers, encode_chunked)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1302, in _send_request
self.endheaders(body, encode_chunked=encode_chunked)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1251, in endheaders
self._send_output(message_body, encode_chunked=encode_chunked)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1011, in _send_output
self.send(msg)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 951, in send
self.connect()
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1425, in connect
self.sock = self._context.wrap_socket(self.sock,
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/ssl.py", line 500, in wrap_socket
return self.sslsocket_class._create(
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/ssl.py", line 1040, in _create
self.do_handshake()
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/ssl.py", line 1309, in do_handshake
self._sslobj.do_handshake()
ssl.SSLCertVerificati torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 65, in __init__
self.download()
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 141, in download
download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 446, in download_and_extract_archive
download_url(url, download_root, filename, md5)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 146, in download_url
url = _get_redirect_url(url, max_hops=max_redirect_hops)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 94, in _get_redirect_url
with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response:
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 222, in urlopen
return opener.open(url, data, timeout)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 525, in open
response = self._open(req, data)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 542, in _open
result = self._call_chain(self.handle_open, protocol, protocol +
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 502, in _call_chain
result = func(*args)
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 1397, in https_open
return self.do_open(http.client.HTTPSConnection, req,
File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 1357, in do_open
raise URLError(err)
urllib.error.URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1131)>
进程已结束,退出代码 1
解决办法:
这个错误是由于SSL证书验证失败导致的。可以尝试在代码中禁用SSL证书验证,但请注意这样做可能会导致安全风险。如果仍然想要尝试禁用SSL证书验证,可以在创建urllib.request.urlopen
对象之前添加以下代码:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
将这段代码添加到你的脚本的开头,然后再次运行脚本。
9.问题:
Downloading http://192.168.0.2:80/ac_portal/proxy.html?template=disclaimer&tabs=pwd&vlanid=0&_ID_=0&switch_url=&url=https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz&controller_type=&mac=04-7c-16-5e-85-7c to ./data/cifar-10-python.tar.gz 100%|█████████████████████████████████| 2250/2250 [00:00<00:00, 10922666.67it/s] Traceback (most recent call last): File "/home/xc/PycharmProjects/Liuyafei/cifar/resnet.py", line 25, in <module> trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 65, in __init__ self.download() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 141, in download download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 446, in download_and_extract_archive download_url(url, download_root, filename, md5) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 167, in download_url raise RuntimeError("File not found or corrupted.") RuntimeError: File not found or corrupted.
这个错误表明在下载过程中出现了问题,导致文件未找到或损坏。可以尝试重新运行代码以重新下载文件。如果问题仍然存在,可能是网络连接问题或者文件链接失效。
10.问题:
Traceback (most recent call last): File "/home/xc/PycharmProjects/Liuyafei/cifar/resnet.py", line 25, in <module> trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 65, in __init__ self.download() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 141, in download download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 446, in download_and_extract_archive download_url(url, download_root, filename, md5) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 146, in download_url url = _get_redirect_url(url, max_hops=max_redirect_hops) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 94, in _get_redirect_url with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response: File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 222, in urlopen return opener.open(url, data, timeout) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 525, in open response = self._open(req, data) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 542, in _open result = self._call_chain(self.handle_open, protocol, protocol + File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 502, in _call_chain result = func(*args) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 1397, in https_open return self.do_open(http.client.HTTPSConnection, req, File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 1358, in do_open r = h.getresponse() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1348, in getresponse response.begin() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 316, in begin version, status, reason = self._read_status() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 285, in _read_status raise RemoteDisconnected("Remote end closed connection without" http.client.RemoteDisconnected: Remote end closed connection without response
解决方法:
这个错误是由于在下载CIFAR-10数据集时,远程服务器关闭了连接而没有响应。这可能是由于网络问题或服务器问题导致的。你可以尝试以下方法解决这个问题:
- 检查你的网络连接是否正常。
- 确保你的防火墙或代理设置不会阻止下载。
- 稍后再尝试下载,可能是服务器暂时不可用。
- 如果问题仍然存在,可以考虑手动下载CIFAR-10数据集,然后将其解压缩到指定的目录(在这个例子中是
./data
)。
11.问题:
Traceback (most recent call last): File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 1354, in do_open h.request(req.get_method(), req.selector, req.data, headers, File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1256, in request self._send_request(method, url, body, headers, encode_chunked) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1302, in _send_request self.endheaders(body, encode_chunked=encode_chunked) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1251, in endheaders self._send_output(message_body, encode_chunked=encode_chunked) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1011, in _send_output self.send(msg) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 951, in send self.connect() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/http/client.py", line 1425, in connect self.sock = self._context.wrap_socket(self.sock, File "/home/xc/anaconda3/envs/share_env/lib/python3.8/ssl.py", line 500, in wrap_socket return self.sslsocket_class._create( File "/home/xc/anaconda3/envs/share_env/lib/python3.8/ssl.py", line 1040, in _create self.do_handshake() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/ssl.py", line 1309, in do_handshake self._sslobj.do_handshake() Connecti torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 65, in __init__ self.download() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 141, in download download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 446, in download_and_extract_archive download_url(url, download_root, filename, md5) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 146, in download_url url = _get_redirect_url(url, max_hops=max_redirect_hops) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 94, in _get_redirect_url with urllib.request.urlopen(urllib.request.Request(url, headers=headers)) as response: File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 222, in urlopen return opener.open(url, data, timeout) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 525, in open response = self._open(req, data) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 542, in _open result = self._call_chain(self.handle_open, protocol, protocol + File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 502, in _call_chain result = func(*args) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 1397, in https_open return self.do_open(http.client.HTTPSConnection, req, File "/home/xc/anaconda3/envs/share_env/lib/python3.8/urllib/request.py", line 1357, in do_open raise URLError(err) urllib.error.URLError: <urlopen error [Errno 104] Connection reset by peer>
解决方法:
这个错误是由于在尝试下载CIFAR-10数据集时,远程服务器关闭了连接。这可能是由于网络问题或服务器问题导致的。可以尝试以下方法解决这个问题:
- 检查你的网络连接是否正常。
- 确保防火墙或代理设置不会阻止下载。
- 稍后再尝试下载,可能是服务器暂时不可用。
- 如果问题仍然存在,可以考虑手动下载CIFAR-10数据集,然后将其解压缩到指定的目录(在这个例子中是
./data
)。
12.问题:
通过这个网址http://192.168.0.2:80/ac_portal/proxy.html?template=disclaimer&tabs=pwd&vlanid=0&_ID_=0&switch_url=&url=https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz&controller_type=&mac=04-7c-16-5e-85-7c下载了名为“cifar-10-python.tar.gz”这个的文件,将下载的文件“cifar-10-python.tar.gz”放到当前运行文件data下面后,运行当前代码,显示以下错误:
Downloading http://192.168.0.2:80/ac_portal/proxy.html?template=disclaimer&tabs=pwd&vlanid=0&_ID_=0&switch_url=&url=https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz&controller_type=&mac=04-7c-16-5e-85-7c to ./data/cifar-10-python.tar.gz 100%|██████████████████████████████████| 2250/2250 [00:00<00:00, 9779465.28it/s] Traceback (most recent call last): File "/home/xc/PycharmProjects/Liuyafei/cifar/resnet.py", line 25, in <module> trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 65, in __init__ self.download() File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/cifar.py", line 141, in download download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 446, in download_and_extract_archive download_url(url, download_root, filename, md5) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torchvision/datasets/utils.py", line 167, in download_url raise RuntimeError("File not found or corrupted.") RuntimeError: File not found or corrupted.
解决方法:
从错误信息来看,下载的文件可能已经损坏或者文件路径不正确。请尝试以下步骤:
1. 确保文件 "cifar-10-python.tar.gz" 已经被正确下载到 "./data" 目录下。
2. 如果文件不存在,请重新运行代码以下载文件。
3. 如果文件已存在但仍然出现错误,请检查文件是否完整且未损坏。可以尝试重新下载文件。
4. 确保在创建 `CIFAR10` 数据集时使用正确的文件路径。在这个例子中,文件应该位于 "./data/cifar-10-python.tar.gz"。
如果问题仍然存在,可以尝试手动下载文件并将其移动到正确的位置。以下是手动下载和解压文件的步骤:
1. 访问 https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz 并下载文件。
2. 将下载的文件移动到 "./data" 目录下。
3. 确保文件名正确(应为 "cifar-10-python.tar.gz")。
4. 解压文件:在终端中运行 `tar -xzvf ./data/cifar-10-python.tar.gz -C ./data`。
5. 重新运行代码。
13.问题:
Traceback (most recent call last): File "/home/xc/PycharmProjects/Liuyafei/cifar/resnet.py", line 67, in <module> outputs = net(inputs) File "/home/xc/anaconda3/envs/share_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, **kwargs) File "/home/xc/PycharmProjects/Liuyafei/cifar/resnet.py", line 46, in forward x = self.pool(F.relu(self.conv1(x))) AttributeError: 'TypeVar' object has no attribute 'relu'
解决方法:
这个错误是因为在代码中使用了`F.relu`,但是没有正确导入`torch.nn.functional`模块。要解决这个问题,需要在代码开头添加以下导入语句:
import torch.nn.functional as F
这样,`F.relu`就可以正常工作了。
原文地址:https://blog.csdn.net/miodi/article/details/140814447
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!