I substring题解
虽然已经有官方题解,这个算是个稍微详细那么一点的题解吧
题意:
给出一个串S,问最多可以选出多少个子串使得选出的子串两两不同构,同构的定义是,两个字符串A和B,通过一个映射函数f,让B的每个字符c通过映射函数c = f(c)后得到B',如果A == B',则A,B同构。例如ab和bc同构,映射函数是f(a) = b,f(b) = c,f(c) = a
思路:这题其实就是算S有多少个不同构的子串。先看个小例子,aab、bbc、aac,怎么算这3个串中有多少个不同构的串?这题字符只有abc三种,则映射函数只有3! = 6种(即将abc做全排列 后与abc一一对应即可得到映射函数),那我们可以按照映射函数,将每个串通过每个映射后得到的串放在一起然后去重,得到aab、aac、bba、bbc、cca、ccb,6个串,这样,得到的结果会算多6次,因为映射函数6种,而有个特殊情况是单字符的时候,aa、bb、cc,这时候只会算多3次,那答案就是(不同子串数量 + 3 × 单一字符的串数量) / 6
解法:个人用SAM计算不同子串,先将S通过映射函数得到6个串,S1、S2、...、S6,拼接得到T = S1:S2:...:S6,其中':'是除['a', 'z']字符以外的任意字符,接着跑个SAM,那SAM如何算多个串的不同子串个数呢(也就是T中不含':'的不同子串个数),我们知道某个串x是T的子串的话,通过起始状态跑x能够跑完,也就是说T的所有子串都可以通过起始状态跑到,那做个dp,dp[i]表示i状态与i的所有后续状态的路径总数(注意不跑拼接字符':'那条边),记忆化一下,得到不同子串个数。单个字符的串的数量则通过跑单个字符'a'一直跑,跑到最大长度时长度即个数
更新代码:
#include<iostream>
#include<cstring>
#include<vector>
#include<cstdio>
#include<algorithm>
#define rep(i,e) for(int i=0;i<(e);i++)
#define PB push_back
#define scd(a) scanf("%d",&a)
using namespace std;
typedef long long ll;
const int N = 1e6+10;
int idx;
int maxlen[N], minlen[N], trans[N][27], slink[N];
int new_state(int _maxlen, int _minlen, int* _trans, int _slink) {
maxlen[idx] = _maxlen;
minlen[idx] = _minlen;
for(int i = 0; i < 27; i++) {
if(_trans == NULL)
trans[idx][i] = -1;
else
trans[idx][i] = _trans[i];
}
slink[idx] = _slink;
return idx++;
}
int add_char(char ch, int u) {
int c = ch - 'a';
int z = new_state(maxlen[u] + 1, -1, NULL, -1);
while(u != -1 && trans[u][c] == -1) {
trans[u][c] = z;
u = slink[u];
}
if(u == -1) {
minlen[z] = 1;
slink[z] = 0;
return z;
}
int x = trans[u][c];
if(maxlen[u] + 1 == maxlen[x]) {
minlen[z] = maxlen[x] + 1;
slink[z] = x;
return z;
}
int y = new_state(maxlen[u] + 1, -1, trans[x], slink[x]);
minlen[z] = minlen[x] = maxlen[y] + 1;
slink[z] = slink[x] = y;
while(u != -1 && trans[u][c] == x) {
trans[u][c] = y;
u = slink[u];
}
minlen[y] = maxlen[slink[y]] + 1;
return z;
}
int n;
char s[N];
char f[200]; // 存映射
vector<int> ve;
void deal(int &st){
rep(k,3) f[ve[k] + 'a'] = k + 'a';
rep(i,n){
st = add_char(f[s[i]], st);
}
}
ll dp[N];
ll dfs(int st){
ll& ret = dp[st];
if(ret!=-1) return ret;
ret = st!=0; // 不算状态0
rep(i,26){
if(trans[st][i]!=-1){
ret += dfs(trans[st][i]);
}
}
return ret;
}
void work(){
idx=0;
scanf("%s",s);
n = strlen(s);
ve.clear();
rep(i,3)ve.PB(i);
int sta = new_state(0,0,NULL,-1);
do{
deal(sta);
sta = add_char(26 + 'a', sta);
}while(next_permutation(ve.begin(), ve.end()));//全排列
rep(i,idx) dp[i] = -1;
int cnt = -1; // 因为状态0是空串不算进去
for(int st = 0;st!=-1;st = trans[st][0], cnt++); // 计算同字符的串的数量(最大长度)
printf("%lld\n", (dfs(0)+cnt*3)/6);
}
int main() {
while(scd(n)==1)
work();
}

查看18道真题和解析