首页 > 试题广场 >

“好序列”的个数

[编程题]“好序列”的个数
  • 热度指数:462 时间限制:C/C++ 1秒,其他语言2秒 空间限制:C/C++ 32M,其他语言64M
  • 算法知识视频讲解
现在你面前有一棵n个节点的树(全连通无环图)。树上的边只有2种颜色,红色或者黑色。现在还给你一个整数k,考虑下面这个k个节点的序列[a1, a2, ..., ak]。
[a1, a2, ..., ak]如果是”好序列“当且仅当满足下面的条件:
1. 我们要走一条从a1开始到ak结束的路径。
2. 从a1开始,到a2走一条a1到a2的最短路。然后从a2开始,继续走一条到a3的最短路,以此类推,最终到a(k-1)和ak。
3. 走的路径中至少包含一条黑色的边。

我们看一下上面的图片中的树,如果k=3,那么下面的序列是“好序列”:[1,4,7][5,5,3]。下面的序列不是好序列: [1,4,6][5,5,5][3,7,3]
总共有n^k(n的k次方种路径方案),那么有多少路径是“好序列”呢?这个值可能非常大,输出的结果对(10^9+7)取模就可以。

输入描述:
第一行是2个整数n和k,其中(2 <= n <= 10^5, 2 <= k <= 100),n表示树的节点个数,k表示序列的长度。

下面n-1行,每行包含3个整数,u[i], v[i], w[i],其中1 <= u[i], v[i] <= n, w[i] = 0或1。u[i], v[i]表示这两个节点之间有一条边,w[i]表示这条边的颜色,其中0表示红色,1表示黑色。


输出描述:
输出所有“好序列”的个数模(10^9+7)
示例1

输入

4 4
1 2 1
2 3 1
3 4 1

输出

252

说明

这个例子中,所有序列一共有4^4 = 256个,其中不是好序列的只有4个:

[1, 1, 1, 1]

[2, 2, 2, 2]

[3, 3, 3, 3]

[4, 4, 4, 4]
示例2

输入

4 6
1 2 0
1 3 0
1 4 0

输出

0
示例3

输入

3 5
1 2 1
2 3 0

输出

210
好序列的个数不太好直接算出来,于是考虑用总的路径方案数减去坏序列的个数来得到,总的路径方案数正如题目所说,是n^k个,而一个坏序列由于不能有黑色的边,所以只能在一个没有黑边的子图里取它的所有节点,假如一个没有黑边的子图有m个节点,那么这个子图里一共有m^k个坏序列,如果我们在处理输入的时候不加黑边只加红边,那么这棵树就成为若干个没有黑边的子图(因为我们根本就不加黑边),用dfs数它们的节点数,就能够得到坏序列的个数。求幂用上快速幂。Python3代码:
def fastPow(x, n): #快速幂
	ans = 1
	while n:
		if n & 1: ans = ans * x % MOD
		x = x * x % MOD
		n >>= 1
	return ans

def scan(): return map(int, input().split())

def dfs(now): #dfs求某个子图的节点数
	vis[now] = True
	ans = 1
	for i in G[now]:
		if vis[i] == False:
			ans += dfs(i)
	return ans

MOD = 1000000007
n, k = scan()
G = [[] for _ in range(n+1)]
for _ in range(n-1):
	u, v, color = scan()
	if color == 0: #只加红边,不加黑边
		G[u].append(v)
		G[v].append(u)
vis = [False] * (n+1)
cntBad = 0 #坏序列的个数
for i in range(1, n+1):
	if vis[i] == False:
		cntBad = (cntBad + fastPow(dfs(i), k)) % MOD
print((fastPow(n, k) - cntBad + MOD) % MOD)


编辑于 2020-03-14 11:12:47 回复(3)

反着求,先求出所有序列个数,再减去不符合的个数。

黑边把图分成一个,一个的连通分量。

