记录编号 481157 评测结果 AAAAAAAAAAAAAAAAAAAA
题目名称 [HAOI 2016]找相同子串 最终得分 100
用户昵称 GravatarShirry 是否通过 通过
代码语言 C++ 运行时间 1.138 s
提交时间 2018-01-01 15:43:11 内存使用 13.64 MiB
显示代码纯文本
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=400010;
char s[maxn],a[maxn],b[maxn];
int n,c[maxn],x[maxn],y[maxn],sa[maxn],height[maxn],stack[maxn][3];
void get_sa(){
	int *rnk=x,*tp=y;
	int m=500;
	for(int i=1;i<=n;i++)rnk[i]=tp[i]=s[i],c[rnk[i]]++;
	for(int i=1;i<=m;i++)c[i]+=c[i-1];
	for(int i=n;i>=1;i--)sa[c[rnk[i]]--]=i;
	for(int j=1;j<=n;j<<=1){
		int p=0;
		for(int i=n-j+1;i<=n;i++)tp[++p]=i;
		for(int i=1;i<=n;i++)if(sa[i]>j)tp[++p]=sa[i]-j;
		for(int i=0;i<=m;i++)c[i]=0;
		for(int i=1;i<=n;i++)c[rnk[tp[i]]]++;
		for(int i=1;i<=m;i++)c[i]+=c[i-1];
		for(int i=n;i>=1;i--)sa[c[rnk[tp[i]]]--]=tp[i];
		swap(rnk,tp);p=1,rnk[sa[1]]=1;
		for(int i=2;i<=n;i++){
			int O1=sa[i]+j>n?-1:tp[sa[i]+j];
			int O2=sa[i-1]+j>n?-1:tp[sa[i-1]+j];
			rnk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&O1==O2)?p:++p;
		}
		m=p;
		if(m>=n)break;
	}
}
void get_height(){
	int p=0,j;
	for(int i=1;i<=n;i++)x[sa[i]]=i;
	for(int i=1;i<=n;i++){
		if(p)p--;
		j=sa[x[i]-1];
		while(s[i+p]==s[j+p])p++;
		height[x[i]]=p;
	}
}
void work(int la){
	int top=0,cnt=0;
	long long ans=0,tot=0;
	for(int i=1;i<=n;i++){
		if(height[i]<1)top=0,tot=0;
		else{
			cnt=0;
			if(sa[i-1]<=la)cnt++,tot+=height[i];
			while(top&&height[i]<=stack[top-1][0]){
				top--;
				tot+=(height[i]-stack[top][0])*stack[top][1];
				cnt+=stack[top][1];
			}
			stack[top][0]=height[i],stack[top++][1]=cnt;
			if(sa[i]>la+1)ans+=tot;
		}
	}
	top=tot=0;
	for(int i=1;i<=n;i++){
		if(height[i]<1)top=0,tot=0;
		else{
			cnt=0;
			if(sa[i-1]>la+1)cnt++,tot+=height[i];
			while(top&&height[i]<=stack[top-1][0]){
				top--;
				tot+=(height[i]-stack[top][0])*stack[top][1];
				cnt+=stack[top][1];
			}
			stack[top][0]=height[i],stack[top++][1]=cnt;
			if(sa[i]<=la)ans+=tot;
		}
	}
	printf("%lld\n",ans);
}
int main(){
	freopen("find_2016.in","r",stdin);
	freopen("find_2016.out","w",stdout);
	scanf("%s%s",a+1,b+1);
	int lena=strlen(a+1),lenb=strlen(b+1);
	for(int i=1;i<=lena;i++)s[++n]=a[i];
	s[++n]='#';
	for(int i=1;i<=lenb;i++)s[++n]=b[i];
	//for(int i=1;i<=n;i++)printf("%c",s[i]);
	get_sa();get_height();
	work(lena);
	return 0;
}