小强今天体检,其中有一个环节是测视力
小强看到的视力表是一张的表格,但是由于小强视力太差,他无法看清表格中的符号。不过热爱数学的他给自己出了这样一个问题:假设现在有a个向上的符号,b个向下的符号,c个向左的符号,d个向右的符号,把这些符号填到视力表中,总共有多少种可能的情况呢?
第一行输入五个数N, a, b, c, d保证
输出一个数字,表示答案由于结果可能很大,只需输出对998244353取模之后的结果即可
2 3 1 0 0
4
共有如下四种情况上上 上上上下 下上上下 下上上上 上上
2 2 1 1 0
12
2 1 1 1 1
24
数据保证,且对于的数据,N = 2对于的数据,对于的数据,对于的数据,
#include<iostream> #include<cmath> #include<string> #include<cstring> #include<vector> #include<map> #include<iomanip> #include<algorithm> #include<cstdio> #include<queue> #include<deque> #include<stack> #include<set> #include <cstdlib> #include <climits> #include <ctype.h> #include<functional> using namespace std; const int maxn = 1e5 + 5; typedef long long int ll; int n, a, b, c, d; const int MOD = 998244353; ll f[maxn], inv[maxn];//f存储i阶乘,inv存储(i阶乘)的逆元 // ll Pow(ll x, ll t) { ll cnt = 1; while (t) { if (t & 1) cnt = cnt*x%MOD; x = x * x %MOD; t >>= 1; } return cnt; } void init() { cin >> n >> a >> b >> c >> d; n = n*n; f[0] = 1; for (int i = 1;i <= n;i++) f[i] = i*f[i - 1] % MOD; //a的逆元 = a^(p-2) mod p inv[0] = 1; inv[n] = Pow(f[n], MOD - 2); //inv[i]=inv[i+1]*(i+1) for (int i = n-1;i >= 1;i--) inv[i] = inv[i + 1] * (i + 1) % MOD; } int main() { init(); ll ans = f[n]; //cout << f[n] << ' ' << inv[a] << ' ' << inv[b] << ' ' << inv[c] << ' ' << inv[d] << '\n'; //for (int i = n;i >= 1;i--) // cout << inv[i] << ' '; //cout << endl; ans = ans*inv[a] % MOD; ans = ans*inv[b] % MOD; ans = ans*inv[c] % MOD; ans = ans*inv[d] % MOD; cout << ans << '\n'; } /* 3 7 1 0 1 */
由于这个题没有给出具体数据, 写的时候照着能解决尽可能大的数据写的
通过分析, 不难发现最终结果为组合数
可以选择正常的组合数公式, 也可以选择使用阶乘公式+逆元进行组合数计算, 但是由于初衷是解决尽可能大的数据, 选择了使用lucas定理解决
lucas定理: (摘自百度)
我们令
那么:
代码可以递归的去实现这个过程, 其中递归终点为
时间
换成人话说就是
#include<iostream> #include<algorithm> using namespace std; #define int long long int qmi(int a,int k,int p) { int res = 1; while(k) { if(k&1)res = res*a%p; a = a*a%p; k>>=1; } return res; } int C(int a,int b,int p) { if(b>a)return 0; int res = 1; for(int i=1,j=a;i<=b;i++,j--) { res = res*j%p; res = res*qmi(i,p-2,p)%p; } return res; } int lucas(int a,int b,int p) { if(a<p && b<p) return C(a,b,p); return C(a%p,b%p,p)*lucas(a/p,b/p,p)%p; } signed main() { int n, a, b, c, d; cin >> n >> a >> b >> c >> d; int p = 998244353; // C_{n*n}^a*C_{n*n-a}^b*C_{n*n-a-b}^c int ans = lucas(n*n, a, p) * lucas(n*n-a, b, p) %p; ans *= lucas(n*n-a-b, c, p) %p; cout << ans % p << endl; }
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.math.BigInteger; import java.util.Arrays; public class Main { static int mod = 998244353; public static void main(String[] args) throws InterruptedException, IOException { BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); int[] s = Arrays.stream(br.readLine().split(" ")).mapToInt(Integer::parseInt).toArray(); int N = s[0]; int a = s[1]; int b = s[2]; int c = s[3]; BigInteger t = cal(N * N - a - b - c + 1, N * N); BigInteger ka = cal(1, a); BigInteger kb = cal(1, b); BigInteger kc = cal(1, c); t = t.divide(ka.multiply(kb.multiply(kc))); t = t.mod(new BigInteger(mod + "")); System.out.println(t.toString()); } public static BigInteger cal(int a, int b) { BigInteger t = new BigInteger(1 + ""); for (int i = a; i <= b; i++) { t = t.multiply(new BigInteger(i + "")); } return t; } }