暴力解法会超时,普遍只有 60% 的通过率。
使用归并排序,在合并时统计距离,可以把时间复杂度从 O(n²) 降到 O(n log n)。
首先题目明确说明是计算坐标之间的距离,例子中的说明也表明了这一点。
但题目同时说明,输入的数组是 1 到 n 的一个排列,因此所有逆序对的坐标距离之和恰好等于所有逆序对元素差(大的减小的)之和:对 1 到 n 的排列,两者都等于 ½·Σ(a_k − k)²(a_k 为第 k 个位置上的数)。
以排列 [1, 3, 4, 2] 为例,其逆序对为:
(3, 2)差为1
(4, 2)差为2
总和为 3,恰好等于这两对逆序对的坐标距离之和(2 + 1 = 3)。
归并排序本身不再赘述,下面只说明合并时如何统计逆序对的元素差。
将left与right数组合并,left中的元素与right中的元素都是已排序的。
这时,如果遇到 left[i] > right[j],不仅说明 (i, j) 是一个逆序对,(i + 1, j)、(i + 2, j) …… 直到 (len(left) - 1, j) 也都是逆序对,因为 left 已升序,i 之后的元素都不小于 left[i]。
如果只是单纯将left[i] - right[j]加到总距离中,然后j指针后移,显然,i + 1, j等逆序对的距离就被忽略了。
“正确”的做法,需要从i开始遍历left,计算所有的距离差,并加到总距离中,即:
dis += left[i] - right[j] + left[i + 1] - right[j] + ... + left[len(left) - 1] - right[j]
当然,如果使用遍历,那么时间复杂度是无法降低的。所以,观察上述式子,我们可以得出:
dis += sum(left[i] ... left[len(left) - 1]) - right[j] * (len(left) - i)
使用一个变量记录 sum(left[i] ... left[len(left) - 1]) 的值即可:初始时它等于 left 的总和,每当 left[i] 被放入合并结果(即 left[i] <= right[j])时,就从中减去 left[i]。这样每次发现逆序对,只需要 O(1) 的时间就能算出这一批逆序对的距离之和。
贴上代码:
n = int(input())
nums = list(map(int, input().split()))
ans = 0


def mergesort(arr, left, right):
    """Sort arr[left..right] in place and add its inversion value-gap sum to ``ans``.

    For every pair p < q inside the range with arr[p] > arr[q], the amount
    arr[p] - arr[q] is added to the module-level ``ans``.
    """
    global ans
    if right <= left:
        return
    mid = (left + right) // 2
    mergesort(arr, left, mid)
    mergesort(arr, mid + 1, right)

    merged = []
    lo, hi = left, mid + 1
    # Sum of the left-half elements that have not been moved into ``merged`` yet.
    pending = sum(arr[left:mid + 1])
    while lo <= mid and hi <= right:
        head_left, head_right = arr[lo], arr[hi]
        if head_left <= head_right:
            pending -= head_left
            merged.append(head_left)
            lo += 1
            continue
        # head_right is smaller than every remaining left element (mid + 1 - lo of them).
        ans += pending - (mid + 1 - lo) * head_right
        merged.append(head_right)
        hi += 1
    # At most one of these tails is non-empty.
    merged.extend(arr[lo:mid + 1])
    merged.extend(arr[hi:right + 1])
    arr[left:right + 1] = merged


