基于图卷积神经网络(GCN)的高光谱图像分类详细教程(含python代码)
目录
一、背景
图卷积神经网络(Graph Convolutional Networks, GCNs)在高光谱图像分类中是一种有效的方法,特别适用于处理具有复杂空间关系的数据。高光谱图像通常包含数百个甚至数千个连续的频谱波段,每个波段对应一个光谱特征,这使得传统的卷积神经网络在处理高光谱图像时面临困难,因为它们无法有效地捕获像素之间的空间关系。
GCNs通过利用图结构来解决这一问题,将像素(或者像素附近的区域)视为图中的节点,并利用这些节点之间的关系进行特征学习和分类。以下是GCNs在高光谱图像分类中的一些关键点和优势:
-
图结构建模:将高光谱图像中的像素视为图中的节点,像素之间的空间关系(例如邻近关系)作为图的边,这样就能够在整个图上利用节点的局部和全局信息。
-
卷积操作:GCN引入了图卷积操作,允许在图结构上进行类似于传统卷积神经网络中的卷积操作。这种操作可以捕获节点及其邻居的特征,并利用这些信息来提取更有意义的特征表示。
-
特征学习:通过多层的图卷积操作,GCNs能够逐步学习出更加抽象和高级的特征表示,这对于高光谱数据的复杂特征提取尤为重要。
-
分类器:最后一层通常是一个分类器,用于将学习到的特征映射到类别标签空间,从而进行分类。
-
适应性:GCNs在处理高光谱图像时具有很强的适应性和灵活性,能够处理不同大小和分辨率的图像,以及不同数量和配置的频谱波段。
总体来说,图卷积神经网络通过充分利用高光谱图像中像素之间的空间关系,有效地提升了分类性能,并在遥感图像分析和其他高维数据的处理中展现出了广阔的应用前景。
二、基于卷积神经网络的代码实现
下面我们以IP数据集为例子进行展开讲解。
1、安装依赖库
matplotlib==3.3.4
networkx==2.1
numpy==1.19.5
pandas==1.1.5
scikit_learn==1.5.1
scipy==1.5.4
seaborn==0.11.2
spectral==0.22.4
torch==1.7.1+cu110
torch_geometric==2.0.2
tqdm==4.62.3
2、建立图卷积神经网络
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch
import torch.nn as nn
class GCN(torch.nn.Module):
def __init__(self, num_node_features, num_classes):
super(GCN, self).__init__()
self.conv1 = GCNConv(num_node_features, 32)
self.conv1_bn_relu = nn.Sequential(
nn.BatchNorm1d(32),
nn.ReLU()
)
self.conv2 = GCNConv(32, 64)
self.conv2_bn_relu = nn.Sequential(
nn.BatchNorm1d(64),
nn.ReLU()
)
self.cls = nn.Sequential(
nn.Linear(64, num_classes),
)
def forward(self, edge, data):
x, edge_index = data, edge
x = self.conv1_bn_relu(self.conv1(x, edge_index))
x = self.conv2_bn_relu(self.conv2(x, edge_index))
return self.cls(x)
3、建立数据的边
首先进行PCA数据降维:
X_pca = applyPCA(X, numComponents=pca_components)
然后将无标签数据进行剔除:
X_pca = X_pca.reshape(-1,pca_components)
y = y.ravel()
mask = y == 0
# 剔除无标签的数据
data = X_pca[~mask]
label = y[~mask]
划分训练验证集(训练70%):
X_train, X_test, y_train, y_test = splitTrainTestSet(range(len(data)),label,trainRatio=0.7)
最后建立所有样本的边(这里取最近邻的样本为3):
Edge_build(data,k=3)
4、训练模型
加载数据和模型:
X_train_index,X_test_index = utils.create_train_test('./data/'+patch_+'/train_index.txt',
'./data/'+patch_+'/test_index.txt')
data,label = utils.create_features('./data/'+patch_+'/data.txt',
'./data/'+patch_+'/label.txt')
edge = pd.read_csv('./data/'+patch_+'/edge.txt', sep=" ", header=None).values.T
# 建立模型
model = GCN(30, 16)
训练模型:
class Trainer():
def __init__(self, data,y,edge,X_train_index,X_test_index, model, optimizer, loss_function, epochs):
self.y = y
self.edge = torch.from_numpy(edge).type(torch.LongTensor).to(device)
self.X_train_index = X_train_index
self.X_test_index = X_test_index
self.data = torch.from_numpy(data).type(torch.FloatTensor).to(device)
self.model = model.to(device)
self.optimizer = optimizer
self.loss_function = loss_function
self.epochs = epochs
self.y_train = torch.from_numpy(y[X_train_index]).type(torch.LongTensor).to(device)
self.y_test = torch.from_numpy(y[X_test_index]).type(torch.LongTensor).to(device)
self.preds = None
def train(self):
pass
def test(self):
self.model.eval()
pass
trainer = Trainer(
data=data,
y=label,
edge=edge,
X_train_index=X_train_index,
X_test_index=X_test_index,
model=model,
optimizer=optim.Adam(model.parameters(), lr=0.001),
loss_function=nn.CrossEntropyLoss(),
epochs=1000
)
trainer.train()
trainer.test()
5、可视化
if __name__ == '__main__':
patch_ = "IP"
graph, A = utils.create_Graphs_with_attributes_adjadjency_matrix('./data/' + patch_ + '/edge.txt',
'./data/' + patch_ + '/data.txt')
data, label = utils.create_features('./data/' + patch_ + '/data.txt',
'./data/' + patch_ + '/label.txt')
edge = pd.read_csv('./data/' + patch_ + '/edge.txt', sep=" ", header=None).values.T
model = GCN(30, 16)
model.eval()
net_params = torch.load("./weight/model.pkl")
model.load_state_dict(net_params) # 加载模型可学习参数
trainer = Trainer(
data=data,
y=label,
edge=edge,
model=model,
)
pred = trainer.pre() + 1
y_ = sio.loadmat('./data/Indian_pines_gt.mat')['indian_pines_gt']
a, b = y_.shape
print('Label shape: ', y_.shape)
y = y_.ravel()
mask = y == 0
outputs = np.zeros_like(y)
outputs[~mask] = pred
outputs = outputs.reshape((a, b))
import spectral
import matplotlib.pyplot as plt
predict_image = spectral.imshow(classes=outputs.astype(int), figsize=(5, 5))
plt.savefig('./results/pre.png', dpi=300)
plt.pause(1)
三、项目代码
本项目的代码通过以下链接下载:基于图卷积神经网络(GCN)的高光谱图像分类详细教程(含python代码)
原文地址:https://blog.csdn.net/qq_45100200/article/details/140752649
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!