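// Approach: find every white node and look down its subtree (tree rooted at node 1) for a black node; if one exists, the white node must be kept (counted), otherwise it is not needed.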
#include <bits/stdc++.h>
//#define int long long
using namespace std;
// Capacity of every per-node array; nodes are 1-indexed, so up to N - 1 nodes fit.
const int N = 100000 + 114;

int n;              // number of nodes read from input
int a[N];           // colour of node i: 0 = white, non-zero = black
int white[N];       // white[1..cnt]: indices of the white nodes, in input order
vector<int> v[N];   // undirected adjacency lists of the tree
int cnt;            // how many entries of white[] are filled
int ans = 0;        // number of white nodes that must be kept
int flag[N];        // set by dfs(): 1 if node i's subtree (away from its parent) holds a black node
int FA[N];          // set by DFS(): parent of node i (0 for the starting root)
int vis[N];         // set by dfs(): 1 once node i has been visited and flag[i] is being/was computed
// DFS: walk the tree starting at x (whose parent is fa) and record each
// reached node's parent in FA[].
// An explicit stack replaces the original recursion: recursion depth equals
// the tree height, which for a path of ~1e5 nodes can overflow a small
// default call stack.
//   x  - node to start from (main passes 1)
//   fa - parent recorded for x (main passes 0, i.e. "no parent")
void DFS(int x, int fa){
    vector<pair<int, int> > stk;  // pending (node, parent) pairs
    stk.push_back(make_pair(x, fa));
    while (!stk.empty()) {
        const int u = stk.back().first;
        const int p = stk.back().second;
        stk.pop_back();
        FA[u] = p;
        for (int y : v[u]) {
            if (y == p) continue;  // never walk back up to the parent
            stk.push_back(make_pair(y, u));
        }
    }
}
// dfs: report whether x or any node below it (treating fa as x's parent)
// is black (a[] != 0).
// Results are memoised in flag[]/vis[]: a node with vis[u] set is not
// expanded again, so across all calls every node is expanded at most once.
// A black node sets flag = 1 and its children are not examined.
// Implemented as an explicit-stack post-order walk instead of recursion,
// because recursion depth equals the subtree height and a path-shaped tree
// of ~1e5 nodes can overflow a small default call stack. The values written
// to vis[]/flag[] and the returned value match the recursive formulation.
//   x  - node to query
//   fa - x's parent (main passes FA[x]); the walk never steps onto it
// Returns flag[x]: 1 if a black node was found, 0 otherwise.
int dfs(int x, int fa){
    if(vis[x]){
        return flag[x];
    }
    vis[x] = true;
    if(a[x]){
        return flag[x] = true;
    }
    // One frame per white node whose neighbours are still being examined.
    struct Frame {
        int node;     // node being expanded
        int parent;   // its parent, skipped among the neighbours
        size_t next;  // index of the next neighbour of `node` to examine
    };
    vector<Frame> stk;
    stk.push_back(Frame{x, fa, 0});
    while(!stk.empty()){
        const int u = stk.back().node;
        const int p = stk.back().parent;
        if(stk.back().next == v[u].size()){
            // Every child of u is resolved, so flag[u] is final: fold it into u's parent frame.
            stk.pop_back();
            if(!stk.empty()){
                const int up = stk.back().node;
                flag[up] = flag[up] | flag[u];
            }
            continue;
        }
        // Read and advance the cursor before any push_back can reallocate the stack.
        const int y = v[u][stk.back().next++];
        if(y == p) continue;
        if(vis[y]){
            // Already visited by an earlier query: reuse its memoised answer.
            flag[u] = flag[u] | flag[y];
            continue;
        }
        vis[y] = true;
        if(a[y]){
            // Black child: answers "yes" immediately and is not expanded.
            flag[y] = true;
            flag[u] = flag[u] | flag[y];
            continue;
        }
        stk.push_back(Frame{y, u, 0});  // white child: expand it next
    }
    return flag[x];
}
signed main()
{
    // Node colours: remember every white node (a[node] == 0) in white[1..cnt].
    cin >> n;
    for (int node = 1; node <= n; ++node) {
        cin >> a[node];
        if (a[node] == 0) white[++cnt] = node;
    }

    // The n - 1 undirected edges of the tree.
    for (int e = 1; e < n; ++e) {
        int from, to;
        cin >> from >> to;
        v[from].push_back(to);
        v[to].push_back(from);
    }

    // Root the tree at node 1 so every node knows its parent.
    DFS(1, 0);

    // Keep a white node iff something black lies in its subtree.
    for (int k = 1; k <= cnt; ++k) {
        const int w = white[k];
        ans += dfs(w, FA[w]) ? 1 : 0;
    }

    cout << ans << endl;
    return 0;
}