斜率优化dp学习 + 模板
参考博客
例题:HDU 5956 The Elder(斜率优化DP)
参考博客2
模板题:HDU3045
Picnic Cows(HDU3045)
题目链接:https://vjudge.net/problem/HDU-3045
题目大意:
给出一个有N (1<= N <=400000)个正数的序列,要求把序列分成若干组(可以打乱顺序),每组的元素个数不能小于T (1 < T <= N)。每一组的代价是每个元素与最小元素的差之和,总代价是每个组的代价之和,求总代价的最小值。
样例输入包含:
第一行 N
第二行 N个数,如题意
样例输出包含:
第一行 最小的总代价
分析:
首先,审题。可以打乱序列顺序,又知道代价为组内每个元素与最小值差之和,故想到贪心,先将序列排序(用STL sort)。
先从最简单的DP方程想起:
容易想到:
f[i] = min( f[j] + (a[j + 1 -> i] - Min k) ) (0 <= j < i)
– –> f[i] = min( f[j] + sum[i] - sum[j] - a[j + 1] * ( i - j ) )
Min k 代表序列 j + 1 -> i 内的最小值,排序后可以简化为a[j + 1]。提取相似项合并成前缀和sum。这个方程的思路就是枚举 j 不断地计算状态值更新答案。但是数据规模达到了 40000 ,这种以O(n ^ 2)为绝对上界方法明显行不通。所以接下来我们要引入“斜率”来优化。
首先要对方程进行变形:
f[i] = f[j] + sum[i] - sum[j] - a[j + 1] * ( i - j )
– –> f[i] = (f[j] - sum[j] + a[j + 1] * j) - i * a[j + 1] + sum[i]
(此步将只由i决定的量与只由j决定的量分开)
由于 sum[i] 在当前枚举到 i 的状态下是一个不变量,所以在分析时可以忽略(因为对决策优不优没有影响)(当然写的时候肯定不能忽略)
令 i = k
a[j + 1] = x
f[j] - sum[j] + a[j + 1] * j = y
f[i] = b
原方程变为
– –> b = y - k * x
移项
– –> y = k * x + b
code:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
typedef long long dnt;
int n, T, Q[405005];
dnt sum[405005], f[405005], a[405005];
dnt Y( int i, int j )
{
return f[j] - sum[j] + j * a[j + 1] - (f[i] - sum[i] + i * a[i + 1]);
}
dnt X( int i, int j )
{
return a[j + 1] - a[i + 1];
}
dnt DP( int i, int j )
{
return f[j] + (sum[i] - sum[j]) - (i - j) * a[j + 1];
}
inline dnt R()
{
static char ch;
register dnt res, T = 1;
while( ( ch = getchar() ) < '0' || ch > '9' )if( ch == '-' )T = -1;
res = ch - 48;
while( ( ch = getchar() ) <= '9' && ch >= '0')
res = res * 10 + ch - 48;
return res*T;
}
int main()
{
sum[0] = 0;
while(~scanf( "%d%d", &n, &T ))
{
a[0] = 0, f[0] = 0;
for(int i = 1; i <= n; i++)
scanf( "%I64d", &a[i] );
sort(a + 1, a + n + 1);
for(int i = 1; i <= n; i++)
sum[i] = sum[i - 1] + a[i];
int h = 0, t = 0;
Q[++t] = 0;
for(int i = 1; i <= n; i++)
{
int cur = i - T + 1;
for(; h + 1 < t && Y(Q[h + 1], Q[h + 2]) <= i * X(Q[h + 1], Q[h + 2]); h++);
f[i] = DP(i, Q[h + 1]);
if(cur < T) continue;
for(; h + 1 < t && Y(Q[t - 1], Q[t]) * X(Q[t], cur) >= X(Q[t - 1], Q[t]) * Y(Q[t], cur); t--);
Q[++t] = cur;
}
printf( "%I64d\n", f[n] );
}
return 0;
}
HDU 5956 The Elder(斜率优化DP)
#include<bits/stdc++.h>
using namespace std;
int n,p;
typedef long long dnt;
int Q[405005];
dnt sum[405005], f[405005], a[405005];
const int N = 2e5+7;
int head[N],cnt;
int stk[N],L,R;
/* 题意: 蛤题材,就是一颗树上每条边有个权值,每个节点都有新闻要送到根节点就是1节点, 运送过程中如果不换青蛙就是走过的所有边权之和的平方,如果换就每次更换要加上P, 也就是求“每个节点到根节点这段路径切分成几块之后 [每块的权值和的平方加上(块个数-1)*P] 的最小值”。 然后找到所有节点中消耗最大的那个是多少。 思路: 明显是树DP,先推一下转移方程:dp[i] = min{dp[j]+P+(sum[i]-sum[j])^2} 如此明显的斜率优化DP,那么就很简单了只要书上跑DP用斜率优化把O(N^2)优化到O(N)就行了。 */
dnt ans;
struct node
{
int u,v,w;
int next;
} edge[N];
void addedge(int u,int v,int w)
{
edge[cnt].v = v;
edge[cnt].w = w;
edge[cnt].next =head[u];
head[u] = cnt++;
edge[cnt].v = u;
edge[cnt].next = head[v];
head[v] = cnt++;
}
dnt Y( int j, int i )
{
return f[j]-f[i]+(sum[j]*sum[j])-sum[i]*sum[i];
// return f[j] - sum[j] + j * a[j + 1] - (f[i] - sum[i] + i * a[i + 1]);
}
dnt X( int j, int i )
{
// return a[j + 1] - a[i + 1];
return sum[j] - sum[i];
//树上前缀和不用-1
}
dnt DP( int i, int j )
{
// return f[j] + (sum[i] - sum[j]) - (i - j) * a[j + 1];
return f[j]+p+(sum[i]-sum[j])*(sum[i]-sum[j]);
}
//dnt Xie(int k,int i)
//{
// return (Y(k)-Y(i))/(X(k)-X(i));
//}
void dfs(int now,int fa)
{
vector<pair<int,int> >sv;
int l =L,r =R;//l,r 保存现场存储L,R
//假设存在j<k<i
//且f[k]<f[j] 即k比j优则淘汰掉j
//队尾元素L
while(L+1<R&&Y(stk[L+1],stk[L])<=2LL*sum[now]*X(stk[L+1],stk[L]))
{
sv.push_back(make_pair(L,stk[L]));
L++;
}
if(now!=1)
{
// f[now] = DP(now,stk[L]);
f[now]=f[stk[L]]+(1LL*sum[now]-sum[stk[L]])*(1LL*sum[now]-sum[stk[L]])+p;
ans = max(ans,f[now]);
}
//队尾元素R-1
while(L+1<R&&Y(stk[R-1],stk[R-2])*X(now,stk[R-1])>=Y(now,stk[R-1])*X(stk[R-1],stk[R-2]))
{
R--;
sv.push_back(make_pair(R,stk[R]));
}
stk[R++] = now;
for(int i=head[now];~i;i=edge[i].next)
{
int v = edge[i].v;
if(v==fa)
{
continue;
}
sum[v] =sum[now]+edge[i].w;
dfs(v,now);
}
L = l;
R = r;
//恢复 现场
int SZ = sv.size();
for(int i=0; i<SZ; i++)
{
int pos= sv[i].first;
int id = sv[i].second;
stk[pos] = id;
}
}
int main()
{
int T;
cin>>T;
while(T--)
{
cnt =0;
memset(head,-1,sizeof head);
memset(sum,0,sizeof sum);
cin>>n>>p;
int u,v,w;
for(int i=1; i<=n-1; i++)
{
scanf("%d%d%d",&u,&v,&w);
addedge(u,v,w);
}
ans = 0;
L = R =0;
f[1] = -p;
dfs(1,0);
cout<<ans<<endl;
}
return 0;
}
in:
1
6 50
1 2 4
2 3 5
1 4 3
4 5 3
5 6 3
out:81