https://www.luogu.com.cn/problem/P1505
代码:
#include <bits/stdc++.h>
#define debug(a) cout<<#a<<"="<<a<<"\n";
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define pep(i,a,b) for(int i=(a);i>=(b);i--)
#define rrep(i,a,b,s) for(int i=(a);i<=(b);i+=(s))
#define ppep(i,a,b,s) for(int i=(a);i>=(b);i-=(s))
using namespace std;
int n;
int m;
int cnt;
int a[2001000];
int s[2001000];
int f[2001000];
int dfn[2001000];
int dfn2[2001000];
int siz[2001000];
int son[2001000];
int top[2001000];
int h[2001000];
int lz[2001000];
int mx[2001000];
int mn[2001000];
struct node
{
int u;
int v;
};
struct node2
{
int to;
int ss;
};
struct node3
{
int s;
int maxn;
int mins;
};
node k[2001000];
vector<node2> g[2001000];
void dfs(int u,int fa)
{
siz[u]=1;
h[u]=h[fa]+1;
f[u]=fa;
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i].to;
if(f[u]!=v)
{
a[v]=g[u][i].ss;
dfs(v,u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v])
{
son[u]=v;
}
}
}
}
void dfs2(int u,int fa)
{
top[u]=fa;
dfn[u]=++cnt;
dfn2[cnt]=u;
if(!son[u])
{
return ;
}
dfs2(son[u],fa);
for(int i=0;i<g[u].size();i++)
{
int v=g[u][i].to;
if(f[u]!=v&&v!=son[u])
{
dfs2(v,v);
}
}
}
void pt(int x,int l,int r)
{
s[x]=-s[x];
swap(mn[x],mx[x]);
mn[x]=-mn[x];
mx[x]=-mx[x];
lz[x]=lz[x]^1;
}
void pd(int x,int l,int r)
{
int mid=(l+r)/2;
pt(x*2,l,mid);
pt(x*2+1,mid+1,r);
lz[x]=0;
}
void add(int x,int l,int r)
{
if(l==r)
{
s[x]=mn[x]=mx[x]=a[dfn2[l]];
return;
}
int mid=(l+r)/2;
add(x*2,l,mid);
add(x*2+1,mid+1,r);
s[x]=s[x*2]+s[x*2+1];
mn[x]=min(mn[x*2],mn[x*2+1]);
mx[x]=max(mx[x*2],mx[x*2+1]);
}
void agg1(int x,int l,int r,int xx,int yy)
{
if(l==r)
{
s[x]=mn[x]=mx[x]=yy;
return;
}
if(lz[x])
{
pd(x,l,r);
}
int mid=(l+r)/2;
if(mid>=xx) agg1(x*2,l,mid,xx,yy);
else agg1(x*2+1,mid+1,r,xx,yy);
s[x]=s[x*2]+s[x*2+1];
mn[x]=min(mn[x*2],mn[x*2+1]);
mx[x]=max(mx[x*2],mx[x*2+1]);
}
void agg2(int x,int l,int r,int xx,int yy)
{
if(xx<=l&&r<=yy)
{
pt(x,l,r);
return;
}
if(lz[x])
{
pd(x,l,r);
}
int mid=(l+r)/2;
if(mid>=xx) agg2(x*2,l,mid,xx,yy);
if(mid+1<=yy) agg2(x*2+1,mid+1,r,xx,yy);
s[x]=s[x*2]+s[x*2+1];
mn[x]=min(mn[x*2],mn[x*2+1]);
mx[x]=max(mx[x*2],mx[x*2+1]);
}
void agg3(int u,int v)
{
while(top[u]!=top[v])
{
if(h[top[u]]<h[top[v]])
{
swap(u,v);
}
agg2(1,1,n,dfn[top[u]],dfn[u]);
u=f[top[u]];
}
if(u==v) return;
if(h[u]<h[v])
{
swap(u,v);
}
agg2(1,1,n,dfn[v]+1,dfn[u]);
}
node3 find2(int x,int l,int r,int xx,int yy)
{
if(xx<=l&&r<=yy)
{
return {s[x],mx[x],mn[x]};
}
if(lz[x])
{
pd(x,l,r);
}
int mid=(l+r)/2;
node3 u;
node3 u1={0,INT_MIN,INT_MAX};
node3 u2={0,INT_MIN,INT_MAX};
if(mid>=xx)
{
u1=find2(x*2,l,mid,xx,yy);
}
if(mid+1<=yy)
{
u2=find2(x*2+1,mid+1,r,xx,yy);
}
u.s=u1.s+u2.s;
u.maxn=max(u1.maxn,u2.maxn);
u.mins=min(u1.mins,u2.mins);
return u;
}
node3 find(int x,int y)
{
node3 ans={0,INT_MIN,INT_MAX};
node3 kk;
while(top[x]!=top[y])
{
if(h[top[x]]<h[top[y]])
{
swap(x,y);
}
kk=find2(1,1,n,dfn[top[x]],dfn[x]);
debug(kk.s)
ans.s=ans.s+kk.s;
ans.maxn=max(ans.maxn,kk.maxn);
ans.mins=min(ans.mins,kk.mins);
x=f[top[x]];
}
if(x==y) return ans;
if(h[x]<h[y])
{
swap(x,y);
}
kk=find2(1,1,n,dfn[y]+1,dfn[x]);
ans.s=ans.s+kk.s;
ans.maxn=max(ans.maxn,kk.maxn);
ans.mins=min(ans.mins,kk.mins);
return ans;
}
int main()
{
cin>>n;
rep(i,1,n-1)
{
int u;
int v;
int w;
cin>>u>>v>>w;
u++;
v++;
k[i].u=u;
k[i].v=v;
g[u].push_back({v,w});
g[v].push_back({u,w});
}
cin>>m;
dfs(1,0);
dfs2(1,1);
add(1,1,n);
rep(i,1,m)
{
string op;
int u;
int v;
cin>>op>>u>>v;
if(op=="C")
{
u++;
agg1(1,1,n,dfn[u],v);
}
else if(op=="N")
{
u++;
v++;
agg3(u,v);
}
else if(op=="SUM")
{
u++;
v++;
cout<<find(u,v).s<<endl;
}
else if(op=="MAX")
{
u++;
v++;
cout<<find(u,v).maxn<<endl;
}
else if(op=="MIN")
{
u++;
v++;
cout<<find(u,v).mins<<endl;
}
}
return 0;
}