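// Presumably residue from the web page this code was copied from ("show code as plain text"); not part of the program.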
#include <algorithm>
#include <cctype>
#include <cstdio>
#include <iostream>
using namespace std;
// One grid edge: x and y are cell ids (values of num[][]), z is the edge
// weight (absolute height difference between the two cells, set in main).
struct arr{
    int x;
    int y;
    int z;
};
arr b[550100];          // edge list, used as b[1..M]

int a[510][510];        // height value of cell (i,j) read from input, 1-indexed
int num[510][510];      // cell id (i-1)*m+j of cell (i,j)
int fa[250100];         // union-find parent; fa[i] == i for a root
int ans[250100];        // weight at which this root's component reached size T, or -1
int n, m;               // grid rows / columns
int N;                  // number of cells, n*m
int M;                  // number of edges stored in b[]
int T;                  // component-size threshold read from input
int ret;                // scratch value for read()
// NOTE(review): with `using namespace std;` under C++17, unqualified `size`
// can become ambiguous with std::size; build as C++11/14 or rename.
int size[250100];       // component size (meaningful for roots)
int st[250100];         // ids of cells flagged '1' in the input, st[1..sum]
int sum;                // number of entries in st[]
char c;                 // scratch character for input parsing
long long maxans;       // total of the per-start-cell answers printed by main
int read(){
while (!isdigit(c = getchar()));
ret = c-48;
while (isdigit(c = getchar()))
(ret *= 10) += c-48;
return ret;
}
// Returns the root of x's union-find tree by following fa[] links upward.
// Intentionally does NOT compress paths: getans() depends on the original
// parent chain to find the ancestor where the threshold was recorded.
int getfa(int x){
    int node = x;
    for (; fa[node] != node; node = fa[node]) {
    }
    return node;
}
// Walks from x toward its root and returns the first ans[] value that has
// been recorded (!= -1). If none is found before the root, returns the
// root's ans[] value, which may still be -1.
long long getans(int x){
    int node = x;
    for (;;){
        if (fa[node] == node) break;    // reached the root
        if (ans[node] != -1) break;     // first ancestor with a recorded value
        node = fa[node];
    }
    return static_cast<long long>(ans[node]);
}
// Merges two distinct roots x and y through an edge of weight z.
// The root with the smaller size is attached under the other (ties keep x
// as the child). If the merged size reaches T, z is recorded on the root
// whose own component was still below T, with the new parent checked first.
void _union(int x,int y,int z){
    int child = x;
    int parent = y;
    if (size[child] > size[parent]) swap(child, parent);
    fa[child] = parent;
    const int merged = size[child] + size[parent];
    if (merged >= T){
        if (size[parent] < T) ans[parent] = z;
        else if (size[child] < T) ans[child] = z;
    }
    size[parent] = merged;
}
// Reads the input from stdin and prepares the union-find state:
//   n, m, T, then n*m cell heights into a[][], then n*m '0'/'1' flags.
// Each cell (i,j) gets id num[i][j] = (i-1)*m+j; every id starts as its own
// root of size 1. Ids of cells flagged '1' are appended to st[1..sum].
// Returns 0; the caller does not use the return value.
int init(){
    int i,j;
    n = read(); m = read(); T = read();
    N = n*m;
    // A lone cell is already a component of size 1. When T <= 1, every cell
    // therefore qualifies at weight 0. _union() only records a value when a
    // component crosses from below T, so it would never set ans here; seed
    // 0 instead of leaving -1 (which would make the printed total negative).
    const int initial_ans = (T <= 1) ? 0 : -1;
    for (i = 1; i <= N; i++){
        ans[i] = initial_ans;
        size[i] = 1;
        fa[i] = i;
    }
    for (i = 1; i <= n; i++)
        for (j = 1; j <= m; j++){
            a[i][j] = read();
            num[i][j] = (i-1)*m+j;
        }
    for (i = 1; i <= n; i++)
        for (j = 1; j <= m; j++){
            // Skip separators up to the next '0'/'1' flag. Use an int so that
            // EOF is detectable and stops the scan instead of spinning forever.
            int ch = getchar();
            while (ch != EOF && ch != '1' && ch != '0') ch = getchar();
            if (ch == '1'){
                st[++sum] = num[i][j];
            }
        }
    return 0;   // previously fell off the end of a non-void function (UB)
}
bool com(const arr &a,const arr &b){
return a.z < b.z;
}
// Absolute value of an int (undefined for INT_MIN, as with std::abs).
int _abs(int x){
    return (x < 0) ? -x : x;
}
// Reads skilevel.in and writes skilevel.out.
// Builds all 4-neighbour grid edges weighted by absolute height difference,
// then processes them in ascending weight with union-find. For each cell
// flagged '1', it adds the weight at which that cell's component first
// reached size T, and prints the total.
int main(){
    freopen("skilevel.in","r",stdin);
    freopen("skilevel.out","w",stdout);
    init();

    // Horizontal edges: (row,col) -- (row,col+1).
    for (int row = 1; row <= n; ++row){
        for (int col = 1; col < m; ++col){
            ++M;
            b[M].x = num[row][col];
            b[M].y = num[row][col+1];
            b[M].z = _abs(a[row][col+1]-a[row][col]);
        }
    }
    // Vertical edges: (row,col) -- (row+1,col).
    for (int row = 1; row < n; ++row){
        for (int col = 1; col <= m; ++col){
            ++M;
            b[M].x = num[row][col];
            b[M].y = num[row+1][col];
            b[M].z = _abs(a[row+1][col]-a[row][col]);
        }
    }

    sort(b+1,b+1+M,com);

    // Kruskal-style merge in ascending weight order.
    for (int e = 1; e <= M; ++e){
        const int rootU = getfa(b[e].x);
        const int rootV = getfa(b[e].y);
        if (rootU == rootV) continue;
        _union(rootU,rootV,b[e].z);
    }

    for (int k = 1; k <= sum; ++k)
        maxans += getans(st[k]);

    printf("%lld",maxans);
    return 0;
}