mergesort(nums, 0, n - 1)
print(ans)
import java.util.Scanner; public class Main { static long ans = 0; public static void main(String[] args) { Scanner sc = new Scanner(System.in); int n = sc.nextInt(); int[] nums = new int[n]; for(int i = 0; i < n; i++) { nums[i] = sc.nextInt(); } mergerSort(nums, 0, nums.length - 1); System.out.println(ans); } public static void mergerSort(int[] nums, int l, int r) { if (l >= r) { return; } int m = l + (r - l) / 2; mergerSort(nums, l, m); mergerSort(nums, m + 1, r); if(nums[m] >= nums[m+1]) { merger(nums, l, m, r); } } private static void merger(int[] nums, int l, int m, int r) { int[] temp = new int[r - l + 1]; int i = l, j = m + 1, k = 0; while (i <= m && j <= r) { if (nums[i] > nums[j]) { // 产生了 j-i+1 对逆序对 int t = i; while(t <= m) { ans += (nums[t++] - nums[j]); } temp[k++] = nums[j++]; } else { temp[k++] = nums[i++]; } } while(i <= m) { temp[k++] = nums[i++]; } while(j <= r) { temp[k++] = nums[j++]; } for(int a = 0; a < temp.length; a++) { nums[l+a] = temp[a]; } } }
import java.util.*;

/**
 * Java solution using merge sort: a two-column array stores each value together with its
 * original index, and the running count must be a {@code long}.
 *
 * <p>The answer is the sum of (j - i) over all index pairs i < j whose values are inverted.
 */
public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int[][] arr = new int[n][2];
        for (int i = 0; i < n; i++) {
            arr[i][0] = sc.nextInt();
            arr[i][1] = i;
        }
        long res = getAns(arr, n);
        System.out.println(res);
    }

    /** Returns the inversion-distance sum of the first {@code n} rows of {@code arr}. */
    public static long getAns(int[][] arr, int n) {
        int[][] tmp = new int[n][2];
        return reverse(arr, 0, n - 1, tmp);
    }

    /** Sorts {@code arr[left..right]} by value and returns the inversion-distance sum inside it. */
    public static long reverse(int[][] arr, int left, int right, int[][] tmp) {
        if (left >= right) {
            return 0;
        }
        int mid = (left + right) >>> 1;
        long leftDistance = reverse(arr, left, mid, tmp);
        long rightDistance = reverse(arr, mid + 1, right, tmp);
        // Halves already in order: no inversions cross the split.
        if (arr[mid][0] <= arr[mid + 1][0]) {
            return leftDistance + rightDistance;
        }
        long crossDistance = reverseCross(arr, left, mid, right, tmp);
        return crossDistance + leftDistance + rightDistance;
    }

    /**
     * Merges the sorted halves {@code [left, mid]} and {@code [mid+1, right]} and returns the
     * distance sum of inversions with one element in each half.
     */
    public static long reverseCross(int[][] arr, int left, int mid, int right, int[][] tmp) {
        for (int i = left; i <= right; i++) {
            tmp[i][0] = arr[i][0];
            tmp[i][1] = arr[i][1];
        }
        // Sum of the original indices of left-half elements not yet written back. Every
        // left-half index is smaller than every right-half index, so a right element j that
        // jumps ahead of the remaining left elements contributes
        // (mid - i + 1) * index(j) - leftIndexSum in O(1), instead of a loop over them.
        long leftIndexSum = 0;
        for (int p = left; p <= mid; p++) {
            leftIndexSum += tmp[p][1];
        }
        int i = left, j = mid + 1;
        long count = 0;
        for (int k = left; k <= right; k++) {
            if (i == mid + 1) {
                arr[k][0] = tmp[j][0];
                arr[k][1] = tmp[j][1];
                j++;
            } else if (j == right + 1) {
                arr[k][0] = tmp[i][0];
                arr[k][1] = tmp[i][1];
                i++;
            } else if (tmp[i][0] <= tmp[j][0]) {
                arr[k][0] = tmp[i][0];
                arr[k][1] = tmp[i][1];
                leftIndexSum -= tmp[i][1];
                i++;
            } else {
                arr[k][0] = tmp[j][0];
                arr[k][1] = tmp[j][1];
                count += (long) (mid - i + 1) * tmp[j][1] - leftIndexSum;
                j++;
            }
        }
        return count;
    }
}
#include <cstdio> #define ll long long const int N = 1e5+5; int a[N],tmp[N],pre[N],pos[N]; ll merge(int l1,int r1,int l2,int r2){ int s1=l1,s2=l2; int cnt = s1; ll res = 0; pre[l1-1] = 0; for(int i=l1;i<=r1;i++) pre[i] = pre[i-1] + pos[a[i]]; while(s1<=r1&&s2<=r2){ if(a[s1]>a[s2]){ res+=(ll)(r1-s1+1)*pos[a[s2]] - (pre[r1]-pre[s1-1]); tmp[cnt++] = a[s2++]; }else{ tmp[cnt++] = a[s1++]; } } while(s1<=r1) tmp[cnt++] = a[s1++]; while(s2<=r2) tmp[cnt++] = a[s2++]; for(int i=l1;i<=r2;i++) a[i] = tmp[i]; return res; } ll mergeSort(int l,int r){ if(l>=r) return 0; int mid = (l+r)>>1; ll r1 = mergeSort(l,mid); ll r2 = mergeSort(mid+1,r); ll r3 = merge(l,mid,mid+1,r); return r1+r2+r3; } int main(){ int n; scanf("%d",&n); for(int i=1;i<=n;i++) { scanf("%d",&a[i]); pos[a[i]] = i; } ll res = mergeSort(1,n); // for(int i=1;i<=n;i++) printf("%d ",a[i]); // puts(""); printf("%lld\n",res); return 0; }
def inversion_distance_sum(perm):
    """Return the sum of (j - i) over all pairs i < j with perm[i] > perm[j].

    ``perm`` must be a permutation of 1..len(perm). Runs in O(n) using exact
    integer arithmetic; an empty list yields 0.
    """
    total = 0
    prefix = 0
    for m, x in enumerate(perm, start=1):
        prefix += x
        # prefix - m(m+1)/2 is the number of inversion pairs straddling the gap
        # after position m: each of the first m values v has v - 1 smaller values
        # in total, and removing those already inside the prefix leaves exactly
        # the smaller ones after it. An inversion (i, j) straddles j - i gaps, so
        # summing over all gaps gives the total distance.
        # Integer "//" keeps the sum exact; the former "/ 2" made it a float,
        # which cannot represent every integer above 2**53 (reached for large n).
        total += prefix - m * (m + 1) // 2
    return total


if __name__ == "__main__":
    n = int(input())
    num = list(map(int, input().split()))
    print(inversion_distance_sum(num))
# Merge sort that also records a running sum of the left half while merging.
ans = 0


def mergesort(l):
    """Return a sorted copy of ``l`` and add its inversion value-gap sum to ``ans``.

    For every pair p < q with l[p] > l[q], the amount l[p] - l[q] is added to the
    module-level ``ans``; for a permutation of 1..n that equals the sum of the
    pairs' index distances. Lists of length 0 or 1 are returned unchanged.
    """
    global ans
    if len(l) <= 1:
        return l
    mid = len(l) // 2
    left = mergesort(l[:mid])
    right = mergesort(l[mid:])
    res = []
    i = j = 0
    # Sum of left[i:], i.e. of the left elements not merged yet.
    sum_left = sum(left)
    # Index pointers instead of list.pop(0): pop(0) shifts the whole list on every
    # call, which made each merge quadratic.
    while i < len(left) and j < len(right):
        if right[j] < left[i]:
            # right[j] is smaller than every element of left[i:].
            ans += sum_left - (len(left) - i) * right[j]
            res.append(right[j])
            j += 1
        else:
            sum_left -= left[i]
            res.append(left[i])
            i += 1
    res.extend(left[i:])
    res.extend(right[j:])
    return res


if __name__ == "__main__":
    n = int(input())
    nums = list(map(int, input().split()))
    mergesort(nums)
    print(ans)
import java.util.Scanner;

/**
 * Prints the sum of {@code a[i] - a[j]} over all pairs i < j with {@code a[i] > a[j]}. For a
 * permutation of 1..n this equals the sum of the inversion pairs' index distances.
 */
public class demo {
    public static void main(String[] args) {
        try (Scanner sc = new Scanner(System.in)) {
            int n = sc.nextInt();
            int[] a = new int[n];
            for (int k = 0; k < n; k++) {
                a[k] = sc.nextInt();
            }
            System.out.println(inversionGapSum(a));
        }
    }

    /**
     * Returns the sum of {@code a[i] - a[j]} over all i < j with {@code a[i] > a[j]}, in
     * O(n log n). The input array is not modified. Uses {@code long} because the result exceeds
     * the int range for large inputs (about 1.7e14 for a reversed array of 1e5 elements).
     */
    static long inversionGapSum(int[] a) {
        int[] work = a.clone();
        int[] buf = new int[work.length];
        return sortAndCount(work, buf, 0, work.length - 1);
    }

    /** Sorts {@code v[lo..hi]} in place and returns the inversion value-gap sum inside it. */
    private static long sortAndCount(int[] v, int[] buf, int lo, int hi) {
        if (lo >= hi) {
            return 0;
        }
        int mid = (lo + hi) >>> 1;
        long total = sortAndCount(v, buf, lo, mid) + sortAndCount(v, buf, mid + 1, hi);
        // Sum of the left-run elements not yet merged.
        long leftSum = 0;
        for (int p = lo; p <= mid; p++) {
            leftSum += v[p];
        }
        int i = lo, j = mid + 1, k = lo;
        while (i <= mid && j <= hi) {
            if (v[i] <= v[j]) {
                // Equal values are not inversions (the original used a strict '>').
                leftSum -= v[i];
                buf[k++] = v[i++];
            } else {
                // v[j] is smaller than all of v[i..mid]: each contributes v[p] - v[j].
                total += leftSum - (long) (mid - i + 1) * v[j];
                buf[k++] = v[j++];
            }
        }
        while (i <= mid) {
            buf[k++] = v[i++];
        }
        while (j <= hi) {
            buf[k++] = v[j++];
        }
        System.arraycopy(buf, lo, v, lo, hi - lo + 1);
        return total;
    }
}
# include <bits/stdc++.h> using namespace std; long long CountAndMerge(vector<int>& nums, vector<int>& tmp, int l, int r){ if(l>=r) return 0; int mid = (l+r)/2; long long sum = CountAndMerge(nums, tmp, l, mid) + CountAndMerge(nums, tmp, mid+1, r); //三指针 int i = l, j = mid+1, pos = l; //求left的和 long long left_sum = 0; for(int i = l; i <=mid; i++){ left_sum+=nums[i]; } while (i<=mid && j<=r) { if(nums[i]<=nums[j]){ tmp[pos] = nums[i]; left_sum -= nums[i]; ++i; } else{ sum += left_sum - (mid-i+1) * nums[j]; tmp[pos] = nums[j]; ++j; } ++pos; } for(int k = i; k<=mid; k++){ tmp[pos++] = nums[k]; } for(int k = j; k<=r; ++k){ tmp[pos++] = nums[k]; } copy(tmp.begin() + l, tmp.begin() + r+1, nums.begin() + l); return sum; } int main(){ int N; cin >> N; vector<int> nums(N); for(int i = 0; i<N; i++){ int num; cin >> num; nums[i] = num; } //计算逆序对距离和 //临时变量 vector<int> tmp(N); long long ans = CountAndMerge(nums, tmp, 0, N-1); cout << ans << endl; return 0; }
# Original post (translated): "I used a stack and a dict to speed it up, but it still only
# passes 60% of the tests. Is there any way to make it faster?"
# Yes: the stack walk visits every earlier larger value for each element (O(n^2) overall)
# and list.insert is itself O(n). Two Fenwick trees indexed by value rank give the same
# answer in O(n log n).


def inversion_distance_sum(l):
    """Return the sum of (i - j) over all pairs j < i with l[j] > l[i].

    Works for any list of comparable, hashable values (duplicates allowed);
    an empty list yields 0.
    """
    rank = {v: r for r, v in enumerate(sorted(set(l)), start=1)}
    size = len(rank)
    count_tree = [0] * (size + 1)  # Fenwick tree: number of earlier elements per rank
    index_tree = [0] * (size + 1)  # Fenwick tree: sum of their indices per rank
    res = 0
    for i, x in enumerate(l):
        r = rank[x]
        # Prefix query: earlier elements whose value is <= x.
        not_greater = 0
        not_greater_idx = 0
        k = r
        while k > 0:
            not_greater += count_tree[k]
            not_greater_idx += index_tree[k]
            k -= k & -k
        # The i earlier elements have indices 0..i-1 (sum i*(i-1)//2); the rest are > x.
        greater = i - not_greater
        greater_idx = i * (i - 1) // 2 - not_greater_idx
        res += greater * i - greater_idx
        k = r
        while k <= size:
            count_tree[k] += 1
            index_tree[k] += i
            k += k & -k
    return res


if __name__ == "__main__":
    n = int(input())
    l = list(map(int, input().split()))
    print(inversion_distance_sum(l))
[7,5,6,4]通过归并的时候会变成
[7] [5] | [6] [4] # | 表示两个数组分割
[5,7] [4,6]
[4,5,6,7]
import java.util.Arrays;
import java.util.Scanner;

/** Prints the sum of (j - i) over all index pairs i < j with {@code a[i] > a[j]}. */
public class Main {
    public static void main(String[] args) {
        // nextInt() tolerates any whitespace layout; the former split("\\s") broke on repeated
        // spaces, and sizing the array by the token count could disagree with n.
        try (Scanner scanner = new Scanner(System.in)) {
            int n = scanner.nextInt();
            int[] arrInt = new int[n];
            for (int i = 0; i < n; i++) {
                arrInt[i] = scanner.nextInt();
            }
            System.out.println(inversionDistanceSum(arrInt));
        }
    }

    /**
     * Returns the sum of (j - i) over all i < j with {@code values[i] > values[j]}, in
     * O(n log n) instead of the former O(n^2) double loop. Returns a {@code long} because the
     * result overflows int for large n (about 1.7e14 for a reversed array of 1e5 elements).
     * {@code values} is not modified.
     */
    static long inversionDistanceSum(int[] values) {
        int n = values.length;
        int[] order = new int[n];
        Arrays.setAll(order, p -> p);
        int[] buf = new int[n];
        return sortAndCount(values, order, buf, 0, n - 1);
    }

    /**
     * Sorts the index slice {@code order[lo..hi]} by value and returns the inversion-distance sum
     * inside it. The slice always holds exactly the original indices lo..hi, so every left-run
     * index is smaller than every right-run index.
     */
    private static long sortAndCount(int[] values, int[] order, int[] buf, int lo, int hi) {
        if (lo >= hi) {
            return 0;
        }
        int mid = (lo + hi) >>> 1;
        long total = sortAndCount(values, order, buf, lo, mid)
                + sortAndCount(values, order, buf, mid + 1, hi);
        // Sum of the original indices of left-run entries not yet merged.
        long leftIndexSum = 0;
        for (int p = lo; p <= mid; p++) {
            leftIndexSum += order[p];
        }
        int i = lo, j = mid + 1, k = lo;
        while (i <= mid && j <= hi) {
            if (values[order[i]] <= values[order[j]]) {
                // Equal values are not inversions (the original used a strict '>').
                leftIndexSum -= order[i];
                buf[k++] = order[i++];
            } else {
                // order[j] forms an inversion with every remaining left entry.
                total += (long) (mid - i + 1) * order[j] - leftIndexSum;
                buf[k++] = order[j++];
            }
        }
        while (i <= mid) {
            buf[k++] = order[i++];
        }
        while (j <= hi) {
            buf[k++] = order[j++];
        }
        System.arraycopy(buf, lo, order, lo, hi - lo + 1);
        return total;
    }
}