// Leftover web-page UI label ("Show code as plain text"); not part of the program.
#include<algorithm>
#include<cstring>
#include<iostream>
#include<utility>
#include<vector>
using namespace std;
// Sentinel for an unreachable DP state (4 * 10^18).
const long long INF = 4000000000000000000LL;
// Capacity of every per-node / per-edge array; node and edge ids must stay below it.
const int MAXN = 200005;

// dp[v][0] / dp[v][1]: the two subtree states of node v, filled by dfs().
long long dp[MAXN][2];
// f[i][0] / f[i][1]: linear DP along the current cycle r[1..len], filled in solve().
long long f[MAXN][2];
// c[i]: per-node cost read from input.
long long c[MAXN];

// vis[v]: v has already been processed by dfs().
bool vis[MAXN];
// loop[v]: v lies on a cycle (flagged in solve()).
bool loop[MAXN];

// Linked adjacency lists: edge e points to to[e], the next edge of the same
// source is nxt[e], and h[u] is u's first edge (0 means no edges; ids start at 1).
int to[MAXN];
int nxt[MAXN];
int h[MAXN];

// a[i]: the single successor of node i.
int a[MAXN];
// r[1..len]: nodes of the cycle currently being solved, in successor order.
int r[MAXN];

// cnt: number of edges added so far.
int cnt;
// len: length of the current cycle.
int len;
// Prepends the directed edge u -> v to u's adjacency list.
// Edge ids are handed out from 1 upward, so id 0 terminates every list.
void add(int u, int v){
    const int e = ++cnt;
    to[e] = v;
    nxt[e] = h[u];
    h[u] = e;
}
void dfs(int u){
vis[u] = 1;
dp[u][0] = (u != a[u]) ? c[u] : 0;
dp[u][1] = 0;
for(int i = h[u]; i; i = nxt[i]){
int v = to[i];
if(!loop[v] && !vis[v]) {
dfs(v);
dp[u][0] += min(dp[v][0], dp[v][1]);
dp[u][1] += dp[v][0];
}
}
}
// Linear DP along the cycle order r[1..len], used by solve().
// f[i][0] / f[i][1]: best total for r[1..i] with r[i] in state 0 / state 1, where a
// node in state 1 must follow a node in state 0 (the same rule dfs() uses for children).
// If firstMayBeFree is false, r[1] is pinned to state 0 by giving f[1][1] the INF sentinel.
void cycleChain(bool firstMayBeFree){
    f[1][0] = dp[r[1]][0];
    f[1][1] = firstMayBeFree ? dp[r[1]][1] : INF;
    for(int i = 2; i <= len; ++i){
        const long long prevBest = min(f[i - 1][0], f[i - 1][1]);
        f[i][0] = prevBest + dp[r[i]][0];
        f[i][1] = f[i - 1][0] + dp[r[i]][1];
    }
}

// Solves the component whose cycle passes through x and returns its minimum total.
// Steps:
//   1. Records the cycle in r[1..len] and flags its nodes in loop[].
//   2. Runs dfs() on the tree hanging off each cycle node.
//   3. Closes the cycle with the chain DP, run twice to cover the closing edge r[len] -> r[1].
// Precondition: x lies on a cycle (following a[] from x returns to x).
long long solve(int x){
    len = 0;
    for(int cur = x; ; ){
        r[++len] = cur;
        loop[cur] = 1;
        cur = a[cur];
        if(cur == x) break;
    }
    for(int i = 1; i <= len; ++i) dfs(r[i]);

    // Single-node cycle: no cycle edge to other nodes, so take the better root state.
    if(len == 1) return min(dp[r[1]][0], dp[r[1]][1]);

    // Case A: r[1] pinned to state 0, so r[len] may end in either state.
    cycleChain(false);
    long long best = min(f[len][0], f[len][1]);

    // Case B: r[1] unrestricted, so r[len] must end in state 0.
    cycleChain(true);
    if(f[len][0] < best) best = f[len][0];
    return best;
}
int main(){
int n;
cin >> n;
for(int i = 1; i <= n; i ++){
cin >> a[i];
add(a[i], i);
}
for(int i = 1; i <= n; i ++) cin >> c[i];
long long ans = 0;
memset(vis, 0, sizeof(vis));
for(int i = 1; i <= n; i ++){
if(!vis[i]){
int x = a[i], y = i;
while(x != y){
x = a[a[x]];
y = a[y];
}
ans += solve(x);
}
}
cout << ans << '\n';
return 0;
}