自学内容网 自学内容网

CatBoost模型Python代码——用CatBoost模型实现机器学习

一、CatBoost模型简介

1.1适用范围

CatBoost(Categorical Boosting)是一种基于梯度提升的机器学习算法,特别适用于处理具有类别特征的数据集。它可以用于分类、回归和排序任务,并且在处理具有大量类别特征的数据时表现优异。典型应用包括但不限于:

  • 电子商务中的推荐系统
  • 客户行为分析
  • 财务风险评估
  • 医疗数据分析
1.2原理

CatBoost使用梯度提升决策树(GBDT)作为其核心算法。其主要特点包括:

  1. 处理类别特征:CatBoost原生支持类别特征,并在内部使用目标编码(target encoding)来处理它们,从而减少了类别变量处理的复杂性。
  2. 顺序增强(Ordered Boosting):在构建每棵树时,CatBoost通过引入一种新的顺序提升方法来避免传统梯度提升中的预测偏差问题。
  3. 随机分片:为了进一步减少过拟合,CatBoost在每次树构建时随机分割数据集。
1.3优点
  • 高效处理类别特征:无需复杂的预处理步骤。
  • 减少过拟合:通过顺序增强和随机分片技术。
  • 易于使用:内置了许多默认的优化参数,适合初学者和快速原型开发。
  • 高性能:在许多实际应用中表现优于其他GBDT算法(如XGBoost和LightGBM)。
1.4缺点
  • 模型训练时间较长:尽管有许多优化,训练时间可能比其他简单模型更长。
  • 内存占用较高:在处理大规模数据时,内存需求较大。

二、实现CatBoost模型的Python代码

下面是一个使用CatBoost进行分类任务的完整Python代码示例,包含详细注释。

2.1导入必要的包和测试数据
import pandas as pd
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns

# 加载Titanic数据集
url = 'https://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv'
data = pd.read_csv(url)

# 查看数据集的列名
print("Columns in the dataset:", data.columns)
2.2简单的数据预处理
# 简单的数据预处理
# 填充缺失值
# data['Age'].fillna(data['Age'].median(), inplace=True)
# data['Embarked'].fillna(data['Embarked'].mode()[0], inplace=True)

# 将Sex和Embarked转换为类别型特征
data['Sex'] = data['Sex'].astype('category')
# data['Pclass'] = data['Pclass'].astype('Pclass')

# 选择特征和目标
features = ['Pclass', 'Sex', 'Age', 'Siblings/Spouses Aboard', 'Parents/Children Aboard', 'Fare']
target = 'Survived'

X = data[features]
y = data[target]
2.3构建CatBoost模型
# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建CatBoost数据池
categorical_features = ['Sex', 'Pclass']
train_pool = Pool(X_train, y_train, cat_features=categorical_features)
test_pool = Pool(X_test, y_test, cat_features=categorical_features)

# 初始化并训练CatBoost分类器
model = CatBoostClassifier(
    iterations=1000,
    learning_rate=0.1,
    depth=6,
    loss_function='Logloss',  # 二分类任务使用'Logloss'
    verbose=100  # 每100次迭代打印一次信息
)

# 训练模型
model.fit(train_pool)

# 在测试集上进行预测
y_pred = model.predict(test_pool)
y_pred_proba = model.predict_proba(test_pool)[:, 1]
2.4模型评估
# 评估模型
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
print(classification_report(y_test, y_pred))

模型评估输出结果如下 :

0:learn: 0.6538633total: 159msremaining: 2m 39s
100:learn: 0.2814504total: 891msremaining: 7.93s
200:learn: 0.2007734total: 1.68sremaining: 6.68s
300:learn: 0.1536222total: 2.45sremaining: 5.69s
400:learn: 0.1220845total: 3.19sremaining: 4.77s
500:learn: 0.0961718total: 3.95sremaining: 3.93s
600:learn: 0.0810769total: 4.7sremaining: 3.12s
700:learn: 0.0694396total: 5.45sremaining: 2.33s
800:learn: 0.0598153total: 6.2sremaining: 1.54s
900:learn: 0.0527771total: 6.93sremaining: 761ms
999:learn: 0.0474017total: 7.67sremaining: 0us
Accuracy: 0.8033707865168539
              precision    recall  f1-score   support

           0       0.84      0.85      0.84       111
           1       0.74      0.73      0.74        67

    accuracy                           0.80       178
   macro avg       0.79      0.79      0.79       178
weighted avg       0.80      0.80      0.80       178

Feature: Pclass, Importance: 16.480181005946406
Feature: Sex, Importance: 24.322199798316337
Feature: Age, Importance: 27.28642174968946
Feature: Siblings/Spouses Aboard, Importance: 5.125530737270014
Feature: Parents/Children Aboard, Importance: 3.006729091175773
Feature: Fare, Importance: 23.77893761760206
2.5可视化特征重要性(可选)
# 可视化特征重要性(可选)
plt.figure(figsize=(10, 6))
plt.barh(X.columns, feature_importances)
plt.xlabel('Feature Importance')
plt.title('CatBoost Feature Importances')
plt.show()

特征重要性输出结果如下:

 2.6绘制混淆矩阵
# 绘制混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()

绘制混淆矩阵输出结果如下:

2.7绘制ROC曲线
# 绘制ROC曲线
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.show()

绘制ROC曲线输出结果如下:


原文地址:https://blog.csdn.net/qq_41698317/article/details/140531984

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!