首页 > 试题广场 >

牛牛的航路

[编程题]牛牛的航路
  • 热度指数:164 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 256M,其他语言512M
  • 算法知识视频讲解
牛牛当上了牛客国的国王,在牛客国一共有个城市,这个城市之间原有条航路,但是由于城市航路规划,需要删除一些航路,使剩下的航路刚好能让这个城市联通,并且剩下的航路总载客量最大。牛客国的计数方式比较特别,他们喜欢用组合数来计数。所以两个城市之间航路的载客量也是用表示的。作为国师你能告诉牛牛剩下的航路的总载客量是多少呢,由于数可能很大请对取模后告诉牛牛。

输入描述:
第一行为两个整数,表示城市数量和原有航路条数。
接下来有行,每行有四个整数,表示一条航路之间的两个城市编号和航路载客量中的



输出描述:
输出为一行,输出剩下的航路的总载客量并对取模,若剩下的航路不能让所有城市联通输出
示例1

输入

5 5
1 2 1 1
1 5 1 1
3 5 1 1
2 4 1 1
4 5 2 1

输出

5
MST裸题,正常的并查集维护就可以
但是阴间在怎么比较组合数的大小
下面的代码使用的是直接高精度打表 + 各种字符串优化
本题卡常十分严重,在此感谢本题的一血大佬帮助优化代码 Orz 。
#include <bits/stdc++.h>
using namespace std;

string C[1005][1005];
int modC[1005][1005];
int mod = 1e9 + 7;

string add(const string &a, const string &b, string &ans)
{
    ans.reserve(max(a.size(), b.size()) + 1);
    int carry = 0;
    int i = a.size() - 1, j = b.size() - 1;
    while (i >= 0 || j >= 0) {
        int tmp = carry;
        if (i >= 0) {
            tmp += a[i] - '0';
            i--;
        }
        if (j >= 0) {
            tmp += b[j] - '0';
            j--;
        }
        carry = tmp / 10;
        ans += (char)((tmp % 10) + '0');
    }
    if (carry != 0) ans += (char)((carry % 10) + '0');
    reverse(ans.begin(), ans.end());
    return ans;
}

void initCo(int n)
{
    for (int i = 0; i <= n; i++) {
        C[i][0] = C[i][i] = "1";
        modC[i][0] = modC[i][i] = 1;
    }
    for (int i = 2; i <= n; i++) {
        for (int j = 1; j <= i / 2; j++) {
            add(C[i - 1][j], C[i - 1][j - 1], C[i][j]);
            C[i][i - j] = C[i][j];
            modC[i][j] = modC[i - 1][j] + modC[i - 1][j - 1];
            modC[i][j] %= mod;
            modC[i][i - j] = modC[i][j];
        }
    }
}


struct Edge {
    int u, v, a, b;
    string *val;
};

int cmp(Edge a, Edge b)
{
    if (a.val->size() != b.val->size()) {
        return a.val->size() > b.val->size();
    }
    return *(a.val) > *(b.val);
}

Edge e[500005];
int fa[1005];

int findFa(int x)
{
    int root = x;
    while (fa[root] != root) {
        root = fa[root];
    }
    while (x != root) {
        int fx = fa[x];
        fa[x] = root;
        x = fx;
    }
    return root;
}

void mergeFa(int x, int y)
{
    int fx = findFa(x);
    int fy = findFa(y);
    if (fx != fy) {
        fa[fx] = fy;
    }
}

int main()
{
    initCo(1000);
    int n, m;
    // input
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i++) {
        scanf("%d%d%d%d", &e[i].u, &e[i].v, &e[i].a, &e[i].b);
        e[i].val = &C[e[i].a][e[i].b];
    }
    // sort edges
    sort(e, e + m, cmp);

    // union find
    int ans = 0;
    for (int i = 1; i <= n; i++) {
        fa[i] = i;
    }
    int cnt = 0;
    for (int i = 0; i < m; i++) {
        if (findFa(e[i].u) == findFa(e[i].v)) {
            continue;
        }
        mergeFa(e[i].u, e[i].v);
        ans = (ans + modC[e[i].a][e[i].b]) % mod;
        cnt++;
    }
    if (cnt != n - 1) {
        printf("-1\n");
    } else {
        printf("%d\n", ans);
    }

    return 0;
}

