首页 > 试题广场 >

验证集可达到的最优F1值

[编程题]验证集可达到的最优F1值
  • 热度指数:721 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
  • 决策树若完全按训练集递归生长,往往能把训练样本分得很“细”,但一到未见过的数据就容易出错,即出现过拟合。为缓解这一问题,常用“剪枝”把某些子树整体替换成单个叶子,使模型更简单。
  • 现在有一棵用于二分类的二叉决策树(标签1表示正类,0表示负类)。对非叶节点,按“第 f_i 个特征 ≤  th_i 走左子树,否则走右子树”的规则继续判断;到达叶子时直接输出该节点自带的 label
  • 允许在整棵树上任选若干处进行剪枝(把某个内部节点整体替换为叶节点,其输出为该节点给定的 label)。请在给定验证集上寻找使 F1 值最大的剪枝方案,输出最优 F1(四舍五入保留6位小数)。

输入描述:
第一行:N M K  
  N 为节点数(1~100),M 为验证集条数(1~300),K 为每条验证样本的特征维数(1~100)。

接下来的 N 行:按节点编号1..N给出每个节点的信息:  
  l_i  r_i   f_i   th_i  label_i  
  其中 l_ir_i 为左右子编号(0表示无子节点,且不存在只有一个子节点的情况);  
  若为非叶节点,f_i 是用于分裂的特征序号(1-based),th_i 为阈值;  
  若为叶节点,f_i 与 th_i 置 0;label_i  表示当该节点作为叶子时的输出标签(0或1)。

接下来的 M 行:每行 K+1 个整数,前 K 个为该条验证样本的特征,最后一个为真实标签(0或1)。



输出描述:
输出单行浮点数:在验证集上能达到的最大 F1 值,四舍五入到小数点后 6 位。
示例1

输入

5 5 2
2 3 1 50 0
0 0 0 0 1
4 5 2 70 0
0 0 0 0 0
0 0 0 0 1
40 80 1
55 60 0
55 90 1
55 85 0
20 10 0

输出

0.666667

说明

路由规则:特征1≤50 进左子树,否则进右子树;在右子树中再按特征2≤70 判到左叶(输出0),否则到右叶(输出1)。  
若不剪枝,五条样本的预测与真实标签对比如下:命中两条正类,出现两次“将负类判为正类”,未漏判正类,计算得 F1=2*2/(2*2+2+0)=0.666667。  
尝试将右子树整体剪为叶(输出0)或将根剪为叶(输出0/1)等方案,F1 反而更低。因此最优为 0.666667。
示例2

输入

5 6 2
2 3 1 30 1
0 0 0 0 0
4 5 2 50 1
0 0 0 0 1
0 0 0 0 0
35 40 1
35 70 0
35 60 1
25 80 0
28 10 1
50 45 1

输出

0.800000

说明

路由规则:特征1≤30 走左子树(叶,输出0),否则进入右子树;在右子树内,特征2≤50 走左叶(输出1),否则走右叶(输出0)。
不剪枝时:TP=2(命中两条正类),FN=2(漏判两条正类),FP=0,F1=22/(4+0+2)=0.666667。
若把根节点直接剪成叶并输出1,则6条样本预测为1,其中TP=4(四条为正类),FP=2(两条为负类),FN=0,F1=24/(8+2+0)=0.800000。其他剪枝方案(如只剪右子树)得到的F1更低,因此最优为0.800000。

备注:
本题由牛友@Charles 整理上传
先实现二叉树的数据结构,然后递归实现决策树的推理过程。遍历验证集,每个节点记录当前以它为根的子树的真实标签列表和预测标签列表。然后遍历决策树,计算当前各个节点的局部F1值,当然根节点的F1就是在完整验证集上的全局F1。

接下来是剪枝过程,其实不需要排列组合所有的情况,我们前序遍历,如果根节点合并了,就没有必要去遍历子树了;如果根节点没有合并,两个子树是否剪枝是在它们各自分配到的样本子集中进行的独立过程。我们只需要将预测标签列表暂时全部更新为该节点的label,重新计算F1值,如果比历史F1值要高,那就合并子树。

最后我们再跑一遍完全一样的遍历和分类过程,计算根节点的F1值。代码如下:
def calc_f1(truth, preds):
    TP = FP = FN = 0
    for y, pred in zip(truth, preds):
        TP += pred == y == 1
        FP += pred == 1 and y == 0
        FN += pred == 0 and y == 1
    recall = TP / (TP + FN) if TP + FN else 0
    precision = TP / (TP + FP) if TP + FP else 0
    f1 = 2 * recall * precision / (recall + precision) if recall + precision else 0
    return f1


