题解 | 虫洞网络

虫洞网络

https://www.nowcoder.com/practice/29d2fe7e8d314d849f687d035133c463

m, s_start, s_des, e_total = map(int, input().split())

# Each of the next m lines describes one network: its entry cost first, then
# (presumably) a station count that is ignored here, then the station ids.
# Only the cost and the set of station ids are kept.
rows = [[int(tok) for tok in input().split()] for _ in range(m)]
c_list = [row[0] for row in rows]
net_list = [set(row[2:]) for row in rows]


def main(costs=None, networks=None, start=None, dest=None, budget=None):
    """Return the cheapest total network cost from ``start`` to ``dest``, or -1.

    Network ``i`` charges ``costs[i]`` when it is used and serves the stations
    in ``networks[i]``. A route begins in any network that serves ``start``.
    It may switch to another network whenever the two share a station, and it
    ends in any network that serves ``dest``. Its cost is the sum of the costs
    of every network it uses, the first one included.

    Every argument is optional. Each defaults to the module-level value read
    from stdin (``c_list``, ``net_list``, ``s_start``, ``s_des``,
    ``e_total``), so ``main()`` keeps its original script behaviour while the
    search can also be called directly.

    NOTE(review): when ``start == dest`` the result is the cost of the cheapest
    network serving that station, not 0. Confirm against the problem statement.

    Args:
        costs: per-network entry cost. Assumed non-negative, since Dijkstra
            requires it (presumably guaranteed by the problem; verify).
        networks: per-network iterable of station ids.
        start: station the journey begins at.
        dest: station the journey must reach.
        budget: largest total cost that is still acceptable.

    Returns:
        The minimum total cost if some route exists and it does not exceed
        ``budget``; otherwise -1.
    """
    from collections import defaultdict
    from heapq import heappop, heappush

    if costs is None:
        costs = c_list
    if networks is None:
        networks = net_list
    if start is None:
        start = s_start
    if dest is None:
        dest = s_des
    if budget is None:
        budget = e_total

    # Station -> indices of the networks serving it. This replaces the
    # pairwise set intersection of every network pair (O(m^2) pairs).
    nets_at = defaultdict(list)
    for idx, stations in enumerate(networks):
        for station in stations:
            nets_at[station].append(idx)

    # .get() avoids inserting empty entries into the defaultdict.
    target_nets = set(nets_at.get(dest, ()))
    if not target_nets:
        return -1

    best = [float("inf")] * len(costs)
    heap = []
    for idx in nets_at.get(start, ()):
        best[idx] = costs[idx]
        heappush(heap, (costs[idx], idx))

    # Stations whose neighbouring networks have already been relaxed.
    expanded = set()
    while heap:
        cost, idx = heappop(heap)
        if cost > best[idx]:
            continue  # stale heap entry
        if idx in target_nets:
            # Entries are popped in nondecreasing cost order, so the first
            # destination network reached is the cheapest one.
            return cost if cost <= budget else -1
        for station in networks[idx]:
            # The first network to expand a station is the cheapest one that
            # contains it. Any later expansion through the same station would
            # start from a cost at least as high, so it cannot improve anything.
            if station in expanded:
                continue
            expanded.add(station)
            for nxt in nets_at[station]:
                candidate = cost + costs[nxt]
                if candidate < best[nxt]:
                    best[nxt] = candidate
                    heappush(heap, (candidate, nxt))

    # No network serving dest is reachable from start.
    return -1


# Run the search with the values parsed from stdin and emit the single answer.
answer = main()
print(answer)

全部评论

相关推荐

评论
点赞
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务