二更:将阶乘取对数,由于阶乘的对数和阶乘本身都满足递增,所以这种方法可以省去大量高精度的字符串操作,节省大量时间和空间。
C(n, m) = n! / (n-m)! / m!
等号两侧取对数:log(C(n, m)) = log(n!) - log((n-m)!) - log(m!)
其中 log(n!) 可以预处理为:log(1) + log(2) + ... + log(n)
#include <bits/stdc++.h>
using namespace std;

int C[1005][1005];
double logA[1005];
int mod = 1e9 + 7;

void init(int n) {
    for (int i = 0; i <= n; i++) {
        C[i][0] = C[i][i] = 1;
    }
    for (int i = 2; i <= n; i++) {
        for (int j = 1; j < i; j++) {
            C[i][j] = C[i - 1][j] + C[i - 1][j - 1];
            C[i][j] %= mod;
        }
    }
    logA[1] = log(1);
    for (int i = 2; i <= n; i++) {
        logA[i] = logA[i - 1] + log(i);
    }
}

struct Edge {
    int u, v, a, b;
    double val;
};

int cmp(Edge a, Edge b) {
    return a.val > b.val;
}

Edge e[500005];
int fa[1005];

int findFa(int x)
{
    int root = x;
    while (fa[root] != root) {
        root = fa[root];
    }
    while (x != root) {
        int fx = fa[x];
        fa[x] = root;
        x = fx;
    }
    return root;
}

void mergeFa(int x, int y)
{
    int fx = findFa(x);
    int fy = findFa(y);
    if (fx != fy) {
        fa[fx] = fy;
    }
}

int main()
{
    init(1000);
    int n, m;
    scanf("%d%d", &n, &m);
    for (int i = 0; i < m; i++) {
        scanf("%d%d%d%d", &e[i].u, &e[i].v, &e[i].a, &e[i].b);
        e[i].val = logA[e[i].a] - logA[e[i].b] - logA[e[i].a - e[i].b];
    }
    sort(e, e + m, cmp);

    int ans = 0;
    for (int i = 1; i <= n; i++) {
        fa[i] = i;
    }
    int cnt = 0;
    for (int i = 0; i < m; i++) {
        if (findFa(e[i].u) == findFa(e[i].v)) {
            continue;
        }
        mergeFa(e[i].u, e[i].v);
        ans = (ans + C[e[i].a][e[i].b]) % mod;
        cnt++;
    }
    if (cnt != n - 1) {
        printf("-1\n");
    } else {
        printf("%d\n", ans);
    }

    return 0;
}



编辑于 2021-09-30 20:18:43 回复(1)
采用 并查集和拓扑搜索 都只能 通过 40%,有没有大佬 分享下 题解代码。
import java.util.*;

public class Main {
    public static void main(String[] args) throws Exception {
        Scanner sc = new Scanner(System.in);
        while (sc.hasNext()) {
            String[] s = sc.nextLine().split(" ");
            int n = Integer.parseInt(s[0]);
            int m = Integer.parseInt(s[1]);
            long[][] arr = new long[m][3];
            for (int i = 0; i < m; i++) {
                s = sc.nextLine().split(" ");
                arr[i][0] = Long.parseLong(s[0]);
                arr[i][1] = Long.parseLong(s[1]);
                long a = Long.parseLong(s[2]);
                long b = Long.parseLong(s[3]);
                long a1 = fact(a);
                long b1 = (fact(b) * (fact(a - b)));
                if(b1 == 0) {
                    arr[i][2] = 0;
                } else {
                    arr[i][2] = a1 / b1;
                }
            }
            if (m < n - 1) {
                System.out.println("-1");
                continue;
            }
            Arrays.sort(arr, (o1, o2) -> (int) (o2[2] - o1[2]));
            UnionFind unionFind = new UnionFind();
            unionFind.init(n + 1);
            long res = 0;
            for (int i = 0; i < m; i++) {
                int u = (int) arr[i][0];
                int v = (int) arr[i][1];
                long weight = arr[i][2];
//                System.out.println(u + "    " + v + "   " + weight);
                if (unionFind.find(u) == unionFind.find(v)) {
                    continue;
                }
                unionFind.merge(u, v);
                res += weight;
            }
            int[] parent = unionFind.parent;
            int count = 0;
            for (int i = 1; i < parent.length; i++) {
                if (i == parent[i]) {
                    count++;
                }
            }
            if (count == 1) {
                System.out.println((int) res % 1000000007);
            } else {
                System.out.println("-1");
            }
        }
    }
    // 求阶乘
    private static long fact(long n) {
        long sum = 1;
        for (int i = 2; i <= n; i++) {
            sum *= i;
        }
        return sum;
    }
}

