自学内容网 自学内容网

【python】sklearn基础教程及示例

【python】sklearn基础教程及示例


Scikit-learn(简称sklearn)是一个非常流行的Python机器学习库,提供了许多常用的机器学习算法和工具。以下是一个基础教程的概述:


 1. 安装scikit-learn


首先,确保你已经安装了Python和pip,然后使用以下命令安装scikit-learn:

pip install -U scikit-learn

2. 导入库

在你的Python脚本或Jupyter Notebook中,首先导入scikit-learn库:

import sklearn

3. 加载数据

你可以加载各种数据集,包括样本数据集和真实世界数据集。例如,加载经典的鸢尾花数据集:

from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data  # 特征矩阵
y = iris.target  # 目标向量

4. 数据预处理

在应用机器学习算法之前,通常需要进行一些数据预处理,例如特征缩放、特征选择、数据清洗等。以下是一些常用的数据预处理方法:

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

5. 数据拆分

将数据集拆分为训练集和测试集:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

6. 建立模型

使用各种机器学习算法来建立模型,例如逻辑回归:

from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
model.fit(X_train, y_train)

7. 模型评估

在训练模型之后,评估模型的性能,例如使用准确度评估:

from sklearn.metrics import accuracy_score
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")

8. 交叉验证

使用交叉验证来评估模型的稳定性和泛化能力:

from sklearn.model_selection import cross_validate
result = cross_validate(model, X, y, cv=5)
print(result['test_score'])

sklearn示例

1.简单例子:鸢尾花分类

这是一个经典的机器学习任务,用于分类鸢尾花的种类。

load_iris 是一个经典的机器学习数据集,通常用于分类和聚类任务。这个数据集包含了三种不同种类的鸢尾花(Iris Setosa、Iris Versicolour 和 Iris Virginica)的信息,每种鸢尾花有四个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。

具体来说,load_iris 数据集包含以下内容:

  • 150个样本:每种鸢尾花各50个样本。
  • 4个特征:花萼长度、花萼宽度、花瓣长度和花瓣宽度。
  • 目标标签:每个样本的目标类别标签,分别为0(Setosa)、1(Versicolour)和2(Virginica)。
  • StandardScaler 是 scikit-learn 库中的一个类,用于对数据进行标准化处理。标准化的目的是将数据的特征缩放到相同的尺度,通常是均值为0,标准差为1。这对于许多机器学习算法来说是非常重要的,特别是那些基于距离的算法(如K-近邻、支持向量机等)和需要计算协方差矩阵的算法(如PCA、线性回归等)。

# 导入必要的库
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 数据预处理
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 拆分数据集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

# 建立和训练模型
model = LogisticRegression()
model.fit(X_train, y_train)

# 预测和评估
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")

2.复杂例子:手写数字识别

这个例子使用手写数字数据集,并应用支持向量机(SVM)进行分类。

load_digits 是 scikit-learn 提供的一个经典数据集,用于手写数字识别任务。这个数据集包含了 0 到 9 共 10 个数字的手写图像,每个图像是一个 8x8 的灰度图像。

  • 数据集内容 样本数量:1797 个手写数字图像。
  • 特征维度:每个图像有 64 个特征(8x8 像素)。
  • 特征值:每个特征值是一个整数,范围从 0 到 16,表示像素的灰度值。
  • 目标标签:每个样本对应一个目标标签,表示数字 0 到 9。

# 导入必要的库
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import classification_report

# 加载数据集
digits = load_digits()
X = digits.data
y = digits.target

# 数据预处理
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 拆分数据集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

# 使用网格搜索进行超参数调优
param_grid = {'C': [0.1, 1, 10, 100], 'gamma': [1, 0.1, 0.01, 0.001]}
grid = GridSearchCV(SVC(), param_grid, refit=True, verbose=2)
grid.fit(X_train, y_train)

# 最佳参数和模型评估
print(f"Best Parameters: {grid.best_params_}")
y_pred = grid.predict(X_test)
print(classification_report(y_test, y_pred))

在这个复杂的例子中,我们使用了网格搜索(GridSearchCV)来找到支持向量机(SVM)的最佳超参数,并使用分类报告(classification_report)来评估模型的性能。

  • param_grid:这是一个字典,定义了要搜索的参数范围。在这个例子中,我们要调整两个参数:
    • C:正则化参数,控制模型的复杂度。较小的 C 值会使模型更简单,但可能欠拟合;较大的 C 值会使模型更复杂,但可能过拟合。
    • gamma:核函数系数,控制单个训练样本的影响范围。较大的 gamma 值会使模型更复杂,但可能过拟合;较小的 gamma 值会使模型更简单,但可能欠拟合。
  • GridSearchCV:这是 scikit-learn 提供的一个工具,用于通过交叉验证来搜索最佳参数组合。
    • SVC():支持向量机分类器。
    • param_grid:要搜索的参数网格。
    • refit=True:在找到最佳参数组合后,使用整个训练集重新训练模型。
    • verbose=2:设置详细程度,输出更多的搜索过程信息。


原文地址:https://blog.csdn.net/weixin_44502754/article/details/140683201

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!