高精度板子求加速
  • 板块P1727 计算π
  • 楼主Ruiqun2009
  • 当前回复17
  • 已保存回复18
  • 发布时间2025/1/21 01:53
  • 上次更新2025/1/21 10:37:43
查看原帖
高精度板子求加速
589895
Ruiqun2009楼主2025/1/21 01:53

我实现了 Chudnovsky 公式求 $\pi$。然而我的 `bigint::sqrt` 十分慢,导致 $\sqrt{10005}$ 的高精度计算成为我程序的瓶颈。

经过比较,计算 $10^5$ 位的 $\pi$ 时,`bigint::sqrt` 需要大约 17 秒来计算 $\sqrt{10005\times10^{6\times5\times10^4}}$(代码中为 `a <<= 50000` 个 $10^6$ 进制位),而 `mpz_sqrt` 仅需约 40ms。我怀疑是复杂度问题。

求调 bigint::sqrt

namespace ntt {
// Generator of the multiplicative group of mint's prime modulus (as named by
// mint::get_primitive_root_prime); every twiddle factor is a power of it.
static const mint G = mint::get_primitive_root_prime();
// Two-adic valuation of mod() - 1, i.e. the largest log2 transform length.
// Counted over 64 bits: the former `(unsigned)` cast kept only the low 32 bits,
// so __builtin_ctz received 0 (undefined) whenever mod() - 1 has 32 or more
// trailing zero bits. For every modulus the old form handled, the value is
// unchanged.
static constexpr unsigned ntt_len = __builtin_ctzll((unsigned long long)(mint::mod() - 1));
// root[i] / iroot[i]: a 2^i-th root of unity and its inverse, built by
// repeated squaring from root[ntt_len] (primitive provided G is a primitive root).
mint root[ntt_len + 1];
mint iroot[ntt_len + 1];
// Twiddle increments used when moving from block s to block s+1 inside a
// radix-2 (rate2) or radix-4 (rate3) stage; indexed by the number of trailing
// one bits of s. Layout follows AtCoder Library's butterfly.
mint rate2[ntt_len], irate2[ntt_len];
mint rate3[ntt_len], irate3[ntt_len];
// Fills root/iroot/rate2/irate2/rate3/irate3; must be called before NTT/INTT.
// Despite its name it builds no bit-reversal table: the transforms below
// perform no index permutation at all.
inline void get_rev() {
    root[ntt_len] = G.pow((mint::mod() - 1) >> ntt_len);
    iroot[ntt_len] = root[ntt_len].inv();
    for (unsigned i = ntt_len - 1; ~i; i--) {
        root[i] = root[i + 1] * root[i + 1];
        iroot[i] = iroot[i + 1] * iroot[i + 1];
    }
    mint prod = 1, iprod = 1;
    for (unsigned i = 0; i < ntt_len - 1; i++) {
        rate2[i] = root[i + 2] * prod;
        irate2[i] = iroot[i + 2] * iprod;
        prod *= iroot[i + 2];
        iprod *= root[i + 2];
    }
    prod = iprod = 1;
    for (unsigned i = 0; i < ntt_len - 2; i++) {
        rate3[i] = root[i + 3] * prod;
        irate3[i] = iroot[i + 3] * iprod;
        prod *= iroot[i + 3];
        iprod *= root[i + 3];
    }
}
// In-place forward transform of a[0, 2^log_len) (presumably log_len <= ntt_len).
// Stages are consumed two at a time with radix-4 butterflies; one radix-2
// stage finishes an odd log_len. The output stays in the butterfly's own
// order (no permutation), which INTT consumes directly, so pointwise products
// of two NTT outputs are meaningful.
inline void NTT(mint* a, int log_len) {
    int len = 0;
    while (len < log_len) {
        if (log_len - len == 1) {
            // Last remaining stage: radix-2.
            unsigned p = 1u << (log_len - len - 1);
            mint rot = 1;
            for (unsigned s = 0; s < (1u << len); ++s) {
                unsigned offset = s << (log_len - len);
                for (unsigned i = 0; i < p; ++i) {
                    mint l = a[i + offset];
                    mint r = a[i + offset + p] * rot;
                    a[i + offset] = l + r;
                    a[i + offset + p] = l - r;
                }
                if (s + 1 != (1u << len)) rot = rot * rate2[__builtin_ctz(~s)];
            }
            ++len;
        }
        else {
            // Two stages at once: radix-4, imag = root[2] is a 4th root of unity.
            unsigned p = 1u << (log_len - len - 2);
            mint rot = 1, imag = root[2];
            for (unsigned s = 0; s < (1u << len); ++s) {
                mint rot2 = rot * rot;
                mint rot3 = rot2 * rot;
                unsigned offset = s << (log_len - len);
                for (unsigned i = 0; i < p; ++i) {
                    mint a0 = a[i + offset + 0 * p];
                    mint a1 = a[i + offset + 1 * p] * rot;
                    mint a2 = a[i + offset + 2 * p] * rot2;
                    mint a3 = a[i + offset + 3 * p] * rot3;
                    mint a1na3imag = (a1 - a3) * imag;
                    mint na2 = -a2;
                    a[i + offset] = a0 + a1 + a2 + a3;
                    a[i + offset + 1 * p] = a0 - a1 + a2 - a3;
                    a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
                    a[i + offset + 3 * p] = a0 + na2 - a1na3imag;
                }
                if (s + 1 != (1u << len)) rot = rot * rate3[__builtin_ctz(~s)];
            }
            len += 2;
        }
    }
}
// In-place inverse of NTT on a[0, 2^log_len), including the final scaling by
// 1/N, so INTT(NTT(a)) == a.
inline void INTT(mint* a, int log_len) {
    int len = log_len;
    while (len) {
        if (len == 1) {
            // Last remaining stage: radix-2.
            unsigned p = 1u << (log_len - len);
            mint irot = 1;
            for (unsigned s = 0; s < (1u << (len - 1)); ++s) {
                unsigned offset = s << (log_len - len + 1);
                for (unsigned i = 0; i < p; ++i) {
                    mint l = a[i + offset];
                    mint r = a[i + offset + p];
                    a[i + offset] = l + r;
                    a[i + offset + p] = (l - r) * irot;
                }
                if (s + 1 != (1u << (len - 1))) irot = irot * irate2[__builtin_ctz(~s)];
            }
            --len;
        }
        else {
            // Two stages at once: radix-4 with the inverse twiddles.
            unsigned p = 1u << (log_len - len);
            mint irot = 1, iimag = iroot[2];
            for (unsigned s = 0; s < (1u << (len - 2)); ++s) {
                mint irot2 = irot * irot;
                mint irot3 = irot2 * irot;
                unsigned offset = s << (log_len - len + 2);
                for (unsigned i = 0; i < p; ++i) {
                    mint a0 = a[i + offset + 0 * p];
                    mint a1 = a[i + offset + 1 * p];
                    mint a2 = a[i + offset + 2 * p];
                    mint a3 = a[i + offset + 3 * p];
                    mint a2na3iimag = (a2 - a3) * iimag;
                    a[i + offset] = a0 + a1 + a2 + a3;
                    a[i + offset + 1 * p] = (a0 - a1 + a2na3iimag) * irot;
                    a[i + offset + 2 * p] = (a0 + a1 - a2 - a3) * irot2;
                    a[i + offset + 3 * p] = (a0 - a1 - a2na3iimag) * irot3;
                }
                if (s + 1 != (1u << (len - 2))) irot = irot * irate3[__builtin_ctz(~s)];
            }
            len -= 2;
        }
    }
    const unsigned N = 1u << log_len;
    mint invN = mint(N).inv();
    for (unsigned i = 0; i < N; ++i) a[i] *= invN;
}
// Shared scratch area for the vector<mint> convolution helpers; operator*
// places its two padded operands in buf[0, len) and buf[len, 2*len), so the
// padded length may not exceed 2^25. Not safe for concurrent use.
static mint buf[1u << 26];
}
using ntt::get_rev;
using ntt::NTT;
using ntt::INTT;
// Limb-vector convolution helpers used by bigint.
//
// Both functions return the linear convolution of their inputs with every
// coefficient reduced modulo mint::mod(). Carrying into base-10^6 limbs is
// left to the caller. Large inputs go through ntt::NTT / ntt::INTT using the
// shared ntt::buf scratch area, so these helpers are not reentrant.
// Operands whose shorter side has at most naive_mul_limit entries use a
// schoolbook product instead. It yields exactly the same residues without
// paying for full-length transforms.
// Buffers are filled with std::copy_n / std::fill_n (the element-wise,
// well-defined equivalent of memcpy/memset), so no -Wclass-memaccess
// suppression is needed for mint.

