题解 | #术式终端的并行调度#

术式终端的并行调度

https://www.nowcoder.com/practice/de33104bddb3458592fcbaa7c4fa2ffc

题目链接

术式终端的并行调度

题目描述

小红有 台规格相同的服务器,每台服务器的 CPU 算力上限为 ,内存容量上限为 。 现有 个任务,第 个任务需要消耗 算力、 内存,并产生 的价值。 任务在服务器上的分配必须满足:每台服务器上所有任务的算力总和 ,且内存总和 。 请分别计算:当拥有 台服务器()时,所能获得的最大任务总价值。

解题思路

由于任务数量 较小(通常 ),本题可以采用**状态压缩动态规划(Bitmask DP)**来解决。这实际上是一个变种的装箱问题。

  1. 预计算: 对于 种任务组合(由位掩码 表示),预计算出该组合所需的总算力 、总内存 以及总价值

  2. 状态定义: 定义 为完成 中所有任务所需的最少服务器数量

  3. 状态转移

    • 如果一个任务组合 满足 ,则该组合可以放入一台服务器,即
    • 对于无法放入一台服务器的组合,我们通过枚举子集进行转移: 这是一个经典的枚举子集优化,总时间复杂度为
  4. 最终答案: 定义 为使用 台服务器所能获得的最大价值。 遍历所有 ,如果 ,则更新 。 最后通过 确保答案随服务器数量单调递增。

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int main() {
    int n, c_limit, m_limit;
    if (!(cin >> n >> c_limit >> m_limit)) return 0;

    vector<int> c(n), m(n), v(n);
    for (int i = 0; i < n; i++) {
        cin >> c[i] >> m[i] >> v[i];
    }

    int num_masks = 1 << n;
    vector<long long> sum_c(num_masks, 0);
    vector<long long> sum_m(num_masks, 0);
    vector<long long> sum_v(num_masks, 0);

    // 预计算每个 mask 的总消耗和总价值
    for (int i = 0; i < n; i++) {
        for (int mask = 0; mask < (1 << i); mask++) {
            sum_c[mask | (1 << i)] = sum_c[mask] + c[i];
            sum_m[mask | (1 << i)] = sum_m[mask] + m[i];
            sum_v[mask | (1 << i)] = sum_v[mask] + v[i];
        }
    }

    // dp[mask] 表示完成 mask 中任务所需的最少服务器数
    vector<int> dp(num_masks, n + 1);
    dp[0] = 0;
    for (int mask = 1; mask < num_masks; mask++) {
        if (sum_c[mask] <= c_limit && sum_m[mask] <= m_limit) {
            dp[mask] = 1;
        } else {
            // 枚举子集 sub,其中 dp[sub] == 1 表示 sub 能被一台服务器装下
            for (int sub = mask; sub > 0; sub = (sub - 1) & mask) {
                if (dp[sub] == 1) {
                    if (dp[mask ^ sub] + 1 < dp[mask]) {
                        dp[mask] = dp[mask ^ sub] + 1;
                    }
                }
            }
        }
    }

    vector<long long> ans(n + 1, 0);
    for (int mask = 0; mask < num_masks; mask++) {
        if (dp[mask] <= n) {
            ans[dp[mask]] = max(ans[dp[mask]], sum_v[mask]);
        }
    }

    for (int k = 1; k <= n; k++) {
        ans[k] = max(ans[k], ans[k - 1]);
        cout << ans[k] << endl;
    }

    return 0;
}
import java.util.*;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int cLimit = sc.nextInt();
        int mLimit = sc.nextInt();

        int[] c = new int[n];
        int[] m = new int[n];
        int[] v = new int[n];
        for (int i = 0; i < n; i++) {
            c[i] = sc.nextInt();
            m[i] = sc.nextInt();
            v[i] = sc.nextInt();
        }

        int numMasks = 1 << n;
        long[] sumC = new long[numMasks];
        long[] sumM = new long[numMasks];
        long[] sumV = new long[numMasks];

        for (int i = 0; i < n; i++) {
            for (int mask = 0; mask < (1 << i); mask++) {
                sumC[mask | (1 << i)] = sumC[mask] + c[i];
                sumM[mask | (1 << i)] = sumM[mask] + m[i];
                sumV[mask | (1 << i)] = sumV[mask] + v[i];
            }
        }

        int[] dp = new int[numMasks];
        Arrays.fill(dp, n + 1);
        dp[0] = 0;

        for (int mask = 1; mask < numMasks; mask++) {
            if (sumC[mask] <= cLimit && sumM[mask] <= mLimit) {
                dp[mask] = 1;
            } else {
                for (int sub = mask; sub > 0; sub = (sub - 1) & mask) {
                    if (dp[sub] == 1) {
                        dp[mask] = Math.min(dp[mask], dp[mask ^ sub] + 1);
                    }
                }
            }
        }

        long[] ans = new long[n + 1];
        for (int mask = 0; mask < numMasks; mask++) {
            if (dp[mask] <= n) {
                ans[dp[mask]] = Math.max(ans[dp[mask]], sumV[mask]);
            }
        }

        for (int k = 1; k <= n; k++) {
            ans[k] = Math.max(ans[k], ans[k - 1]);
            System.out.println(ans[k]);
        }
    }
}
def solve():
    n, c_limit, m_limit = map(int, input().split())
    
    c, m, v = [], [], []
    for _ in range(n):
        ti = list(map(int, input().split()))
        c.append(ti[0])
        m.append(ti[1])
        v.append(ti[2])
        
    num_masks = 1 << n
    sum_c = [0] * num_masks
    sum_m = [0] * num_masks
    sum_v = [0] * num_masks
    
    for i in range(n):
        bit = 1 << i
        for mask in range(bit):
            sum_c[mask | bit] = sum_c[mask] + c[i]
            sum_m[mask | bit] = sum_m[mask] + m[i]
            sum_v[mask | bit] = sum_v[mask] + v[i]
            
    dp = [n + 1] * num_masks
    dp[0] = 0
    
    valid = [False] * num_masks
    for mask in range(1, num_masks):
        if sum_c[mask] <= c_limit and sum_m[mask] <= m_limit:
            dp[mask] = 1
            valid[mask] = True
            
    for mask in range(1, num_masks):
        if dp[mask] == 1:
            continue
        # 子集枚举:枚举 mask 的所有非空子集 sub
        sub = mask
        while sub > 0:
            if valid[sub]:
                if dp[mask ^ sub] + 1 < dp[mask]:
                    dp[mask] = dp[mask ^ sub] + 1
            sub = (sub - 1) & mask
            
    ans = [0] * (n + 1)
    for mask in range(num_masks):
        k = dp[mask]
        if k <= n:
            if sum_v[mask] > ans[k]:
                ans[k] = sum_v[mask]
                
    for k in range(1, n + 1):
        if ans[k-1] > ans[k]:
            ans[k] = ans[k-1]
        print(ans[k])

solve()

算法及复杂度

  • 算法:状态压缩动态规划(Bitmask DP)。
  • 时间复杂度:预计算 ,状态转移通过枚举子集优化为 。对于 ,在 C++ 和 Java 中可轻松通过,Python 环境下也处于可接受范围。
  • 空间复杂度:
全部评论

相关推荐

评论
1
收藏
分享

创作者周榜

更多
牛客网
牛客网在线编程
牛客网题解
牛客企业服务