// 输出的答案有的偏大、有的偏小；目前 67pts，有 3 个测试点 WA。
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <stack>
#include <vector>
using namespace std;
// Upper bound on each grid dimension; per-node arrays below are sized N * N.
const int N = 510;
// Capacity of the edge array e[].
const int M = 1e6 + 10;
// n and m are the grid dimensions read in main (m is read first, then n).
// H[i][j] is the value read for row i (1..n) and column j (1..m); untouched
// entries keep their zero initialisation.
int n, m, H[N][N];
// cnt: number of edges stored so far in e[] (edges are numbered from 1).
// head[u]: index of the most recently added edge leaving node u, 0 if none.
// ans1 / ans2: number of components with no incoming / no outgoing
// cross-component edge, filled in by main.
int cnt, head[N * N], ans1, ans2;
// low[] and dfn[] are the low-link and discovery-time arrays used by tarjan();
// dfn[v] == 0 means v has not been visited. scc counts the components found
// so far and tot is the running discovery-time counter.
int low[N * N], dfn[N * N], scc, tot;
// in[c] / out[c]: number of edges entering / leaving component c from or to a
// different component. id[v]: component number assigned to node v.
int in[N * N], out[N * N], id[N * N];
// ins[v]: whether node v is currently on the stack st.
bool ins[N * N];
// Stack of visited nodes not yet assigned to a component.
stack <int> st;
// One directed edge stored in a linked adjacency list.
struct node {
    int from;  // node the edge leaves
    int to;    // node the edge enters
    int next;  // index in e[] of the next edge leaving `from`; 0 ends the list
};

// Edge pool; slot 0 is never used so that index 0 can act as the list terminator.
node e[M];
// Appends the directed edge u -> v to the edge pool and pushes it onto the
// front of u's adjacency list.
void add(int u, int v) {
    const int slot = ++cnt;
    e[slot].from = u;
    e[slot].to = v;
    e[slot].next = head[u];  // previous first edge of u becomes the second
    head[u] = slot;
}
// Tarjan's strongly-connected-components search started at node u.
//
// Assigns dfn[u] / low[u], explores every edge leaving u, and when u turns out
// to be the root of a component pops that component off st, giving every
// member the same id[] (the freshly incremented scc).
//
// NOTE(review): the function recurses once per node along a DFS path, so the
// recursion depth can reach the number of nodes; on large inputs this may
// need a bigger stack or an iterative rewrite.
void tarjan(int u) {
    dfn[u] = low[u] = ++tot;
    // Store an explicit boolean rather than the node id, so the flag is
    // correct for every id (including 0).
    ins[u] = true;
    st.push(u);

    for (int i = head[u]; i; i = e[i].next) {
        const int v = e[i].to;
        if (!dfn[v]) {
            // Tree edge: finish v first, then inherit its low-link.
            tarjan(v);
            low[u] = min(low[u], low[v]);
        } else if (ins[v]) {
            // Edge to a node still on the stack: it belongs to the current component.
            low[u] = min(low[u], dfn[v]);
        }
    }

    if (low[u] == dfn[u]) {
        // u is the root of a component: pop everything down to and including u.
        ++scc;
        int v;
        do {
            v = st.top();
            st.pop();
            ins[v] = false;
            id[v] = scc;
        } while (v != u);
    }
}
// Maps grid cell (row h, column l), both 1-based, to its node number in
// row-major order using the global column count m.
inline int got(int h, int l) {
    const int cellsBefore = m * (h - 1);  // cells in the rows above row h
    return cellsBefore + l;
}
inline void R(int h, int l) {
if(!H[h][l + 1]) return;
int u = got(h, l + 1), v = got(h, l);
if(H[h][l + 1] > H[h][l]) add(u, v);
if(H[h][l + 1] == H[h][l]) add(u, v), add(v, u);
}
// Connects cell (h, l) with the neighbour in the previous row, (h - 1, l).
//
// A directed edge is added from the cell with the larger H value to the one
// with the smaller value; when the values are equal, edges are added in both
// directions. Nothing is added when (h, l) is in the first row.
inline void D(int h, int l) {
    // Detect the grid border from the row index instead of testing
    // H[h - 1][l] == 0: a stored value of 0 inside the grid must not be
    // mistaken for "no neighbour".
    if (h - 1 < 1) return;

    const int here = got(h, l);
    const int up = got(h - 1, l);
    const int hereH = H[h][l];
    const int upH = H[h - 1][l];

    // Each adjacent pair is visited only once, so both orientations have to
    // be handled here (including the case where the upper cell is higher).
    if (hereH >= upH) add(here, up);
    if (hereH <= upH) add(up, here);
}
// Builds the graph: for every grid cell, processes the pair formed with the
// cell in the previous row (D) and then the pair formed with the cell to the
// right (R).
inline void solve() {
    for (int row = 1; row <= n; ++row) {
        for (int col = 1; col <= m; ++col) {
            D(row, col);
            R(row, col);
        }
    }
}
// Reads the next decimal integer from standard input.
//
// Skips every character that is not a digit; each '-' met while skipping
// flips the sign of the result. Then consumes the run of digits (plus the one
// character that follows it) and returns the signed value.
// Returns 0 if end of input is reached before any digit is found.
inline int rd() {
    int sign = 1;
    int value = 0;
    // getchar() returns int so that EOF can be told apart from every real
    // character; keeping it in an int lets the skip loop stop at end of input
    // instead of spinning forever.
    int c = getchar();
    while (c < '0' || c > '9') {
        if (c == EOF) return 0;
        if (c == '-') sign = -sign;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        value = value * 10 + (c - '0');
        c = getchar();
    }
    return value * sign;
}
int main() {
m = rd(), n = rd();
for(int i = 1; i <= n; i ++) {
for(int j = 1; j <= m; j ++) {
H[i][j] = rd();
}
}
solve();
for(int i = 1; i <= n * m; i ++) {
if(!dfn[i]) tarjan(i);
}
if(scc == 1) {
cout << "0";
return 0;
}
for(int i = 1; i <= cnt; i ++) {
int u = e[i].from, v = e[i].to;
if(id[u] == id[v]) continue;
in[id[v]] ++, out[id[u]] ++;
}
for(int i = 1; i <= scc; i ++) {
if(in[i] == 0) ans1 ++;
if(out[i] == 0) ans2 ++;
}
printf("%d", max(ans1, ans2));
return 0;
}