[Using Grounding DINO to Crop a Classification Dataset] and Using Text-Prompted Detection of Arbitrary Targets in Images (Grounding DINO)
Background
- While working with the public ImageNet-21k dataset, we found that many samples are problematic: the target often occupies only a small part of the image (a lot of background), and some classes contain images of other categories.
- To deal with the excessive background, we crop the targets out of the images across all 21k classes of ImageNet-21k.
- Grounding DINO detects arbitrary targets in an image from a text prompt, which makes it a good fit for this scenario.
1. Installing Grounding DINO
- Clone the GroundingDINO repository from GitHub.
git clone https://github.com/IDEA-Research/GroundingDINO.git
- Change into the GroundingDINO directory.
cd GroundingDINO/
- Install the required dependencies in the current directory.
pip install -e .
- Download the pre-trained model weights.
mkdir weights
cd weights
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
cd ..
- Download bert-base-uncased into a text_encoder_type directory (create the folder yourself).
You need to download the three model files and place them in the text_encoder_type folder.
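If you prefer to script the download, here is a minimal sketch using huggingface_hub (an assumption: huggingface_hub is installed, and the target folder name matches the text_encoder_type folder above):

from huggingface_hub import snapshot_download

# Download the bert-base-uncased files from the Hugging Face Hub into the
# local text_encoder_type folder referenced by the Grounding DINO config.
snapshot_download(repo_id="bert-base-uncased", local_dir="text_encoder_type")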
- Modify the path
Change the text_encoder_type path in /GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py.
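For reference, the relevant setting in that config file looks roughly like the excerpt below (the local path is a placeholder; point it at the folder you created):

# groundingdino/config/GroundingDINO_SwinT_OGC.py (excerpt)
# text_encoder_type = "bert-base-uncased"         # original value (Hugging Face hub id)
text_encoder_type = "/path/to/text_encoder_type"  # placeholder: your local folder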
- If you have a CUDA environment, make sure the CUDA_HOME environment variable is set. If no CUDA is available, the package will be compiled in CPU-only mode.
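A quick way to check this before building (a minimal sketch, assuming PyTorch is already installed):

import os
import torch

# CUDA_HOME must be set for the CUDA extensions to compile; otherwise the
# install falls back to a CPU-only build.
print("CUDA_HOME:", os.environ.get("CUDA_HOME"))
print("torch CUDA available:", torch.cuda.is_available())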
- Possible bug
Segmentation fault (core dumped)
This happens when the timm version does not match your CUDA and PyTorch versions; reinstalling timm resolves it.
pip uninstall timm
pip install timm
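After reinstalling, a small sanity check of the versions (a sketch; it only prints the versions, checking them against your CUDA build is still up to you):

import timm
import torch

# A timm build that does not match the installed torch/CUDA versions is what
# triggers the segmentation fault above.
print("timm:", timm.__version__)
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)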
2. Script for cropping the specified targets
- Below is a test demo:
import cv2
from groundingdino.util.inference import load_model, load_image, predict, annotate

# Load the model on CPU; pass "cuda" instead if a GPU is available.
model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "weights/groundingdino_swint_ogc.pth", "cpu")

IMAGE_PATH = r"images/th.jpg"
TEXT_PROMPT = "dolphins"
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25

image_source, image = load_image(IMAGE_PATH)

boxes, logits, phrases = predict(
    model=model,
    image=image,
    caption=TEXT_PROMPT,
    box_threshold=BOX_THRESHOLD,
    text_threshold=TEXT_THRESHOLD,
)
print(boxes)

# Draw the detected boxes and phrases on the source image and save it.
annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
cv2.imwrite("annotated_image.jpg", annotated_frame)
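Note that predict returns boxes normalized to [0, 1] in cxcywh format; the cropping script below rescales them to pixels and converts them to xyxy corners, for example:

import torch
from torchvision.ops import box_convert

# Rescale the normalized cxcywh boxes to pixel coordinates, then convert to corner format.
h, w, _ = image_source.shape
xyxy = box_convert(boxes=boxes * torch.tensor([w, h, w, h]), in_fmt="cxcywh", out_fmt="xyxy").numpy()
print(xyxy)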
- Script for cropping the specified targets
Given a root directory, the script crops the different target classes in each subfolder and writes the crops to the corresponding relative path under the output directory.
Full script:
import os
import time

import cv2
import torch
from torchvision.ops import box_convert

from groundingdino.util.inference import load_model, load_image, predict


def save_cropped_images(image, boxes, image_name, output_folder):
    os.makedirs(output_folder, exist_ok=True)
    h, w, _ = image.shape
    # Boxes come back normalized in cxcywh format; scale to pixels and convert to xyxy.
    boxes = boxes * torch.tensor([w, h, w, h])
    xyxy_boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    for i, box in enumerate(xyxy_boxes):
        x_min, y_min, x_max, y_max = map(int, box)
        cropped_image = image[y_min:y_max, x_min:x_max]
        # Ensure the color channels are in BGR order for OpenCV
        cropped_image_bgr = cv2.cvtColor(cropped_image, cv2.COLOR_RGB2BGR)
        cv2.imwrite(f"{output_folder}/{image_name}_cropped_{i}.jpg", cropped_image_bgr)


def process_image(image_path, model, output_folder, box_threshold=0.35, text_threshold=0.25):
    image_source, image = load_image(image_path)
    try:
        boxes, logits, phrases = predict(
            model=model,
            image=image,
            caption=TEXT_PROMPT,
            box_threshold=box_threshold,
            text_threshold=text_threshold
        )
    except RuntimeError as e:
        print(f"RuntimeError: {e}")
        return  # skip this image if inference fails
    # Get the image name without extension
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    # Save cropped images with image name included
    save_cropped_images(image_source, boxes, image_name, output_folder)


def process_images_in_folder(folder_path, model, box_threshold=0.35, text_threshold=0.25):
    folder_name = os.path.basename(folder_path.rstrip('/'))
    output_folder = os.path.join("/animals_classify/Cropped_Dataset/QuanKe", folder_name)
    print(f"{folder_name}, cropping.")
    # Start timer for processing this folder
    start_time = time.time()
    for filename in os.listdir(folder_path):
        if filename.endswith(".jpg") or filename.endswith(".png") or filename.endswith(".JPEG"):
            image_path = os.path.join(folder_path, filename)
            process_image(image_path, model, output_folder, box_threshold, text_threshold)
    # End timer for processing this folder
    folder_processing_time = time.time() - start_time
    process_images_in_folder.total_time += folder_processing_time
    print(f"{folder_name}, cropped. Time taken: {folder_processing_time:.2f} seconds")
    print(f"Total time taken so far: {process_images_in_folder.total_time:.2f} seconds")


# Initialize the total time taken to 0
process_images_in_folder.total_time = 0.0

# Configuration and model loading
model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", "weights/groundingdino_swint_ogc.pth")
TEXT_PROMPT = "canine"
BOX_THRESHOLD = 0.35
TEXT_THRESHOLD = 0.25
FOLDERS_PATH = "/animals_classify/Raw_Dataset/QuanKe"

for FOLDER_Name in os.listdir(FOLDERS_PATH):
    FOLDER_PATH = os.path.join(FOLDERS_PATH, FOLDER_Name)
    # Process all images in the folder
    process_images_in_folder(FOLDER_PATH, model, BOX_THRESHOLD, TEXT_THRESHOLD)
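After the run finishes, a quick sanity check (a minimal sketch, assuming the output layout produced by the script above) counts the crops written for each class folder:

import os

OUTPUT_ROOT = "/animals_classify/Cropped_Dataset/QuanKe"  # same output root as in the script
for class_name in sorted(os.listdir(OUTPUT_ROOT)):
    class_dir = os.path.join(OUTPUT_ROOT, class_name)
    if os.path.isdir(class_dir):
        print(f"{class_name}: {len(os.listdir(class_dir))} cropped images")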
Cropping example:
Original image:
Result:
Original article: https://blog.csdn.net/ban102055/article/details/140329322