282ms-290ms被卡,时限200ms,求优化
#include <cstring>
#include <string>
#include <stdio.h>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <stack>
#include <queue>
#include <limits.h>
#include <list>
#include <set>
#include <map>
#include <unordered_map>
#include <bitset>
#include <random>
using namespace std;
#define min(a,b) ((a)>(b)?(b):(a))
#define max(a,b) ((a)<(b)?(b):(a))
#define INF 0x3f3f3f3f3f3f3f3f
#define ll long long
#define sc scanf
#define pr printf
#define v1 first
#define v2 second
#define f(nm1,nm2,nm3) for(ll nm1=nm2; nm1<= nm3; nm1++)
int n,m;
vector<pair<int,int>> v[10005];
bool b[10005];
int sz[10005];
int f[10005];
#undef max
int dfs1(int k,int fa,int sztree){
sz[k]=1;
f[k]=fa;
int maxn=0;
int res=0;
for(auto y:v[k]){
if(y.v1==fa||b[y.v1])
continue;
res=max(res,dfs1(y.v1,k,sztree));
sz[k]+=sz[y.v1];
maxn=max(maxn,sz[y.v1]);
}
maxn=max(maxn,sztree-sz[k]);
if(maxn*2<=sztree){
return k;
}
return res;
}
unordered_map<int,int> mp;
vector<pair<int,int>> sn;
int dep[10005];
void dfs3(int k,int fa,int sd,int anc){
dep[k]=dep[fa]+sd;
mp[dep[k]]++;
sn.push_back({dep[k],anc});
for(auto y:v[k]){
if(y.v1==fa||b[y.v1])
continue;
dfs3(y.v1,k,y.v2,anc);
}
}
int tar;
bool ans=0;
vector<pair<vector<pair<int,int>>,unordered_map<int,int>>> vv;
void dfs2(int k,int szt){
int res=dfs1(k,0,szt);
memset(dep,0,sizeof(dep));
sn.clear();
mp.clear();
for(auto y:v[res]){
if(!b[y.v1])
dfs3(y.v1,res,y.v2,y.v1);
}
sn.push_back({0,res});
sort(sn.begin(),sn.end());
vv.push_back({sn,mp});
b[res]=1;
for(auto y:v[res]){
if(b[y.v1])
continue;
if(y.v1==f[res])
dfs2(y.v1,szt-sz[res]);
else
dfs2(y.v1,sz[y.v1]);
}
}
namespace command{
void query(int k){
tar=k;
ans=0;
for(auto yy:vv){
auto son=yy.v1;
auto mp2=yy.v2;
for(int i=0,j=son.size()-1; i<j; i++){
while(i<j&&son[j].v1+son[i].v1>tar)
j--;
if(i>=j)
break;
if(son[i].v1+son[j].v1==tar){
if(son[i].v2==son[j].v2){
if(mp2[son[i].v1]>1||mp2[son[j].v1]>1)
ans=1;
}
else
ans=1;
}
if(ans)
break;
}
if(ans)
break;
}
if(ans)
pr("AYE\n");
else{
pr("NAY\n");
}
}
}
int main(){
sc("%d%d",&n,&m);
for(int i=1,x,y,w; i < n; i++){
sc("%d%d%d",&x,&y,&w);
v[x].push_back({y,w});
v[y].push_back({x,w});
}
dfs2(1,n);
while(m--){
int k;
sc("%d",&k);
command::query(k);
}
return 0;
}