显示代码纯文本
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=2e5+7;
vector<int>G[N];
int a[N],w[N],n,m,fa[N],vis[N],f[N][2],bk;
void add(int u,int v){
G[u].push_back(v);
fa[v]=u;
}
int check(int u){
vis[u]=1;
int ff=fa[u];
if(vis[ff]){
return ff;
}
else{
return check(ff);
}
}
void dfs(int u){
f[u][1]=0;
f[u][0]=w[u];
vis[u]=1;
for(auto v:G[u]){
if(v!=bk){
dfs(v);
f[u][1]+=f[v][0];
f[u][0]+=min(f[v][0],f[v][1]);
}
}
}
int calc(int u){
int tmp=0;
bk=check(u);
dfs(bk);
tmp=f[bk][0];
bk=fa[bk];
dfs(bk);
tmp=min(tmp,f[bk][0]);
return tmp;
}
signed main(){
// freopen("Function.in","r",stdin);
// freopen("Function.out","w",stdout);
cin>>n;
memset(f,0x3f,sizeof(f));
for(int i=1;i<=n;i++){
cin>>a[i];
}
for(int i=1;i<=n;i++){
cin>>w[i];
if(a[i]==i){
w[i]=0;
}
}
for(int i=1;i<=n;i++){
add(a[i],i);
}
int ans=0;
for(int i=1;i<=n;i++){
if(!vis[i]){
ans+=calc(i);
}
}
cout<<ans;
return 0;
}