题解 | #术式终端的并行调度#
术式终端的并行调度
https://www.nowcoder.com/practice/de33104bddb3458592fcbaa7c4fa2ffc
题目链接
题目描述
小红有 台规格相同的服务器,每台服务器的 CPU 算力上限为
,内存容量上限为
。
现有
个任务,第
个任务需要消耗
算力、
内存,并产生
的价值。
任务在服务器上的分配必须满足:每台服务器上所有任务的算力总和
,且内存总和
。
请分别计算:当拥有
台服务器(
)时,所能获得的最大任务总价值。
解题思路
由于任务数量 较小(通常
),本题可以采用**状态压缩动态规划(Bitmask DP)**来解决。这实际上是一个变种的装箱问题。
-
预计算: 对于
种任务组合(由位掩码
表示),预计算出该组合所需的总算力
、总内存
以及总价值
。
-
状态定义: 定义
为完成
中所有任务所需的最少服务器数量。
-
状态转移:
- 如果一个任务组合
满足
且
,则该组合可以放入一台服务器,即
。
- 对于无法放入一台服务器的组合,我们通过枚举子集进行转移:
这是一个经典的枚举子集优化,总时间复杂度为
。
- 如果一个任务组合
-
最终答案: 定义
为使用
台服务器所能获得的最大价值。 遍历所有
,如果
,则更新
。 最后通过
确保答案随服务器数量单调递增。
代码
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
int main() {
int n, c_limit, m_limit;
if (!(cin >> n >> c_limit >> m_limit)) return 0;
vector<int> c(n), m(n), v(n);
for (int i = 0; i < n; i++) {
cin >> c[i] >> m[i] >> v[i];
}
int num_masks = 1 << n;
vector<long long> sum_c(num_masks, 0);
vector<long long> sum_m(num_masks, 0);
vector<long long> sum_v(num_masks, 0);
// 预计算每个 mask 的总消耗和总价值
for (int i = 0; i < n; i++) {
for (int mask = 0; mask < (1 << i); mask++) {
sum_c[mask | (1 << i)] = sum_c[mask] + c[i];
sum_m[mask | (1 << i)] = sum_m[mask] + m[i];
sum_v[mask | (1 << i)] = sum_v[mask] + v[i];
}
}
// dp[mask] 表示完成 mask 中任务所需的最少服务器数
vector<int> dp(num_masks, n + 1);
dp[0] = 0;
for (int mask = 1; mask < num_masks; mask++) {
if (sum_c[mask] <= c_limit && sum_m[mask] <= m_limit) {
dp[mask] = 1;
} else {
// 枚举子集 sub,其中 dp[sub] == 1 表示 sub 能被一台服务器装下
for (int sub = mask; sub > 0; sub = (sub - 1) & mask) {
if (dp[sub] == 1) {
if (dp[mask ^ sub] + 1 < dp[mask]) {
dp[mask] = dp[mask ^ sub] + 1;
}
}
}
}
}
vector<long long> ans(n + 1, 0);
for (int mask = 0; mask < num_masks; mask++) {
if (dp[mask] <= n) {
ans[dp[mask]] = max(ans[dp[mask]], sum_v[mask]);
}
}
for (int k = 1; k <= n; k++) {
ans[k] = max(ans[k], ans[k - 1]);
cout << ans[k] << endl;
}
return 0;
}
import java.util.*;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int cLimit = sc.nextInt();
int mLimit = sc.nextInt();
int[] c = new int[n];
int[] m = new int[n];
int[] v = new int[n];
for (int i = 0; i < n; i++) {
c[i] = sc.nextInt();
m[i] = sc.nextInt();
v[i] = sc.nextInt();
}
int numMasks = 1 << n;
long[] sumC = new long[numMasks];
long[] sumM = new long[numMasks];
long[] sumV = new long[numMasks];
for (int i = 0; i < n; i++) {
for (int mask = 0; mask < (1 << i); mask++) {
sumC[mask | (1 << i)] = sumC[mask] + c[i];
sumM[mask | (1 << i)] = sumM[mask] + m[i];
sumV[mask | (1 << i)] = sumV[mask] + v[i];
}
}
int[] dp = new int[numMasks];
Arrays.fill(dp, n + 1);
dp[0] = 0;
for (int mask = 1; mask < numMasks; mask++) {
if (sumC[mask] <= cLimit && sumM[mask] <= mLimit) {
dp[mask] = 1;
} else {
for (int sub = mask; sub > 0; sub = (sub - 1) & mask) {
if (dp[sub] == 1) {
dp[mask] = Math.min(dp[mask], dp[mask ^ sub] + 1);
}
}
}
}
long[] ans = new long[n + 1];
for (int mask = 0; mask < numMasks; mask++) {
if (dp[mask] <= n) {
ans[dp[mask]] = Math.max(ans[dp[mask]], sumV[mask]);
}
}
for (int k = 1; k <= n; k++) {
ans[k] = Math.max(ans[k], ans[k - 1]);
System.out.println(ans[k]);
}
}
}
def solve():
n, c_limit, m_limit = map(int, input().split())
c, m, v = [], [], []
for _ in range(n):
ti = list(map(int, input().split()))
c.append(ti[0])
m.append(ti[1])
v.append(ti[2])
num_masks = 1 << n
sum_c = [0] * num_masks
sum_m = [0] * num_masks
sum_v = [0] * num_masks
for i in range(n):
bit = 1 << i
for mask in range(bit):
sum_c[mask | bit] = sum_c[mask] + c[i]
sum_m[mask | bit] = sum_m[mask] + m[i]
sum_v[mask | bit] = sum_v[mask] + v[i]
dp = [n + 1] * num_masks
dp[0] = 0
valid = [False] * num_masks
for mask in range(1, num_masks):
if sum_c[mask] <= c_limit and sum_m[mask] <= m_limit:
dp[mask] = 1
valid[mask] = True
for mask in range(1, num_masks):
if dp[mask] == 1:
continue
# 子集枚举:枚举 mask 的所有非空子集 sub
sub = mask
while sub > 0:
if valid[sub]:
if dp[mask ^ sub] + 1 < dp[mask]:
dp[mask] = dp[mask ^ sub] + 1
sub = (sub - 1) & mask
ans = [0] * (n + 1)
for mask in range(num_masks):
k = dp[mask]
if k <= n:
if sum_v[mask] > ans[k]:
ans[k] = sum_v[mask]
for k in range(1, n + 1):
if ans[k-1] > ans[k]:
ans[k] = ans[k-1]
print(ans[k])
solve()
算法及复杂度
- 算法:状态压缩动态规划(Bitmask DP)。
- 时间复杂度:预计算
,状态转移通过枚举子集优化为
。对于
,
,在 C++ 和 Java 中可轻松通过,Python 环境下也处于可接受范围。
- 空间复杂度:
。


查看15道真题和解析