MnZn求算复杂度
查看原帖
MnZn求算复杂度
674793
luoguhandongheng楼主2025/1/27 20:26
#pragma GCC optimize(2)
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
// Shorthand for a single-type template header; used by the scalar read().
#define Tp template <typename T>
// Shorthand for a variadic template header; used by the multi-argument read().
#define Ts template <typename T, typename... _T>
// Input buffer of 1 << 20 bytes. p1 points at the next unread byte, p2 one past the last byte fetched.
char buf[1 << 20], *p1 = buf, *p2 = buf;
// Replaces getchar() with a buffered reader: when p1 reaches p2 the buffer is refilled from stdin
// with fread(); if the refill yields no bytes the expression evaluates to EOF, otherwise to the next byte.
#define getchar() (p1 == p2 && (p2 = buf + fread(p1 = buf, 1, 1 << 20, stdin), p1 == p2) ? EOF : *p1++)
// Parses the next decimal integer from the buffered input into x.
// Bytes before the first digit are skipped; if any skipped byte is '-', the parsed value is negated.
// If the input ends before a digit is found, x is left as 0 and the function returns
// (the previous version looped forever in that case because getchar() keeps yielding EOF).
Tp void read(T &x)
{
    x = 0;
    // Hold the byte as int so that EOF stays distinguishable from ordinary characters.
    int c = getchar();
    bool negative = false;
    // Explicit range checks instead of isdigit(): no library call on a possibly negative value.
    for (; c != EOF && (c < '0' || c > '9'); c = getchar())
        if (c == '-')
            negative = true;
    for (; c >= '0' && c <= '9'; c = getchar())
        x = x * 10 + (c - '0');
    if (negative)
        x = -x;
}
// Reads several integers in argument order: the first one, then the rest recursively.
Ts void read(T &x, _T &...y)
{
    read(x);
    read(y...);
}
// Order-statistics red-black tree from GNU pb_ds keyed by pii (no mapped value);
// solve() uses its order_of_key() to count stored keys below a threshold.
#define bst __gnu_pbds::tree<pii, __gnu_pbds::null_type, less<pii>, \
                             __gnu_pbds::rb_tree_tag,               \
                             __gnu_pbds::tree_order_statistics_node_update>
