题目：P8981 Round 1 距离
以下是我的代码，只得了 10pts。
（码风良好）
求求了求求了，求大佬帮忙看看哪里错了 TwT
#include<bits/stdc++.h>
using namespace std;
// Alias used for loop counters and locals; despite the name it is simply long long.
#define rint long long
// Every `int` written below this line becomes 64-bit, which is why main is declared `signed main`.
#define int long long
// Array bound: vertex ids must stay below N (the edge pool uses 2*N).
const int N=5e6+10;
// Prime modulus applied to the accumulated answer.
const int mod=998244353;
// Fast reader: parses one integer from stdin into x.
// Non-digit characters before the number are skipped; the number is negated
// if the last skipped character was '-'. Consumes the character right after
// the last digit.
// ch is an int (not a char) so that EOF stays distinguishable and bytes are
// passed to isdigit() as non-negative values (a negative char is UB there).
// Stopping at EOF prevents the skip loop from spinning forever on truncated input;
// in that case x is left as 0.
template<typename type>
inline void read(type &x){
	x=0;
	bool flag(0);
	int ch=getchar();
	while(ch!=EOF&&!isdigit(ch)) flag=ch=='-',ch=getchar();
	while(isdigit(ch)) x=(x<<1)+(x<<3)+(ch^48),ch=getchar();
	if(flag) x=-x;
}
// Fast writer: prints integer x in decimal, followed by '\n' when mode is true
// or by a single space when mode is false. Negative values get a leading '-'.
template<typename type>
inline void write(type x,bool mode=1){
	if(x<0){
		x=-x;
		putchar('-');
	}
	// Collect digits least-significant first, then emit them in reverse.
	char digits[50];
	int used=0;
	do{
		digits[used++]=char('0'+x%10);
		x/=10;
	}while(x);
	while(used>0) putchar(digits[--used]);
	putchar(mode?'\n':' ');
}
// Global state (zero-initialized). Per-vertex arrays are indexed by vertex id:
//   n, k   - input sizes read by main
//   x, y   - scratch variables for reading edge endpoints
//   d, p   - longest path length found by dfs and the vertex where it was found
//   num    - number of directed edges stored so far in e[]
//   len[v] - length of the longest downward path from v (filled by dfs)
//   cnt[v] - number of vertices at that maximal depth below v (filled by dfs)
//   tag[v] - per-vertex value accumulated by dfs1 / main and summed into the answer
//   dis[v] - subtree size of v (filled by dfs); dis[v]==1 marks a leaf
//   head[v]- index of v's first edge in e[] (0 = no edges)
// head is indexed only by vertex id, so N slots suffice (it was previously N<<1,
// which wasted N extra 64-bit entries). The edge pool e[] stores both directions
// of every undirected edge, hence N<<1.
int n,k,x,y,d,p,num,len[N],cnt[N],tag[N],dis[N],head[N];
struct node{
	int nxt,to;
}e[N<<1];
// Adds the directed edge x -> y by pushing it onto the front of x's
// singly linked adjacency list (head[x] -> e[].nxt -> ... -> 0).
inline void add(int x,int y){
	const int id=++num;
	e[id].nxt=head[x];
	e[id].to=y;
	head[x]=id;
}
// Post-order DFS of the subtree of x (whose parent is fa). For every visited
// vertex it fills:
//   dis[x] - subtree size;
//   len[x] - length (in edges) of the longest downward path from x;
//   cnt[x] - number of vertices in x's subtree at distance len[x] from x.
// It also updates the globals d / p: d is the longest path seen that turns at
// a single vertex (two downward arms), and p is that turning vertex. Only a
// strictly longer path replaces p, so p is the first such vertex found.
// All per-vertex fields are re-initialized on entry, so the function gives
// correct results when called again with a different root (main calls
// dfs(1,0) and then dfs(p,0)).
inline void dfs(int x,int fa){
	dis[x]=1,cnt[x]=1;
	// Must be reset as well. Otherwise, on the second traversal len[x] still holds
	// the value from the previous rooting, which corrupts cnt[] (e.g. doubles it)
	// and can inflate d and move p.
	len[x]=0;
	for(rint i=head[x];i;i=e[i].nxt){
		rint to=e[i].to;
		if(to==fa) continue;
		dfs(to,x);
		dis[x]+=dis[to];
		// Path through x: child `to`'s deepest chain plus the best earlier branch.
		if(len[to]+1+len[x]>d){
			p=x;
			d=len[to]+len[x]+1;
		}
		if(len[to]+1>len[x]){
			len[x]=len[to]+1;
			cnt[x]=cnt[to];
		}else if(len[to]+1==len[x]){
			cnt[x]+=cnt[to];
		}
	}
}
void dfs1(int x,int fa,int val){
if(dis[x]==1) tag[x]=val;
for(rint i=head[x];i;i=e[i].nxt){
rint to=e[i].to;
if(to==fa) continue;
if(len[to]+1==len[x]){
dfs1(to,x,val);
tag[x]=(tag[x]+tag[to])%mod;
}
}
}
// Entry point: reads n, k and n-1 undirected edges, then prints, modulo
// 998244353, the sum of tag[1..n] when k==1, or the sum of tag[i]^2 otherwise.
// NOTE(review): from the structure below, tag[x] appears to count the longest
// paths (diameters) passing through x; confirm against the problem statement.
signed main(){
read(n),read(k);
// Tree edges, stored in both directions.
for(rint i=1;i<n;++i){
read(x),read(y);
add(x,y);
add(y,x);
}
// The pass rooted at 1 locates p, the vertex where a longest path turns. The
// second pass re-roots at p so every neighbour of p is a child subtree.
// NOTE(review): the second call relies on dfs() recomputing len[] from scratch
// for the new root; verify dfs resets len[x] on entry.
dfs(1,0);dfs(p,0);
// Among p's children:
//   d1   = largest len[]
//   num1 = total cnt[] of the children that reach d1
//   son1 = how many children reach d1
// d2/num2/son2 describe the largest len[] different from d1 (d2 stays -1 and
// son2 stays 0 if every child reaches d1).
rint d1=-1,d2=-1,num1=0,num2=0,son1=0,son2=0;
for(rint i=head[p];i;i=e[i].nxt){
rint to=e[i].to;
if(d1<len[to]) d1=len[to],num1=cnt[to],son1=1;
else if(d1==len[to]) num1+=cnt[to],++son1;
}
for(rint i=head[p];i;i=e[i].nxt){
rint to=e[i].to;
if(len[to]==d1) continue;
if(len[to]>d2) d2=len[to],num2=cnt[to],son2=1;
else if(len[to]==d2) num2+=cnt[to],++son2;
}
if(son1>1){
// Several children reach depth d1. The deepest leaves of branch `to` get
// num1-cnt[to], the number of deepest leaves in the other d1 branches.
for(rint i=head[p];i;i=e[i].nxt){
rint to=e[i].to;
if(len[to]!=d1) continue;
dfs1(to,p,num1-cnt[to]);
tag[p]=(tag[p]+tag[to])%mod;
}
}else{
if(!son2){
// p has exactly one child (it reaches d1): each deepest leaf gets 1.
for(rint i=head[p];i;i=e[i].nxt){
rint to=e[i].to;
if(len[to]==d1) dfs1(to,p,1);
tag[p]=(tag[p]+tag[to])%mod;
}
}else{
// One child reaches d1. Its deepest leaves get num2 (leaves at depth d2 in
// the other branches); leaves of the d2 branches get num1.
for(rint i=head[p];i;i=e[i].nxt){
rint to=e[i].to;
if(len[to]==d1) dfs1(to,p,num2);
if(len[to]==d2) dfs1(to,p,num1);
tag[p]=(tag[p]+tag[to])%mod;
}
}
}
// When two or more branches contributed, tag[p] summed the values of both ends'
// branches, so it is halved. 499122177 is the inverse of 2 modulo 998244353.
if(son1+son2>=2) tag[p]=tag[p]*499122177%mod;
rint ans=0;
for(rint i=1;i<=n;++i){
if(k==1) ans=(ans+tag[i])%mod;
else ans=(ans+tag[i]*tag[i]%mod)%mod;
}
write(ans);
return 0;
}