记录编号 |
481157 |
评测结果 |
AAAAAAAAAAAAAAAAAAAA |
题目名称 |
[HAOI 2016]找相同子串 |
最终得分 |
100 |
用户昵称 |
Shirry |
是否通过 |
通过 |
代码语言 |
C++ |
运行时间 |
1.138 s |
提交时间 |
2018-01-01 15:43:11 |
内存使用 |
13.64 MiB |
显示代码纯文本
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=400010;
char s[maxn],a[maxn],b[maxn];
int n,c[maxn],x[maxn],y[maxn],sa[maxn],height[maxn],stack[maxn][3];
void get_sa(){
int *rnk=x,*tp=y;
int m=500;
for(int i=1;i<=n;i++)rnk[i]=tp[i]=s[i],c[rnk[i]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i>=1;i--)sa[c[rnk[i]]--]=i;
for(int j=1;j<=n;j<<=1){
int p=0;
for(int i=n-j+1;i<=n;i++)tp[++p]=i;
for(int i=1;i<=n;i++)if(sa[i]>j)tp[++p]=sa[i]-j;
for(int i=0;i<=m;i++)c[i]=0;
for(int i=1;i<=n;i++)c[rnk[tp[i]]]++;
for(int i=1;i<=m;i++)c[i]+=c[i-1];
for(int i=n;i>=1;i--)sa[c[rnk[tp[i]]]--]=tp[i];
swap(rnk,tp);p=1,rnk[sa[1]]=1;
for(int i=2;i<=n;i++){
int O1=sa[i]+j>n?-1:tp[sa[i]+j];
int O2=sa[i-1]+j>n?-1:tp[sa[i-1]+j];
rnk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&O1==O2)?p:++p;
}
m=p;
if(m>=n)break;
}
}
void get_height(){
int p=0,j;
for(int i=1;i<=n;i++)x[sa[i]]=i;
for(int i=1;i<=n;i++){
if(p)p--;
j=sa[x[i]-1];
while(s[i+p]==s[j+p])p++;
height[x[i]]=p;
}
}
void work(int la){
int top=0,cnt=0;
long long ans=0,tot=0;
for(int i=1;i<=n;i++){
if(height[i]<1)top=0,tot=0;
else{
cnt=0;
if(sa[i-1]<=la)cnt++,tot+=height[i];
while(top&&height[i]<=stack[top-1][0]){
top--;
tot+=(height[i]-stack[top][0])*stack[top][1];
cnt+=stack[top][1];
}
stack[top][0]=height[i],stack[top++][1]=cnt;
if(sa[i]>la+1)ans+=tot;
}
}
top=tot=0;
for(int i=1;i<=n;i++){
if(height[i]<1)top=0,tot=0;
else{
cnt=0;
if(sa[i-1]>la+1)cnt++,tot+=height[i];
while(top&&height[i]<=stack[top-1][0]){
top--;
tot+=(height[i]-stack[top][0])*stack[top][1];
cnt+=stack[top][1];
}
stack[top][0]=height[i],stack[top++][1]=cnt;
if(sa[i]<=la)ans+=tot;
}
}
printf("%lld\n",ans);
}
int main(){
freopen("find_2016.in","r",stdin);
freopen("find_2016.out","w",stdout);
scanf("%s%s",a+1,b+1);
int lena=strlen(a+1),lenb=strlen(b+1);
for(int i=1;i<=lena;i++)s[++n]=a[i];
s[++n]='#';
for(int i=1;i<=lenb;i++)s[++n]=b[i];
//for(int i=1;i<=n;i++)printf("%c",s[i]);
get_sa();get_height();
work(lena);
return 0;
}