// 原先 TLE 在 #13；参考讨论区的优化帖修改后，#13 和 #15 都 TLE 了。
// 求看看哪里还可以优化。
#include <bits/stdc++.h>
using namespace std;
#define int long long
// ---- Global state. Node ids run 1..n, plan ids run 1..m. ----
int n;                      // number of tree nodes
int m;                      // number of transport plans
int fa[300010][20];         // fa[v][k]: 2^k-th ancestor of v (0 above the root)
int dep[300010];            // depth of each node; the root (node 1) has depth 1
int pre[300010];            // total edge weight on the path from the root to each node
int lg2[300010];            // lg2[i] = floor(log2(i))
int x;                      // time limit currently being tested by check()
int cnt[300010];            // number of over-limit plans whose path uses the edge above this node

// One transport plan: endpoints f and e, their LCA, and the path length t.
struct node {
    int f;
    int e;
    int lca;
    int t;
};
node a[300010];

// Adjacency-list entry: neighbour v reached through an edge of weight w.
struct edge {
    int v;
    int w;
};
vector<edge> e[300010];     // adjacency lists of the tree
unordered_map<int, int> um; // um[v]: weight of the edge from v up to its parent
void dfs(int f, int x, int d) {
pre[x] += pre[f];
fa[x][0] = f;
dep[x] = d;
for (edge v : e[x]) {
if (v.v == f) continue;
pre[v.v] += v.w;
um[v.v] = v.w;
dfs(x, v.v, d + 1);
}
}
// Lowest common ancestor of x and y by binary lifting.
// Reads dep[], lg2[] and the ancestor table fa[][] built in main().
int findlca(int x, int y) {
    // Ensure y is the deeper (or equally deep) node.
    if (dep[y] < dep[x]) swap(x, y);

    // Raise y to x's depth, always taking the largest power-of-two jump that fits.
    for (int gap = dep[y] - dep[x]; gap > 0; gap = dep[y] - dep[x]) {
        y = fa[y][lg2[gap]];
    }
    if (x == y) return x;

    // Raise both nodes together while their ancestors still differ; afterwards
    // they are distinct children of the LCA.
    for (int k = lg2[n]; k >= 0; --k) {
        const int ax = fa[x][k];
        const int ay = fa[y][k];
        if (ax == ay) continue;
        x = ax;
        y = ay;
    }
    return fa[x][0];
}
void getnum(int f, int x) {
for (edge v : e[x]) {
if (v.v == f) continue;
getnum(x, v.v);
cnt[x] += cnt[v.v];
}
}
bool check() {
int num = 0;
memset(cnt, 0, sizeof(cnt));
for (int i = 1; i <= m; i++) {
if (a[i].t > x) {
cnt[a[i].f]++;
cnt[a[i].e]++;
cnt[a[i].lca] -= 2;
num++;
}
}
getnum(0, 1);
int mx = -1, id;
for (int i = 1; i <= n; i++) {
if (cnt[i] == num && um[i] > mx) {
mx = um[i];
id = i;
}
}
if (mx == -1) return false;
for (int i = 1; i <= m; i++) {
if (a[i].t > x) {
if (a[i].t - mx > x) return false;
}
}
return true;
}
signed main() {
    int tmp1 = 0; // length of the longest transport plan
    int tmp2 = 0; // heaviest edge weight in the tree
    // n and m are long long (via the int macro), so %lld is required; passing
    // long long* to %d is undefined behaviour.
    scanf("%lld %lld", &n, &m);
    // lg2[i] = floor(log2(i)). findlca only reads lg2[0..n]; the old bound
    // n + 10 wrote lg2[300010], one past the end of the array, when n = 300000.
    for (int i = 1; i <= n; i++) {
        lg2[i] = lg2[i - 1] + (2 * (1 << lg2[i - 1]) == i);
    }
    for (int i = 1; i < n; i++) {
        int u, v, w; // named so they no longer shadow the global x
        scanf("%lld %lld %lld", &u, &v, &w);
        e[u].push_back({v, w});
        e[v].push_back({u, w});
        tmp2 = max(tmp2, w);
    }
    for (int i = 1; i <= m; i++) {
        int u, v;
        scanf("%lld %lld", &u, &v);
        a[i] = {u, v, 0, 0};
    }
    // fa is a zero-initialised global, so no memset is needed before the DFS.
    dfs(0, 1, 1);
    // Binary lifting: fa[j][i] is the 2^i-th ancestor of j.
    for (int i = 1, len = 2; len <= n; i++, len *= 2) {
        for (int j = 1; j <= n; j++) {
            fa[j][i] = fa[fa[j][i - 1]][i - 1];
        }
    }
    for (int i = 1; i <= m; i++) {
        a[i].lca = findlca(a[i].f, a[i].e);
        a[i].t = pre[a[i].f] + pre[a[i].e] - 2 * pre[a[i].lca];
        tmp1 = max(tmp1, a[i].t);
    }
    // The answer is never negative. Without the clamp, l could reach -1 with
    // r = 0. Then (l + r) / 2 truncates to 0 == r, check(0) is true, and the
    // loop never terminates.
    int l = max<int>(0, tmp1 - tmp2), r = tmp1;
    while (l < r) {
        x = l + (r - l) / 2;
        if (check()) r = x;
        else l = x + 1;
    }
    printf("%lld", r);
    return 0;
}