所以我们只要用dfs求出每个不包含黑边的连通分量的个数sz,sz^k就是该集合不符合的个数。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=100050;
vector<int>ve[maxn];
struct node{
    int nxt,to,f;
}e[maxn*2];
int cnt=0,head[maxn],vis[maxn];
void add(int u,int v,int f){
    e[cnt].to=v;
    e[cnt].f=f;
    e[cnt].nxt=head[u];
    head[u]=cnt++;
}
int sz=0;
void dfs(int u,int fa){
    vis[u]=1;
    sz++;
    for(int i=head[u];~i;i=e[i].nxt){
        int v=e[i].to;
        int f=e[i].f;
        if(v==fa||f||vis[v]) continue;
        dfs(v,u);
    }
}
const ll mod=1e9+7;
ll fpow(ll a,ll p){
    ll ans=1;
    while(p){
        if(p&1) ans=ans*a%mod;
        a=a*a%mod;
        p>>=1;
    }
    return ans;
}
int main()
{
    int n,k;
    memset(head,-1,sizeof(head));
    scanf("%d%d",&n,&k);
    for(int i=1;i<n;i++){
        int x,y,z;
        scanf("%d%d%d",&x,&y,&z);
        add(x,y,z);
        add(y,x,z);
    }
    ll ans=fpow(n,k);
    for(int i=1;i<=n;i++){
        if(vis[i]) continue;
        sz=0;
        dfs(i,i);
        ans=(ans+mod-fpow(sz,k))%mod;
    }
    printf("%lld\n",ans);
    return 0;
}


发表于 2020-03-20 21:47:27 回复(0)
package com.wql; import java.io.Serializable; import java.math.BigInteger; import java.util.HashMap; import java.util.HashSet; import java.util.Scanner; import java.util.Set; public class main { static int mod = 1000000007; static Set<Integer>[] adj; static boolean[] visited; public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in); int n = scanner.nextInt(); int k = scanner.nextInt(); adj = new HashSet[n+1]; visited = new boolean[n+1];
        BigInteger bad = new BigInteger("0"); for (int i = 0; i <= n; i++) { adj[i] = new HashSet<>();
        } while (scanner.hasNext()) { int a = scanner.nextInt(); int b = scanner.nextInt(); int weight = scanner.nextInt(); if(weight== 0){ adj[a].add(b); adj[b].add(a);
            }
        } for (int i = 0; i <= n; ++i) { if(!visited[i]){
                bad = bad.add(fastPow(dfs(i),k)).mod(new BigInteger(""+mod));
            }
            BigInteger big = fastPow(n,k);
            System.out.println(big.subtract(bad).add(new BigInteger(mod+"")).mod(new BigInteger(mod+"")));
        }
    } public static int dfs(int num) { visited[num] = true; int res = 1; for (int next : adj[num]) { if (!visited[next]) res += dfs(next);
        } return res;
    } public static BigInteger fastPow(int n,int k){ if(k==0){ return new BigInteger("1");
            }
            BigInteger half = fastPow(n,k/2); if(k%2==0){ return half.multiply(half).mod(new BigInteger(mod+""));
            }else { return half.multiply(half).multiply(new BigInteger(n+"")).mod(new BigInteger(n+""));
            }
        }
    }

发表于 2020-03-15 11:20:28 回复(0)
参照上面的例子,给一Java能通过的解法
import java.math.*;
import java.util.*;
public class Main{
    
    static int mod=1000000007;
    static Set<Integer>[] adj;
    static boolean[] visited;
    public static void main(String[] args){

        Scanner scanner=new Scanner(System.in);
        int n=scanner.nextInt(),k=scanner.nextInt();
        adj=new HashSet[n+1];
        visited=new boolean[n+1];
        BigInteger bad=new BigInteger("0");
        for(int i=1;i<=n;i++) adj[i]=new HashSet<>();
        while(scanner.hasNext()){
            int a=scanner.nextInt(),b=scanner.nextInt(),weight=scanner.nextInt();
            if(weight==0) {
                adj[a].add(b);
                adj[b].add(a);
            }
        }
        for(int i=1;i<=n;i++)
            if(!visited[i]) bad=bad.add(fastPow(dfs(i),k)).mod(new BigInteger(""+mod));
        BigInteger big=fastPow(n,k);
        System.out.println(big.subtract(bad).add(new BigInteger(mod+"")).mod(new BigInteger(mod+"")));
    }

    public static int dfs(int num){
        visited[num]=true;
        int res=1;
        for(int next:adj[num]){
            if(!visited[next]) res+=dfs(next);
        }
        return res;
    }

    public static BigInteger fastPow(int n,int k){
        if(k==0) return new BigInteger("1");
        BigInteger half=fastPow(n,k/2);
        if(k%2==0) return half.multiply(half).mod(new BigInteger(mod+""));
        else return half.multiply(half).multiply(new BigInteger(n+"")).mod(new BigInteger(mod+""));
    }
    
}


发表于 2020-03-14 11:45:25 回复(0)