class Node:

    def __init__(self, no, lc, rc, thres, f, label):
        self.no = no
        self._lc, self._rc = lc, rc
        self.thres = thres
        self.f, self.label = f, label
        self.is_leaf = thres == 0

        self.f1 = 0
        self.truth = []
        self.preds = []

    @property
    def lc(self):
        return nodes[self._lc]

    @property
    def rc(self):
        return nodes[self._rc]

    def classify(self, x, y):
        ans = self.label
        if not self.is_leaf:
            ans = (self.lc if x[self.f] <= self.thres else self.rc).classify(x, y)
        self.truth.append(y)
        self.preds.append(ans)
        return ans

    def calc_local_f1(self):
        self.f1 = calc_f1(self.truth, self.preds)
        if not self.is_leaf:
            self.lc.calc_local_f1()
            self.rc.calc_local_f1()

    def prun(self):
        if self.is_leaf:
            return
        new_f1 = calc_f1(self.truth, [self.label] * len(self.truth))
        if new_f1 > self.f1:
            self.is_leaf = True
        else:
            self.lc.prun()
            self.rc.prun()

    def reset_buffer(self):
        self.truth, self.preds = [], []
        if not self.is_leaf:
            self.lc.reset_buffer()
            self.rc.reset_buffer()

    def classify_and_update_f1(self, valid_set):
        self.reset_buffer()
        for *x, y in valid_set:
            pred = root.classify(x, y)
        self.calc_local_f1()


nodes = []
N, M, K = map(int, input().split())
for i in range(N):
    l_no, r_no, f, thres, label = map(int, input().split())
    nodes.append(Node(i, l_no - 1, r_no - 1, thres, f - 1, label))
root = nodes[0]
valid_set = [list(map(int, input().split())) for _ in range(M)]

root.classify_and_update_f1(valid_set)
root.prun()
root.classify_and_update_f1(valid_set)
print(f"{root.f1:>.6f}")


发表于 2025-09-27 21:31:59 回复(0)
Python3百行代码实现二叉树决策树剪枝
class Treenode:
    def __init__(self, id=None, f=None, threshhold=None, label=None, left=None, right=None):
        self.id = id
        self.f = f
        self.threshhold = threshhold
        self.label = label
        self.left = left
        self.right = right
        self.cut_table = True

class Tree:
    def __init__(self, root):
        self.root = root
        self.F1 = 0

    # 决策树预测结果
    def judge(self, x):
        node = self.root
        while node.left and node.right:
            if x[node.f] <= node.threshhold:
                node = node.left
            else:
                node = node.right
        return node.label

    def update_F1_score(self):
        TP, FP = 0, 0
        for i in range(m):
            res = self.judge(dataset[i][0])
            if res == dataset[i][1] == 1:
                TP += 1
            elif res == 1 and dataset[i][1] == 0:
                FP += 1
        precision = TP / (TP + FP) if TP else 0
        recall = TP / True_count
        F1 = 2 * precision * recall / (precision + recall) if TP else 0
        if F1 > self.F1:
            self.F1 = F1
            return True
        else:
            return False

    def cut_tree(self):
        """后续遍历剪枝"""
        def traverse(node):
            if node is not None:
                if node.left is None and node.right is None: # 为叶节点
                    node.cut_table = False
               
                traverse(node.left)
                traverse(node.right)
                if node.cut_table:
                    temp = [node.left, node.right]
                    node.left = None
                    node.right = None

                    if not self.update_F1_score():
                        node.left = temp[0]
                        node.right = temp[1]
               

        traverse(self.root)

# 输入
n, m, k = map(int, input().split())

# 构建树
Tree_ls = [Treenode(id=i) for i in range(n)]
for i in range(n):
    ls = input().split()
    # 左节点
    idx_left = int(ls[0])-1
    if idx_left > 0:
        Tree_ls[i].left = Tree_ls[idx_left] # 编号从1开始为根节点
    # 右节点
    idx_right = int(ls[1])-1
    if idx_right > 0:
        Tree_ls[i].right = Tree_ls[idx_right]
   
    Tree_ls[i].f = int(ls[2])-1
    Tree_ls[i].threshhold = float(ls[3])
    Tree_ls[i].label = int(ls[4])
Tree = Tree(Tree_ls[0])

# 输入数据集
dataset = []
for _ in range(m):
    ls = list(map(float, input().split()))
    dataset.append([ls[:-1], int(ls[-1])])

# 正例的个数
True_count = 0
for e in dataset:
    True_count += e[1]

Tree.update_F1_score()
if n > 1:
    Tree.cut_tree()
print(f'{round(Tree.F1, 6):.6f}')


"""
datasets:
(1) res: 0.800000
5 6 2
2 3 1 30 1
0 0 0 0 0
4 5 2 50 1
0 0 0 0 1
0 0 0 0 0
35 40 1
35 70 0
35 60 1
25 80 0
28 10 1
50 45 1

(2) res: 0.666667
3 4 2
2 3 2 50 1
0 0 0 0 0
0 0 0 0 1
10 40 0
20 60 1
5 55 0
5 10 1

(3) res: 0.800000
1 3 1
0 0 0 0 1
10 1
20 1
30 0

(4) res:  0.666667
5 5 2
2 3 1 50 0
0 0 0 0 1
4 5 2 70 0
0 0 0 0 0
0 0 0 0 1
40 80 1
55 60 0
55 90 1
55 85 0
20 10 0

"""
发表于 2025-09-10 20:11:27 回复(0)