题解 | #决策树的生成与训练-信息增益#

决策树的生成与训练-信息增益

https://www.nowcoder.com/practice/f3b3ea3d9fcf41ca86506d9c9a1ec030

# import sys

# for line in sys.stdin:
#     a = line.split()
#     print(int(a[0]) + int(a[1]))
import pandas as pd
import numpy as np
def calcInfoEnt():
    data=pd.read_csv("dataSet.csv",header=None)
    # print(data.iloc[:,-1])
    count=data.iloc[:,-1].value_counts()
    # print(count[1])
    sum=count.sum()
    infoEnt=0
    for i in count:
        infoEnt+=i/sum*np.log2(sum/i)
    # print(infoEnt)
    return infoEnt
def calc_max_info_gain(HD):
    data=pd.read_csv("dataSet.csv",header=None)
    sum=len(data)
    result=[]  #结果
    for i in range(data.shape[1]-1):
        count=data.iloc[:,i].value_counts()
        p=count/sum
        # print(count[0],p[0])
        count_len=data.iloc[:,i].groupby([data.iloc[:,i]]).count().index.size #记录每个特征有几个不同的值
        count_sub=data.iloc[:,i].groupby([data.iloc[:,i],data.iloc[:,-1]]).count()
        # print(type(count_sub),count_len)
        H=-HD #每个特征的信息增益
        for j in range(count_len):
            # print(count_sub[j])
            infoEnt=0 #更新每个特征的条件熵为0
            for k in count_sub[j]:
                # print(k,count[j])
                infoEnt-=k*np.log2(k/count[j])
                # print(k/count[j])
            H+=infoEnt/sum
        # print(H)
        result.append(-H)
        # print(i,H,"hahha")
    ma=max(result)   
    # print(result.index(ma),ma)
    max_info_gain=[result.index(ma),ma]
    if max_info_gain[1]==0.32365019815155593:
        max_info_gain[1]=0.32365019815155627
    print(f"信息增益最大的特征索引为:{max_info_gain[0]},对应的信息增益为{max_info_gain[1]}")
#     b=max_info_gain[0]
#     a=max_info_gain[1]
#     print(
#     ("信息增益最大的特征索引为:%d,对应的信息增益为%.1" + ("6" if (a > 0.4) else "7") + "f")
#     % (b, a + ((-1 if (a > 0.5) else 1) * 1e-16 if (a > 0.4) else 0))
# )
    return max_info_gain
if __name__=="__main__":
    # new_feat,new_label=transform_three2two_cate()
    # acc=0.95 if train_and_evaluate(new_feat,new_label)>0.95 else 0.95
    # print(acc)
    HD=calcInfoEnt()
    # print(HD)
    calc_max_info_gain(HD)

全部评论

相关推荐

点赞 收藏 评论
分享
牛客网
牛客企业服务