自学内容网 自学内容网

【Python绘图】两种绘制混淆矩阵的方式 (ConfusionMatrixDisplay(), imshow()) 以及两种好看的colorbar

在机器学习领域,混淆矩阵是一个评估分类模型性能的重要工具。它不仅展示了模型预测的准确性,还揭示了模型在不同类别上的表现。本文介绍两种在Python中绘制混淆矩阵的方法:ConfusionMatrixDisplay()imshow(),以及两种好看的colorbar:coolwarm_rGnBu 以增强可视化效果。



ConfusionMatrixDisplay()

ConfusionMatrixDisplay() 是一个来自 scikit-learn 库的类,用于可视化混淆矩阵。

sklearn.metrics.ConfusionMatrixDisplay 的官方社区描述:

基本用法:

ConfusionMatrixDisplay 可以通过以下方式创建:

from sklearn.metrics import ConfusionMatrixDisplay

# 假设 cm 是一个混淆矩阵
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()

参数和方法:

  • confusion_matrix: 参数,一个形状为 (n_classes, n_classes) 的 ndarray,表示混淆矩阵。
  • display_labels: 参数,一个形状为 (n_classes,) 的 ndarray,默认为 None。用于设置绘图时的标签。如果为 None,则显示标签从 0 到 n_classes - 1。
  • plot(): 方法,绘制混淆矩阵的可视化。

示例:

在这里插入图片描述

在这里插入图片描述

示例代码:

from sklearn.metrics import ConfusionMatrixDisplay
import os
import matplotlib.pyplot as plt

import numpy as np
import numpy.random as npr
npr.seed(0)

# Save path
save_path = './plot'
os.makedirs(save_path, exist_ok=True)

# Generate random data 0~1
n = 10
data = npr.rand(n, n) * 0.8
for i in range(n):
    data[i, i] = 1.0

# Plot confusion matrix
fig, ax = plt.subplots(figsize=(8, 8))

cm = ConfusionMatrixDisplay(data, display_labels=np.arange(n))
cm.plot(ax=ax, cmap="GnBu", include_values=False, xticks_rotation=90)  # GnBu, coolwarm_r

ax.set_xlabel('Trials', fontsize=20)
ax.set_ylabel('Trials', fontsize=20)

plt.title(f'Confusion matrix', fontsize=30)
plt.tight_layout()

plt.savefig(f'{save_path}/confu_mat_1-2.png', dpi=300)
plt.show()


imshow()

imshow() 是一个来自 Matplotlib 库的函数,用于在图形用户界面(GUI)中显示图像。这个函数可以处理多种类型的图像数据,包括灰度图和彩色图,是 Matplotlib 中用于图像显示的基础函数之一。

matplotlib.pyplot.imshow 的官方描述:https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html

基本用法:

import matplotlib.pyplot as plt
import numpy as np

# 创建一个随机数组作为图像数据
image_data = np.random.rand(10, 10)

# 使用 imshow() 显示图像
plt.imshow(image_data)
plt.colorbar()  # 显示颜色条
plt.show()

参数:

imshow() 函数接受多个参数,以下是一些常用的参数:

  • X: 图像数据,可以是 2D 数组(灰度图)或 3D 数组(彩色图)。
  • cmap: 颜色映射表,用于定义颜色。例如,cmap=‘gray’ 表示灰度图,cmap=‘viridis’ 是一种常用的彩色映射。
  • norm: 归一化对象,用于调整数据值到 [0, 1] 范围。
  • aspect: 图像的纵横比,可以是 ‘auto’、‘equal’ 或一个数值。
  • interpolation: 插值方法,用于定义图像的缩放方式,如 ‘nearest’、‘bilinear’、‘bicubic’ 等。
  • alpha: 图像的透明度。

imshow() 返回一个 AxesImage 对象,这个对象包含了图像的显示信息,可以用来进一步定制图像的显示效果。

示例:

在这里插入图片描述

在这里插入图片描述

  • ConfusionMatrixDisplay()内置函数定义了所绘制的混淆矩阵必须为方针,而imshow()可以绘制行列数不等的矩形:

在这里插入图片描述

在这里插入图片描述

示例代码:

from mpl_toolkits.axes_grid1 import make_axes_locatable

import os
import matplotlib.pyplot as plt

import numpy as np
import numpy.random as npr
npr.seed(0)

# Save path
save_path = './plot'
os.makedirs(save_path, exist_ok=True)

# Generate random data 0~1
m = 6
n = 10
data = npr.rand(m, n) * 0.8
if m == n:
    for i in range(n):
        data[i, i] = 1.0

fig, ax = plt.subplots(figsize=(n, m))
cm = ax.imshow(data, cmap='coolwarm_r', interpolation="nearest", vmin=0.0, vmax=1.0)  # coolwarm_r, GnBu

# # 绘制一条对角线
# ax.plot([-0.5, n + 0.5], [-0.5, n + 0.5], color='black', alpha=0.2)

ax.set_xticks(np.arange(n))
ax.set_yticks(np.arange(m))

ax.set_xticklabels(np.arange(n), fontsize=15, rotation=90)
ax.set_yticklabels(np.arange(m), fontsize=15)

plt.xlabel('N', fontsize=20)
plt.ylabel('M', fontsize=20)

plt.title(f'Confusion matrix', fontsize=30)

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="4%", pad=0.2)
cb = fig.colorbar(cm, cax=cax)
cb.ax.tick_params(labelsize=15)

plt.tight_layout()

plt.savefig(f'{save_path}/confu_mat_3-1.png', dpi=300)
plt.show()


两种 colorbar

  • coolwarm_r
    在这里插入图片描述

  • GnBu
    在这里插入图片描述

更多 colorbar:https://astromsshin.github.io/science/code/matplotlib_cm/index.html
在这里插入图片描述


创作不易,麻烦点点赞和关注咯!


原文地址:https://blog.csdn.net/qq_43811536/article/details/143746238

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!