【零基础保姆级教程】MMDetection3安装与训练自己的数据集
最近在跑对比试验,由于MMDetection框架的算法较齐全,遂决定写一篇教程留做参考。若你对流程有问题与疑问欢迎评论区指出
本文运行环境如下供参考:
python版本 | 3.9 |
MMDetection版本 | 3.3 |
一、虚拟环境的搭建
参考该博客搭建基本环境:
【保姆级最简洁教程】零基础如何快速搭建YOLOv5/v7?_yolo安装-CSDN博客
可能会遇到的报错需要更新的包,输入如下代码解决:
pip install mmcv-full==2.0.0rc4 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.12/index.html
pip install mmdet==3.0.0
对mmcv还有疑问的,可参考该博客解决:
二、跑通项目流程
首先到官网拉取项目
GitHub - open-mmlab/mmdetection: OpenMMLab Detection Toolbox and Benchmark
拉取后,进行以下步骤:
1.准备自己的数据集(VOC/COCO格式)
coco文件如图所示。
其中,test2017、train2017和val2017均存放图片。
annotations中文件结构如图所示。
2.在拉取项目的目录下新建文件夹data
把你的COCO或VOC格式数据集拉取进来,如图所示
3.根据自己的数据集类别进行修改
打开mmdetection-main\mmdet\evaluation\functional\class_names.py文件
将def coco_classes()和def voc_classes()的return中的内容改为自己数据集类别,博主此处示例只有一个类别,称为“polyp”,若你有多个类别名称不同请自行修改。如图所示:
打开mmdetection-main\mmdet\datasets\coco.py的class CocoDataset(BaseDetDataset):的 METAINFO = {}中的'classes':类别改为自己数据集的类别
打开mmdetection-main\mmdet\datasets\voc.py的
class VOCDataset(XMLDataset):的 METAINFO = {}中的'classes':类别改为自己数据集的类别,如图所示
4.在自己的根目录下运行重构项目
python setup.py install build
5.以SSD算法为例修改参数开始训练(若你要跑其它算法也均按此步骤)
打开mmdetection-main\configs\ssd\,configs目录可以挑选你需要跑的算法
打开ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py
对里面的一些参数做修改,作用是符合自己的训练目的,博主修改了以下位置:
num_classes = 类别数,默认是80,博主数据集只有一个类别,故改为1(不用管背景)
max_epochs = 世代数
input_size = 输入图像尺寸,默认是300,根据你的要求修改
batch_size = 批次量,按住ctrl+f搜索后依次修改,考虑你的计算机资源量后修改
num_workers=线程数,考虑你的计算机资源量后修改
val_interval=评估间隔次数,默认是5,即五次epoch一次评估,这里博主选择的是1,频繁的评估可能造成训练总时间的延长,建议根据自己的需要做调整。
🔺无论你跑的是哪个算法,此处_base_中的参数都需要调整,尤其是博主上述给出的参数。若不修改,例如,batchsize在许多默认算法中是192,这对于许多资源有限的计算机可能就会导致异常。
注:若是跑VOC数据集的同学,此处也需要修改
改为
包括下方的ann_file和data_prefix参数,该项目默认数据集都是COCO数据集的,所以选择COCO格式数据集的可以少改一些。
除此之外,往上两张图中的mmdetection-main\configs\_base_\schedules\schedules_2x.py中的参数也需要调整。
同学们在刚刚开始训练时,需要观察每个step输出的loss值高低,若有一个或多个突然走高的数值是十分正常的,但若是走高乃至直接nan,可以尝试调整optimizer=,将其的学习率调低(如默认的lr×0.1乃至更低)在博主的训练在常有效。
在博主的数据集上,除了ssd\ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py无需繁琐的调参外,centernet\centernet_r18_8xb16-crop512-140e_coco.py、retinanet\retinanet_effb3_fpn_8xb4-crop896-1x_coco.py也不需要调参,按照上述参数调整后即可开始训练。
同时,ssd\ssd300_coco.py、faster_rcnn\faster-rcnn_r18_fpn_8xb8-amp-lsj-200e_coco.py等的训练就很麻烦,经常需要调整学习率。
6.控制台输入命令开始训练
python tools/train.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py --work-dir work_dirs
即可在你的目录下生成work_dirs文件夹,待训练开始,打开后如图所示即代表训练成功
此处博主遭遇报错
ValueError: need at least one array to concatenate
该问题网上有许多解决方法均不适用,博主用如下方法解决:
在自己的本地环境下尤其注意类别需要改在:
./anaconda3/envs/(你的环境名称)/Lib/site-packages/mmdet/dataset/coco.py中的
'classes':('你的类别名称',)即可解决
7.正常运行结果如下所示
如图所示即运行成功。
三、输出指标Precision、Recall、F1
运行后会在work_dirs下生成你的设置的epoch等数量的权重文件。接下来做如下工作:
输入命令:
python tools/test.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py work_dirs/epoch_xx.pth --out=result.pkl
此处命令的xx需改为实际的数字,即可对你训练出的权重进行测试,测试结果会输出COCO指标与一个result.pkl文件,这个文件可用于生成precision/recall/f1。
打开文件tools/analysis_tools/confusion_matrix.py
在文件后加入代码
TP = np.diag(confusion_matrix)
FP = np.sum(confusion_matrix, axis=0) - TP
FN = np.sum(confusion_matrix, axis=1) - TP
precision = TP / (TP + FP)
recall = TP / (TP + FN)
average_precision = np.mean(precision)
average_recall = np.mean(recall)
f1 = 2 * (precision * recall) / (precision + recall)
print('AP:', average_precision)
print('AR:', average_recall)
print('F1:', f1)
print('Precision', precision[0])
print('Recall', recall[0])
output_file_path = os.path.join(save_dir, 'PRF1.txt')
with open(output_file_path, 'a') as output_file:
output_file.write(f'{prediction_path} {precision[0]:.5f} {recall[0]:.5f} {f1:.5f}\n')
输入命令
python tools/analysis_tools/confusion_matrix.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py result.pkl results/ --score-thr 0.5
即可生成对应一个epoch权重的指标。
若要生成所有权重的指标,请见链接:
【零基础保姆级教程】MMDetection3训练输出Precision/Recall/F1-Score指标-CSDN博客
四、输出指标GFLOPS、参数量
输入命令
python tools/analysis_tools/get_flops.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py
/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py换为你的配置文件即可输出GFLOPS与参数量。
五、输出指标FPS
打开tools/analysis_tools/benchmark.py,覆盖如下3.1.0的mmdetection该代码至原文
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
from mmengine import MMLogger
from mmengine.config import Config, DictAction
from mmengine.dist import init_dist
from mmengine.registry import init_default_scope
from mmengine.utils import mkdir_or_exist
from mmdet.utils.benchmark import (DataLoaderBenchmark, DatasetBenchmark,InferenceBenchmark)
def parse_args():
parser = argparse.ArgumentParser(description='MMDet benchmark')
parser.add_argument('config', help='test config file path')
parser.add_argument('--checkpoint', help='checkpoint file')
parser.add_argument('--task',choices=['inference', 'dataloader', 'dataset'],default='inference',help='Which task do you want to go to benchmark')
parser.add_argument('--repeat-num',type=int,default=1,help='number of repeat times of measurement for averaging the results')
parser.add_argument('--max-iter', type=int, default=2000, help='num of max iter')
parser.add_argument('--log-interval', type=int, default=50, help='interval of logging')
parser.add_argument('--num-warmup', type=int, default=5, help='Number of warmup')
parser.add_argument('--fuse-conv-bn',action='store_true',help='Whether to fuse conv and bn, this will slightly increase the inference speed')
parser.add_argument('--dataset-type',choices=['train', 'val', 'test'],default='val',help='Benchmark dataset type. only supports train, val and test')
parser.add_argument('--work-dir',help='the directory to save the file containing benchmark metrics')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
parser.add_argument(
'--launcher',
choices=['none', 'pytorch', 'slurm', 'mpi'],
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()
if 'LOCAL_RANK' not in os.environ:
os.environ['LOCAL_RANK'] = str(args.local_rank)
return args
def inference_benchmark(args, cfg, distributed, logger):
benchmark = InferenceBenchmark(
cfg,
args.checkpoint,
distributed,
args.fuse_conv_bn,
args.max_iter,
args.log_interval,
args.num_warmup,
logger=logger)
return benchmark
def dataloader_benchmark(args, cfg, distributed, logger):
benchmark = DataLoaderBenchmark(
cfg,
distributed,
args.dataset_type,
args.max_iter,
args.log_interval,
args.num_warmup,
logger=logger)
return benchmark
def dataset_benchmark(args, cfg, distributed, logger):
benchmark = DatasetBenchmark(
cfg,
args.dataset_type,
args.max_iter,
args.log_interval,
args.num_warmup,
logger=logger)
return benchmark
def main():
args = parse_args()
cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
init_default_scope(cfg.get('default_scope', 'mmdet'))
distributed = False
if args.launcher != 'none':
init_dist(args.launcher, **cfg.get('env_cfg', {}).get('dist_cfg', {}))
distributed = True
log_file = None
if args.work_dir:
log_file = os.path.join(args.work_dir, 'benchmark.log')
mkdir_or_exist(args.work_dir)
logger = MMLogger.get_instance(
'mmdet', log_file=log_file, log_level='INFO')
benchmark = eval(f'{args.task}_benchmark')(args, cfg, distributed, logger)
benchmark.run(args.repeat_num)
if __name__ == '__main__':
main()
输入命令即可输出FPS:
python tools/analysis_tools/benchmark.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py --checkpoint work_dirs/ssd-mobilev2/epoch_150.pth
其中,此处configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py代表你的配置文件,--checkpoint部分代表你的权重路径。
附、运行过程中常见的报错
请见链接:
【持续更新中】MMDetection3训练自己的数据集常见报错解决-CSDN博客
更多文章产出中,主打简洁和准确,欢迎关注我,共同探讨!
原文地址:https://blog.csdn.net/2401_84870184/article/details/142624840
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!