题解 | 验证集可达到的最优F1值

# Header line: number of tree nodes n, number of samples m, and k
# (k is not referenced by the code below).
n, m, k = (int(tok) for tok in input().split())

# Next n lines: one tree node each, laid out as
# [left_child, right_child, feature_index, threshold, label] (children 1-based, 0 = none).
nodes = [[int(tok) for tok in input().split()] for _ in range(n)]
# Next m lines: one sample each -- feature values followed by the true label as the last field.
samples = [[int(tok) for tok in input().split()] for _ in range(m)]


def pareto(states):
    """Return the non-dominated subset of ``(tp, fp)`` states.

    We want many true positives and few false positives, so a state is
    dropped when another state has at least as many true positives and no
    more false positives; of exact duplicates only one copy is kept.

    Args:
        states: iterable of ``(tp, fp)`` pairs. It is not modified.

    Returns:
        A new list of the surviving states, ordered by ``tp`` descending
        (and therefore ``fp`` strictly descending).
    """
    frontier = []
    # No fixed upper bound: fp can be as large as the number of samples, so a
    # finite sentinel (previously 300) could silently discard every state.
    best_fp = float("inf")
    # Highest tp first; among equal tp, the smallest fp comes first and wins.
    for state in sorted(states, key=lambda s: (-s[0], s[1])):
        if state[1] < best_fp:
            frontier.append(state)
            best_fp = state[1]
    return frontier


def cal_f1(tp, fp, yp):
    """Return the F1 score for the given counts.

    F1 = 2*TP / (2*TP + FP + FN). ``yp`` is expected to be the total number
    of actually-positive samples (TP + FN), which makes the denominator
    ``tp + fp + yp``.

    Args:
        tp: number of true positives.
        fp: number of false positives.
        yp: number of actually-positive samples.

    Returns:
        The F1 score as a float, or ``0.`` when the denominator is zero.
    """
    total = tp + fp + yp
    return 2 * tp / total if total else 0.


def main():
    """Return the best F1 score achievable on the validation samples.

    Reads the module-level ``n``, ``nodes`` and ``samples``. Each node row is
    ``[left, right, feature, threshold, label]``. ``left``/``right`` are
    1-based child indices, with ``left == 0`` marking a leaf. ``feature`` is
    a 1-based index into a sample row, and a sample's last field is its true
    label (assumed 0/1, since it is summed as a positive count).

    At every node the candidate outcomes are either "predict this node's own
    label for every sample reaching it", or any combination of one candidate
    from each child. Candidates are ``(tp, fp)`` pairs pruned to their Pareto
    frontier. The answer is the best F1 over the root's candidates.
    """
    # Route every sample from the root to a leaf, tallying at each visited
    # node how many samples reach it and how many of those are positive.
    reach = [0] * n
    positives = [0] * n
    for sample in samples:
        node_id = 0
        while True:
            node = nodes[node_id]
            reach[node_id] += 1
            positives[node_id] += sample[-1]
            if node[0] == 0:
                break
            child = node[0] if sample[node[2] - 1] <= node[3] else node[1]
            node_id = child - 1

    # Pre-order walk from the root using an explicit stack. Reversed, it lists
    # children before their parents. This replaces a recursive solve() that
    # could exceed Python's recursion limit on deep (chain-shaped) trees.
    order = []
    stack = [0]
    while stack:
        node_id = stack.pop()
        order.append(node_id)
        left, right = nodes[node_id][0], nodes[node_id][1]
        if left:
            stack.append(left - 1)
            stack.append(right - 1)

    frontier = [None] * n
    for node_id in reversed(order):
        left, right, _, _, label = nodes[node_id]
        pos = positives[node_id]
        # Option 1: stop here and predict this node's label for all its samples.
        states = [(pos, reach[node_id] - pos)] if label else [(0, 0)]
        # Option 2: keep the subtree -- combine one candidate from each child.
        if left:
            right_states = frontier[right - 1]
            for tp1, fp1 in frontier[left - 1]:
                for tp2, fp2 in right_states:
                    states.append((tp1 + tp2, fp1 + fp2))
        frontier[node_id] = pareto(states)

    yp = sum(s[-1] for s in samples)
    return max(cal_f1(tp, fp, yp) for tp, fp in frontier[0])

# Emit the best achievable F1 score with six decimal places.
best_f1 = main()
print(f"{best_f1:.6f}")
            
这难度较难以上了吧orz
全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务