自学内容网 自学内容网

绘制人体3D关键点

一、背景

最近学习了3D人体骨骼关键点检测算法,需要修改3D可视化效果,在此记录3D骨骼关键点的绘制思路以及代码实现。

二、可视化需求

希望在一张图中同时显示:标签的3D结果、模型预测的3D结果、预测与标签叠加的结果以及对应的原始图像,并将结果保存为视频。

三、代码实现

1. 读取标签数据

import os, sys, copy, cv2
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import imageio
import io
import matplotlib.animation as animation
import matplotlib as mpl

def string_to_float(data):
    """Return a new list holding ``float(item)`` for every item of *data*, in order."""
    return [float(token) for token in data]


def read_label(label_txt_file):
    """Read one label text file and return root-relative 3D points plus 2D points.

    Every non-blank line is parsed as whitespace-separated floats. Columns are
    used as follows:

    * columns 2-3 of every row except the first -> ``point2d``;
    * columns 4 onward -> the per-row "3D block". For rows whose LAST column is
      non-zero, block column 2 is overwritten with block column 3 / 1000.
      NOTE(review): looks like a depth value rescaled by 1000 (e.g. mm -> m);
      confirm against the label format.

    Args:
        label_txt_file: path of the label ``.txt`` file.

    Returns:
        ``(data, point2d)``. ``data`` holds the first three 3D-block columns of
        every row, minus row 0's values, so row 0 becomes the origin.
        ``point2d`` holds columns 2-3 of rows 1..N-1.

    Raises:
        ValueError: if the file contains no data rows.

    Prints a message and exits the process with status 1 (``sys.exit``) if the
    file does not exist.
    """
    if not os.path.exists(label_txt_file):
        print('find not file : ', label_txt_file)
        sys.exit(1)

    with open(label_txt_file, 'r', encoding='utf-8') as f:
        # split() without an argument tolerates repeated spaces/tabs, and blank
        # lines (e.g. a trailing empty line) are skipped instead of crashing
        # in float('').
        rows = [[float(tok) for tok in raw.split()] for raw in f if raw.strip()]

    if not rows:
        raise ValueError(f'no data rows in label file: {label_txt_file}')

    table = np.array(rows, dtype=float)

    point2d = table[1:, 2:4].copy()
    block = table[..., 4:].copy()

    # Rows whose last column is non-zero get column 2 replaced by column 3 / 1000.
    has_value = block[..., -1] != 0.0
    block[has_value, 2] = block[has_value, 3] / 1000

    # Express every point relative to the first row (the root).
    data = block[..., 0:3] - block[0, 0:3]
    return data, point2d

标签文件中同时包含2D坐标和3D坐标。

2. 获取模型预测数据

def pred_label(pred_txt_file):
    """Read a model-prediction text file and return root-relative 3D keypoints.

    Only the first 17 non-blank lines are used. From each line, only its last
    three whitespace-separated values are taken as x, y, z.

    * Line 0 is the root.
    * Lines 1..16 are consumed as 8 consecutive pairs, and each pair is emitted
      in swapped order: line 2, line 1, line 4, line 3, ...
      NOTE(review): the swap presumably reorders left/right keypoints to match
      the label ordering; confirm.

    Args:
        pred_txt_file: path of the prediction ``.txt`` file.

    Returns:
        A ``(17, 3)`` float array with the root's coordinates subtracted from
        every row, so row 0 is ``(0, 0, 0)``.

    Raises:
        ValueError: if the file has fewer than 17 non-blank lines.

    Prints a message and exits the process with status 1 (``sys.exit``) if the
    file does not exist.
    """
    if not os.path.exists(pred_txt_file):
        print('find not file : ', pred_txt_file)
        sys.exit(1)

    num_pairs = 8
    num_rows = 1 + 2 * num_pairs  # root + 8 swapped pairs

    with open(pred_txt_file, 'r', encoding='utf-8') as f:
        # split() tolerates repeated whitespace; blank lines are ignored.
        rows = [raw.split() for raw in f if raw.strip()][:num_rows]

    if len(rows) < num_rows:
        raise ValueError(
            f'expected at least {num_rows} lines in {pred_txt_file}, got {len(rows)}')

    # Only the trailing x, y, z values of each line are used.
    xyz = np.array([[float(tok) for tok in row[-3:]] for row in rows], dtype=float)

    root = xyz[0]
    children = xyz[1:]
    reordered = [root]
    for i in range(num_pairs):
        reordered.append(children[2 * i + 1])
        reordered.append(children[2 * i])

    return np.array(reordered) - root

注意,注意,注意:读取标签和模型预测结果时,都减去了根节点(第0行关键点)的坐标,即所有关键点坐标都是相对于根节点的坐标。

3. 绘制3D骨骼图

