浅析 poj2785 4 Values whose Sum is 0 (善用lower/upper_bound()函数)
题目链接:
http://poj.org/problem?id=2785
题面:
Description
The SUM problem can be formulated as follows: given four lists A, B, C, D of integer values, compute how many quadruplet (a, b, c, d ) ∈ A x B x C x D are such that a + b + c + d = 0 . In the following, we assume that all lists have the same size n .
Input
The first line of the input file contains the size of the lists n (this value can be as large as 4000). We then have n lines containing four integer values (with absolute value as large as 228 ) that belong respectively to A, B, C and D .
Output
For each input file, your program has to write the number quadruplets whose sum is zero.
Sample Input
6
-45 22 42 -16
-41 -27 56 30
-36 53 -37 77
-36 30 -75 -46
26 -38 -10 62
-32 -54 -6 45
Sample Output
5
Hint
Sample Explanation:
Indeed, the sum of the five following quadruplets is zero:
(-45, -27, 42, 30), (26, 30, -10, -46),
(-32, 22, 56, -46),(-32, 30, -75, 77), (-32, -54, 56, 30).
题意:
给你 N 行 4 列的数,从每一列选取一个数,问使它们的和为0的情况有多少种
(N 4000)
分析过程+代码:
四分为二 + 双指针
-
把四列的数组手工分为前两列和与后两列,计算前两列和后两列的和(需要 次运算),然后对后两列的和所有可能情况进行统计,并分别排序(快排时间复杂度 )。
-
再用双指针判断前两列和与后两列和的和是否为0,若为0, 则把 前两列中该计算值出现次数 乘上 后两列中该计算值出现次数。(双指针时间复杂度 )
基于上面的思想写出如下代码:
poj——TLE代码:
#include<iostream>
#include<map>
#include<vector>
#include<algorithm>
using namespace std;
const int maxn = 4004;
int n;
int row1[maxn] = {};
int row2[maxn] = {};
int row3[maxn] = {};
int row4[maxn] = {};
map<int, int> cnt12;
map<int, int> cnt34;
vector<int> v12;
vector<int> v34;
long long ans = 0;
int main () {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> row1[i] >> row2[i] >> row3[i] >> row4[i];
}
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
cnt12[row1[i] + row2[j]]++;
if (!count(v12.begin(), v12.end(), row1[i] + row2[j])) { // 应该是这里超时
v12.push_back(row1[i] + row2[j]);
}
cnt34[row3[i] + row4[j]]++;
if (!count(v34.begin(), v34.end(), row3[i] + row4[j])) { // 应该是这里超时
v34.push_back(row3[i] + row4[j]);
}
}
}
sort(v12.begin(), v12.end());
sort(v34.begin(), v34.end());
// cout << endl;
// for(auto a : v12) {
// cout << a << " ";
// }
// cout << endl;
//
// for (auto a : v34) {
// cout << a << " ";
// }
// cout << endl;
vector<int>::iterator it12 = v12.begin();
vector<int>::reverse_iterator it34 = v34.rbegin();
while (it12 != v12.end() && it34 != v34.rend()) {
if ((*it12) < -(*it34)) {
// cout << (*it12) << "(it12) ";
it12++;
}
else if ((*it12) > -(*it34)) {
// cout << (*it34) << "(it34) ";
it34++;
}
else {
ans += (cnt12[*it12] * cnt34[*it34]);
it12++;
// cout << endl << ans << endl;
}
}
cout << ans;
}
简单懒惰的修改
poj——CE代码:
#include<iostream>
#include<map>
#include<vector>
#include<algorithm>
using namespace std;
const int maxn = 4004;
int n;
int row1[maxn] = {};
int row2[maxn] = {};
int row3[maxn] = {};
int row4[maxn] = {};
map<int, int> cnt12;
map<int, int> cnt34;
vector<int> v12;
vector<int> v34;
long long ans = 0;
int main () {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> row1[i] >> row2[i] >> row3[i] >> row4[i];
}
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= n; j++) {
cnt12[row1[i] + row2[j]]++;
if (!count(v12.begin(), v12.end(), row1[i] + row2[j])) { // 应该是这里超时
v12.push_back(row1[i] + row2[j]);
}
cnt34[row3[i] + row4[j]]++;
if (!count(v34.begin(), v34.end(), row3[i] + row4[j])) {
v34.push_back(row3[i] + row4[j]);
}
}
}
for (auto a : v12) {
if (count(v34.begin(), v34.end(), -a)) {
ans += (cnt12[a] * cnt34[-a]);
}
}
cout << ans;
return 0;
}
经过观察和分析发现,超时主要在于自己懒惰的模拟照搬的思想,总是把问题做拆分做的很复杂,其实大可不必死板地按照最初思路按步实现,不必每次插入数据都排个序(必超时),思维要灵活!!!
利用好STL中的Lower_bound()函数和upper_bound()函数,对已排序的数组中的具体数值进行定位获得该数值的数量即可!
// AC代码
#include<iostream>
#include<algorithm>
#include<math.h>
using namespace std;
const int maxn = 4004;
typedef long long ll;
int n;
int row1[maxn] = {};
int row2[maxn] = {};
int row3[maxn] = {};
int row4[maxn] = {};
int row12[maxn*maxn] = {};
int row34[maxn*maxn] = {};
ll ans = 0;
int main () {
cin >> n;
for (int i = 0; i < n; i++) {
cin >> row1[i] >> row2[i] >> row3[i] >> row4[i];
}
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
row12[i*n + j] = row1[i] + row2[j];
row34[i*n + j] = row3[i] + row4[j];
}
}
sort(row12, row12 + (n*n));
for (int i = 0; i < n*n; i++) {
ans += upper_bound(row12, row12+(n*n), -row34[i]) - lower_bound(row12, row12+(n*n), -row34[i]);
}
cout << ans << endl;
}