目标检测、语义分割离线增强脚本
主要利用 albumentations 库进行离线数据增强(第一节简要介绍的 Augmentor 是另一个同类库),使用时只需修改脚本中的文件路径即可。
一、Augmentor 简介:
Augmentor 是一个面向机器学习任务的 Python 包,用于人工生成图像数据和进行数据增强。它主要是一个数据增强工具,同时也包含基本的图像预处理功能。Augmentor 为标准图像处理操作提供了许多类,例如旋转类 Rotate、裁剪类 Crop 等。支持的操作包括:旋转(rotate)、裁剪(crop)、透视倾斜(perspective skewing)、错切(shearing)、弹性形变(elastic distortions)以及亮度、对比度、颜色调整等;更多操作及其参数设定请参阅 Augmentor 官方文档。
二、目标检测数据增强
import xml.dom.minidom
import cv2
from albumentations import(
BboxParams, RandomGamma,Compose,Blur,CenterCrop,HueSaturationValue,
MotionBlur,Cutout)
import os
import glob
def read_xml(path):
    """Read the first annotated object from a Pascal VOC style XML file.

    Args:
        path: Filename or file object accepted by ``xml.dom.minidom.parse``.

    Returns:
        A list of six strings:
        ``[image_file_name, class_name, xmin, ymin, xmax, ymax]``.
        Only the first ``<name>``/``<xmin>``/``<ymin>``/``<xmax>``/``<ymax>``
        elements are read, so further objects in the same file are ignored.

    Raises:
        IndexError: if one of the expected tags is missing or empty.
    """
    # Suffixes meaning <filename> already names the image file itself.
    image_exts = (".jpg", ".jpeg", ".png", ".bmp")
    root = xml.dom.minidom.parse(path).documentElement

    def first_text(tag):
        # Text content of the first element named ``tag``.
        return root.getElementsByTagName(tag)[0].childNodes[0].data

    img_name = first_text("filename")
    # The XML files written by this script (see main) store the name without
    # an extension, but other annotation files may already include it; only
    # append ".jpg" when it is missing, otherwise we would look for
    # "xxx.jpg.jpg".
    if not img_name.lower().endswith(image_exts):
        img_name += ".jpg"
    return [img_name] + [first_text(tag) for tag in ("name", "xmin", "ymin", "xmax", "ymax")]
def modify_xml(path,bbox,new_img_name,aug_file,num,n):
    """Save a copy of a VOC annotation with a new image name and bounding box.

    Args:
        path: Source XML annotation to copy.
        bbox: Sequence whose first four items are ``xmin, ymin, xmax, ymax``.
        new_img_name: Value written into ``<filename>``.
        aug_file: Output directory.
        num: Image index within the current pass (used in the file name).
        n: Augmentation pass index (used in the file name).

    The result is written to ``aug_file/aug_img{n}_{num}.xml``. Only the
    first ``<xmin>``/``<ymin>``/``<xmax>``/``<ymax>`` elements are rewritten.

    Raises:
        IndexError: if ``bbox`` has fewer than four items or a tag is missing.
    """
    dom = xml.dom.minidom.parse(path)
    root = dom.documentElement
    root.getElementsByTagName("filename")[0].childNodes[0].data = str(new_img_name)
    for i, tag in enumerate(("xmin", "ymin", "xmax", "ymax")):
        # Text nodes should hold strings; the caller passes ints.
        root.getElementsByTagName(tag)[0].childNodes[0].data = str(bbox[i])
    # Join only the bare file name onto the directory. Joining with
    # aug_file + "\\..." worked only on Windows; on POSIX the backslash is not
    # a separator and the file ended up outside aug_file.
    out_path = os.path.join(aug_file, "aug_img{}_{}.xml".format(n, num))
    # Write UTF-8 explicitly and declare it in the XML header, so non-ASCII
    # labels survive regardless of the platform's default encoding.
    with open(out_path, "w", encoding="utf-8") as fh:
        dom.writexml(fh, encoding="utf-8")