# Draw the 3D skeletons (bones and keypoints).
def draw3Dpose(label_pose_3d, pred_pose_3d, ax1, ax2, ax3, label_total_ids, pred_total_ids, lcolor="r", rcolor="g", add_labels=False):  # red, green
    """Draw label and predicted 3D skeletons on three 3D axes.

    ``ax1`` shows the label skeleton, ``ax2`` the prediction, and ``ax3`` both
    overlaid. Each axis gets fixed limits x, y in [-100, 100] and z in [70, 200].

    Args:
        label_pose_3d: label keypoints, indexed as ``label_pose_3d[joint, coord]``
            with coord 0..2 (a 2D numpy array).
        pred_pose_3d: predicted keypoints, indexed the same way.
        ax1, ax2, ax3: 3D axes used for plotting.
        label_total_ids: label bone list; each entry is
            ``[start_joint, end_joint, flag]``.
        pred_total_ids: prediction bone list in the same format. It is paired
            with ``label_total_ids`` by position and must be at least as long.
        lcolor: bone colour when the entry's flag is truthy.
        rcolor: bone colour when the entry's flag is falsy.
        add_labels: unused; kept for interface compatibility.

    A flag of 3 draws that bone in blue on all three axes, and the other
    skeleton's paired bone is then not drawn.
    """
    colors_keys = (
        '#FF0000',  # red
        '#00FF00',  # green
        '#0000FF',  # blue
        '#FFFF00',  # yellow
        '#FF00FF',  # magenta
        '#00FFFF',  # cyan
        '#FFA500',  # orange
        '#800080',  # purple
        '#008000',  # dark green
        '#000080',  # navy
        '#808000',  # olive
        '#800000',  # maroon
        '#008080',  # teal
        '#808080',  # grey
        '#A52A2A',  # brown
        '#D2691E',  # chocolate
        '#00FFFF',  # cyan
    )

    def _segment(pose, ids):
        # x, y and z arrays for the bone between joints ids[0] and ids[1].
        return [np.array([pose[ids[0], j], pose[ids[1], j]]) for j in range(3)]

    for k, l_ids in enumerate(label_total_ids):
        # Prediction bones are paired with label bones by position.
        p_ids = pred_total_ids[k]
        lx, ly, lz = _segment(label_pose_3d, l_ids)
        px, py, pz = _segment(pred_pose_3d, p_ids)

        if l_ids[2] == 3:
            # NOTE(review): flag 3 draws only the LABEL bone, in blue, on all
            # three axes (including the prediction axis ax2); confirm intent.
            for ax in (ax1, ax2, ax3):
                ax.plot(lx, ly, lz, lw=2, c='b')
        elif p_ids[2] == 3:
            # Likewise a prediction bone flagged 3 is drawn on ax1 as well.
            for ax in (ax1, ax2, ax3):
                ax.plot(px, py, pz, lw=2, c='b')
        else:
            l_color = lcolor if l_ids[2] else rcolor
            p_color = lcolor if p_ids[2] else rcolor
            ax1.plot(lx, ly, lz, lw=2, c=l_color)
            ax2.plot(px, py, pz, lw=2, c=p_color)
            ax3.plot(lx, ly, lz, lw=2, c=l_color)
            ax3.plot(px, py, pz, lw=2, c=p_color)

        # Keypoint colours cycle, so more bones than colours no longer raise IndexError.
        key_color = colors_keys[k % len(colors_keys)]
        ax1.scatter(lx, ly, lz, color=key_color, marker='o', s=5)
        ax2.scatter(px, py, pz, color=key_color, marker='o', s=5)
        ax3.scatter(lx, ly, lz, color=key_color, marker='o', s=5)
        ax3.scatter(px, py, pz, color=key_color, marker='o', s=5)

    # Same fixed view volume and axis labels on all three 3D axes.
    for ax in (ax1, ax2, ax3):
        ax.set_xlim3d([-100, 100])
        ax.set_zlim3d([70, 200])
        ax.set_ylim3d([-100, 100])
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("z")

# Rasterise a matplotlib figure into an image array, used as a video frame.
def get_img_from_fig(fig, dpi=500):
    """Render *fig* to PNG in memory at *dpi* and return it as an RGBA image array.

    OpenCV decodes the PNG in BGR channel order, so the result is converted
    from BGR to RGBA before being returned.
    """
    with io.BytesIO() as png_buffer:
        # NOTE(review): matplotlib documents pad_inches as applying together
        # with bbox_inches='tight'; confirm whether 0.2 has any effect here.
        fig.savefig(png_buffer, format='png', dpi=dpi, pad_inches=0.2)
        # getvalue() returns an independent bytes copy, so it outlives the buffer.
        encoded_png = np.frombuffer(png_buffer.getvalue(), dtype=np.uint8)

    bgr_frame = cv2.imdecode(encoded_png, cv2.IMREAD_COLOR)
    return cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGBA)

def _to_plot_coords(data_3d):
    """Return a scaled, axis-swapped copy of *data_3d* for plotting.

    The points are multiplied by 100, and the columns are reordered to
    (col 0, col 2, col 1). The new third column is then replaced by
    ``200 - value``. The input array is not modified.
    NOTE(review): presumably maps root-relative coordinates into the fixed
    view volume used by draw3Dpose (z in [70, 200]); confirm.
    """
    scaled = np.asarray(data_3d, dtype=float) * 100
    out = scaled[..., [0, 2, 1]]  # fancy indexing -> new array
    out[..., 2] = 200 - out[..., 2]
    return out


