首页 > 试题广场 >

验证集可达到的最优F1值

[编程题]验证集可达到的最优F1值
  • 热度指数:724 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
  • 决策树若完全按训练集递归生长,往往能把训练样本分得很“细”,但一到未见过的数据就容易出错,即出现过拟合。为缓解这一问题,常用“剪枝”把某些子树整体替换成单个叶子,使模型更简单。
  • 现在有一棵用于二分类的二叉决策树(标签1表示正类,0表示负类)。对非叶节点,按“第 f_i 个特征 ≤  th_i 走左子树,否则走右子树”的规则继续判断;到达叶子时直接输出该节点自带的 label
  • 允许在整棵树上任选若干处进行剪枝(把某个内部节点整体替换为叶节点,其输出为该节点给定的 label)。请在给定验证集上寻找使 F1 值最大的剪枝方案,输出最优 F1(四舍五入保留6位小数)。

输入描述:
第一行:N M K  
  N 为节点数(1~100),M 为验证集条数(1~300),K 为每条验证样本的特征维数(1~100)。

接下来的 N 行:按节点编号1..N给出每个节点的信息:  
  l_i  r_i   f_i   th_i  label_i  
  其中 l_ir_i 为左右子编号(0表示无子节点,且不存在只有一个子节点的情况);  
  若为非叶节点,f_i 是用于分裂的特征序号(1-based),th_i 为阈值;  
  若为叶节点,f_i 与 th_i 置 0;label_i  表示当该节点作为叶子时的输出标签(0或1)。

接下来的 M 行:每行 K+1 个整数,前 K 个为该条验证样本的特征,最后一个为真实标签(0或1)。



输出描述:
输出单行浮点数:在验证集上能达到的最大 F1 值,四舍五入到小数点后 6 位。
示例1

输入

5 5 2
2 3 1 50 0
0 0 0 0 1
4 5 2 70 0
0 0 0 0 0
0 0 0 0 1
40 80 1
55 60 0
55 90 1
55 85 0
20 10 0

输出

0.666667

说明

路由规则:特征1≤50 进左子树,否则进右子树;在右子树中再按特征2≤70 判到左叶(输出0),否则到右叶(输出1)。  
若不剪枝,五条样本的预测与真实标签对比如下:命中两条正类,出现两次“将负类判为正类”,未漏判正类,计算得 F1=2*2/(2*2+2+0)=0.666667。  
尝试将右子树整体剪为叶(输出0)或将根剪为叶(输出0/1)等方案,F1 反而更低。因此最优为 0.666667。
示例2

输入

5 6 2
2 3 1 30 1
0 0 0 0 0
4 5 2 50 1
0 0 0 0 1
0 0 0 0 0
35 40 1
35 70 0
35 60 1
25 80 0
28 10 1
50 45 1

输出

0.800000

说明

路由规则:特征1≤30 走左子树(叶,输出0),否则进入右子树;在右子树内,特征2≤50 走左叶(输出1),否则走右叶(输出0)。
不剪枝时:TP=2(命中两条正类),FN=2(漏判两条正类),FP=0,F1=22/(4+0+2)=0.666667。
若把根节点直接剪成叶并输出1,则6条样本预测为1,其中TP=4(四条为正类),FP=2(两条为负类),FN=0,F1=24/(8+2+0)=0.800000。其他剪枝方案(如只剪右子树)得到的F1更低,因此最优为0.800000。

备注:
本题由牛友@Charles 整理上传
# 决策树节点类
class Node:
    def __init__(self, idx, l, r, f, th, label):
        self.idx = idx    # 节点编号(从1开始)
        self.l = l        # 左子树
        self.r = r        # 右子树
        self.f = f        # 分裂特征
        self.th = th      # 阈值
        self.label = label  # 标签
# 用决策树对样本分类
def predict_tree(node, sample):
    while node.f != 0:  # 如果是非叶节点
        if sample[node.f - 1] <= node.th:
            node = node.l  # 向左子树走
        else:
            node = node.r  # 向右子树走
    return node.label  # 叶节点返回标签
# 计算 F1 值
def calculate_f1(predictions, true_labels):
    tp = sum((predictions[i] == 1 and true_labels[i] == 1) for i in range(len(predictions)))
    fp = sum((predictions[i] == 1 and true_labels[i] == 0) for i in range(len(predictions)))
    fn = sum((predictions[i] == 0 and true_labels[i] == 1) for i in range(len(predictions)))
    if tp + fp == 0 or tp + fn == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * (precision * recall) / (precision + recall)
# 后剪枝:如果剪枝能提高F1值就剪枝
def prune_tree(node, val_samples):
    if not node or node.f == 0:
        return node
    # 先剪枝左子树和右子树
    if node.l is not None:
        node.l = prune_tree(node.l, val_samples)
    if node.r is not None:
        node.r = prune_tree(node.r, val_samples)
    # 计算剪枝前的F1值
    predictions_before = [predict_tree(node, sample) for sample, _ in val_samples]
    true_labels = [label for _, label in val_samples]
    f1_before = calculate_f1(predictions_before, true_labels)
    # 将当前节点剪枝成叶节点
    node_l = node.l
    node.l = None
    node_r = node.r
    node.r = None
    node_f = node.f
    node.f = 0
    # 计算剪枝后的F1值
    predictions_after = [predict_tree(node, sample) for sample, _ in val_samples]
    f1_after = calculate_f1(predictions_after, true_labels)
    # 如果剪枝后F1值更高,保留剪枝后的树
    if f1_after > f1_before:
        return node
    else:
        # 否则恢复原来的子树
        node.l = node_l
        node.r = node_r
        node.f = node_f
        return node
# 输入数据
N, M, K = map(int, input().split())
nodes = []
# 构建树
for i in range(N):
    l, r, f, th, label = map(int, input().split())
    nodes.append(Node(i, l, r, f, th, label))
for i in range(N):
    node = nodes[i]
    if node.f != 0:
        if node.l != 0:
            node.l = nodes[node.l - 1]
        else:
            node.l = None
        if node.r != 0:
            node.r = nodes[node.r - 1]
        else:
            node.r = None
    else:
        node.l = None
        node.r = None
# 验证集
val_samples = []
for i in range(M):
    data =list(map(int, input().split()))
    sample = data[:K]
    label = data[K]
    val_samples.append((sample, label))
root = nodes[0]
pruned_tree = prune_tree(root, val_samples)
predictions = [predict_tree(pruned_tree, sample) for sample, _ in val_samples]
true_labels = [label for _, label in val_samples]
f1 = calculate_f1(predictions, true_labels)
print(f"{f1:.6f}")
发表于 2025-09-03 14:45:25 回复(0)