TLE 90pts求助
查看原帖
TLE 90pts求助
1227383
MPLN楼主2025/1/23 16:26

原本只有 #13 超时（TLE）；参考讨论区里关于优化的帖子修改后，反而 #13 和 #15 都超时了。
求帮忙看看还有哪里可以优化。

#include <bits/stdc++.h>
using namespace std;

#define int long long

// Capacity of every per-vertex / per-plan array. The arrays were originally
// sized 300010, presumably for n up to 3e5. The extra slack keeps loops that
// run a few indices past n (e.g. a table fill up to n + 10) in bounds.
const int MAXN = 300020;

int n, m;          // number of vertices, number of transport plans
int fa[MAXN][20];  // fa[v][k]: the 2^k-th ancestor of v (0 above the root)
int dep[MAXN];     // depth of v as assigned by dfs()
int pre[MAXN];     // sum of edge weights on the path from the root to v
int lg2[MAXN];     // lg2[i] = floor(log2(i))
int x;             // candidate answer currently tested by check()
int cnt[MAXN];     // tree-difference counter: paths crossing the edge above v

// A transport plan: endpoints f and e, their LCA, and the path length t.
struct node {
    int f, e, lca, t;
} a[MAXN];

// Adjacency-list entry: neighbour v, reached over an edge of weight w.
struct edge {
    int v, w;
};

std::vector<edge> e[MAXN];

// um[v]: weight of the edge between v and its parent (0 for the root).
// A plain zero-initialised array instead of std::unordered_map. Every use is
// um[v] indexed by a vertex id, so the syntax is unchanged and reading an
// unset slot still yields 0, but there is no hashing or node allocation.
int um[MAXN];

void dfs(int f, int x, int d) {
    pre[x] += pre[f];
    fa[x][0] = f;
    dep[x] = d;
    for (edge v : e[x]) {
        if (v.v == f) continue;
        pre[v.v] += v.w;
        um[v.v] = v.w;
        dfs(x, v.v, d + 1);
    }
}

// Lowest common ancestor of x and y by binary lifting.
// Needs dep[], the ancestor table fa[][] and lg2[] to be filled in first.
int findlca(int x, int y) {
    // Arrange for y to be the deeper vertex.
    if (dep[y] < dep[x]) swap(x, y);

    // Raise y by the largest power of two that does not overshoot x's depth,
    // until both vertices sit on the same level.
    for (int gap = dep[y] - dep[x]; gap > 0; gap = dep[y] - dep[x]) {
        y = fa[y][lg2[gap]];
    }
    if (x == y) return x;

    // Jump both vertices while their 2^k-th ancestors still differ; they end
    // up as two distinct children of the LCA.
    for (int k = lg2[n]; k >= 0; --k) {
        const int upX = fa[x][k];
        const int upY = fa[y][k];
        if (upX != upY) {
            x = upX;
            y = upY;
        }
    }
    return fa[x][0];
}

void getnum(int f, int x) {
    for (edge v : e[x]) {
        if (v.v == f) continue;
        getnum(x, v.v);
        cnt[x] += cnt[v.v];
    }
}

bool check() {
    int num = 0;
    memset(cnt, 0, sizeof(cnt));
    
    for (int i = 1; i <= m; i++) {
        if (a[i].t > x) {
            cnt[a[i].f]++;
            cnt[a[i].e]++;
            cnt[a[i].lca] -= 2;
            num++;
        }
    }
    
    getnum(0, 1);
    
    int mx = -1, id;
    for (int i = 1; i <= n; i++) {
        if (cnt[i] == num && um[i] > mx) {
            mx = um[i];
            id = i;
        }
    }
    
    if (mx == -1) return false;
    
    for (int i = 1; i <= m; i++) {
        if (a[i].t > x) {
            if (a[i].t - mx > x) return false;
        }
    }
    
    return true;
}

// Reads the tree (n - 1 weighted edges) and m plans, computes each plan's
// length via LCA, then binary-searches the smallest achievable maximum plan
// length after zeroing one edge.
signed main() {
    int maxPath = 0, maxEdge = 0;

    // `int` is `long long` in this file, so every conversion must be %lld
    // (passing long long* to %d is undefined behaviour).
    scanf("%lld %lld", &n, &m);

    // lg2[i] = floor(log2(i)). Only indices 1..n are ever read, so stop at n
    // rather than n + 10, which could run past the end of lg2[].
    for (int i = 2; i <= n; i++) lg2[i] = lg2[i / 2] + 1;

    for (int i = 1; i < n; i++) {
        int u, v, w;
        scanf("%lld %lld %lld", &u, &v, &w);
        e[u].push_back({v, w});
        e[v].push_back({u, w});
        maxEdge = max(maxEdge, w);
    }

    for (int i = 1; i <= m; i++) {
        int u, v;
        scanf("%lld %lld", &u, &v);
        a[i] = {u, v, 0, 0};
    }

    // fa[][] is a zero-initialised global, so no memset is needed.
    dfs(0, 1, 1);

    // Binary-lifting table: fa[j][k] = fa[fa[j][k-1]][k-1].
    for (int k = 1, len = 2; len <= n; k++, len *= 2) {
        for (int j = 1; j <= n; j++) {
            fa[j][k] = fa[fa[j][k - 1]][k - 1];
        }
    }

    for (int i = 1; i <= m; i++) {
        a[i].lca = findlca(a[i].f, a[i].e);
        a[i].t = pre[a[i].f] + pre[a[i].e] - 2 * pre[a[i].lca];
        maxPath = max(maxPath, a[i].t);
    }

    // The answer lies in [maxPath - maxEdge, maxPath] and is never negative.
    // Clamping at 0 also matters for termination: (l + r) / 2 truncates toward
    // zero, so with l = -1, r = 0 it gives x == r and the loop could spin forever.
    int l = maxPath - maxEdge, r = maxPath;
    if (l < 0) l = 0;
    while (l < r) {
        x = l + (r - l) / 2;
        if (check()) r = x;
        else l = x + 1;
    }

    printf("%lld", r);
    return 0;
}
2025/1/23 16:26
加载中...