如果你 F 题只拿到了 78 分,可以对照下面这份 78 分的代码及其错误原因:
// The 78-point submission, kept for comparison with the accepted version.
// For each query (u, k) it prints max(s[v]) - s[u] over the nodes v in u's subtree
// with depth[v] == depth[u] + k (the weight of the path from u down to v), or -1.
//
// Defects (all fixed in the accepted version):
//   1. When u's subtree contains no node at depth t, up() and lower() can both
//      succeed yet return l > r; only -1 is checked, so an empty range is queried.
//   2. val stores the long long sums s[] in vector<int>, truncating any sum that
//      does not fit in int.
//   3. query() starts the second window at r - (1 << k) instead of
//      r - (1 << k) + 1, so it can miss index r and read index l - 1.
//   4. The long long answer is printed with %d instead of %lld.
#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <math.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, M = N * 2;
int n,m,seq;                        // node count, query count, DFS timestamp counter
int L[N], R[N];                     // DFS entry time of each node / largest entry time in its subtree
int h[N], e[M], ne[M], w[M], idx;   // adjacency lists as linked edge arrays (h[u] == -1: empty)
int depth[N],mx;                    // depth of each node (root 1 has depth 0) and the maximum depth
LL s[N];                            // sum of edge weights on the path from the root
vector<int> f[N];                   // f[d]: entry times of the nodes at depth d, increasing
vector<int> val[N];                 // val[d][i]: s[] of the node whose entry time is f[d][i]
                                    // BUG: s[] is long long but is narrowed to int here
vector<vector<vector<LL>>> p;       // RMQ: p[d - 1][j][k] = max of val[d][j .. j + 2^k - 1]

// Build one sparse table per depth 1..mx over val[depth] and append it to p.
void init()
{
    for(int i = 1; i <= mx; i++)
    {
        int len = f[i].size() - 1;   // index of the last node at depth i
        vector<vector<LL>> vt(len + 1);
        for(int j = 0; j <= len; j++)
        {
            vector<LL> tmp(20);
            vt[j] = tmp;
        }
        for(int k = 0; k <= 17; k++)
        {
            for(int j = 0; j + (1 << k) - 1 <= len; j++)
            {
                if(!k)vt[j][k] = val[i][j];
                // 1 << k - 1 parses as 1 << (k - 1)
                else vt[j][k] = max(vt[j][k - 1], vt[j + (1 << k - 1) ][k - 1]);
            }
        }
        p.push_back(vt);   // p[i - 1] is the table for depth i
    }
}

// Append a directed edge a -> b with weight c to a's adjacency list.
void add(int a, int b, int c)
{
    e[idx] = b; w[idx] = c; ne[idx] = h[a]; h[a] = idx++;
}

// Assign DFS entry times, depths and root-path sums, and record every non-root
// node's entry time and s[] under its depth.
void dfs(int u, int fa)
{
    L[u] = ++ seq;
    for(int i = h[u]; ~i; i = ne[i])   // ~i == 0 exactly when i == -1 (end of list)
    {
        int j = e[i];
        if(j == fa) continue;
        depth[j] = depth[u] + 1;
        mx = max(mx, depth[j]);
        s[j] = s[u] + w[i];
        dfs(j, u);
        f[depth[j]].push_back( L[j] );   // store the DFS entry time
        val[depth[j]].push_back(s[j]);   // and the root-path sum paired with it
    }
    R[u] = seq;   // every node in u's subtree has an entry time in [L[u], R[u]]
}

// Intended: max of val[t][l..r] from the sparse table of depth t.
LL query(int t, int l, int r)
{
    // BUG: when l > r (see main) len <= 0, log(len) is -inf or NaN, and the
    // conversion to int below is undefined.
    int len = r - l + 1;
    int k = log(len) / log(2);
    LL res = p[t - 1][l][k];
    // BUG: the second window must start at r - (1 << k) + 1 to end exactly at r.
    if(r - (1 << k) >= 0)res = max(res, p[t - 1][r - (1 << k)][k]);
    return res;
}

// Return the first index in f[t] whose value is >= x, or -1 if there is none.
int up(int t, int x)
{
    int len = f[t].size();
    if(!len)return -1;
    int l = 0, r = len - 1;
    while(l < r)
    {
        int mid = (l + r) >> 1;
        if(f[t][mid] >= x)r = mid;
        else l = mid + 1;
    }
    if(f[t][l] < x) return -1;
    return l;
}

// Return the last index in f[t] whose value is <= x, or -1 if there is none.
int lower(int t, int x)
{
    int len = f[t].size();
    if(!len)return -1;
    int l = 0, r = len - 1;
    while(l < r)
    {
        int mid = (l + r + 1) >> 1;
        if(f[t][mid] > x)r = mid - 1;
        else l = mid;
    }
    if(f[t][l] > x) return -1;
    return l;
}

int main()
{
    memset(h, -1, sizeof h);   // -1 marks an empty adjacency list
    scanf("%d", &n);
    for(int i = 1; i < n; i++)
    {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        add(a, b, c), add(b, a, c);
    }
    dfs(1, -1);
    init();
    scanf("%d", &m);
    while(m--)
    {
        int u, k;
        scanf("%d%d", &u, &k); // read one query
        int t = depth[u] + k;  // depth of the candidate nodes
        if(t>mx)
        {
            printf("-1\n");
            continue;
        }
        // l..r: indices in f[t] of the depth-t nodes whose entry times lie in [L[u], R[u]]
        int l = up(t, L[u]), r = lower(t, R[u]);
        // BUG: must also reject l > r, which happens when u's subtree has no depth-t node.
        if(l == -1 || r == -1)
        {
            printf("-1\n");
            continue;
        }
        // BUG: the expression is long long, so the format must be %lld.
        printf("%d\n", query(t, l, r) - s[u]);
    }
    return 0;
}
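错误原因:二分得到的下标区间 [l, r] 有可能出现 l > r,此时区间非法。这种情况发生在 u 的子树中没有深度为 t 的结点、但 up 和 lower 仍各自找到了下标时(l 指向子树之后的第一个该深度结点,r 指向子树之前的最后一个该深度结点),此时应直接输出 -1。另外,对比 AC 代码还能看到几处修正:val 改为 vector<LL>,避免前缀和被截断为 int;ST 表查询时第二个窗口的起点应为 r - (1 << k) + 1;输出 long long 应使用 %lld。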
错误数据:
7
1 2 1
2 3 1
2 5 1
1 3 1
1 6 1
6 7 1
1
3 1
AC 代码:
// Accepted solution.
// For each query (u, k) prints max(s[v]) - s[u] over the nodes v in u's subtree with
// depth[v] == depth[u] + k (the weight of the path from u down to v), or -1 if none.
// Nodes of each depth are kept sorted by DFS entry time, so u's subtree is a contiguous
// index range at every depth; a per-depth sparse table answers the range maximum.
#include <iostream>
#include <cstring>
#include <algorithm>
#include <utility>
#include <vector>
#include <math.h>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, M = N * 2;

int n, m, seq;                      // node count, query count, DFS timestamp counter
int L[N], R[N];                     // DFS entry time of each node / largest entry time in its subtree
int h[N], e[M], ne[M], w[M], idx;   // adjacency lists as linked edge arrays (h[u] == -1: empty)
int depth[N], mx;                   // depth of each node (root 1 has depth 0) and the maximum depth
LL s[N];                            // sum of edge weights on the path from the root
vector<int> f[N];                   // f[d]: entry times of the nodes at depth d, increasing
vector<LL> val[N];                  // val[d][i]: s[] of the node whose entry time is f[d][i]
vector<vector<vector<LL>>> p;       // p[d - 1][k][j] = max of val[d][j .. j + 2^k - 1]
int lg[N];                          // lg[x] = floor(log2(x)) for 1 <= x < N

// Fill lg[] and build, for every depth 1..mx, a sparse table over val[depth]
// whose number of levels matches that depth's size.
void init()
{
    lg[1] = 0;
    for (int i = 2; i < N; i++) lg[i] = lg[i >> 1] + 1;

    p.reserve(mx);
    for (int d = 1; d <= mx; d++)
    {
        int len = f[d].size();      // depths 1..mx are all non-empty, so len >= 1
        int levels = lg[len] + 1;
        vector<vector<LL>> st(levels);
        st[0] = val[d];
        for (int k = 1; k < levels; k++)
        {
            int half = 1 << (k - 1);
            int cnt = len - (1 << k) + 1;   // number of windows of length 2^k
            st[k].resize(cnt);
            for (int j = 0; j < cnt; j++)
                st[k][j] = max(st[k - 1][j], st[k - 1][j + half]);
        }
        p.push_back(std::move(st)); // p[d - 1] is the table for depth d; moved, not copied
    }
}

// Append a directed edge a -> b with weight c to a's adjacency list.
void add(int a, int b, int c)
{
    e[idx] = b;
    w[idx] = c;
    ne[idx] = h[a];
    h[a] = idx++;
}

// Assign DFS entry times, depths and root-path sums, and record every non-root
// node's entry time and s[] under its depth (the root is not stored in f[0]).
// NOTE(review): recursion depth equals the tree height, so a path-shaped tree of
// about N nodes needs a deep call stack; confirm against the judge's stack limit.
void dfs(int u, int fa)
{
    L[u] = ++seq;
    for (int i = h[u]; ~i; i = ne[i])   // ~i == 0 exactly when i == -1 (end of list)
    {
        int j = e[i];
        if (j == fa) continue;
        depth[j] = depth[u] + 1;
        mx = max(mx, depth[j]);
        s[j] = s[u] + w[i];
        dfs(j, u);
        f[depth[j]].push_back(L[j]);    // DFS entry time of j
        val[depth[j]].push_back(s[j]);  // root-path sum paired with that entry time
    }
    R[u] = seq;   // every node in u's subtree has an entry time in [L[u], R[u]]
}

// Max of val[t][l..r]; requires 1 <= t <= mx and 0 <= l <= r < f[t].size().
// Two power-of-two windows [l, l + 2^k) and (r - 2^k, r] cover the range exactly.
LL query(int t, int l, int r)
{
    int k = lg[r - l + 1];
    const vector<vector<LL>>& st = p[t - 1];
    return max(st[k][l], st[k][r - (1 << k) + 1]);
}

// First index in f[t] whose entry time is >= x, or -1 if there is none.
int up(int t, int x)
{
    const vector<int>& v = f[t];
    int pos = static_cast<int>(lower_bound(v.begin(), v.end(), x) - v.begin());
    return pos < static_cast<int>(v.size()) ? pos : -1;
}

// Last index in f[t] whose entry time is <= x, or -1 if there is none.
int lower(int t, int x)
{
    const vector<int>& v = f[t];
    return static_cast<int>(upper_bound(v.begin(), v.end(), x) - v.begin()) - 1;
}

int main()
{
    memset(h, -1, sizeof h);   // -1 marks an empty adjacency list
    scanf("%d", &n);
    for (int i = 1; i < n; i++)
    {
        int a, b, c;
        scanf("%d%d%d", &a, &b, &c);
        add(a, b, c), add(b, a, c);
    }
    dfs(1, -1);
    init();
    scanf("%d", &m);
    while (m--)
    {
        int u, k;
        scanf("%d%d", &u, &k);   // one query
        // Computed in long long so a huge k cannot overflow; a negative target depth
        // would otherwise index f[] out of bounds.
        LL target = static_cast<LL>(depth[u]) + k;
        if (target < 0 || target > mx)
        {
            printf("-1\n");
            continue;
        }
        int t = static_cast<int>(target);
        // NOTE(review): the root is never stored in f[0], so (1, 0) prints -1 while
        // (u, 0) prints 0 for any other u; confirm which the problem expects.

        // l..r: indices in f[t] of the depth-t nodes whose entry times lie in [L[u], R[u]].
        // l > r means u's subtree has no node at depth t.
        int l = up(t, L[u]), r = lower(t, R[u]);
        if (l == -1 || r == -1 || l > r)
        {
            printf("-1\n");
            continue;
        }
        printf("%lld\n", query(t, l, r) - s[u]);
    }
    return 0;
}