def visualize_bbox(img, bbox, class_id, class_idx_to_name):
    """Draw one labelled bounding box on ``img`` and return the drawn image.

    ``bbox`` holds exactly four corner values ``x_min, y_min, x_max, y_max``,
    each truncated to ``int``. ``class_idx_to_name[class_id]`` is the text
    drawn on a filled strip just above the box's top-left corner.
    """
    corners = list(bbox)
    left, top, right, bottom = corners
    left, top, right, bottom = int(left), int(top), int(right), int(bottom)
    box_color = (255, 0, 0)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.35

    canvas = cv2.rectangle(img, (left, top), (right, bottom), box_color, 2)
    label = class_idx_to_name[class_id]
    (label_w, label_h), _baseline = cv2.getTextSize(label, font, font_scale, 1)
    # Solid background behind the label so the white text stays readable.
    strip_top = top - int(1.3 * label_h)
    cv2.rectangle(canvas, (left, strip_top), (left + label_w, top), box_color, -1)
    text_origin = (left, top - int(0.3 * label_h))
    cv2.putText(canvas, label, text_origin, font, font_scale, (255, 255, 255), lineType=cv2.LINE_AA)
    return canvas
def get_aug(aug, min_area=0., min_visibility=0.):
    """Wrap a list of transforms in a bounding-box-aware ``Compose``.

    Boxes use the ``pascal_voc`` format and their labels are expected under
    the ``category_id`` keyword of the call. ``min_area`` and
    ``min_visibility`` are forwarded unchanged to ``BboxParams``.
    """
    box_settings = BboxParams(
        format='pascal_voc',
        min_area=min_area,
        min_visibility=min_visibility,
        label_fields=["category_id"],
    )
    return Compose(aug, bbox_params=box_settings)
def augment(min_area=0., min_visibility=0.):
    """Build the detection augmentation pipeline.

    The transforms are wrapped with ``get_aug`` so the pipeline carries
    ``BboxParams`` (``pascal_voc`` boxes, labels under ``category_id``).
    Without bbox parameters albumentations does not update the boxes passed
    alongside the image, so e.g. a centre crop would leave stale coordinates.
    Recent versions reject such a call with a ValueError instead.

    Args:
        min_area: Forwarded to ``get_aug`` / ``BboxParams``.
        min_visibility: Forwarded to ``get_aug`` / ``BboxParams``.
    """
    transforms = [
        Blur(blur_limit=7, p=0.3),
        RandomGamma(gamma_limit=(80, 120), p=0.5),
        # NOTE(review): CenterCrop needs inputs of at least 400x400 pixels.
        CenterCrop(height=400, width=400, p=0.2),
        HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.3),
        MotionBlur(blur_limit=7, p=0.5),
        Cutout(num_holes=8, max_h_size=8, max_w_size=8, fill_value=0, always_apply=False, p=0.2),
    ]
    return get_aug(transforms, min_area=min_area, min_visibility=min_visibility)
def get_data(xml_date_path,img_path):
    """Load the image and first bounding box described by a VOC XML file.

    Args:
        xml_date_path: Path of the XML annotation (parsed by ``read_xml``).
        img_path: Directory containing the image named in ``<filename>``.

    Returns:
        ``(bbox, img)``. ``bbox`` is ``[[xmin, ymin, xmax, ymax]]``, a single
        box of ints, and ``img`` is the array returned by ``cv2.imread``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image file.
    """
    xml_date = read_xml(xml_date_path)
    image_file = os.path.join(img_path, xml_date[0])
    img = cv2.imread(image_file)
    # cv2.imread reports failure by returning None rather than raising; stop
    # here with the path instead of failing later inside albumentations.
    if img is None:
        raise FileNotFoundError("cannot read image: {}".format(image_file))
    # Coordinates may be written as decimals ("12.0"); int("12.0") raises,
    # so convert through float first.
    bbox = [[int(float(value)) for value in xml_date[2:6]]]
    return bbox, img
