Day48:决策树详解与案例

alt

上一节我们讲到了朴素贝叶斯算法,这一节我们将讲解决策树。它是一种基于树形结构的分类与回归算法,它通过对数据进行一系列的分裂,构建出一个树形模型,用于预测目标变量的取值。决策树的优点包括易于理解和解释、能够处理离散型和连续型特征、可以处理多分类问题等。

1. 决策树

1.1 基本算法步骤

  1. 特征选择:根据某个准则选择最佳的特征作为当前节点的分裂特征,常见的特征选择准则有信息增益、信息增益比、基尼系数等。
  2. 结点分裂:根据选定的分裂特征和分裂准则,将当前节点的样本集合分裂成子节点。
  3. 递归构建:对每个子节点,重复步骤1和步骤2,直到满足终止条件,例如达到最大深度或样本数量小于阈值。
  4. 剪枝:通过剪枝操作,去除一些节点或子树,以避免过拟合。

1.2 决策树算法

常见的决策树算法有ID3、C4.5、CART等。

  1. ID3(Iterative Dichotomiser 3):ID3算法是一种基于信息增益的决策树算法。它选择每次分裂时能够提供最大信息增益的特征作为分裂准则。信息增益是指在当前节点的基础上,选择某个特征后,目标变量的不确定性减少的程度。ID3算法在处理分类问题时使用离散型特征,并且不对缺失值进行处理。
  2. C4.5:C4.5算法是ID3算法的改进版本,也是基于信息增益的决策树算法。与ID3不同,C4.5算法引入了信息增益比来解决ID3算法对取值较多的特征有所偏好的问题。信息增益比是信息增益与分裂信息之比,其中分裂信息考虑了特征取值的多样性。C4.5算法可以处理离散型和连续型特征,并对缺失值进行处理。
  3. CART(Classification and Regression Trees):CART算法是一种既能处理分类问题又能处理回归问题的决策树算法。它采用基尼系数作为分裂准则,基尼系数衡量了将当前节点的样本集合分裂成不同类别的不纯度。CART算法可以处理离散型和连续型特征,并对缺失值进行处理。

2. 实际案例

在sklearn库中,我们可以使用DecisionTreeClassifierDecisionTreeRegressor来实现决策树分类和回归。

下面我们依旧适用鸢尾花数据,利用决策树进行分类,并且尝试上述三种算法的不同,我们将使用export_graphviz方法生成Graphviz格式的决策树图形描述,然后使用graphviz库将图形描述转换为可视化的图形,并可以选择保存图形或在默认图片查看器中显示图形。

  1. 导入库,加载数据集:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.tree import export_graphviz
import graphviz

# 加载数据集
data = load_iris()
  1. 划分训练集和测试集:
X = data.data
y = data.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
  1. 使用ID3算法,创建决策树,并训练数据:
# ID3算法
model_id3 = DecisionTreeClassifier(criterion='entropy')
model_id3.fit(X_train, y_train)
y_pred_id3 = model_id3.predict(X_test)
accuracy_id3 = accuracy_score(y_test, y_pred_id3)
print("Accuracy (ID3):", accuracy_id3)  #输出:Accuracy (ID3): 1.0
  1. 输出ID3算法的决策树图形:
dot_data = export_graphviz(model_id3, out_file=None,
                           feature_names=data.feature_names,
                           class_names=data.target_names,
                           filled=True, rounded=True,
                           special_characters=True)

# 使用Graphviz库将图形描述转换为可视化的图形
graph = graphviz.Source(dot_data)
graph.render('decision_tree')  # 保存图形为文件
graph.view()  # 在默认图片查看器中显示图形

1alt

  1. 使用C4.5算法,创建决策树,并训练数据:
# C4.5算法
model_c45 

剩余60%内容,订阅专栏后可继续查看/也可单篇购买

大模型-AI小册 文章被收录于专栏

1. AI爱好者,爱搞事的 2. 想要掌握第二门语言的Javaer或者golanger 3. 决定考计算机领域研究生,给实验室搬砖的uu,强烈建议你花时间学完这个,后续搬砖比较猛 4. 任何对编程感兴趣的,且愿意掌握一门技能的人

全部评论
看不懂啊
点赞
送花
回复 分享
发布于 2023-07-14 15:33 上海

相关推荐

1 4 评论
分享
牛客网
牛客企业服务