// NOTE(review): threshold picked by estimate, not by measurement — tune it.
inline constexpr size_t naive_mul_limit = 32;

/// Convolution of a and b; an empty vector stands for zero (empty result).
inline vector<mint> operator*(const vector<mint>& a, const vector<mint>& b) {
    // a.size() + b.size() - 1 would wrap around for an empty operand.
    if (a.empty() || b.empty()) return vector<mint>();
    const size_t anssiz = a.size() + b.size() - 1;
    if (std::min(a.size(), b.size()) <= naive_mul_limit) {
        vector<mint> c(anssiz, mint(0));
        for (size_t i = 0; i < a.size(); i++)
            for (size_t j = 0; j < b.size(); j++) c[i + j] += a[i] * b[j];
        return c;
    }
    using ntt::buf;
    size_t len = 1;
    while (len < anssiz) len <<= 1;
    // buf[0, len) holds a, buf[len, 2*len) holds b, both zero-padded to len.
    std::fill_n(std::copy_n(a.data(), a.size(), buf), len - a.size(), mint(0));
    std::fill_n(std::copy_n(b.data(), b.size(), buf + len), len - b.size(), mint(0));
    const int x = std::__lg(len);
    NTT(buf, x);
    NTT(buf + len, x);
    for (size_t i = 0; i < len; i++) buf[i] *= buf[i + len];
    INTT(buf, x);
    return vector<mint>(buf, buf + anssiz);
}

