自学内容网 自学内容网

目标检测、语义分割离线增强脚本

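下面给出目标检测与语义分割两类任务的离线数据增强脚本,使用时只需修改脚本中的文件路径。注意:第一节介绍的是 Augmentor 库,但下面两个脚本实际调用的是功能类似的 albumentations 库(运行前需安装 albumentations 与 opencv-python)。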

一、Augmentor 简介:
Augmentor 是一个面向机器学习任务的 Python 图像增强包,用于人工生成图像数据和进行数据增强。它主要是数据增强工具,同时也提供基本的图像预处理功能。Augmentor 为标准图像处理操作提供了许多类,例如旋转类 Rotate、裁剪类 Crop 等。支持的操作包括:旋转(rotate)、裁剪(crop)、透视变换(perspective skewing)、错切(shearing)、弹性形变(elastic distortions),以及亮度、对比度、颜色调整等;更多操作及其参数设定请参阅 Augmentor 官方文档。

二、目标检测数据增强

import xml.dom.minidom
import cv2
from albumentations import(
BboxParams, RandomGamma,Compose,Blur,CenterCrop,HueSaturationValue,
MotionBlur,Cutout)

import os
import glob
def read_xml(path):
    """Read the first object annotation from a Pascal-VOC style XML file.

    Only the first ``<filename>``, ``<name>``, ``<xmin>``, ``<ymin>``,
    ``<xmax>`` and ``<ymax>`` elements are read, so a file containing several
    objects yields only the first box.

    Args:
        path: path to the XML annotation file.

    Returns:
        ``[image_file_name, label, xmin, ymin, xmax, ymax]``. The coordinates
        are the raw text of the XML elements (strings, not numbers).
        ``".jpg"`` is appended to the file name only when it does not already
        end with a common image extension.

    Raises:
        IndexError: if one of the required elements is missing or empty.
    """
    image_exts = (".jpg", ".jpeg", ".png", ".bmp")
    dom = xml.dom.minidom.parse(path)
    root = dom.documentElement

    def first_text(tag):
        # Text content of the first element with this tag name.
        return root.getElementsByTagName(tag)[0].childNodes[0].data

    img_name = first_text("filename")
    # Tools such as labelImg store "name.jpg"; the augmented XMLs written by
    # this script store a bare stem. Appending unconditionally would turn the
    # former into "name.jpg.jpg".
    if not img_name.lower().endswith(image_exts):
        img_name += ".jpg"

    exp_xml = [img_name, first_text("name")]
    exp_xml.extend(first_text(tag) for tag in ("xmin", "ymin", "xmax", "ymax"))
    return exp_xml


def modify_xml(path,bbox,new_img_name,aug_file,num,n):
    """Write a copy of an annotation XML with a new filename and bounding box.

    Only the first ``<filename>``, ``<xmin>``, ``<ymin>``, ``<xmax>`` and
    ``<ymax>`` elements of the source document are updated; everything else is
    copied unchanged.

    Args:
        path: source XML annotation to copy.
        bbox: ``[xmin, ymin, xmax, ymax]``; each value is written via ``str()``.
        new_img_name: value stored in the ``<filename>`` element.
        aug_file: output directory.
        num: index of the image within the current round (used in the file name).
        n: augmentation round (used in the file name).

    The result is written to ``aug_file/aug_img{n}_{num}.xml``.
    """
    new_dom = xml.dom.minidom.parse(path)
    new_root = new_dom.documentElement
    new_root.getElementsByTagName("filename")[0].childNodes[0].data = new_img_name
    for i, tag in enumerate(("xmin", "ymin", "xmax", "ymax")):
        # Text node data must be a string; the augmented bbox holds ints.
        new_root.getElementsByTagName(tag)[0].childNodes[0].data = str(bbox[i])
    # Join the directory with a bare file name. The old code joined it with
    # aug_file + "\\...", which only worked on Windows by accident.
    out_path = os.path.join(aug_file, "aug_img{}_{}.xml".format(n, num))
    with open(out_path, 'w', encoding='utf-8') as fh:
        new_dom.writexml(fh, encoding='utf-8')

def visualize_bbox(img, bbox, class_id, class_idx_to_name):
    """Draw one box and its class label onto ``img`` and return the result.

    Args:
        img: image array; ``cv2.rectangle`` / ``cv2.putText`` draw into it.
        bbox: four numbers ``(x_min, y_min, x_max, y_max)``, truncated to int.
        class_id: key looked up in ``class_idx_to_name`` for the label text.
        class_idx_to_name: mapping from class id to display name.

    Returns:
        The image returned by the first ``cv2.rectangle`` call, with a filled
        label background and the label text drawn above the box.
    """
    x0, y0, x1, y1 = list(bbox)
    x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
    box_colour = (255, 0, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX

    canvas = cv2.rectangle(img, (x0, y0), (x1, y1), box_colour, 2)
    label = class_idx_to_name[class_id]

    (label_w, label_h), _baseline = cv2.getTextSize(label, font, 0.35, 1)
    # Filled strip just above the box so the white text stays readable.
    strip_top = y0 - int(1.3 * label_h)
    cv2.rectangle(canvas, (x0, strip_top), (x0 + label_w, y0), box_colour, -1)
    text_origin = (x0, y0 - int(0.3 * label_h))
    cv2.putText(canvas, label, text_origin, font, 0.35, (255, 255, 255), lineType=cv2.LINE_AA)
    return canvas


def get_aug(aug, min_area=0., min_visibility=0.):
    """Wrap a list of transforms in a bbox-aware ``Compose``.

    Boxes are interpreted in Pascal-VOC pixel format and labels are read from
    the ``category_id`` field passed at call time.

    Args:
        aug: list of albumentations transforms.
        min_area: forwarded to ``BboxParams(min_area=...)``.
        min_visibility: forwarded to ``BboxParams(min_visibility=...)``.
    """
    bbox_settings = BboxParams(
        format='pascal_voc',
        min_area=min_area,
        min_visibility=min_visibility,
        label_fields=["category_id"],
    )
    return Compose(aug, bbox_params=bbox_settings)

def augment():
    """Build the random augmentation pipeline applied to each detection sample.

    The pipeline is bbox-aware: callers pass ``bboxes`` (Pascal-VOC pixel
    coordinates) and ``category_id``. Geometric steps such as ``CenterCrop``
    shift the boxes accordingly. Boxes that end up fully outside the crop are
    dropped, so ``result['bboxes']`` may be empty.
    """
    transforms = [
        Blur(blur_limit=7, p=0.3),
        RandomGamma(gamma_limit=(80, 120), p=0.5),
        CenterCrop(height=400, width=400, p=0.2),
        HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.3),
        MotionBlur(blur_limit=7, p=0.5),
        Cutout(num_holes=8, max_h_size=8, max_w_size=8, fill_value=0, p=0.2),
    ]
    # Without bbox_params albumentations rejects the `bboxes` argument that
    # main() passes; even where tolerated, CenterCrop would leave boxes unmoved.
    return Compose(
        transforms,
        bbox_params=BboxParams(format='pascal_voc', label_fields=["category_id"]),
    )

def get_data(xml_date_path,img_path):
    """Load one annotated sample: its bounding box and the matching image.

    Args:
        xml_date_path: path to the Pascal-VOC XML annotation.
        img_path: directory that contains the image named in the XML.

    Returns:
        ``(bbox, img)`` where ``bbox`` is ``[[xmin, ymin, xmax, ymax]]`` as ints
        (a one-element list, matching the albumentations ``bboxes`` argument)
        and ``img`` is the array returned by ``cv2.imread``.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    xml_date = read_xml(xml_date_path)
    image_file = os.path.join(img_path, xml_date[0])
    img = cv2.imread(image_file)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError("could not read image: {}".format(image_file))
    # VOC coordinates are sometimes written as floats ("12.0"), which int() rejects.
    bbox = [[int(float(v)) for v in xml_date[2:6]]]
    return bbox, img

def keep_aug_img(annotations):
    """Extract the augmented image and its bounding box as integers.

    Args:
        annotations: augmentation output dict with ``'image'`` and ``'bboxes'``.

    Returns:
        ``(aug_img, aug_bbox)``. ``aug_img`` is a copy of ``annotations['image']``.
        ``aug_bbox`` is ``[x_min, y_min, x_max, y_max]`` truncated to ints, taken
        from the LAST box in ``'bboxes'``; callers in this script supply one box
        per image. ``aug_bbox`` is ``None`` when ``'bboxes'`` is empty, e.g.
        because a crop removed the object.
    """
    aug_img = annotations['image'].copy()
    aug_bbox = None  # previously unbound -> UnboundLocalError when no boxes survive
    for bbox in annotations['bboxes']:
        x_min, y_min, x_max, y_max = list(bbox)
        aug_bbox = [int(x_min), int(y_min), int(x_max), int(y_max)]
    return aug_img, aug_bbox


def visualize(annotations, category_id_to_name):
    """Return a copy of ``annotations['image']`` with every box and label drawn.

    Box ``i`` is labelled with ``annotations['category_id'][i]`` mapped through
    ``category_id_to_name``. The input image itself is not passed to the drawing
    helper; only the copy is.
    """
    canvas = annotations['image'].copy()
    for position, box in enumerate(annotations['bboxes']):
        label_id = annotations['category_id'][position]
        canvas = visualize_bbox(canvas, box, label_id, category_id_to_name)
    return canvas

def main(xml_img_path=r"H:\darknet_y\data\car", aug_file=r"H:\darknet_y\data\car\aug_file", shample=6):
    """Offline augmentation of every Pascal-VOC annotated image in a folder.

    For each of ``shample`` rounds, every ``*.xml`` in ``xml_img_path`` is
    loaded together with its image and passed through ``augment()``. The
    result is saved into ``aug_file`` as ``aug_img{round}_{index}.jpg`` plus
    an XML whose filename and box fields are rewritten.

    Args:
        xml_img_path: folder holding the source XML files and images.
        aug_file: output folder for augmented images/XMLs (created if missing).
        shample: number of augmentation rounds over the whole folder.
    """
    os.makedirs(aug_file, exist_ok=True)
    aug = augment()  # build the pipeline once instead of once per image
    # Sorted so output indices map to the same source files on every run.
    xml_files = sorted(glob.glob(os.path.join(xml_img_path, "*.xml")))
    for n in range(shample):
        print(" 第%d次"%n)
        for num, xml_name in enumerate(xml_files):
            print(" 第%d张图片"%num)
            bbox, img = get_data(xml_name, xml_img_path)  # image and its bbox from the XML
            annotations = {'image': img, 'bboxes': bbox, 'category_id': [1]}
            augmented = aug(**annotations)
            if len(augmented['bboxes']) == 0:
                # A crop can remove the object entirely; there is no box to save.
                continue
            # To inspect a sample: cv2.imshow("x", visualize(augmented, {1: "juanyuanzi"}))
            aug_img, aug_bbox = keep_aug_img(augmented)  # augmented image and adjusted bbox
            new_name = "aug_img{}_{}".format(n, num)
            cv2.imwrite(os.path.join(aug_file, new_name + ".jpg"), aug_img)
            # The XML <filename> stores the bare stem, matching read_xml's ".jpg" handling.
            modify_xml(xml_name, aug_bbox, new_name, aug_file, num, n)

if __name__ == "__main__":
    main()

三、语义分割数据增强

import glob
import cv2
from albumentations import (
    PadIfNeeded,
    HorizontalFlip,
    VerticalFlip,
    CenterCrop,
    Crop,
    Compose,
    Transpose,
    RandomRotate90,
    ElasticTransform,
    GridDistortion,
    OpticalDistortion,
    RandomSizedCrop,
    OneOf,
    CLAHE,
    RandomBrightnessContrast,
    RandomGamma
)
def data_num(train_path,mask_path):
    """List training images and masks matching two glob patterns.

    Both lists are sorted. ``glob.glob`` returns files in arbitrary
    (filesystem-dependent) order, and callers pair images with masks by list
    index, so an unsorted result could match an image with the wrong mask.

    Args:
        train_path: glob pattern for the input images.
        mask_path: glob pattern for the segmentation masks.

    Returns:
        ``(train_img, masks)`` as sorted lists of paths.
    """
    train_img = sorted(glob.glob(train_path))
    masks = sorted(glob.glob(mask_path))
    return train_img, masks

def mask_aug():
    """Build the paired image/mask pipeline.

    It always rotates by a random multiple of 90 degrees, then always takes a
    300x300 centre crop. albumentations applies the same geometric parameters
    to ``image`` and ``mask`` within one call.
    """
    steps = [
        RandomRotate90(p=1),
        CenterCrop(p=1, height=300, width=300),
    ]
    return Compose(steps)

def main(train_path=r"E:\augmentor\mask_augmentor\original/*.jpg",
         mask_path=r"E:\augmentor\mask_augmentor\segmentation/*.png",
         augtrain_path=r"E:\augmentor\mask_albumentations\augtrain_path",
         augmask_path=r"E:\augmentor\mask_albumentations\augmask_path",
         num=2):
    """Offline augmentation of image/mask pairs for semantic segmentation.

    Every image is paired with the mask at the same position in the sorted
    file lists. Each pair is augmented ``num`` times, and every result is saved
    as ``aug_img{pair}_{k}.jpg`` / ``aug_mask{pair}_{k}.png``.

    Args:
        train_path: glob pattern for input images.
        mask_path: glob pattern for input masks. The default originally lacked
            the "E:" drive letter used by the other paths; it was added on the
            assumption it was a typo — confirm against your layout.
        augtrain_path: output folder for augmented images (created if missing).
        augmask_path: output folder for augmented masks (created if missing).
        num: number of augmented copies per pair.

    Raises:
        ValueError: if the number of images and masks differ.
        FileNotFoundError: if an image or mask cannot be read.
    """
    import os

    os.makedirs(augtrain_path, exist_ok=True)
    os.makedirs(augmask_path, exist_ok=True)
    aug = mask_aug()
    train_img, masks = data_num(train_path, mask_path)
    # Pairing is by index, so both lists must be in a deterministic order.
    train_img, masks = sorted(train_img), sorted(masks)
    if len(train_img) != len(masks):
        raise ValueError("found {} images but {} masks".format(len(train_img), len(masks)))
    for data, (img_file, mask_file) in enumerate(zip(train_img, masks)):
        image = cv2.imread(img_file)
        # IMREAD_UNCHANGED keeps the mask's own channel count / bit depth;
        # the default flag would convert a single-channel label mask to BGR.
        mask = cv2.imread(mask_file, cv2.IMREAD_UNCHANGED)
        if image is None or mask is None:
            raise FileNotFoundError("could not read {} / {}".format(img_file, mask_file))
        for i in range(num):
            augmented = aug(image=image, mask=mask)
            # Saving happens inside both loops. Previously these calls were
            # dedented, so only the very last augmentation was ever written.
            cv2.imwrite(os.path.join(augtrain_path, "aug_img{}_{}.jpg".format(data, i)), augmented['image'])
            cv2.imwrite(os.path.join(augmask_path, "aug_mask{}_{}.png".format(data, i)), augmented['mask'])
        print(data)
if __name__ == "__main__":
    main()

原文地址:https://blog.csdn.net/qq_41920323/article/details/140751609

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!