RE 0pts求条 悬3关
查看原帖
RE 0pts求条 悬3关
502314
Just_A_Sentence楼主2024/12/15 00:04

#include<bits/stdc++.h>
using namespace std;
struct tree{
	int to,nxt;
}tr[60010];
struct tree1{
	int l,r;
	int sum,imax;
}tr1[60010];
int dep[30005],FA[30005],siz[30005],son[30005],top[30005],head[30005],DFS[30005],id[30005],a[30005];
int tot=0,dtot=0;
void add(int u,int v){
	tr[++tot].to=v;
	tr[tot].nxt=head[u];
	head[u]=tot;
}
void dfs1(int x,int fa){
	FA[x]=fa;
	dep[x]=dep[fa]+1;
	int imax=-1,pmax=0;
	siz[x]=1;
	for(int i=head[x];i;i=tr[i].nxt){
		int u=tr[i].to;
		if(u!=fa){
			dfs1(u,x);
			if(siz[u]>imax){
				imax=siz[u];
				pmax=u;
			}
			siz[x]+=siz[u];
		}
	}
	son[x]=pmax;
	return;
}
void dfs2(int x,int fa){
	if(son[fa]==x) top[x]=top[fa];
	else top[x]=x;
	for(int i=head[x];i;i=tr[i].nxt){
		int u=tr[i].to;
		if(u!=fa) dfs2(u,x);
	}
}
void dfs(int x,int fa){
	DFS[++dtot]=x;
	id[x]=dtot;
	if(son[x]!=0)
	dfs(son[x],x);
	for(int i=head[x];i;i=tr[i].nxt){
		int u=tr[i].to;
		if(u!=fa&&u!=son[x]){
			dfs(u,x);
		}
	}
}
void build(int x,int l,int r){
	tr1[x].l=l;
	tr1[x].r=r;
	if(l==r){
		tr1[x].imax=tr1[x].sum=a[DFS[l]];
		return;
	}
	int mid=(l+r)>>1;
	build(x<<1,l,mid);
	build(x<<1|1,mid+1,r);
	tr1[x].sum=tr1[x<<1].sum+tr1[x<<1|1].sum;
	tr1[x].imax=max(tr1[x<<1].imax,tr1[x<<1|1].imax);
}
void change(int x,int xx,int y){
	if(tr1[x].l==tr1[x].r){
		tr1[x].imax=tr1[x].sum=y;
		return;
	}
	int mid=(tr1[x].l+tr1[x].r)>>1;
	if(id[xx]<=mid) change(x<<1,xx,y);
	else change(x<<1|1,xx,y);
	tr1[x].sum=tr1[x<<1].sum+tr1[x<<1|1].sum;
	tr1[x].imax=max(tr1[x<<1].imax,tr1[x<<1|1].imax);
}
int LCA(int u,int v){
	if(top[u]==top[v]){
		if(dep[u]<dep[v]) return u;
		return v;
	}
	if(dep[top[u]]<dep[top[v]]) swap(u,v);
	while(top[u]!=top[v]){
		if(dep[top[u]]<dep[top[v]]) swap(u,v);
		u=FA[top[u]];
	}
	if(dep[u]<dep[v]) return u;
	return v;
}
int tfind1(int x,int l,int r){
	if(tr1[x].l>=l&&tr1[x].r<=r) return tr1[x].sum;
	int mid=(l+r)>>1,sum=0;
	if(mid>=l) sum+=tfind1(x<<1,l,r);
	if(mid<r) sum+=tfind1(x<<1|1,l,r);
	return sum;
}
int tfind2(int x,int l,int r){
	if(tr1[x].l>=l&&tr1[x].r<=r) return tr1[x].imax;
	int mid=(l+r)>>1,imax=0;
	if(mid>=l) imax=max(imax,tfind2(x<<1,l,r));
	if(mid<r) imax=max(imax,tfind2(x<<1|1,l,r));
	return imax;
}
int find1(int u,int v){
	int lca=LCA(u,v);
	int sum=0;
	while(top[u]!=top[lca]){
		sum+=tfind1(1,id[top[u]],id[u]);
		u=FA[top[u]];
	}
	sum+=tfind1(1,id[lca],id[u]);
	while(top[v]!=top[lca]){
		sum+=tfind1(1,id[top[v]],id[v]);
		v=FA[top[v]];
	}
	sum+=tfind1(1,id[lca],id[v]);
	return sum;
}
int find2(int u,int v){
	int lca=LCA(u,v);
	int imax=0;
	while(top[u]!=top[lca]){
		imax=max(imax,tfind2(1,id[top[u]],id[u]));
		u=FA[top[u]];
	}
	imax=max(imax,tfind2(1,id[lca],id[u]));
	while(top[v]!=top[lca]){
		imax=max(imax,tfind2(1,id[top[v]],id[v]));
		v=FA[top[v]];
	}
	imax=max(imax,tfind2(1,id[lca],id[v]));
	return imax;
}
int main(){
	int n;
	scanf("%d",&n);
	for(int i=1;i<n;i++){
		int u,v;
		scanf("%d%d",&u,&v);
		add(u,v);
		add(v,u);
	}
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
	}
	dfs1(1,0);
	dfs2(1,0);
	dfs(1,0);
	build(1,1,n);
	int q;
	scanf("%d",&q);
	while(q--){
		char s[15];
		int x,y;
		scanf("%s%d%d",s,&x,&y);
		if(s[1]=='H'){
			change(1,x,y);
		}
		else if(s[1]=='M'){
			printf("%d\n",find2(x,y));
		}
		else{
			printf("%d\n",find1(x,y));
		}
	}
	return 0;
}

2024/12/15 00:04
加载中...