代码:
#include <bits/stdc++.h>
#define debug(a) cout<<#a<<"="<<a<<"\n";
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define pep(i,a,b) for(int i=(a);i>=(b);i--)
#define rrep(i,a,b,s) for(int i=(a);i<=(b);i+=(s))
#define ppep(i,a,b,s) for(int i=(a);i>=(b);i-=(s))
using namespace std;
int T;
int n;
int m;
int cnt;
int tot;
int ttt;
int a[2001000];
int a2[2001000];
int to[2001000];
int nx[2001000];
int head[2001000];
int dep[2001000];
int siz[2001000];
int son[2001000];
int faa[2001000];
int dfn[2001000];
int top[2001000];
void init()
{
cnt=0;
tot=0;
ttt=0;
memset(a,0,sizeof a);
memset(a2,0,sizeof a2);
memset(to,0,sizeof to);
memset(nx,0,sizeof nx);
memset(head,0,sizeof head);
memset(dep,0,sizeof dep);
memset(siz,0,sizeof siz);
memset(son,0,sizeof son);
memset(faa,0,sizeof faa);
memset(dfn,0,sizeof dfn);
memset(top,0,sizeof top);
}
struct line_tree
{
int l;
int r;
int l_color;
int r_color;
int sum;
int lz;
};
line_tree s[2001000];
void add(int u,int v)
{
to[++cnt]=v;
nx[cnt]=head[u];
head[u]=cnt;
}
void pu(int x)
{
s[x].sum=s[x*2].sum+s[x*2+1].sum;
if(s[x*2].r_color==s[x*2+1].l_color&&s[x*2].r_color!=0&&s[x*2+1].l_color!=0)
{
s[x].sum++;
}
s[x].l_color=s[x*2].l_color;
s[x].r_color=s[x*2+1].r_color;
}
void pd(int x)
{
s[x*2].sum=(s[x*2].r-s[x*2].l);
s[x*2].l_color=s[x*2].r_color=s[x].lz;
s[x*2].lz=s[x].lz;
s[x*2+1].sum=(s[x*2+1].r-s[x*2+1].l);
s[x*2+1].l_color=s[x*2+1].r_color=s[x].lz;
s[x*2+1].lz=s[x].lz;
s[x].lz=0;
}
line_tree add2(line_tree l,line_tree r)
{
line_tree res={0,0,0,0,0,0};
res.l_color=l.l_color,res.r_color=r.r_color;
res.sum=l.sum+r.sum;
if(l.r_color==r.l_color&&l.r_color!=0&&r.l_color!=0)
{
res.sum++;
}
return res;
}
void add3(int x,int l,int r)
{
s[x]={0,0,0,0,0,0};
s[x].l=l;
s[x].r=r;
if(l==r)
{
s[x].sum=1;
s[x].l_color=s[x].r_color=0;
return;
}
int mid=(l+r)/2;
add3(x*2,l,mid);
add3(x*2+1,mid+1,r);
pu(x);
}
void update(int x,int l,int r,int k)
{
if(s[x].l>=l&&s[x].r<=r)
{
s[x].sum=r-l;
s[x].l_color=s[x].r_color=k;
s[x].lz=k;
return;
}
if(s[x].lz)
{
pd(x);
}
int mid=(s[x].l+s[x].r)/2;
if(l<=mid)
{
update(x*2,l,r,k);
}
if(r>mid)
{
update(x*2+1,l,r,k);
}
pu(x);
}
line_tree find(int x,int l,int r)
{
if(s[x].l>=l&&s[x].r<=r)
{
return s[x];
}
if(s[x].lz)
{
pd(x);
}
int mid=(s[x].l+s[x].r)/2;
line_tree aa={0,0,0,0,0,0},bb={0,0,0,0,0,0};
if(l<=mid)
{
aa=find(x*2,l,r);
}
if(mid<r)
{
bb=find(x*2+1,l,r);
}
return add2(aa,bb);
}
void update2(int x,int y,int k)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
{
swap(x,y);
}
update(1,dfn[top[x]],dfn[x],k);
x=faa[top[x]];
}
if(dep[x]>dep[y])
{
swap(x,y);
}
update(1,dfn[x],dfn[y],k);
}
int find2(int x,int y)
{
line_tree l={0,0,0,0,0,0};
line_tree r={0,0,0,0,0,0};
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])
{
r=add2(find(1,dfn[top[y]],dfn[y]),r);
y=faa[top[y]];
}
else
{
l=add2(find(1,dfn[top[x]],dfn[x]),l);
x=faa[top[x]];
}
}
if(dep[x]<dep[y])
{
r=add2(find(1,dfn[x],dfn[y]),r);
}
else
{
l=add2(find(1,dfn[y],dfn[x]),l);
}
swap(l.l_color,l.r_color);
return add2(l,r).sum;
}
void dfs1(int u,int fa)
{
faa[u]=fa;
siz[u]=1;
dep[u]=dep[fa]+1;
for(int i=head[u];i;i=nx[i])
{
int v=to[i];
if(v==fa)
{
continue;
}
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>siz[son[u]])
{
son[u]=v;
}
}
}
void dfs2(int u,int _top)
{
top[u]=_top;
dfn[u]=++tot;
a2[tot]=a[u];
if(!son[u])
{
return;
}
dfs2(son[u],_top);
for(int i=head[u];i;i=nx[i])
{
int v=to[i];
if(v==faa[u]||v==son[u])
{
continue;
}
dfs2(v,v);
}
}
int main()
{
scanf("%d",&T);
while(T--)
{
init();
scanf("%d%d",&n,&m);
rep(i,1,n-1)
{
int u;
int v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
dfs1(1,0);
dfs2(1,1);
add3(1,1,n);
rep(i,1,m)
{
int op;
int a;
int b;
scanf("%d%d%d",&op,&a,&b);
if(op==1)
{
update2(a,b,++ttt);
}
else if(op==2)
{
printf("%d\n",find2(a,b));
}
}
}
return 0;
}