记录编号 |
599577 |
评测结果 |
AAAAAAAAAAAAAAAAAAAAAAAAA |
题目名称 |
[NOIP 2024]树的遍历 |
最终得分 |
100 |
用户昵称 |
flyfree |
是否通过 |
通过 |
代码语言 |
C++ |
运行时间 |
7.160 s |
提交时间 |
2025-03-24 16:33:34 |
内存使用 |
14.96 MiB |
显示代码纯文本
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define MAXN 100010
#define mod 1000000007
#define debug cout<<"flyfree\n";
inline ll read(){
ll x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
x=x*10+c-'0';
c=getchar();
}
return x*f;
}
ll T,n,k,idx,sum,ans,type,cp;
ll hd[MAXN],ed[MAXN*2],nxt[MAXN*2];
ll dep[MAXN],d[MAXN],fa[MAXN],fac[MAXN],inf[MAXN],cnt[MAXN],_sum[MAXN];
bool mp[MAXN],used[MAXN];
queue<ll> q;
ll fpow(ll bas,ll ind){
ll now=1;
while(ind>0){
if(ind&1)now=now*bas%mod;
bas=bas*bas%mod;
ind/=2;
}
return now;
}
void dfs(ll now,ll p){
fa[now]=p,dep[now]=dep[p]+1,sum=(sum*fac[d[now]-1]%mod)%mod;
// cout<<now<<" "<<sum<<" "<<d[now]<<endl;
for(int i=hd[now];i;i=nxt[i]){
ll y=ed[i];
if(y==p)continue;
dfs(y,now);
}
}
void clear(){
memset(dep,0,sizeof(dep));
memset(d,0,sizeof(d));
memset(fa,0,sizeof(fa));
memset(hd,0,sizeof(hd));
memset(nxt,0,sizeof(nxt));
memset(ed,0,sizeof(ed));
memset(mp,0,sizeof(mp));
memset(used,0,sizeof(used));
memset(cnt,0,sizeof(cnt));
memset(_sum,0,sizeof(_sum));
idx=1,sum=1,ans=0,cp=0;
}
void find(ll now,ll p){
// cout<<now<<" "<<len<<" "<<p<<endl;
// if(used[now])debug;
// len=(len*(d[now]-1)%mod)%mod;
used[now]=true;
// _sum[now]=fpow(d[now]-1,mod-2);
ll s=fpow(d[now]-1,mod-2);
// _sum[now]=s;
for(int i=hd[now];i;i=nxt[i]){
ll y=ed[i];
if(y==p||mp[i/2])continue;
find(y,now);
cp=(cp+cnt[y]*cnt[now]%mod)%mod;
ans=(ans+_sum[now]*_sum[y]%mod*s%mod)%mod;
_sum[now]=(_sum[now]+_sum[y])%mod;
cnt[now]=(cnt[now]+cnt[y])%mod;
}
for(int i=hd[now];i;i=nxt[i]){
if(ed[i]!=p&&mp[i/2])ans=(ans+_sum[now]*s)%mod,_sum[now]=(_sum[now]+1)%mod,cp=(cp+cnt[now])%mod,cnt[now]=(cnt[now]+1)%mod;
}
_sum[now]=(_sum[now]*s%mod)%mod;
// _sum[now]=(_sum[now]*fpow(d[now]-1,mod-2));
// cout<<now<<" cnt:"<<cnt[now]<<" sum:"<<_sum[now]<<" d:"<<d[now]<<endl;
}
void insert(ll x,ll y){
nxt[++idx]=hd[x];
ed[idx]=y;
hd[x]=idx;
}
int main(){
// freopen("traverse9.in","r",stdin);
// freopen("w.out","w",stdout);
type=read(),T=read();
fac[0]=1;
for(int i=1;i<=100000;i++)fac[i]=(fac[i-1]*i%mod)%mod;
// inf[100000]=fpow(fac[100000],mod-2);
// for(int i=99999;i>=0;i--)inf[i]=(inf[i+1]*(i+1))%mod;
while(T--){
// idx=1,sum=1;
clear();
n=read(),k=read();
for(int i=1;i<n;i++){
ll x=read(),y=read();
d[x]++,d[y]++;
insert(x,y);
insert(y,x);
}
dfs(1,0);
for(int i=1;i<=k;i++){
ll c=read();
mp[c]=true;
q.push(c);
}
// if(k==1){
// cout<<sum<<endl;
// continue;
// }
while(!q.empty()){
ll x=ed[q.front()*2],y=ed[q.front()*2+1];
if(!used[x])find(x,y),ans=(ans+_sum[x])%mod,cp=(cp+cnt[x])%mod;
if(!used[y])find(y,x),ans=(ans+_sum[y])%mod,cp=(cp+cnt[y])%mod;
// find(x,1,y);
// find(y,1,x);
q.pop();
}
// if(type>=19&&type<=21){
// cout<<((n-1)*(n-2)%mod*fac[n-2])%mod<<endl;
// continue;
// }
// ans=ans?(ans*fpow(2,mod-2)%mod)%mod:ans;
// cout<<ans<<" "<<cp<<endl;
cout<<(k*sum%mod+mod-ans*(sum%mod)%mod)%mod<<endl;
}
return 0;
}