[Learning AI Fundamentals in 21 Days] Day 13 (Kaggle beginner tutorial) Exercise: Underfitting and Overfitting
URL: https://www.kaggle.com/code/meirou674/exercise-underfitting-and-overfitting/edit
Code
1. Review
# Code you have previously used to load data
import pandas as pd
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
# Path of the file to read
iowa_file_path = '../input/home-data-for-ml-course/train.csv'
home_data = pd.read_csv(iowa_file_path)
# Create target object and call it y
y = home_data.SalePrice
# Create X
features = ['LotArea', 'YearBuilt', '1stFlrSF', '2ndFlrSF', 'FullBath', 'BedroomAbvGr', 'TotRmsAbvGrd']
X = home_data[features]
# Split into validation and training data
train_X, val_X, train_y, val_y = train_test_split(X, y, random_state=1)
# Specify Model
iowa_model = DecisionTreeRegressor(random_state=1)
# Fit Model
iowa_model.fit(train_X, train_y)
# Make validation predictions and calculate mean absolute error
val_predictions = iowa_model.predict(val_X)
val_mae = mean_absolute_error(val_predictions, val_y)
print("Validation MAE: {:,.0f}".format(val_mae))
# Set up code checking
from learntools.core import binder
binder.bind(globals())
from learntools.machine_learning.ex5 import *
print("\nSetup complete")
Output:
Validation MAE: 29,653
This is the model's MAE before any tuning.
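For intuition, MAE is just the average of the absolute differences between the true values and the predictions. A minimal sketch with made-up numbers (not the competition data):

```python
from sklearn.metrics import mean_absolute_error

# Hypothetical true house prices and model predictions (illustrative only)
y_true = [200000, 150000, 320000]
y_pred = [210000, 140000, 300000]

# MAE = mean(|y_true - y_pred|) = (10000 + 10000 + 20000) / 3
mae = mean_absolute_error(y_true, y_pred)
print(mae)  # 13333.33...
```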
2. Step 1: Compare Different Tree Sizes
Below we wrap the MAE computation in a helper function. Its inputs are the maximum number of leaf nodes plus the training and validation sets; it returns the MAE.
def get_mae(max_leaf_nodes, train_X, val_X, train_y, val_y):
    model = DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes, random_state=0)
    model.fit(train_X, train_y)
    preds_val = model.predict(val_X)
    mae = mean_absolute_error(val_y, preds_val)
    return mae
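A quick way to see get_mae in action without the Kaggle data is to run it on a small synthetic regression problem. The data below is randomly generated for illustration, so the exact MAE values will differ from the exercise; the point is that a very small tree underfits (high error) while very large trees start to overfit:

```python
import numpy as np
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

def get_mae(max_leaf_nodes, train_X, val_X, train_y, val_y):
    model = DecisionTreeRegressor(max_leaf_nodes=max_leaf_nodes, random_state=0)
    model.fit(train_X, train_y)
    preds_val = model.predict(val_X)
    return mean_absolute_error(val_y, preds_val)

# Synthetic data: y depends noisily on a single feature
rng = np.random.RandomState(0)
X = rng.uniform(0, 10, size=(500, 1))
y = 3 * X[:, 0] + rng.normal(scale=2, size=500)
train_X, val_X, train_y, val_y = train_test_split(X, y, random_state=1)

for n in [2, 50, 400]:
    print(n, get_mae(n, train_X, val_X, train_y, val_y))
```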
Here is the key step: finding the best maximum number of leaf nodes. The code is very terse and hard to follow at first glance; a line-by-line explanation comes afterwards.
candidate_max_leaf_nodes = [5, 25, 50, 100, 250, 500]
# Write loop to find the ideal tree size from candidate_max_leaf_nodes
scores = {leaf_nodes: get_mae(leaf_nodes, train_X, val_X, train_y, val_y) for leaf_nodes in candidate_max_leaf_nodes}
# Store the best value of max_leaf_nodes (it will be either 5, 25, 50, 100, 250 or 500)
best_tree_size = min(scores, key=scores.get)
print(best_tree_size)
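For reference, the dictionary comprehension plus the min call is equivalent to this explicit loop. To keep the snippet self-contained, get_mae is replaced here by a hypothetical stub returning made-up MAE values; the real function trains a tree each time:

```python
# Stand-in for get_mae: pretend these are the measured validation MAEs
# (made-up numbers for illustration, not results from the exercise)
fake_mae = {5: 35044, 25: 29016, 50: 27405, 100: 27282, 250: 27893, 500: 29454}

def get_mae(leaf_nodes):  # hypothetical stub, real version trains a model
    return fake_mae[leaf_nodes]

candidate_max_leaf_nodes = [5, 25, 50, 100, 250, 500]

best_tree_size = None
best_mae = float("inf")
for leaf_nodes in candidate_max_leaf_nodes:
    mae = get_mae(leaf_nodes)          # compute the error for this tree size
    if mae < best_mae:                 # keep track of the smallest error seen
        best_mae = mae
        best_tree_size = leaf_nodes

print(best_tree_size)  # 100, the key with the smallest fake MAE
```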
Now let's break this code down step by step and explain what each part means.
The original code:
scores = {leaf_size: get_mae(leaf_size, train_X, val_X, train_y, val_y) for leaf_size in candidate_max_leaf_nodes}
best_tree_size = min(scores, key=scores.get)
Explanation of the first line
scores = {leaf_size: get_mae(leaf_size, train_X, val_X, train_y, val_y) for leaf_size in candidate_max_leaf_nodes}
1. This is a dictionary comprehension:
- The general form is {key: value for item in iterable}.
- It builds a dictionary whose keys and values are computed dynamically.
2. What the code does:
- It iterates over each value in the list candidate_max_leaf_nodes, naming the current value leaf_size.
- It calls get_mae(), passing the current leaf_size along with the other arguments (train_X, val_X, train_y, val_y), and gets back an error value (the MAE, mean absolute error).
- It uses leaf_size as the dictionary key and the corresponding error as the value, producing a new dictionary called scores.
3. With some assumed data:
- candidate_max_leaf_nodes = [5, 10, 20]
- The results of get_mae() might be:
get_mae(5, ...) = 0.25
get_mae(10, ...) = 0.20
get_mae(20, ...) = 0.30
- The resulting scores would be:
scores = {5: 0.25, 10: 0.20, 20: 0.30}
Explanation of the second line
best_tree_size = min(scores, key=scores.get)
1. What min() does here:
- It finds the key in the scores dictionary whose associated value is smallest.
- key=scores.get means the comparison uses each key's value in the dictionary, not the key itself.
2. Step by step:
- Iterate over the keys of scores (5, 10, 20).
- For each key, look up its value (0.25, 0.20, 0.30).
- Pick the key with the smallest value and assign it to best_tree_size.
3. Continuing with the assumed data:
- scores = {5: 0.25, 10: 0.20, 20: 0.30}
- The smallest value is 0.20, whose key is 10.
- So best_tree_size = 10.
What these two lines accomplish
- Loop over different leaf-node counts (leaf_size) and compute the model's error (MAE) for each one.
- Pick the leaf-node count with the smallest error and assign it to best_tree_size.
Bonus: simple examples of dictionary comprehensions and min()
A dictionary comprehension:
numbers = [1, 2, 3, 4]
squares = {x: x**2 for x in numbers}
print(squares)  # prints {1: 1, 2: 4, 3: 9, 4: 16}
Using min() with a key:
values = {5: 10, 3: 7, 8: 2}
result = min(values, key=values.get)
print(result)  # prints 8, because the smallest value, 2, belongs to key 8
3. Step 2: Fit Model Using All Data
Step 1 found the best maximum number of leaf nodes. In Step 2 we build a new decision tree with max_leaf_nodes=best_tree_size and train it on the entire dataset (all of X and y, not just the training split):
# Fill in argument to make optimal size and uncomment
final_model = DecisionTreeRegressor(max_leaf_nodes=best_tree_size, random_state=1)
# fit the final model and uncomment the next two lines
final_model.fit(X, y)
In the end, we have trained the most reliable model we can from the available data; given a new batch of X, we can predict y! (I suspect the results still won't be great, but they are certainly better than before tuning. Math is wonderful.)
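Once fitted, final_model.predict accepts rows with the same seven feature columns and returns price predictions. A self-contained sketch with tiny made-up data (not the Iowa dataset; the values and the choice of max_leaf_nodes=100 are illustrative only):

```python
import pandas as pd
from sklearn.tree import DecisionTreeRegressor

# Made-up training data using the same column names as the exercise
features = ['LotArea', 'YearBuilt', '1stFlrSF', '2ndFlrSF',
            'FullBath', 'BedroomAbvGr', 'TotRmsAbvGrd']
X = pd.DataFrame([[8450, 2003, 856, 854, 2, 3, 8],
                  [9600, 1976, 1262, 0, 2, 3, 6],
                  [11250, 2001, 920, 866, 2, 3, 6]], columns=features)
y = pd.Series([208500, 181500, 223500])

final_model = DecisionTreeRegressor(max_leaf_nodes=100, random_state=1)
final_model.fit(X, y)

# "New" rows to score: same columns, previously unseen values
new_X = pd.DataFrame([[10000, 1990, 1000, 500, 2, 3, 7]], columns=features)
print(final_model.predict(new_X))
```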
Original article: https://blog.csdn.net/keira674/article/details/145161042