10分WA求助
查看原帖
10分WA求助
375195
闲着没事人楼主2021/1/8 14:50

应该是操作2有问题,但我看了好久都看不出来,求各位大佬的解答,多谢!

#include <bits/stdc++.h>
#define N 100001
#define lson p << 1
#define rson (p << 1) + 1
using namespace std;
long long n, m, r, P, cnt, res, head[N], point[N], id[N], np[N], tot, top[N], son[N], size[N], fa[N], deep[N];

struct edge{
	int v, next;
}e[N * 2];

struct tree{
	int l, r;
	long long dat, add;
}t[4 * N];

int len(int p){
	return t[p].r - t[p].l + 1;
}

void add_edge(int u, int v){
	e[++tot].v = v;
	e[tot].next = head[u];
	head[u] = tot;
}

void dfs1(int x, int f, int dep){
	deep[x] = dep;
	fa[x] = f;
	size[x] = 1;
	int mxson = -1;
	for(int i = head[x]; i; i = e[i].next){
		int v = e[i].v;
		if(v == f) continue;
		dfs1(v, x, dep+1);
		size[x] += size[v];
		if(size[v] > mxson) mxson = size[v], son[x] = v;
	}
}

void dfs2(int x, int topf){
	id[x] = ++cnt;
	top[x] = topf;
	np[cnt] = point[x];
	if(!son[x]) return;
	dfs2(son[x], topf);
	for(int i = head[x]; i; i = e[i].next){
		int v = e[i].v;
		if(v == fa[x] || v == son[x]) continue;
		dfs2(v, v);
	}
}

void build(int p, int l, int r){
	t[p].l = l;
	t[p].r = r;
	if(l == r){
		t[p].dat = np[l] % P;
		return;
	}
	int mid = (l + r) / 2;
	build(lson, l, mid);
	build(rson, mid + 1, r);
	t[p].dat = (t[lson].dat + t[rson].dat) % P;
	return;
}

void spread(int p){
	if(!t[p].add) return;
	t[lson].dat += t[p].add * len(lson);
	t[rson].dat += t[p].add * len(rson);
	t[lson].add += t[p].add;
	t[rson].add += t[p].add;
	t[lson].dat %= P;
	t[rson].dat %= P;
	t[p].add = 0;
	return; 
}

void add(int p, int l, int r, int u){
	if(l <= t[p].l && r >= t[p].r){
		t[p].dat += u * len(p);
		t[p].add += u;
		return;
	}
	spread(p);
	int mid = (t[p].l + t[p].r) / 2;
	if(l <= mid) add(lson, l, r, u);
	if(r > mid) add(rson, l, r, u);
	t[p].dat += t[lson].dat + t[rson].dat;
	t[p].dat %= P;
	return;
}

void chain(int p, int l, int r){
	if(l <= t[p].l && r >= t[p].r){
		res += t[p].dat;
		return;
	}
	spread(p);
	int mid = (t[p].l + t[p].r) / 2;
	if(l <= mid) chain(lson, l, r);
	if(r > mid) chain(rson, l, r);
	return;
}

int lca_sum(int x, int y){
//	cout << "lca_sum = ";
	int ans = 0;
	while(top[x] != top[y]){
		if(deep[top[x]] < deep[top[y]]) swap(x, y);
		res = 0;
		chain(1, id[top[x]], id[x]);
		ans = (ans + res) % P;
		x = fa[top[x]];
	}
	if(deep[x] > deep[y]) swap(x, y);
	res = 0;
	chain(1, id[x], id[y]);
	ans += res;
	return ans % P;
}

void lca_add(int x, int y, int u){
	u %= P;
	while(top[x] != top[y]){
		if(deep[top[x]] < deep[top[y]]) swap(x, y);
		add(1, id[top[x]], id[x], u);
		x = fa[top[x]];
	}
	if(deep[x] > deep[y]) swap(x, y);
	add(1, id[x], id[y], u);
//	cout << "lca_add" << endl;
	return;
}

int tree_sum(int x){
	res = 0;
//	cout << "tree_sum = ";
	chain(1, id[x], id[x] + size[x] - 1);
	return res % P;
}

void tree_add(int x, int u){
//	cout << "tree_add" << endl;
	add(1, id[x], id[x] + size[x] - 1, u % P);
	return;
}

//long long ask(int p, int l, int r){
//	if(l <= t[p].l && r >= t[p].r) return t[p].dat;
//	spread(p);
//	long long he = 0;
//	int mid = (t[p].l + t[p].r) / 2;
//	if(l <= mid) he += ask(lson, l, r);
//	if(r > mid) he += ask(rson, l, r);
//	return he;
//}

int main(){
	ios::sync_with_stdio(false);
	cin >> n >> m >> r >> P;
	for(int i = 1; i <= n; i++) cin >> point[i];
	for(int i = 1; i < n; i++){
		int a, b;
		cin >> a >> b;
		add_edge(a, b);
		add_edge(b, a);
	}
	dfs1(r, 0, 0);
	dfs2(r, r);
	build(1, 1, n);
	while(m--){
		int wbbb, o1, o2, o3;
		cin >> wbbb;
		switch(wbbb){
			case 1:
				cin >> o1 >> o2 >> o3;
				lca_add(o1, o2, o3);
				break;
			case 2:
				cin >> o1 >> o2;
				cout << lca_sum(o1, o2) << endl;
				break;
			case 3:
				cin >> o1 >> o2;
				tree_add(o1, o2);
				break;
			case 4:
				cin >> o1;
				cout << tree_sum(o1) << endl;
		}
//		for(int i = 1; i <= n; i++) cout << ask(1, id[i], id[i]) << ' ';
//		cout << endl;
	}
	return 0;
}

2021/1/8 14:50
加载中...