Help!
查看原帖
Help!
1353330
K_J_M楼主2025/1/22 19:52
#include<bits/stdc++.h>
using namespace std;
// NOTE: this macro retypes every `int` below as `long long` (counters, array
// elements, loop indices); `main` has to be declared `signed` as a result.
#define int long long
// Capacity of every per-vertex array; assumes vertex ids stay below N (TODO confirm vs. constraints).
const int N = 5e6+10;
// Prime modulus applied to all accumulated counts.
const int mod = 998244353;
/// Reads one integer from stdin into x.
/// Non-digit characters before the number are skipped; a '-' immediately
/// preceding the first digit makes the value negative. If end of input is
/// reached before any digit appears, x is left as 0 (the old version looped
/// forever on EOF).
template<typename type>
inline void read(type &x){
    x=0;
    bool flag=false;
    // int, not char: it must be able to hold EOF, and isdigit() is undefined
    // for negative values other than EOF.
    int ch=getchar();
    while(ch!=EOF&&!isdigit(ch)){
        flag=(ch=='-');
        ch=getchar();
    }
    while(ch!=EOF&&isdigit(ch)){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    if(flag) x=-x;
}
/// Writes x to stdout in decimal.
/// The number is followed by '\n' when mode is true, or by a single space
/// when mode is false.
template<typename type>
inline void write(type x,bool mode=1){
    if(x<0){
        x=-x;
        putchar('-');
    }
    char digits[50];
    int count=0;
    // Collect digits least-significant first, then emit them in reverse.
    do{
        digits[count++]=static_cast<char>('0'+x%10);
        x/=10;
    }while(x!=0);
    while(count>0) putchar(digits[--count]);
    putchar(mode?'\n':' ');
}
// Global problem state. Arrays are indexed by vertex id; `int` is `long long` via the macro above.
int n, k;          // vertex count; k==1 sums tag[], any other value sums tag[]^2 (see main)
int u, v;          // scratch endpoints while reading edges
int d;             // largest len[to]+1+len[x] seen by dfs (longest path length, in edges)
int p;             // vertex at which dfs first reached the value stored in d
int len[N];        // depth (edges) of the deepest leaf below a vertex in the current rooting
int cnt[N];        // number of leaves below a vertex at exactly that depth
int tag[N];        // per-vertex value accumulated by dfs1/main; presumably #longest paths through it — confirm
int siz[N];        // subtree size in the current rooting
vector<int> e[N];  // adjacency lists of the tree
/// Post-order DFS over the tree rooted at the vertex of the outermost call.
/// For each vertex x it (re)computes:
///   siz[x] - subtree size,
///   len[x] - depth in edges of the deepest leaf below x,
///   cnt[x] - number of leaves below x at depth len[x].
/// While merging each child it also tracks the longest path through x:
/// d holds the best length seen so far, and p the vertex where it was first reached.
/// `fath` is x's parent (0 for the root).
/// NOTE(review): recursion depth equals tree height, so a path-shaped tree with
/// millions of vertices needs a very large stack.
void dfs(int x,int fath){
	// Reset ALL per-vertex fields. main calls dfs a second time after
	// re-rooting; a stale len[x] from the previous rooting would be treated
	// as a real branch depth and corrupt cnt[], d and p.
	siz[x]=1,cnt[x]=1,len[x]=0;
	for(int to:e[x]){
		if(to==fath) continue;
		dfs(to,x);
		siz[x]+=siz[to];
		// Longest path through x that uses this child plus the best earlier branch.
		const int through=len[to]+1+len[x];
		if(through>d){
			p=x;
			d=through;
		}
		if(len[to]+1>len[x]){
			len[x]=len[to]+1;
			cnt[x]=cnt[to];
		}else if(len[to]+1==len[x]){
			cnt[x]+=cnt[to];
		}
	}
}
/// Assigns `val` to every deepest leaf below x and sums those values back up
/// into tag[] (modulo `mod`).
/// Only children continuing a longest downward chain are visited, i.e. those
/// with len[child]+1 == len[x]. A vertex with siz[x]==1 counts as a leaf.
/// Relies on siz[] and len[] from the most recent dfs() rooting; `fath` is x's parent.
void dfs1(int x,int fath,int val){
	const bool isLeaf=(siz[x]==1);
	if(isLeaf){
		tag[x]=val;
	}
	for(const auto &child:e[x]){
		if(child==fath) continue;
		const bool onDeepestChain=(len[child]+1==len[x]);
		if(!onDeepestChain) continue;
		dfs1(child,x,val);
		tag[x]+=tag[child];
		tag[x]%=mod;
	}
}
signed main(){
	cin>>n>>k;
	for(int i=1;i<n;++i){
		cin>>u>>v;
		e[u].push_back(v);
		e[v].push_back(u);
	}
	dfs(1,0);
	dfs(p,0);
	int d1=-1,d2=-1,num1=0,num2=0,son1=0,son2=0;
	for(int i=0;i<e[p].size();++i){
		int to=e[p][i];
		if(d1<len[to]) d1=len[to],num1=cnt[to],son1=1;
		else if(d1==len[to]) num1+=cnt[to],++son1;
	}
	for(int i=0;i<e[p].size();++i){
		int to=e[p][i];
		if(len[to]==d1) continue;
		if(len[to]>d2) d2=len[to],num2=cnt[to],son2=1;
		else if(len[to]==d2) num2+=cnt[to],++son2;
	}
	if(son1>1){
		for(int i=0;i<e[p].size();++i){
			int to=e[p][i];
			if(len[to]!=d1) continue;
			dfs1(to,p,num1-cnt[to]);
			tag[p]=(tag[p]+tag[to])%mod;
		}
	}else{
		if(!son2){//只有最长链 
			for(int i=0;i<e[p].size();++i){
				int to=e[p][i];
				if(len[to]==d1) dfs1(to,p,1);
				tag[p]=(tag[p]+tag[to])%mod;
			}
		}else{
			for(int i=0;i<e[p].size();++i){
				int to=e[p][i];
				if(len[to]==d1) dfs1(to,p,num2);
				if(len[to]==d2) dfs1(to,p,num1);
				tag[p]=(tag[p]+tag[to])%mod;
			}
		}
	}
	if(son1+son2>=2) tag[p]=tag[p]*499122177%mod;
	int ans=0;
	for(int i=1;i<=n;++i){
		if(k==1) ans=(ans+tag[i])%mod;
		else ans=(ans+tag[i]*tag[i]%mod)%mod;
	}
	cout<<ans;
	return 0;
}

2025/1/22 19:52
加载中...