显示代码纯文本
- #include<cstdio>
- #include<algorithm>
- #include<queue>
- using namespace std;
- typedef long long ll;
- typedef unsigned int u32;
- const int N=1e5+10;
- int hash[N],Mi[N];
- int n,l[N],r[N],root,len,top,size[N],s[N];
- ll tag[N];
- int HASH(int i,int l){return hash[i]-hash[i-l]*Mi[l];}
- int lcp(int x,int y){
- int l=0,r=min(x,y);
- while (l<r){
- int mid=(l+r)>>1;
- if (HASH(x,mid+1)==HASH(y,mid+1)) l=mid+1;else r=mid;
- }
- return l;
- }
- bool pre_cmp(int x,int y){
- int l=lcp(x,y);
- return s[x-l]<s[y-l];
- }
- int merge(int x,int y){
- if (!x||!y) return x|y;
- if (size[x]>size[y]){
- size[x]+=size[y];
- r[x]=merge(r[x],y);
- return x;
- }
- else{
- size[y]+=size[x];
- l[y]=merge(x,l[y]);
- return y;
- }
- }
- int que[N],tail;
- void visit(int x){
- if (l[x]) visit(l[x]);
- que[++tail]=x;
- if (r[x]) visit(r[x]);
- }
- int build(int L,int R,ll le,ll ri){
- if (L>R) return 0;
- int mid=(L+R)>>1,x=que[mid];
- tag[x]=(le+ri)>>1;
- l[x]=build(L,mid-1,le,tag[x]);
- r[x]=build(mid+1,R,tag[x],ri);
- size[x]=R-L+1;
- return x;
- }
- int rebuild(int x,ll L,ll R){
- tail=0;visit(x);
- return build(1,tail,L,R);
- }
- int insert(int x,ll L,ll R,int key){
- if (!x){
- size[++top]=1;
- l[top]=r[top]=0;
- tag[top]=(L+R)>>1;
- return top;
- }
- size[x]++;
- if (pre_cmp(x,key)){
- r[x]=insert(r[x],tag[x],R,key);
- if (size[r[x]]>0.65*size[x]) x=rebuild(x,L,R);
- }
- else{
- l[x]=insert(l[x],L,tag[x],key);
- if (size[l[x]]>0.65*size[x]) x=rebuild(x,L,R);
- }
- return x;
- }
- int del(int x,int key){
- if (x==key) return merge(l[x],r[x]);
- size[x]--;
- if (tag[x]<tag[key]) r[x]=del(r[x],key);
- else l[x]=del(l[x],key);
- return x;
- }
- int rank(int key){
- int x=root,ans=0;
- while (1){
- int i=size[l[x]]+1;
- if (key==x) return ans+i;
- tag[x]<tag[key]?x=r[x],ans+=i:x=l[x];
- }
- }
- int find(int R){
- int x=root;
- while (1){
- int i=size[l[x]]+1;
- if (R==i) return x;
- R>i?x=r[x],R-=i:x=l[x];
- }
- }
- int ans,Ans[N];
- void erase(int x){
- int R=rank(x),y=find(R-1),z=find(R+1);
- ans-=lcp(x,y)+lcp(x,z)-lcp(y,z);
- root=del(root,x);
- top--;len--;
- }
- void insert(int w){
- s[++len]=w;
- hash[len]=hash[len-1]*31+w;
- root=insert(root,0,1ll<<62,len);
- if (len<3) return;
- int R=rank(len),y=find(R-1),z=find(R+1);
- ans+=lcp(len,y)+lcp(len,z)-lcp(y,z);
- }
- vector<int> E[N];
- char str[N];
- void dfs(int x,int fa){
- insert(str[x]-'a'+1);
- Ans[x]=(len-1)*(len-2)/2-ans;
- for (int i=E[x].size()-1;i>=0;i--)
- if (E[x][i]!=fa) dfs(E[x][i],x);
- erase(len);
- }
- int main()
- {
- freopen("balsuffix.in","r",stdin);
- freopen("balsuffix.out","w",stdout);
- Mi[0]=1;
- for (int i=1;i<N;i++) Mi[i]=Mi[i-1]*31;
- insert(27);
- insert(0);
- int T;scanf("%d",&T);
- while (T--){
- scanf("%d",&n);
- for (int i=1;i<=n;i++) E[i].clear();
- for (int i=1;i<n;i++){
- int u,v;
- scanf("%d%d",&u,&v);
- E[u].push_back(v);
- E[v].push_back(u);
- }
- scanf("%s",str+1);
- dfs(1,0);
- for (int i=1;i<=n;i++) printf("%d\n",Ans[i]);
- }
- return 0;
- }