/// Convolution of a with itself (one forward transform instead of two);
/// an empty vector stands for zero.
inline vector<mint> sqr(const vector<mint>& a) {
    if (a.empty()) return vector<mint>();
    if (a.size() <= naive_mul_limit) return a * a;
    using ntt::buf;
    const size_t anssiz = a.size() + a.size() - 1;
    size_t len = 1;
    while (len < anssiz) len <<= 1;
    std::fill_n(std::copy_n(a.data(), a.size(), buf), len - a.size(), mint(0));
    const int x = std::__lg(len);
    NTT(buf, x);
    for (size_t i = 0; i < len; i++) buf[i] *= buf[i];
    INTT(buf, x);
    return vector<mint>(buf, buf + anssiz);
}
class bigint {
    static const __uint128_t base = 1000000;
    static const unsigned log_base = 6;
    static void flatten(bigint& a) {
        vector<mint>& arr = a.res;
        mint inc(0);
        for (size_t i = 0; i < arr.size(); i++) {
            arr[i] += inc;
            inc = arr[i].val() / base;
            arr[i] = arr[i].val() % base;
        }
        if (inc.val()) {
            while (inc.val() >= base) {
                arr.push_back(inc.val() % base);
                inc = inc.val() / base;
            }
            arr.push_back(inc.val());
        }
    }
    vector<mint> shrink(const vector<mint>& arr) {
        vector<mint> a(arr);
        while (!a.empty() && a.back().val() == 0) a.pop_back();
        return a;
    }
    inline bigint(const vector<mint>& rhs) : res(rhs), negative(false) { flatten(*this); }
public:
    inline bigint() : res(), negative(false) {}
    inline bigint& operator=(const bigint& rhs) {
        res = rhs.res;
        negative = rhs.negative;
        return *this;
    }
    inline bigint(const bigint& rhs) : res(rhs.res), negative(rhs.negative) {}
    inline bigint& operator=(bigint&& rhs) {
        res = std::move(rhs.res);
        negative = std::move(rhs.negative);
        return *this;
    }
    inline bigint(__uint128_t x) : res(), negative(false) {
        while (x) {
            res.push_back(x % base);
            x /= base;
        }
    }
    inline bigint(bigint&& rhs) : res(std::move(rhs.res)), negative(std::move(rhs.negative)) {}
    inline ~bigint() = default;
public:
    inline void input(const std::string& s) {
        res.clear();
        negative = false;
        int f;
        if(s[0] == '-') f=1, negative = true;
        else f=0;
        for (int i=s.size()-1; i>=f; i-=log_base) {
            int st = std::max(f, int(i-(log_base-1))), len = i-st+1;
            res.push_back((long long)(atoi(s.substr(st,len).c_str())));
        }
    }
    inline std::string output() const {
        std::string ret;
        if (negative) ret += '-';
        if (res.empty()) {
            ret += '0';
            return ret;
        }
        ret += to_string_128(res.back().val());
        if (res.size() > 1) for (size_t i = res.size() - 2; ~i; i--) {
            if (res[i].val() < 100000) ret += '0';
            if (res[i].val() < 10000) ret += '0';
            if (res[i].val() < 1000) ret += '0';
            if (res[i].val() < 100) ret += '0';
            if (res[i].val() < 10) ret += '0';
            ret += to_string_128(res[i].val());
        }
        return ret;
    }
    int cmpabs(const bigint& b) const {
        if (res.size() < b.res.size()) return -1;
        if (res.size() > b.res.size()) return 1;
        for (size_t i = res.size() - 1; ~i; i--) {
            if (res[i].val() < b.res[i].val()) return -1;
            if (res[i].val() > b.res[i].val()) return 1;
        }
        return 0;
    }
public:
    bigint operator-() const {
        bigint ret(*this);
        ret.negative ^= 1;
        return ret;
    }
    bool operator<(const bigint& rhs) const {
        if (negative && !rhs.negative) return true;
        if (!negative && rhs.negative) return false;
        if (negative && rhs.negative) return -rhs < -*this;
        if (res.size() ^ rhs.res.size()) return res.size() < rhs.res.size();
        for (size_t i = res.size() - 1; ~i; i--) if (res[i] != rhs.res[i]) return res[i].val() < rhs.res[i].val();
        return false;
    }
    bigint& operator*=(const bigint& rhs) {
        res = (res * rhs.res);
        negative ^= rhs.negative;
        flatten(*this);
        return *this;
    }
    bigint& operator<<=(size_t len) {
        res.insert(res.begin(), len, mint(0));
        return *this;
    }
    bigint& operator>>=(size_t len) {
        if (len >= res.size()) {
            res.clear();
            return *this;
        }
        res = vector<mint>(res.begin() + len, res.end());
        return *this;
    }
    inline bigint& operator+=(const bigint& b) {
        if (negative && !b.negative) {
            return *this = b - -*this;
        }
        if (negative && b.negative) {
            *this = -*this + -b;
            return *this = -*this;
        }
        if (!negative && b.negative) {
            return *this -= -b;
        }
        res.resize(std::max(res.size(), b.res.size()));
        for (size_t i = 0, iend = b.res.size(); i < iend; i++) res[i] += b.res[i];
        flatten_add(*this);
        return *this;
    }
    inline bigint& operator-=(const bigint& b) {
        if (negative && !b.negative) {
            *this = -*this + b;
            return *this = -*this;
        }
        if (negative && b.negative) {
            return *this = -b - -*this;
        }
        if (!negative && b.negative) {
            return *this += -b;
        }
        if (*this < b) return *this = -(b - *this);
        res.resize(std::max(res.size(), b.res.size()));
        for (size_t i = 0, iend = b.res.size(); i < iend; i++) res[i] -= b.res[i];
        flatten_sub(*this);
        return *this;
    }
    bigint& operator/=(const bigint& b) {
        negative ^= b.negative;
        size_t m = res.size(), n = b.res.size();
        vector<mint> rhs_p(b.res);
        size_t offset;
        if (m <= n << 1) offset = n << 1;
        else {
            offset = m + n;
            rhs_p.insert(rhs_p.begin(), m - n, mint(0));
        }
        auto _res = (*this * div_inv_accurate(rhs_p)) >> offset;
        auto ret = *this - _res * b;
        flatten_sub(ret);
        if (ret < b) ret = _res;
        else {
            ret = _res;
            ret.res[0] += 1;
            flatten_add(ret);
        }
        *this = ret;
        return *this;
    }
    bigint& operator%=(const bigint& b) {
        bigint quo = *this;
        quo /= b;
        *this -= b * quo;
        flatten_sub(*this);
        return *this;
    }
    bigint pow(size_t n) const {
        bigint b(*this), ret(1);
        while (n) {
            if (n & 1) ret *= b;
            b = b.sqr();
            n >>= 1;
        }
        return ret;
    }
    bigint sqrt() const {
        bigint ans(*this >> ((res.size() - 1) / 2));
        flatten_add(ans);
        long rem = 0;
        while (true) {
            bigint tmp = ans + *this / ans;
            rem = 0;
            for (size_t i = tmp.res.size() - 1; ~i; i--) {
                auto [p, q] = std::div((long)tmp.res[i].val() + rem * base, 2l);
                tmp.res[i] = p;
                rem = q;
            }
            if (tmp.res.back() == 0) tmp.res.pop_back();
            if (tmp.cmpabs(ans) >= 0) {
                ans = tmp;
                break;
            }
            ans = tmp;
        }
        do {
            ans.res[0] += 1;
            flatten_add(ans);
        } while (cmpabs(ans.sqr()) >= 0);
        ans.res[0] -= 1;
        flatten_sub(ans);
        ans.negative = negative;
        return ans;
    }
    // floor(e^deg)
    bigint exp(uint64_t deg) const {
        bigint p, q;
        recurse_exp(p, q, 0, deg);
        std::string s = (p * bigint::div_inv_accurate(q)).output();
        s[0] = '2';
        p.input(s);
        q = 1;
        double digits = deg * M_LOG10E;
        while (deg) {
            if (deg & 1) q *= p;
            p = p.sqr();
            deg >>= 1;
        }
        s = q.output();
        int index = digits;
        if (abs(digits - round(digits)) < 0.1) {
            index = digits;
            if (s[0] == '1') index++;
        }
        p.input(s.substr(0, index + 1));
        return p;
    }
public:
    friend bigint operator+(const bigint &lhs, const bigint &rhs) { return bigint(lhs) += rhs; }
    friend bigint operator-(const bigint &lhs, const bigint &rhs) { return bigint(lhs) -= rhs; }
    friend bigint operator*(const bigint &lhs, const bigint &rhs) { return bigint(lhs) *= rhs; }
    friend bigint operator/(const bigint &lhs, const bigint &rhs) { return bigint(lhs) /= rhs; }
    friend bigint operator%(const bigint &lhs, const bigint &rhs) { return bigint(lhs) %= rhs; }
    friend bigint operator<<(const bigint &lhs, size_t rhs) { return bigint(lhs) <<= rhs; }
    friend bigint operator>>(const bigint &lhs, size_t rhs) { return bigint(lhs) >>= rhs; }
    friend bool operator==(const bigint &lhs, const bigint &rhs) { return lhs.res == rhs.res; }
    friend bool operator!=(const bigint &lhs, const bigint &rhs) { return lhs.res != rhs.res; }
private:
    static void flatten_add(bigint& a) {
        vector<mint>& arr = a.res;
        int inc(0);
        for (size_t i = 0; i < arr.size(); i++) {
            arr[i] += inc;
            inc = arr[i].val() >= base;
            if (inc) arr[i] -= base;
        }
        if (inc) arr.push_back(1);
    }
    static void flatten_sub(bigint &a) {
        vector<mint>& arr = a.res;
        int carry = 0;
        for (size_t i = 0; i < arr.size(); i++) {
            arr[i] -= carry;
            carry = arr[i].val() >= (mint::mod() - 100 * base);
            if (carry) arr[i] += base;
        }
        while (!arr.empty() && !arr.back().val()) arr.pop_back();
    }
    static bigint div_inv(const bigint& arr) {
        size_t da = arr.res.size();
        size_t da0 = (da >> 1) + 1;
        if (da == 1) {
            bigint tmp{base * base / arr.res[0].val()};
            flatten(tmp);
            return tmp;
        }
        else if (da == 2) {
            bigint tmp{base * base * base * base / (arr.res[1].val() * base + arr.res[0].val())};
            flatten(tmp);
            return tmp;
        }
        else {
            bigint tmp(div_inv(arr >> (da - da0)));
            tmp = tmp << (da - da0);
            bigint tem{2};
            tem = tem << (da << 1);
            tem -= arr * tmp;
            flatten_sub(tem);
            return (tmp * tem) >> (da << 1);
        }
    }
    static bigint div_inv_accurate(const bigint& b) {
        vector<bigint> t(7);
        t[0] = b;
        for (int i = 1; i ^ 7; ++i) {
            t[i] = t[i - 1] + t[i - 1];
        }
        size_t n = b.res.size();
        size_t err = 0;
        auto _res = div_inv(b), diff = (bigint{1} << (n << 1)) - b * _res;
        flatten_sub(diff);
        for (int i = 6; i >= 0; --i) if (!(diff < t[i])) {
            diff -= t[i];
            flatten_sub(diff);
            err |= 1u << i;
        }
        _res.res[0] += err;
        flatten(_res);
        return _res;
    }
    bigint sqr() const {
        return bigint{::sqr(res)};
    }
    static void recurse_exp(bigint& p, bigint& q, size_t a, size_t b) {
        if (b == a + 1) {
            p = 1;
            q = b;
            return;
        }
        bigint p0, p1, q0, q1;
        size_t m = (a + b) >> 1;
        bigint::recurse_exp(p0, q0, a, m);
        bigint::recurse_exp(p1, q1, m, b);
        q = q0 * q1;
        p = p0 * q1 + p1;
    }
    vector<mint> res;
    bool negative;
};
// Binary-splitting step of the Chudnovsky series over the terms [a, b).
// main() combines the results as
//   pi ~ 426880 * sqrt(10005) * q / (13591409 * q + r).
//   single term k:  p = -(6k-5)(2k-1)(6k-1),  q = 10939058860032000 * k^3,
//                   r = p * (545140134 k + 13591409)
//   merge [a,mid) with [mid,b):  p = p_l p_r,  q = q_l q_r,  r = q_r r_l + p_l r_r
// Requires a < b; otherwise the recursion does not terminate.
void recurse(bigint& p, bigint& q, bigint& r, size_t a, size_t b) {
    if (b - a != 1) {
        const size_t mid = (a + b) >> 1;
        bigint pl, ql, rl, pr, qr, rr;
        recurse(pl, ql, rl, a, mid);
        recurse(pr, qr, rr, mid, b);
        p = pl * pr;
        q = ql * qr;
        r = qr * rl + pl * rr;
        return;
    }
    const size_t k = a;
    p = -bigint(6 * k - 5) * bigint(2 * k - 1) * bigint(6 * k - 1);
    q = bigint(10939058860032000) * bigint(k) * bigint(k) * bigint(k);
    r = p * bigint(545140134 * k + 13591409);
}
// Reads n and prints the first n decimal digits of pi after "3.".
// Digits are grouped 10 per space and 50 per line.
int main() {
    std::ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    get_rev();
    size_t n = 100;
    cin >> n;
    bigint p, q, r;
    // Each Chudnovsky term adds ~14.18 digits, so n / 10 terms is ample.
    recurse(p, q, r, 1, std::max(n / 10, (size_t)10));
#ifdef PI_DEBUG
    // Debug dump of the split products (very large for big n).
    cout << p.output() << '\n' << q.output() << '\n' << r.output() << '\n';
#endif
    bigint a(10005);
    // Shifting by `shift` limbs of 10^6 scales sqrt(10005) by 10^(3 * shift),
    // so the result carries 3 * shift fractional digits. That must cover n
    // plus ~30 guard digits. The old shift of n / 2 limbs produced 1.5 n
    // digits and made the square root needlessly large.
    const size_t shift = std::max((n + 30) / 3 + 1, (size_t)10);
    a <<= shift;
#if 1
    a = a.sqrt() * 426880;
#else // 16993ms ^^^ / vvv 39ms
    mpz_t num;
    mpz_init_set_str(num, a.output().c_str(), 10);
    mpz_sqrt(num, num);
    char* buf = mpz_get_str(NULL, 10, num);
    mpz_clear(num);
    a.input(buf);
    a *= 426880;
    free(buf);
#endif
    // a = pi * 10^(3 * shift), truncated.
    a = (a * q) / (q * 13591409 + r);
    std::string s = a.output().substr(1);
    cout << "3.\n";
    for (size_t i = 0; i < n && i < s.size(); i++) {
        cout << s[i];
        if (i % 10 == 9) cout << ' ';
        if (i % 50 == 49) cout << '\n';
    }
    cout << '\n';
}

Credits:

  • mint 来自 hly1204
  • namespace ntt 内的所有内容来自 Atcoder Library
2025/1/21 01:53
加载中...