def keep_aug_img(annotations):
    """Extract the augmented image and one integer bounding box.

    Args:
        annotations: Result dict of an albumentations call, with at least
            ``'image'`` and ``'bboxes'`` keys.

    Returns:
        ``(aug_img, aug_bbox)``. ``aug_img`` is a copy of the image.
        ``aug_bbox`` is the last box as ``[x_min, y_min, x_max, y_max]``
        truncated to ints, or ``None`` when the augmentation removed every
        box (e.g. a crop that excluded it), instead of failing.
    """
    aug_img = annotations['image'].copy()
    boxes = annotations['bboxes']
    # len() rather than truthiness: bboxes may be a list or an array.
    if len(boxes) == 0:
        return aug_img, None
    # Keep the last box, as the original loop did; extra trailing fields
    # beyond the four corners are ignored.
    aug_bbox = [int(value) for value in list(boxes[-1])[:4]]
    return aug_img, aug_bbox
def visualize(annotations, category_id_to_name):
    """Return a copy of the augmented image with every box and label drawn.

    ``annotations`` supplies ``'image'``, ``'bboxes'`` and a ``'category_id'``
    list parallel to the boxes. ``category_id_to_name`` maps each id to the
    label text passed to ``visualize_bbox``.
    """
    canvas = annotations['image'].copy()
    for position, box in enumerate(annotations['bboxes']):
        label_id = annotations['category_id'][position]
        canvas = visualize_bbox(canvas, box, label_id, category_id_to_name)
    return canvas
def main(xml_img_path=r"H:\darknet_y\data\car", aug_file=r"H:\darknet_y\data\car\aug_file", samples=6):
    """Augment every VOC-annotated image in ``xml_img_path`` several times.

    Args:
        xml_img_path: Directory holding the ``*.xml`` annotations and the
            images they reference.
        aug_file: Output directory for the augmented images and XML files;
            created if it does not exist.
        samples: Number of augmentation passes over the dataset.

    Pass ``n`` writes ``aug_img{n}_{num}.jpg`` and a matching XML file for the
    ``num``-th annotation. Results whose bounding box did not survive the
    augmentation are skipped, since there is no box left to record.

    Raises:
        OSError: if OpenCV fails to write an augmented image.
    """
    os.makedirs(aug_file, exist_ok=True)
    # augment() takes no input and always builds the same pipeline, so one
    # instance can serve every image.
    aug = augment()
    # List inputs once so files written by earlier passes are never picked
    # up as new inputs (relevant if aug_file is the same as xml_img_path).
    xml_files = glob.glob(os.path.join(xml_img_path, "*.xml"))
    for n in range(samples):
        print(" 第%d次"%n)  # progress: pass n
        for num, xml_name in enumerate(xml_files):
            print(" 第%d张图片"%num)  # progress: image num
            bbox, img = get_data(xml_name, xml_img_path)  # image + its bbox from the XML
            annotations = {'image': img, 'bboxes': bbox, 'category_id': [1]}
            augmented = aug(**annotations)
            # To inspect a result, draw it with visualize(augmented, {1: "label"})
            # and display it with cv2.imshow / cv2.waitKey.
            if len(augmented['bboxes']) == 0:
                print("  skipped {}: no bounding box left after augmentation".format(xml_name))
                continue
            aug_img, aug_bbox = keep_aug_img(augmented)  # augmented image and box
            new_name = "aug_img{}_{}".format(n, num)
            img_out = os.path.join(aug_file, new_name + ".jpg")
            # cv2.imwrite reports failure through its return value only.
            if not cv2.imwrite(img_out, aug_img):
                raise OSError("failed to write image: {}".format(img_out))
            # The XML records the image name without extension, matching the
            # aug_img{n}_{num}.xml file modify_xml writes next to the image.
            modify_xml(xml_name, aug_bbox, new_name, aug_file, num, n)


if __name__ == "__main__":
    main()
