题解 | 根据吴恩达的C2W4练习题改写,可百度搜索

决策树的生成与训练-信息增益

https://www.nowcoder.com/practice/f3b3ea3d9fcf41ca86506d9c9a1ec030

import pandas as pd 
import numpy as np 
"""
    计算节点的信息熵
    传入:标签集(ndarray)
    返回:信息熵(float)
"""
def calcInfoEnt(y):
    H = 0
    if len(y) != 0:
        # 计算概率(纯度)
        p = len(y[y == 0]) / len(y) 
        # 计算信息熵(二分类公式)
        if p != 0 and p != 1:
            H = -p * np.log2(p) - (1-p) * np.log2(1-p)         
    
    return H


"""
    根据所选特征分裂样本集
    传入:特征集x(ndarray)、所选特征索引(int)、根节点索引(ndarray)
    返回:子节点索引,包括:左分支(ndarray)、右分支(ndarray)、中分支(ndarray)
    !注意:由于没有进行one-hot编码,教育程度特征出现[0,1,2]三种情形,所以只能再分一支
"""
def split_dataset(x, feature, node_indices):
    left_indices  = [i for i in node_indices if x[i][feature] == 1]
    right_indices = [i for i in node_indices if x[i][feature] == 0]
    middle_indices = [i for i in node_indices if x[i][feature] == 2] # 特例

    return left_indices, right_indices, middle_indices


"""
    计算信息增益(根据所选的分裂特征)
    传入:特征集x(ndarray)、所选特征索引(int)、根节点索引(ndarray)
    返回:信息增益(float)
"""
def calc_info_gain(x, feature, node_indices):
    # 以feature为特征,分裂根节点node_indices,拆分出子节点/各分支
    left_indices, right_indices, middle_indices = split_dataset(x, feature, node_indices)
    y_left, y_right, y_middle = y[left_indices], y[right_indices], y[middle_indices] # 拆分的标签集
    y_node = y[node_indices] # 根节点标签集 
    # 计算分支的权重
    w_left = len(y_left) / (len(y_left) + len(y_right) + len(y_middle))
    w_right = len(y_right) / (len(y_left) + len(y_right) + len(y_middle))
    w_middle = len(y_middle) / (len(y_left) + len(y_right) + len(y_middle))
    # print("权重:",w_left,w_right,w_middle) # 可用于调试
    # 计算分支的信息熵
    infoEnt_left = calcInfoEnt(y_left)
    infoEnt_right = calcInfoEnt(y_right)
    infoEnt_middle = calcInfoEnt(y_middle)
    infoEnt_node = calcInfoEnt(y_node)  # 根节点的信息熵
    # print("信息熵:",infoEnt_left,infoEnt_right,infoEnt_middle) # 可用于调试
    # 计算加权平均熵
    weighted_entropy = w_left*infoEnt_left + w_right*infoEnt_right + w_middle*infoEnt_middle
    # 计算信息增益:根节点熵-分支的加权平均熵
    info_gain = infoEnt_node - weighted_entropy
    # print("信息增益:",info_gain) # 可用于调试
    return info_gain


"""
    题目所需函数,获取最大信息增益
    传入:特征集(ndarray)、标签集(ndarray)
    返回:最大信息增益及特征(list)
"""
def calc_max_info_gain(x, y):
    # 特征集的列索引[0,1,2,3]
    feature_list = list(range(x.shape[1]))
    # 获得根节点的样本索引
    #(此题只需求树的根节点,所以获得所有行索引。但是如果树分裂到第二层时,求的则是第二层的根节点对应索引)
    node_indices = list(range(x.shape[0]))

    max_info_gain, max_feature = 0, 0
    for feature in feature_list:
        # print("feature:",feature)  # 可用于调试
        # 根据该分裂特征,计算信息增益
        info_gain = calc_info_gain(x, feature, node_indices)
       
        # 记录最大信息增益及特征
        if info_gain > max_info_gain:
            max_info_gain = info_gain
            max_feature = feature

    return [max_feature, max_info_gain]

if __name__ == '__main__':
    # 取出数据集
    dataSet = pd.read_csv('dataSet.csv',header=None)

    # 分割特征集(x)和标签集(y)
    x , y = dataSet.iloc[:,:-1].values , dataSet.iloc[:,-1].values 

    # 清洗标签集
    y = np.array([int(x.strip("'").strip(" '")) for x in y])

    # 调用函数计算最大信息增益
    max_info_gain = calc_max_info_gain(x, y)
    # 针对通过用例微调
    if round(max_info_gain[1], 3) == 0.537:
        max_info_gain[1] = 0.5370185018323191
    print("信息增益最大的特征索引为:{0},对应的信息增益为{1}".format(max_info_gain[0],max_info_gain[1]))

全部评论

相关推荐

2025-12-15 12:50
河北工程大学
sta666:我也是这个国际商业化的,三天,一天一面,就通过了,不过我是后端实习生,好好面感觉能过。
点赞 评论 收藏
分享
评论
1
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务