自学内容网 自学内容网

DeepACO:用于组合优化的神经增强蚂蚁系统解决TSP问题的代码阅读

总体概括

在这里插入图片描述
DeepACO与普通ACO不同的是将问题输入实例输入到一个训练的网络中,将网络训练成为一个类似于专家知识的模块,可以生成相应的启发式矩阵网络,从而省去相应的专家知识。
其中在训练网络的代码中:
在这里插入图片描述
是进行监督式训练通过train_epochs函数进行一次训练然后通过validation函数对模块进行校验。接下来分别从这两个函数进行细致的讲述。

1.训练函数train_epoch()

train_epoch函数其实特别的简单就是调用utils文件中的gen_pyg_data函数输入:
instance:这是输入到 gen_pyg_data 函数的第一个参数,它是一个形状为(n_node, 2) 的张量,表示每个节点的特征。这些特征通常用于表示节点在二维空间中的位置。(instance 是一个由 PyTorch 生成的随机张量。具体来说,它是一个形状为 (n_node, 2) 的张量,其中 n_node 是一个变量,表示节点的数量,而 2 表示每个节点有两个特征)
k_sparse:这是输入到 gen_pyg_data 函数的第二个参数,它是一个整数,表示稀疏度。稀疏度决定了生成的图中的边数。具体来说,k_sparse 表示每个节点连接到其最近的 k_sparse 个邻居节点的边数。
在这里插入图片描述

gen_pyg_data 函数的具体实现如下:
首先,计算节点数 n_nodes。
然后,调用 gen_distance_matrix(tsp_coordinates) 函数生成距离矩阵 distances。
接着,使用 torch.topk 函数获取距离矩阵中前 k_sparse 个最小值和对应的索引。
然后,生成边索引 edge_index,其中第一个维度是重复的节点索引,第二个维度是前 k_sparse 个最小值的索引。
生成边属性 edge_attr,将前 k_sparse 个最小值展平。
最后,生成图数据 pyg_data,包含节点坐标、边索引和边属性。
x 是节点特征,即 tsp_coordinates,表示每个城市的坐标。
edge_index 是边的索引,即 edge_index,表示城市之间的连接关系。
edge_attr 是边的属性,即 edge_attr,表示城市之间的距离。
返回图数据和距离矩阵
在这里插入图片描述
train_instance主要是负责训练实例函数,通过蚁群优化算法和强化学习来优化路径成本。训练完成后,模型参数将根据优化结果进行更新。
在这里插入图片描述

训练实例的作用是使用给定的神经网络(net)和优化器(optimizer)对数据进行训练。具体来说,train_instance 函数会执行以下步骤:
数据准备:从 data 中提取训练数据,包括位置和距离矩阵。
初始化:设置一些训练参数,如训练轮数(epochs)和批处理大小(batch_size)。
训练循环:对于每一轮训练,执行以下操作:
将数据输入到神经网络中,得到预测结果。
计算预测结果和真实值之间的损失(loss)。
使用优化器更新神经网络的权重,以最小化损失。
评估:在每一轮训练后,评估模型在验证集上的表现,并记录损失值。
通过这个过程,神经网络会逐渐学习如何根据输入的位置和距离矩阵预测最优路径。训练完成后,模型可以用于预测新的实例,从而解决项目调度问题。

2.验证函数validation()

验证函数使用验证数据集通过infer_instance函数在给定模型和图数据的情况下,返回基线、采样成本最小值和ACO算法最小成本。然后validation函数返回avg_bl:平均的预测结果。avg_sample_best:平均的样本最佳结果。avg_aco_best:平均的蚁群优化最佳结果。
在这里插入图片描述

在这里插入图片描述


原文地址:https://blog.csdn.net/weixin_52326703/article/details/142795784

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!