class UnionFind {
    int[] parent;
    int[] rank;
    void init(int len) {
        this.parent = new int[len];
        this.rank = new int[len];
        
        for (int i = 0; i < len; ++i) {
            parent[i] = i;
            rank[i] = 1;
        }
    }
    int find(int x) {
        if (x == parent[x]) {
            return parent[x];
        } else {
            parent[x] = find(parent[x]);
            return parent[x];
        }
    }
    void merge(int x, int y) {
        int xr = find(x);
        int yr = find(y);
        if (xr == yr) {
            return;
        }
        if (rank[xr] == rank[yr]) {
            parent[xr] = yr;
            rank[yr]++;
        } else if (rank[xr] < rank[yr]) {
            parent[xr] = yr;
        } else {
            parent[yr] = xr;
        }
    }
}
import java.util.*;

public class Main {
    public static void main(String[] args) throws Exception {
        Scanner sc = new Scanner(System.in);
        while (sc.hasNext()) {
            String[] s = sc.nextLine().split(" ");
            int n = Integer.parseInt(s[0]);
            int m = Integer.parseInt(s[1]);
            long[][] arr = new long[m][3];
            Map<Long, List<long[]>> map = new HashMap<>();
            long x = 0;
            long max = 0;
            for (int i = 0; i < m; i++) {
                s = sc.nextLine().split(" ");
                arr[i][0] = Long.parseLong(s[0]);
                arr[i][1] = Long.parseLong(s[1]);
                long a = Long.parseLong(s[2]);
                long b = Long.parseLong(s[3]);
                long a1 = fact(a);
                long b1 = (fact(b) * (fact(a - b)));
                if(b1 == 0) {
                    arr[i][2] = 0;
                } else {
                    arr[i][2] = a1 / b1;
                }
                if (max < arr[i][2]) {
                    x = arr[i][0];
                }
                if (map.containsKey(arr[i][0])) {
                    map.get(arr[i][0]).add(new long[]{arr[i][1], arr[i][2]});
                } else {
                    List<long[]> list = new ArrayList<>();
                    list.add(new long[]{arr[i][1], arr[i][2]});
                    map.put(arr[i][0], list);
                }
                if (map.containsKey(arr[i][1])) {
                    map.get(arr[i][1]).add(new long[]{arr[i][0], arr[i][2]});
                } else {
                    List<long[]> list = new ArrayList<>();
                    list.add(new long[]{arr[i][0], arr[i][2]});
                    map.put(arr[i][1], list);
                }
            }
            if (m < n - 1) {
                System.out.println(-1);
                continue;
            }
            PriorityQueue<long[]> pq = new PriorityQueue<long[]>((o1, o2) -> (int) (o2[1] - o1[1]));
            pq.offer(new long[]{x, 0});
            Set<Long> set = new HashSet<>();
            long res = 0;
            while (!pq.isEmpty()) {
                long[] tmp = pq.poll();
                long u = tmp[0];
                if (set.contains(u)) {
                    continue;
                }
                set.add(u);
                res += tmp[1];
                if (map.containsKey(u)) {
                    List<long[]> vs = map.get(u);
                    for (long[] v : vs) {
                        pq.offer(new long[]{v[0], v[1]});
                    }
                }
            }
            if (set.size() == n) {
                System.out.println((int) res % 1000000007);
            } else {
                System.out.println(-1);
            }
        }
    }

    // 求阶乘
    private static long fact(long n) {
        long sum = 1;
        for (int i = 2; i <= n; i++) {
            sum *= i;
        }
        return sum;
    }
}







编辑于 2021-08-05 11:23:23 回复(2)