#include<bits/stdc++.h>
using namespace std;
// Every later `int` in this file is really `long long`. This lets products
// such as tag[i]*tag[i] (each factor < mod) fit without overflow.
// Because the macro would also turn `int main` into `long long main`, the
// entry point below is declared as `signed main()`.
#define int long long
// Capacity of every per-vertex array: 5e6 vertices plus a small margin.
constexpr int N = 5000010;
// Modulus applied to all counting results.
constexpr int mod = 998244353;
template<typename type>
inline void read(type &x){
    // Parses the next decimal integer from stdin into x.
    // Non-digit bytes before the number are skipped. The number is negative
    // iff the byte immediately preceding its first digit is '-'.
    // If stdin runs out before any digit is seen, x is left at 0 instead of
    // looping forever on EOF.
    x=0;
    bool negative=false;
    // Keep getchar's int result: a char cannot reliably hold EOF, and passing
    // a negative char to isdigit is undefined behaviour.
    auto ch=getchar();
    while(ch!=EOF&&!isdigit(ch)){
        negative=(ch=='-');
        ch=getchar();
    }
    // isdigit(EOF) is false, so this loop also stops at end of input.
    while(isdigit(ch)){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    if(negative) x=-x;
}
template<typename type>
inline void write(type x,bool mode=1){
    // Prints x in decimal to stdout, followed by '\n' when mode is true or
    // ' ' when mode is false.
    // Digits are extracted from x itself instead of from -x, so the most
    // negative value of a signed type (e.g. LLONG_MIN) prints correctly
    // rather than overflowing on negation.
    static short Stack[50],top(0);
    const bool negative=x<0;
    if(negative) putchar('-');
    do{
        // For negative x, x%10 lies in [-9,0] (division truncates toward zero).
        int digit=static_cast<int>(x%10);
        Stack[++top]=static_cast<short>(negative?-digit:digit);
        x/=10;
    }while(x);
    while(top) putchar('0'+Stack[top--]);
    putchar(mode?'\n':' ');
}
// Input: vertex count n, answer selector k, and scratch endpoints u/v used
// while reading edges.
int n,k,u,v;
// Longest path length seen by dfs and the vertex where it was first reached.
int d,p;
// Per-vertex data (vertices are 1..n):
//   len - longest downward path length, cnt - number of deepest vertices,
//   tag - value summed into the answer, siz - subtree size.
int len[N],cnt[N],tag[N],siz[N];
// Adjacency lists; main inserts every edge in both directions.
vector<int>e[N];
// Post-order pass over the subtree of x, where fath is x's parent (0 for the root).
// For every vertex y in the subtree it computes:
//   siz[y] - number of vertices in y's subtree,
//   len[y] - length (in edges) of the longest downward path from y,
//   cnt[y] - number of subtree vertices at depth len[y] below y.
// It also updates the globals d/p with the longest path found through a
// vertex and the vertex where that length was first reached (strict '>').
// len[y] is reset on entry: main calls dfs a second time with a different
// root. Without the reset, vertices whose parent changed would keep lengths
// from the first rooting and corrupt d, len and cnt.
// NOTE(review): recursion depth equals the tree height (up to n); very deep
// trees need a large stack.
void dfs(int x,int fath){
    siz[x]=1;
    cnt[x]=1;
    len[x]=0;
    for(int to:e[x]){
        if(to==fath) continue;
        dfs(to,x);
        siz[x]+=siz[to];
        // Longest path through x that uses this child and the best earlier one.
        const int through=len[to]+1+len[x];
        if(through>d){
            p=x;
            d=through;
        }
        if(len[to]+1>len[x]){
            len[x]=len[to]+1;
            cnt[x]=cnt[to];
        }else if(len[to]+1==len[x]){
            cnt[x]+=cnt[to];
        }
    }
}
// Walks down from x (parent fath) following only children with
// len[child]+1 == len[x], i.e. toward the deepest vertices of x's subtree,
// as computed by dfs for the current rooting.
// Every leaf reached (siz == 1) has its tag set to val. Each visited vertex
// then adds, modulo mod, the tags of the children it followed to its own tag.
// NOTE(review): recursive; depth can reach the tree height.
void dfs1(int x,int fath,int val){
    if(siz[x]==1) tag[x]=val;
    for(int child:e[x]){
        const bool followsDeepest=child!=fath&&len[child]+1==len[x];
        if(!followsDeepest) continue;
        dfs1(child,x,val);
        tag[x]=(tag[x]+tag[child])%mod;
    }
}
signed main(){
    // Up to ~5e6 vertices means ~1e7 integers to read. Unsynchronised
    // iostreams keep that fast, and main only uses cin/cout, never C stdio.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin>>n>>k;
    for(int i=1;i<n;++i){
        std::cin>>u>>v;
        e[u].push_back(v);
        e[v].push_back(u);
    }
    // First pass: find the longest path length d and the vertex p where dfs
    // first reached it.
    dfs(1,0);
    // With a single vertex dfs never updates p; root at vertex 1 instead of
    // the unused index 0 (the printed result is unchanged).
    // NOTE(review): for n == 1 the program prints 0 - confirm this matches
    // the problem statement.
    if(p==0) p=1;
    // Second pass: re-root at p so len/cnt/siz describe the branches hanging off p.
    dfs(p,0);

    // Classify p's children (all neighbours, since p is now the root) by len:
    //   d1/num1/son1 - largest len, summed cnt, number of such children;
    //   d2/num2/son2 - the same for the largest len strictly below d1.
    int d1=-1,d2=-1,num1=0,num2=0,son1=0,son2=0;
    for(int to:e[p]){
        if(d1<len[to]) d1=len[to],num1=cnt[to],son1=1;
        else if(d1==len[to]) num1+=cnt[to],++son1;
    }
    for(int to:e[p]){
        if(len[to]==d1) continue;
        if(len[to]>d2) d2=len[to],num2=cnt[to],son2=1;
        else if(len[to]==d2) num2+=cnt[to],++son2;
    }

    if(son1>1){
        // Several branches reach depth d1. A deepest vertex in branch `to`
        // pairs with the deepest vertices of all the other such branches.
        for(int to:e[p]){
            if(len[to]!=d1) continue;
            dfs1(to,p,num1-cnt[to]);
            tag[p]=(tag[p]+tag[to])%mod;
        }
    }else if(!son2){
        // Every child has len == d1 and son1 <= 1, so p has at most one
        // child: each deepest vertex pairs only with p itself.
        for(int to:e[p]){
            if(len[to]==d1) dfs1(to,p,1);
            tag[p]=(tag[p]+tag[to])%mod;
        }
    }else{
        // One deepest branch. Its num1 deepest vertices pair with the num2
        // vertices at depth d2 in the other branches, and vice versa.
        // Children shallower than d2 are never seeded, so their tag stays 0.
        for(int to:e[p]){
            if(len[to]==d1) dfs1(to,p,num2);
            if(len[to]==d2) dfs1(to,p,num1);
            tag[p]=(tag[p]+tag[to])%mod;
        }
    }
    // When paths through p join two of its branches, each path was added to
    // tag[p] once from each branch; multiply by the inverse of 2 mod `mod`.
    if(son1+son2>=2) tag[p]=tag[p]*((mod+1)/2)%mod;

    // Answer: sum of tag[i] when k == 1, otherwise sum of tag[i]^2 (mod).
    int ans=0;
    for(int i=1;i<=n;++i){
        if(k==1) ans=(ans+tag[i])%mod;
        else ans=(ans+tag[i]*tag[i]%mod)%mod;
    }
    std::cout<<ans;
    return 0;
}