牛客IOI周赛23-提高组 A-单词
单词
https://ac.nowcoder.com/acm/contest/11165/A
先用kmp求出t能出现在s的哪些位置。
我们维护这样的三个dp方程:
dp1[i]表示 最后一个区间结尾r在i的方案数。
dp2[i]表示 最后一个区间结尾r在[0...i]的方案数。
dp3[i]表示 最后一个区间开头l在[0...i]的方案数。
转移即可,dp2[n]即为答案。
复杂度O(n)。
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef long long ll;
const int N = 1e6 + 100;
const int MOD = 99824353;
int n, m;
int nex[N];
char s1[N], s2[N];
void getnex(char *str, int n) {
int i = 0, j = -1;
nex[0] = -1;
while (i < n) {
if (j == -1 || str[i] == str[j]) {
i++; j++;
nex[i] = j;
}
else
j = nex[j];
}
}
vector<int> V;
int dp1[N], dp2[N], dp3[N];
/*
dp1[i] 表示r以i结尾的区间个数
dp2[i] 表示r以0...i结尾的区间个数
dp3[i] 表示l在0...i的区间个数
*/
void upd(int &a, int b) {
a += b;
if (a >= MOD) a -= MOD;
}
int main() {
//freopen("0.txt", "r", stdin);
scanf("%s%s", s1, s2);
n = strlen(s1);
m = strlen(s2);
getnex(s2, m);
for (int i = 0, j = 0; i < n;) {
if (j == -1 || s1[i] == s2[j]) {
i++; j++;
if (j == m) V.push_back(i);
}
else
j = nex[j];
}
if (V.size() == 0) { puts("1"); return 0; }
dp1[0] = dp2[0] = 1;
int ans = 0;
for (int i = 1, j = 0; i <= n; i++) {
dp3[i] = (dp3[i - 1] + dp2[i - 1]) % MOD;
if (i >= V[0]) {
if (j + 1 < V.size() && i == V[j + 1]) j++;
dp1[i] = dp3[V[j] - m + 1];
}
dp2[i] = (dp2[i - 1] + dp1[i]) % MOD;
}
printf("%d\n", dp2[n]);
return 0;
}