记录编号 397152 评测结果 AAAAAAAAAA
题目名称 [JLOI&SHOI 2016] 侦察守卫 最终得分 100
用户昵称 GravatarFoolMike 是否通过 通过
代码语言 C++ 运行时间 2.800 s
提交时间 2017-04-19 20:07:42 内存使用 98.04 MiB
显示代码纯文本
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=5e5+10;
int n,m,d,v[N],w[N*2],head[N],next[N*2];
void add(int f,int t){
	static int cnt=0;
	w[++cnt]=t;
	next[cnt]=head[f];
	head[f]=cnt;
}
int read(){
	int x=0;char ch=getchar();
	while (ch>'9'||ch<'0') ch=getchar();
	while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
	return x;
}
int fa[N];bool use[N];
//f[i][j]表示子树i能向上覆盖距离j的最小花费 
//g[i][j]表示子树j最深的未被覆盖的点到根距离为j的最小花费 
//f单调不减,g单调不增 
int f[N][22],g[N][22],X[22],Y[22];
void solve(int x){
	for (int i=head[x];i;i=next[i])
		if (!fa[w[i]]) fa[w[i]]=x,solve(w[i]);
	for (int i=0;i<=d+1;i++){
		f[x][i]=g[x][i]=1e9;
		if (use[x]) g[x][i]=0;
	}
	if (!use[x]) f[x][0]=0;
	for (int i=head[x];i;i=next[i]){
		int v=w[i];
		if (fa[v]!=x) continue;
		for (int i=0;i<=d;i++) X[i]=Y[i]=1e9;
		//合并f[x]和f[v] 
		for (int a=0;a<=d;a++){
			X[a]=min(X[a],f[x][0]+f[v][a+1]);
			X[a]=min(X[a],f[x][a]+f[v][0]);
		}
		//合并f[x]和g[v] 
		for (int a=0;a<=d;a++){
			if (a) X[a]=min(X[a],f[x][a]+g[v][a-1]);
			Y[a+1]=min(Y[a+1],f[x][0]+g[v][a]);
		}
		//合并g[x]和f[v] 
		for (int a=0;a<=d;a++){
			Y[a]=min(Y[a],g[x][a]+f[v][0]);
			if (a) X[a-1]=min(X[a-1],g[x][a-1]+f[v][a]);
		}
		//合并g[x]和g[v] 
		for (int a=1;a<=d;a++)
			Y[a]=min(Y[a],g[x][a]+g[v][a-1]);
		//更新x 
		for (int i=0;i<=d;i++) f[x][i]=X[i],g[x][i]=Y[i];
	}
	//取x 
	int Min=1e9;
	for (int i=0;i<=d;i++) Min=min(Min,f[x][i]);
	for (int i=0;i<=d;i++) Min=min(Min,g[x][i]);
	f[x][d]=min(f[x][d],Min+v[x]);
	for (int i=d;i;i--) f[x][i-1]=min(f[x][i-1],f[x][i]);
	for (int i=1;i<=d;i++) g[x][i]=min(g[x][i],g[x][i-1]);
}
int main()
{
	freopen("observer.in","r",stdin);
	freopen("observer.out","w",stdout);
	n=read();d=read();
	for (int i=1;i<=n;i++) v[i]=read();
	m=read();
	for (int i=1;i<=m;i++) use[read()]=1;
	for (int i=1;i<n;i++){
		int f=read(),t=read();
		add(f,t);add(t,f);
	}
	fa[1]=1;solve(1);
	int ans=1e9;
	for (int i=0;i<=d;i++) ans=min(ans,f[1][i]);
	printf("%d\n",ans);
	return 0;
}