自学内容网 自学内容网

【零基础保姆级教程】MMDetection3安装与训练自己的数据集

最近在跑对比试验,由于MMDetection框架的算法较齐全,遂决定写一篇教程留做参考。若你对流程有问题与疑问欢迎评论区指出

本文运行环境如下供参考:

python版本3.9
MMDetection版本3.3
一、虚拟环境的搭建

参考该博客搭建基本环境:

【保姆级最简洁教程】零基础如何快速搭建YOLOv5/v7?_yolo安装-CSDN博客

可能会遇到的报错需要更新的包,输入如下代码解决:

pip install mmcv-full==2.0.0rc4 -f https://download.openmmlab.com/mmcv/dist/cu116/torch1.12/index.html
pip install mmdet==3.0.0

对mmcv还有疑问的,可参考该博客解决:

【零基础保姆级教程】mmcv安装教程-CSDN博客

二、跑通项目流程

首先到官网拉取项目

GitHub - open-mmlab/mmdetection: OpenMMLab Detection Toolbox and Benchmark

拉取后,进行以下步骤:

1.准备自己的数据集(VOC/COCO格式)

coco文件如图所示。

其中,test2017、train2017和val2017均存放图片。

annotations中文件结构如图所示。

2.在拉取项目的目录下新建文件夹data

把你的COCO或VOC格式数据集拉取进来,如图所示

3.根据自己的数据集类别进行修改

打开mmdetection-main\mmdet\evaluation\functional\class_names.py文件

将def coco_classes()和def voc_classes()的return中的内容改为自己数据集类别,博主此处示例只有一个类别,称为“polyp”,若你有多个类别名称不同请自行修改。如图所示:

打开mmdetection-main\mmdet\datasets\coco.py的class CocoDataset(BaseDetDataset):的   METAINFO = {}中的'classes':类别改为自己数据集的类别
打开mmdetection-main\mmdet\datasets\voc.py的
class VOCDataset(XMLDataset):的   METAINFO = {}中的'classes':类别改为自己数据集的类别,如图所示

4.在自己的根目录下运行重构项目
python setup.py install build
5.以SSD算法为例修改参数开始训练(若你要跑其它算法也均按此步骤)

打开mmdetection-main\configs\ssd\,configs目录可以挑选你需要跑的算法

打开ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py

对里面的一些参数做修改,作用是符合自己的训练目的,博主修改了以下位置:

num_classes = 类别数,默认是80,博主数据集只有一个类别,故改为1(不用管背景)
max_epochs = 世代数
input_size = 输入图像尺寸,默认是300,根据你的要求修改
batch_size = 批次量,按住ctrl+f搜索后依次修改,考虑你的计算机资源量后修改
num_workers=线程数,考虑你的计算机资源量后修改

val_interval=评估间隔次数,默认是5,即五次epoch一次评估,这里博主选择的是1,频繁的评估可能造成训练总时间的延长,建议根据自己的需要做调整。

🔺无论你跑的是哪个算法,此处_base_中的参数都需要调整,尤其是博主上述给出的参数。若不修改,例如,batchsize在许多默认算法中是192,这对于许多资源有限的计算机可能就会导致异常。

注:若是跑VOC数据集的同学,此处也需要修改

改为

包括下方的ann_file和data_prefix参数,该项目默认数据集都是COCO数据集的,所以选择COCO格式数据集的可以少改一些。

除此之外,往上两张图中的mmdetection-main\configs\_base_\schedules\schedules_2x.py中的参数也需要调整。

同学们在刚刚开始训练时,需要观察每个step输出的loss值高低,若有一个或多个突然走高的数值是十分正常的,但若是走高乃至直接nan,可以尝试调整optimizer=,将其的学习率调低(如默认的lr×0.1乃至更低)在博主的训练在常有效。

在博主的数据集上,除了ssd\ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py无需繁琐的调参外,centernet\centernet_r18_8xb16-crop512-140e_coco.py、retinanet\retinanet_effb3_fpn_8xb4-crop896-1x_coco.py也不需要调参,按照上述参数调整后即可开始训练。

同时,ssd\ssd300_coco.py、faster_rcnn\faster-rcnn_r18_fpn_8xb8-amp-lsj-200e_coco.py等的训练就很麻烦,经常需要调整学习率。

6.控制台输入命令开始训练
python tools/train.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py --work-dir work_dirs

即可在你的目录下生成work_dirs文件夹,待训练开始,打开后如图所示即代表训练成功

此处博主遭遇报错

ValueError: need at least one array to concatenate

该问题网上有许多解决方法均不适用,博主用如下方法解决:

在自己的本地环境下尤其注意类别需要改在:
./anaconda3/envs/(你的环境名称)/Lib/site-packages/mmdet/dataset/coco.py中的
'classes':('你的类别名称',)即可解决

7.正常运行结果如下所示

如图所示即运行成功。

三、输出指标Precision、Recall、F1

运行后会在work_dirs下生成你的设置的epoch等数量的权重文件。接下来做如下工作:

输入命令:

python tools/test.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py work_dirs/epoch_xx.pth --out=result.pkl