三、语义分割数据增强
import glob
import cv2
from albumentations import (
PadIfNeeded,
HorizontalFlip,
VerticalFlip,
CenterCrop,
Crop,
Compose,
Transpose,
RandomRotate90,
ElasticTransform,
GridDistortion,
OpticalDistortion,
RandomSizedCrop,
OneOf,
CLAHE,
RandomBrightnessContrast,
RandomGamma
)
def data_num(train_path,mask_path):
    """Collect image and mask file paths matching two glob patterns.

    Args:
        train_path: Glob pattern for the input images.
        mask_path: Glob pattern for the segmentation masks.

    Returns:
        ``(train_img, masks)``, two lists of matching paths, each sorted.
        ``glob.glob`` returns files in filesystem-dependent order, while the
        caller pairs images with masks by list position. Sorting both lists
        makes image *i* meet mask *i*, assuming the two sets of files share
        base names.
    """
    train_img = sorted(glob.glob(train_path))
    masks = sorted(glob.glob(mask_path))
    return train_img, masks
def mask_aug(height=300, width=300):
    """Build the joint image/mask augmentation pipeline.

    Every call applies a random multiple-of-90-degree rotation (``p=1``) and
    then a centre crop to ``height`` x ``width``. The caller passes image and
    mask in the same call, so both receive the same geometric transform.

    Args:
        height: Crop height in pixels (default 300, the original setting).
        width: Crop width in pixels (default 300, the original setting).
    """
    return Compose([
        RandomRotate90(p=1),
        CenterCrop(p=1, height=height, width=width),
    ])
def main(train_path=r"E:\augmentor\mask_augmentor\original/*.jpg",
         mask_path=r"E:\augmentor\mask_augmentor\segmentation/*.png",
         augtrain_path=r"E:\augmentor\mask_albumentations\augtrain_path",
         augmask_path=r"E:\augmentor\mask_albumentations\augmask_path",
         num=2):
    """Write ``num`` augmented copies of every image/mask pair.

    Args:
        train_path: Glob pattern for the input images.
        mask_path: Glob pattern for the input masks. The previous default
            lacked the ``E:`` drive the other paths use; presumably a typo.
        augtrain_path: Output directory for augmented images (created if missing).
        augmask_path: Output directory for augmented masks (created if missing).
        num: Number of augmented copies per pair.

    Images and masks are paired by position in their sorted lists
    (see ``data_num``). Outputs are named ``aug_img{k}_{i}.jpg`` and
    ``aug_mask{k}_{i}.png``.

    Raises:
        ValueError: if the numbers of images and masks differ.
        FileNotFoundError: if OpenCV cannot read an input file.
        OSError: if OpenCV fails to write an output file.
    """
    import os  # this script's import block does not include os

    aug = mask_aug()
    train_img, masks = data_num(train_path, mask_path)
    # Pairing is positional, so a count mismatch means misaligned pairs.
    if len(train_img) != len(masks):
        raise ValueError("found {} images but {} masks; they are paired by position".format(len(train_img), len(masks)))
    os.makedirs(augtrain_path, exist_ok=True)
    os.makedirs(augmask_path, exist_ok=True)

    for data, (img_file, mask_file) in enumerate(zip(train_img, masks)):
        # Read each pair once; every augmentation round reuses the arrays.
        image = cv2.imread(img_file)
        mask = cv2.imread(mask_file)
        # cv2.imread returns None on failure instead of raising.
        if image is None or mask is None:
            raise FileNotFoundError("cannot read {} or {}".format(img_file, mask_file))
        for i in range(num):
            augmented = aug(image=image, mask=mask)
            img_out = os.path.join(augtrain_path, "aug_img{}_{}.jpg".format(data, i))
            mask_out = os.path.join(augmask_path, "aug_mask{}_{}.png".format(data, i))
            # cv2.imwrite reports failure through its return value only.
            if not cv2.imwrite(img_out, augmented['image']):
                raise OSError("failed to write image: {}".format(img_out))
            if not cv2.imwrite(mask_out, augmented['mask']):
                raise OSError("failed to write mask: {}".format(mask_out))
        print(data)  # progress: index of the finished pair


if __name__ == "__main__":
    main()
原文地址:https://blog.csdn.net/qq_41920323/article/details/140751609
免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!