def draw_label_test(label_file, label_ids, pred_file, pred_ids, img_file):
    """Render label vs. prediction skeletons for every label file and save a video.

    Each regular file in ``sorted(os.listdir(label_file))`` is processed as one frame:

    1. The label file is read.
    2. The prediction ``<stem>.txt`` is read from *pred_file*.
    3. The image ``<stem>.png`` is read from *img_file*.
    4. The four-panel figure (label / prediction / overlay / image) is rasterised
       and appended as a frame to ``./3d_pose_animation.mp4`` at 10 fps.

    The figure is shown at the end.

    Args:
        label_file: directory of label ``.txt`` files.
        label_ids: label bone list passed to draw3Dpose.
        pred_file: directory of prediction ``.txt`` files.
        pred_ids: prediction bone list passed to draw3Dpose.
        img_file: directory of ``.png`` images.
    """
    names = sorted(os.listdir(label_file))

    plt.rcParams['axes.unicode_minus'] = False

    # Register a SimHei font so the Chinese subplot titles render. The path is
    # machine-specific. If it is missing, fall back to matplotlib's default
    # font (titles may show placeholder glyphs) instead of aborting the run.
    font_path = '/home/xx/Downloads/chinese.simhei.ttf'
    if os.path.isfile(font_path):
        mpl.font_manager.fontManager.addfont(font_path)
        mpl.rc('font', family='SimHei')
    else:
        print('font not found, using default font : ', font_path)

    fig = plt.figure(figsize=(12, 8))
    ax1 = fig.add_subplot(141, projection='3d')
    ax2 = fig.add_subplot(142, projection='3d')
    ax3 = fig.add_subplot(143, projection='3d')
    ax4 = fig.add_subplot(144)

    ax1.view_init(elev=29, azim=-60)
    ax2.view_init(elev=16, azim=-75)
    ax3.view_init(elev=18, azim=-73)

    output_video = './3d_pose_animation.mp4'
    fps = 10
    videwrite = imageio.get_writer(uri=output_video, fps=fps)

    plt.ion()
    try:
        for name in names:
            label_path = os.path.join(label_file, name)
            # Skip anything that is not a regular file (e.g. sub-directories).
            if not os.path.isfile(label_path):
                continue

            # Prediction and image files share the label file's stem
            # (splitext works for extensions of any length).
            stem = os.path.splitext(name)[0]
            pred_path = os.path.join(pred_file, stem + '.txt')
            img_path = os.path.join(img_file, stem + '.png')

            for ax in (ax1, ax2, ax3, ax4):
                ax.cla()
            # cla() also clears the titles, so they are set again for every frame.
            ax1.title.set_text("标签结果")
            ax2.title.set_text("模型算法结果")
            ax3.title.set_text("标签和算法结果")
            ax4.title.set_text("原始图片")

            l_data_3d, _ = read_label(label_path)
            p_data_3d = pred_label(pred_path)
            with Image.open(img_path) as pil_img:
                img = np.array(pil_img)

            draw3Dpose(_to_plot_coords(l_data_3d), _to_plot_coords(p_data_3d),
                       ax1, ax2, ax3, label_ids, pred_ids)
            ax4.imshow(img)

            videwrite.append_data(get_img_from_fig(fig))
    finally:
        # Finalise the video even if a frame fails. read_label / pred_label call
        # sys.exit on a missing file, and this clause still runs in that case.
        videwrite.close()

    # NOTE: this runs after every frame was written, so it only affects the
    # final interactive figure shown below, not the saved video.
    plt.tight_layout()

    plt.ioff()
    print("save out video")
    plt.show()

if __name__ == "__main__":
    label_path = '/home/xx/Desktop/simcc_3d/temp/select_label_txt'
    pred_path = '/home/xx/Desktop/simcc_3d/temp/out_txt'
    img_file = '/home/xx/Desktop/simcc_3d/temp/val_img'

    # Bone lists: each entry is [start_joint, end_joint, flag].
    label_ids = [[0, 1, 1], [1, 2, 1], [1, 3, 1], [1, 4, 1], [3, 5, 1],
                 [5, 7, 1], [4, 6, 1], [6, 8, 1],
                 ]

    # Paired with label_ids by position, so it must be at least as long.
    # NOTE(review): the [0, 0, 0] entries look like zero-length placeholders
    # that pad the list to the same length; confirm.
    pred_ids = [[0, 6, 0], [6, 8, 0], [8, 10, 0], [0, 5, 0], [5, 7, 0],
                [7, 9, 0], [0, 0, 0], [0, 0, 0],
                ]

    draw_label_test(label_path, label_ids, pred_path, pred_ids, img_file)

四、总结

以上代码仅作演示,只适用于我自己的场景;用于其他场景时,需要修改标签数据的读取方式以及关键点的连接关系。该代码仅供参考,不可直接照搬。


原文地址:https://blog.csdn.net/qq_43318374/article/details/145034040

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!