typedef long long ll;
// Note: despite the name, both members are long long.
typedef pair<ll, ll> pii;
// N bounds every per-node array; mod is the modulus used for all counting arithmetic.
const int N = 1e5 + 5, mod = 1e9 + 7;
// Adjacency list: e[u] holds (neighbour, edge weight) pairs, filled in main().
vector<pii> e[N];
// vcnt[v]: signed tally accumulated by cal().
// siz / maxsiz / tsiz / rt: working state of getrt() (subtree sizes, largest piece, component size, best node).
// n, k: read from the input in main().
int vcnt[N], siz[N], tsiz, rt, maxsiz[N], n, k;
// dis[v]: distance from the current traversal root, rewritten by getdis() and getson(). L: read from the input.
ll dis[N], L;
// vis[v]: set by dfs() once v has been processed as a decomposition root.
bool vis[N];
// Raises a to b when b is larger; leaves a unchanged otherwise.
void chmax(int &a, int b)
{
    if (a < b)
        a = b;
}
void getrt(int u, int fa)
{
    siz[u] = 1;
    for (pii x : e[u])
    {
        ll v, w;
        tie(v, w) = x;
        if (!vis[v] && v != fa)
        {
            getrt(v, u);
            siz[u] += siz[v];
            chmax(maxsiz[u], siz[v]);
        }
    }
    chmax(maxsiz[u], tsiz - siz[u]);
    if (maxsiz[u] < maxsiz[rt])
        rt = u;
}
vector<ll> dist;
vector<int> s;
void getdis(int u, int fa)
{
    dist.push_back(dis[u]);
    s.push_back(u);
    for (pii x : e[u])
    {
        int v, w;
        tie(v, w) = x;
        if (!vis[v] && v != fa)
        {
            dis[v] = dis[u] + w;
            getdis(v, u);
        }
    }
}
// Collects distances from u over its unvisited component (dis[u] is set to 0 first), then for
// every collected node v adds upd * (number of collected distances <= d - dis[v]) to vcnt[v].
// dfs() calls this with upd = 1 for a root and upd = -1 for each of its unvisited neighbours.
void cal(int u, ll d, int upd)
{
    dist.clear();
    s.clear();
    dis[u] = 0;
    getdis(u, 0);
    sort(dist.begin(), dist.end());
    for (size_t idx = 0; idx < s.size(); ++idx)
    {
        const int node = s[idx];
        // upper_bound on the sorted distances gives the count of entries not exceeding the budget.
        const int within = upper_bound(dist.begin(), dist.end(), d - dis[node]) - dist.begin();
        vcnt[node] += upd * within;
    }
}
void dfs(int u)
{
    vis[u] = 1;
    cal(u, L, 1);
    for (pii x : e[u])
    {
        ll v, w;
        tie(v, w) = x;
        if (!vis[v])
        {
            cal(v, L - 2 * w, -1);
        }
    }
    for (pii x : e[u])
    {
        ll v, w;
        tie(v, w) = x;
        if (!vis[v])
        {
            tsiz = siz[v], rt = 0;
            getrt(v, u);
            dfs(rt);
        }
    }
}
// frac[i] = i! mod p, inv[i] = modular inverse of i, ifrac[i] = inverse of i! mod p (p = mod).
ll frac[N], inv[N], ifrac[N];
// Fills frac / inv / ifrac for indices up to n. Every stored value lies in [0, mod).
void init()
{
    frac[0] = inv[1] = ifrac[0] = 1;
    for (int i = 2; i <= n; ++i)
    {
        // From mod = (mod / i) * i + mod % i it follows that inv[i] = -(mod / i) * inv[mod % i].
        // Written with (mod - mod / i) so the product is non-negative: the former
        // (-mod / i) * inv[...] + mod was still negative for most i, and % keeps that sign.
        inv[i] = (mod - mod / i) * inv[mod % i] % mod;
    }
    for (int i = 1; i <= n; ++i)
    {
        frac[i] = frac[i - 1] * i % mod;
        ifrac[i] = ifrac[i - 1] * inv[i] % mod;
    }
}
// Binomial coefficient "i choose j" modulo mod, built from the frac / ifrac tables.
// Returns 0 when j exceeds i.
ll C(int i, int j)
{
    if (j > i)
        return 0;
    ll result = frac[i] * ifrac[i - j] % mod;
    result = result * ifrac[j] % mod;
    return result;
}
bst tr;
int son[N], idfn[N], dfn[N], cnt;
ll wf[N];
void getson(int u, int fa)
{
    dfn[u] = ++cnt, idfn[cnt] = u;
    siz[u] = 1;
    for (pii x : e[u])
    {
        ll v, w;
        tie(v, w) = x;
        if (v != fa)
        {
            dis[v] = dis[u] + w;
            wf[v] = w;
            getson(v, u);
            if (siz[son[u]] < siz[v])
                son[u] = v;
            siz[u] += siz[v];
        }
    }
}
void color(int u, bool typ)
{
    if (typ == 1)
    {
        for (int i = dfn[u]; i <= dfn[u] + siz[u] - 1; ++i)
            tr.insert(pii{dis[idfn[i]], idfn[i]});
    }
    else
    {
        for (int i = dfn[u]; i <= dfn[u] + siz[u] - 1; ++i)
            tr.erase(pii{dis[idfn[i]], idfn[i]});
    }
}
ll ans;
void solve(int u, int fa)
{
    (ans += C(vcnt[u], k)) %= mod;
    for (pii x : e[u])
    {
        ll v, w;
        tie(v, w) = x;
        if (v != son[u] && v != fa)
        {
            solve(v, u);
        }
    }
    if (son[u])
        solve(son[u], u);
    for (pii x : e[u])
    {
        ll v, w;
        tie(v, w) = x;
        if (v != son[u] && v != fa)
        {
            color(v, 1);
        }
    }
    tr.insert(pii{dis[u], u});
    if (u != 1)
    {
        int c1 = tr.order_of_key(pii{dis[u] + L + 1, 0}), c2 = tr.order_of_key(pii{dis[u] + L - wf[u] + 1, 0});
        (ans -= C(vcnt[u] - c1 + c2, k) - mod) %= mod;
    }
    if (son[fa] != u)
    {
        color(u, 0);
    }
}
signed main()
{
    read(n, k, L);
    for (int i = 1; i < n; ++i)
    {
        ll u, v, w;
        read(u, v, w);
        e[u].emplace_back(v, w);
        e[v].emplace_back(u, w);
    }
    init();
    getrt(1, 0);
    tsiz = siz[1];
    maxsiz[0] = 1e9, rt = 0;
    getrt(1, 0);
    dfs(rt);
    getson(1, 0);
    tr.clear();
    solve(1, 0);
    cout << (ans * frac[k]) % mod << '\n';
    return 0;
}

总体来说就是一个点分 $O(n\log^2 n)$ 加一个启发式合并套平衡树 $O(n\log^2 n)$。应该是 $O(n\log^2 n)$ 的复杂度啊。但是过不了 n=1e5,是我复杂度算错了,还是写假了,还是常数太大了///

RP++

2025/1/27 20:26
加载中...