此处命令的xx需改为实际的数字,即可对你训练出的权重进行测试,测试结果会输出COCO指标与一个result.pkl文件,这个文件可用于生成precision/recall/f1。

打开文件tools/analysis_tools/confusion_matrix.py

在文件后加入代码


    TP = np.diag(confusion_matrix)
    FP = np.sum(confusion_matrix, axis=0) - TP
    FN = np.sum(confusion_matrix, axis=1) - TP

    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    average_precision = np.mean(precision)
    average_recall = np.mean(recall)
    f1 = 2 * (precision * recall) / (precision + recall)

    print('AP:', average_precision)
    print('AR:', average_recall)
    print('F1:', f1)
    print('Precision', precision[0])
    print('Recall', recall[0])


    output_file_path = os.path.join(save_dir, 'PRF1.txt')
    with open(output_file_path, 'a') as output_file:
        output_file.write(f'{prediction_path}    {precision[0]:.5f}   {recall[0]:.5f}   {f1:.5f}\n')

输入命令

python tools/analysis_tools/confusion_matrix.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py result.pkl results/ --score-thr 0.5 

即可生成对应一个epoch权重的指标。

若要生成所有权重的指标,请见链接:

【零基础保姆级教程】MMDetection3训练输出Precision/Recall/F1-Score指标-CSDN博客

四、输出指标GFLOPS、参数量

输入命令

python tools/analysis_tools/get_flops.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py

/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py换为你的配置文件即可输出GFLOPS与参数量。

五、输出指标FPS

打开tools/analysis_tools/benchmark.py,覆盖如下3.1.0的mmdetection该代码至原文

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
from mmengine import MMLogger
from mmengine.config import Config, DictAction
from mmengine.dist import init_dist
from mmengine.registry import init_default_scope
from mmengine.utils import mkdir_or_exist
from mmdet.utils.benchmark import (DataLoaderBenchmark, DatasetBenchmark,InferenceBenchmark)
 
 
def parse_args():
    parser = argparse.ArgumentParser(description='MMDet benchmark')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file')
    parser.add_argument('--task',choices=['inference', 'dataloader', 'dataset'],default='inference',help='Which task do you want to go to benchmark')
    parser.add_argument('--repeat-num',type=int,default=1,help='number of repeat times of measurement for averaging the results')
    parser.add_argument('--max-iter', type=int, default=2000, help='num of max iter')
    parser.add_argument('--log-interval', type=int, default=50, help='interval of logging')
    parser.add_argument('--num-warmup', type=int, default=5, help='Number of warmup')
    parser.add_argument('--fuse-conv-bn',action='store_true',help='Whether to fuse conv and bn, this will slightly increase the inference speed')
    parser.add_argument('--dataset-type',choices=['train', 'val', 'test'],default='val',help='Benchmark dataset type. only supports train, val and test')
    parser.add_argument('--work-dir',help='the directory to save the file containing benchmark metrics')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
 
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)
    return args
 
 
def inference_benchmark(args, cfg, distributed, logger):
    benchmark = InferenceBenchmark(
        cfg,
        args.checkpoint,
        distributed,
        args.fuse_conv_bn,
        args.max_iter,
        args.log_interval,
        args.num_warmup,
        logger=logger)
    return benchmark
 
 
def dataloader_benchmark(args, cfg, distributed, logger):
    benchmark = DataLoaderBenchmark(
        cfg,
        distributed,
        args.dataset_type,
        args.max_iter,
        args.log_interval,
        args.num_warmup,
        logger=logger)
    return benchmark
 
 
def dataset_benchmark(args, cfg, distributed, logger):
    benchmark = DatasetBenchmark(
        cfg,
        args.dataset_type,
        args.max_iter,
        args.log_interval,
        args.num_warmup,
        logger=logger)
    return benchmark
 
 
def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
 
    init_default_scope(cfg.get('default_scope', 'mmdet'))
 
    distributed = False
    if args.launcher != 'none':
        init_dist(args.launcher, **cfg.get('env_cfg', {}).get('dist_cfg', {}))
        distributed = True
 
    log_file = None
    if args.work_dir:
        log_file = os.path.join(args.work_dir, 'benchmark.log')
        mkdir_or_exist(args.work_dir)
 
    logger = MMLogger.get_instance(
        'mmdet', log_file=log_file, log_level='INFO')
 
    benchmark = eval(f'{args.task}_benchmark')(args, cfg, distributed, logger)
    benchmark.run(args.repeat_num)
 
 
if __name__ == '__main__':
    main()

输入命令即可输出FPS:

python tools/analysis_tools/benchmark.py configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py --checkpoint work_dirs/ssd-mobilev2/epoch_150.pth

其中,此处configs/ssd/ssdlite_mobilenetv2-scratch_8xb24-600e_coco.py代表你的配置文件,--checkpoint部分代表你的权重路径。

附、运行过程中常见的报错

请见链接:

【持续更新中】MMDetection3训练自己的数据集常见报错解决-CSDN博客

更多文章产出中,主打简洁和准确,欢迎关注我,共同探讨!


原文地址:https://blog.csdn.net/2401_84870184/article/details/142624840

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!