决策树
概念: 一种树形结构的分类模型,通过不断提问(特征判断)将数据划分到不同子节点,最终在叶子节点给出预测结果。
三要素
-
特征选择:选择最有用的特征先提问
如何选择?
- 信息增益:哪种特征能最大程度减少混乱程度(减少熵,增加信息增益最大的先提问)——IG = 原始熵 - 分裂后的平均熵
- 基尼系数:反映了样本的纯度(0 表示样本都是一类,完全纯净;1 表示样本分布均匀,混乱度高)。随机抽两个东西,它们不是同类的概率有多高(降低基尼系数最大的先提问)
-
节点分类:根据答案划分为不同的分支
-
停止条件:什么时候停止提问
参数
max_depth:树最多能问多少个问题min_samples:控制树的深度,至少有多少样本才继续分裂criterion:选择评分标准
代码开发步骤
- 数据准备
- 将数据划分训练集和测试集
- 将训练集按决策树训练,提供训练深度和随机种子
- 预测测试集
- 将测试集结果和真实结果做准确性等检验评估
- 可视化
鸢尾花分类决策树实战
# -*- coding: utf-8 -*-
''' 鸢尾花分类决策树实战 '''
# 1. 导入库
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
# === 1. 全局配置 ===
# 解决中文显示和Matplotlib警告
plt.rcParams['font.sans-serif'] = ['SimHei'] # 中文字体
plt.rcParams['axes.unicode_minus'] = False # 修复负号显示
import matplotlib
matplotlib.use('TkAgg') # 非交互式后端,避免警告
# 2. 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data # 特征矩阵 (150个样本 x 4个特征)
y = iris.target # 目标标签 (0,1,2对应三个品种)
# 查看数据基本信息
print("特征名称:", iris.feature_names) # 四个特征:萼片长宽 + 花瓣长宽
print("类别名称:", iris.target_names) # 三个品种
# 3. 划分训练集与测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y,
test_size=0.3, # 30%数据作为测试集
random_state=42 # 随机种子保证结果可复现
)
# 4. 创建决策树模型
clf = DecisionTreeClassifier(
max_depth=3, # 控制树的最大深度防止过拟合
random_state=42
)
# 5. 训练模型
clf.fit(X_train, y_train) # 传入训练数据学习规律
# 6. 预测测试集
y_pred = clf.predict(X_test)
# 7. 评估模型
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2%}")
# 8. 可视化决策树
plt.figure(figsize=(15,10))
plot_tree(
clf,
feature_names=iris.feature_names, # 显示特征名称
class_names=iris.target_names, # 显示类别名称
filled=True, # 用颜色填充表示不同类别
rounded=True # 节点圆角样式
)
plt.show()超市顾客会员卡购买预测
# -*- coding: utf-8 -*-
''' 超市顾客会员卡购买预测 '''
# 1. 导入工具包
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
# === 1. 全局配置 ===
# 解决中文显示和Matplotlib警告
plt.rcParams['font.sans-serif'] = ['SimHei'] # 中文字体
plt.rcParams['axes.unicode_minus'] = False # 修复负号显示
import matplotlib
matplotlib.use('TkAgg') # 非交互式后端,避免警告
# 2. 创建模拟数据集(实际项目从CSV读取)
data = {
'年龄': [25, 45, 32, 60, 18, 55, 28, 40],
'年收入(万)': [18, 76, 25, 120, 5, 95, 22, 80],
'是否学生': [1, 0, 1, 0, 1, 0, 1, 0],
'购买会员卡': [0, 1, 0, 1, 0, 1, 0, 1]
}
df = pd.DataFrame(data)
print("原始数据:\n", df)
# 3. 准备数据
X = df.drop('购买会员卡', axis=1) # 特征矩阵:删除目标列
y = df['购买会员卡'] # 目标标签:只保留目标列
# 4. 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y,
test_size=0.25, # 25%作为测试集(2条数据)
random_state=7 # 固定随机种子,确保每次分割结果相同
)
# 5. 创建决策树模型(关键参数详解)
model = DecisionTreeClassifier(
max_depth=2, # 最多问2层问题(防止过度复杂)
criterion='gini', # 使用基尼系数作为分裂标准
min_samples_split=3 # 至少3个样本才允许继续分裂
)
# 6. 训练模型(让树学习规律)
model.fit(X_train, y_train) # 传入特征数据和对应标签
# 7. 预测测试集
y_pred = model.predict(X_test) # 用训练好的树对未知数据做预测
print("\n测试集真实结果:", y_test.values)
print("模型预测结果 :", y_pred)
# 8. 评估模型
conf_matrix = confusion_matrix(y_test, y_pred)
print("\n混淆矩阵(真实vs预测):\n", conf_matrix)
# 9. 可视化决策树
plt.figure(figsize=(12,8))
plot_tree(
model,
feature_names=X.columns, # 显示特征名称(年龄、收入...)
class_names=['不购买', '购买'], # 显示类别名称
filled=True, # 用颜色填充区分类别
rounded=True, # 圆角矩形框
proportion=True # 显示样本比例
)
plt.show()