我实现了 Chudnovsky 公式求 π。然而我的 bigint::sqrt
十分慢,导致 10005 的高精度计算是我程序的瓶颈。
经过比较,计算 105 位的 π 时 bigint::sqrt
需要大约 17 秒计算 10005×105×104,而 mpz_sqrt
仅使用约 40ms。我怀疑是复杂度问题。
求调 bigint::sqrt
。
namespace ntt {
static const mint G = mint::get_primitive_root_prime();
static constexpr unsigned ntt_len = __builtin_ctz((unsigned)(mint::mod() - 1));
mint root[ntt_len + 1];
mint iroot[ntt_len + 1];
mint rate2[ntt_len], irate2[ntt_len];
mint rate3[ntt_len], irate3[ntt_len];
inline void get_rev() {
root[ntt_len] = G.pow((mint::mod() - 1) >> ntt_len);
iroot[ntt_len] = root[ntt_len].inv();
for (unsigned i = ntt_len - 1; ~i; i--) {
root[i] = root[i + 1] * root[i + 1];
iroot[i] = iroot[i + 1] * iroot[i + 1];
}
mint prod = 1, iprod = 1;
for (unsigned i = 0; i < ntt_len - 1; i++) {
rate2[i] = root[i + 2] * prod;
irate2[i] = iroot[i + 2] * iprod;
prod *= iroot[i + 2];
iprod *= root[i + 2];
}
prod = iprod = 1;
for (unsigned i = 0; i < ntt_len - 2; i++) {
rate3[i] = root[i + 3] * prod;
irate3[i] = iroot[i + 3] * iprod;
prod *= iroot[i + 3];
iprod *= root[i + 3];
}
}
inline void NTT(mint* a, int log_len) {
int len = 0;
while (len < log_len) {
if (log_len - len == 1) {
unsigned p = 1u << (log_len - len - 1);
mint rot = 1;
for (unsigned s = 0; s < (1u << len); ++s) {
unsigned offset = s << (log_len - len);
for (unsigned i = 0; i < p; ++i) {
mint l = a[i + offset];
mint r = a[i + offset + p] * rot;
a[i + offset] = l + r;
a[i + offset + p] = l - r;
}
if (s + 1 != (1u << len)) rot = rot * rate2[__builtin_ctz(~s)];
}
++len;
}
else {
unsigned p = 1u << (log_len - len - 2);
mint rot = 1, imag = root[2];
for (unsigned s = 0; s < (1u << len); ++s) {
mint rot2 = rot * rot;
mint rot3 = rot2 * rot;
unsigned offset = s << (log_len - len);
for (unsigned i = 0; i < p; ++i) {
mint a0 = a[i + offset + 0 * p];
mint a1 = a[i + offset + 1 * p] * rot;
mint a2 = a[i + offset + 2 * p] * rot2;
mint a3 = a[i + offset + 3 * p] * rot3;
mint a1na3imag = (a1 - a3) * imag;
mint na2 = -a2;
a[i + offset] = a0 + a1 + a2 + a3;
a[i + offset + 1 * p] = a0 - a1 + a2 - a3;
a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
a[i + offset + 3 * p] = a0 + na2 - a1na3imag;
}
if (s + 1 != (1u << len)) rot = rot * rate3[__builtin_ctz(~s)];
}
len += 2;
}
}
}
inline void INTT(mint* a, int log_len) {
int len = log_len;
while (len) {
if (len == 1) {
unsigned p = 1u << (log_len - len);
mint irot = 1;
for (unsigned s = 0; s < (1u << (len - 1)); ++s) {
unsigned offset = s << (log_len - len + 1);
for (unsigned i = 0; i < p; ++i) {
mint l = a[i + offset];
mint r = a[i + offset + p];
a[i + offset] = l + r;
a[i + offset + p] = (l - r) * irot;
}
if (s + 1 != (1u << (len - 1))) irot = irot * irate2[__builtin_ctz(~s)];
}
--len;
}
else {
unsigned p = 1u << (log_len - len);
mint irot = 1, iimag = iroot[2];
for (unsigned s = 0; s < (1u << (len - 2)); ++s) {
mint irot2 = irot * irot;
mint irot3 = irot2 * irot;
unsigned offset = s << (log_len - len + 2);
for (unsigned i = 0; i < p; ++i) {
mint a0 = a[i + offset + 0 * p];
mint a1 = a[i + offset + 1 * p];
mint a2 = a[i + offset + 2 * p];
mint a3 = a[i + offset + 3 * p];
mint a2na3iimag = (a2 - a3) * iimag;
a[i + offset] = a0 + a1 + a2 + a3;
a[i + offset + 1 * p] = (a0 - a1 + a2na3iimag) * irot;
a[i + offset + 2 * p] = (a0 + a1 - a2 - a3) * irot2;
a[i + offset + 3 * p] = (a0 - a1 - a2na3iimag) * irot3;
}
if (s + 1 != (1u << (len - 2))) irot = irot * irate3[__builtin_ctz(~s)];
}
len -= 2;
}
}
const unsigned N = 1u << log_len;
mint invN = mint(N).inv();
for (unsigned i = 0; i < N; ++i) a[i] *= invN;
}
static mint buf[1u << 26];
}
using ntt::get_rev;
using ntt::NTT;
using ntt::INTT;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wclass-memaccess"
inline vector<mint> operator*(const vector<mint>& a, const vector<mint>& b) {
using ntt::buf;
size_t anssiz = a.size() + b.size() - 1;
vector<mint> c(anssiz);
size_t len = 1;
while (len < anssiz) len <<= 1;
memcpy(buf, a.data(), a.size() * sizeof(mint));
memset(buf + a.size(), 0, (len - a.size()) * sizeof(mint));
memcpy(buf + len, b.data(), b.size() * sizeof(mint));
memset(buf + len + b.size(), 0, (len - b.size()) * sizeof(mint));
const int x = std::__lg(len);
NTT(buf, x);
NTT(buf + len, x);
for (size_t i = 0; i < len; i++) buf[i] *= buf[i + len];
INTT(buf, x);
std::copy_n(buf, anssiz, c.begin());
return c;
}
inline vector<mint> sqr(const vector<mint>& a) {
using ntt::buf;
size_t anssiz = a.size() + a.size() - 1;
vector<mint> c(anssiz);
size_t len = 1;
while (len < anssiz) len <<= 1;
memcpy(buf, a.data(), a.size() * sizeof(mint));
memset(buf + a.size(), 0, (len - a.size()) * sizeof(mint));
const int x = std::__lg(len);
NTT(buf, x);
for (size_t i = 0; i < len; i++) buf[i] *= buf[i];
INTT(buf, x);
std::copy_n(buf, anssiz, c.begin());
return c;
}
class bigint {
static const __uint128_t base = 1000000;
static const unsigned log_base = 6;
static void flatten(bigint& a) {
vector<mint>& arr = a.res;
mint inc(0);
for (size_t i = 0; i < arr.size(); i++) {
arr[i] += inc;
inc = arr[i].val() / base;
arr[i] = arr[i].val() % base;
}
if (inc.val()) {
while (inc.val() >= base) {
arr.push_back(inc.val() % base);
inc = inc.val() / base;
}
arr.push_back(inc.val());
}
}
vector<mint> shrink(const vector<mint>& arr) {
vector<mint> a(arr);
while (!a.empty() && a.back().val() == 0) a.pop_back();
return a;
}
inline bigint(const vector<mint>& rhs) : res(rhs), negative(false) { flatten(*this); }
public:
inline bigint() : res(), negative(false) {}
inline bigint& operator=(const bigint& rhs) {
res = rhs.res;
negative = rhs.negative;
return *this;
}
inline bigint(const bigint& rhs) : res(rhs.res), negative(rhs.negative) {}
inline bigint& operator=(bigint&& rhs) {
res = std::move(rhs.res);
negative = std::move(rhs.negative);
return *this;
}
inline bigint(__uint128_t x) : res(), negative(false) {
while (x) {
res.push_back(x % base);
x /= base;
}
}
inline bigint(bigint&& rhs) : res(std::move(rhs.res)), negative(std::move(rhs.negative)) {}
inline ~bigint() = default;
public:
inline void input(const std::string& s) {
res.clear();
negative = false;
int f;
if(s[0] == '-') f=1, negative = true;
else f=0;
for (int i=s.size()-1; i>=f; i-=log_base) {
int st = std::max(f, int(i-(log_base-1))), len = i-st+1;
res.push_back((long long)(atoi(s.substr(st,len).c_str())));
}
}
inline std::string output() const {
std::string ret;
if (negative) ret += '-';
if (res.empty()) {
ret += '0';
return ret;
}
ret += to_string_128(res.back().val());
if (res.size() > 1) for (size_t i = res.size() - 2; ~i; i--) {
if (res[i].val() < 100000) ret += '0';
if (res[i].val() < 10000) ret += '0';
if (res[i].val() < 1000) ret += '0';
if (res[i].val() < 100) ret += '0';
if (res[i].val() < 10) ret += '0';
ret += to_string_128(res[i].val());
}
return ret;
}
int cmpabs(const bigint& b) const {
if (res.size() < b.res.size()) return -1;
if (res.size() > b.res.size()) return 1;
for (size_t i = res.size() - 1; ~i; i--) {
if (res[i].val() < b.res[i].val()) return -1;
if (res[i].val() > b.res[i].val()) return 1;
}
return 0;
}
public:
bigint operator-() const {
bigint ret(*this);
ret.negative ^= 1;
return ret;
}
bool operator<(const bigint& rhs) const {
if (negative && !rhs.negative) return true;
if (!negative && rhs.negative) return false;
if (negative && rhs.negative) return -rhs < -*this;
if (res.size() ^ rhs.res.size()) return res.size() < rhs.res.size();
for (size_t i = res.size() - 1; ~i; i--) if (res[i] != rhs.res[i]) return res[i].val() < rhs.res[i].val();
return false;
}
bigint& operator*=(const bigint& rhs) {
res = (res * rhs.res);
negative ^= rhs.negative;
flatten(*this);
return *this;
}
bigint& operator<<=(size_t len) {
res.insert(res.begin(), len, mint(0));
return *this;
}
bigint& operator>>=(size_t len) {
if (len >= res.size()) {
res.clear();
return *this;
}
res = vector<mint>(res.begin() + len, res.end());
return *this;
}
inline bigint& operator+=(const bigint& b) {
if (negative && !b.negative) {
return *this = b - -*this;
}
if (negative && b.negative) {
*this = -*this + -b;
return *this = -*this;
}
if (!negative && b.negative) {
return *this -= -b;
}
res.resize(std::max(res.size(), b.res.size()));
for (size_t i = 0, iend = b.res.size(); i < iend; i++) res[i] += b.res[i];
flatten_add(*this);
return *this;
}
inline bigint& operator-=(const bigint& b) {
if (negative && !b.negative) {
*this = -*this + b;
return *this = -*this;
}
if (negative && b.negative) {
return *this = -b - -*this;
}
if (!negative && b.negative) {
return *this += -b;
}
if (*this < b) return *this = -(b - *this);
res.resize(std::max(res.size(), b.res.size()));
for (size_t i = 0, iend = b.res.size(); i < iend; i++) res[i] -= b.res[i];
flatten_sub(*this);
return *this;
}
bigint& operator/=(const bigint& b) {
negative ^= b.negative;
size_t m = res.size(), n = b.res.size();
vector<mint> rhs_p(b.res);
size_t offset;
if (m <= n << 1) offset = n << 1;
else {
offset = m + n;
rhs_p.insert(rhs_p.begin(), m - n, mint(0));
}
auto _res = (*this * div_inv_accurate(rhs_p)) >> offset;
auto ret = *this - _res * b;
flatten_sub(ret);
if (ret < b) ret = _res;
else {
ret = _res;
ret.res[0] += 1;
flatten_add(ret);
}
*this = ret;
return *this;
}
bigint& operator%=(const bigint& b) {
bigint quo = *this;
quo /= b;
*this -= b * quo;
flatten_sub(*this);
return *this;
}
bigint pow(size_t n) const {
bigint b(*this), ret(1);
while (n) {
if (n & 1) ret *= b;
b = b.sqr();
n >>= 1;
}
return ret;
}
bigint sqrt() const {
bigint ans(*this >> ((res.size() - 1) / 2));
flatten_add(ans);
long rem = 0;
while (true) {
bigint tmp = ans + *this / ans;
rem = 0;
for (size_t i = tmp.res.size() - 1; ~i; i--) {
auto [p, q] = std::div((long)tmp.res[i].val() + rem * base, 2l);
tmp.res[i] = p;
rem = q;
}
if (tmp.res.back() == 0) tmp.res.pop_back();
if (tmp.cmpabs(ans) >= 0) {
ans = tmp;
break;
}
ans = tmp;
}
do {
ans.res[0] += 1;
flatten_add(ans);
} while (cmpabs(ans.sqr()) >= 0);
ans.res[0] -= 1;
flatten_sub(ans);
ans.negative = negative;
return ans;
}
// floor(e^deg)
bigint exp(uint64_t deg) const {
bigint p, q;
recurse_exp(p, q, 0, deg);
std::string s = (p * bigint::div_inv_accurate(q)).output();
s[0] = '2';
p.input(s);
q = 1;
double digits = deg * M_LOG10E;
while (deg) {
if (deg & 1) q *= p;
p = p.sqr();
deg >>= 1;
}
s = q.output();
int index = digits;
if (abs(digits - round(digits)) < 0.1) {
index = digits;
if (s[0] == '1') index++;
}
p.input(s.substr(0, index + 1));
return p;
}
public:
friend bigint operator+(const bigint &lhs, const bigint &rhs) { return bigint(lhs) += rhs; }
friend bigint operator-(const bigint &lhs, const bigint &rhs) { return bigint(lhs) -= rhs; }
friend bigint operator*(const bigint &lhs, const bigint &rhs) { return bigint(lhs) *= rhs; }
friend bigint operator/(const bigint &lhs, const bigint &rhs) { return bigint(lhs) /= rhs; }
friend bigint operator%(const bigint &lhs, const bigint &rhs) { return bigint(lhs) %= rhs; }
friend bigint operator<<(const bigint &lhs, size_t rhs) { return bigint(lhs) <<= rhs; }
friend bigint operator>>(const bigint &lhs, size_t rhs) { return bigint(lhs) >>= rhs; }
friend bool operator==(const bigint &lhs, const bigint &rhs) { return lhs.res == rhs.res; }
friend bool operator!=(const bigint &lhs, const bigint &rhs) { return lhs.res != rhs.res; }
private:
static void flatten_add(bigint& a) {
vector<mint>& arr = a.res;
int inc(0);
for (size_t i = 0; i < arr.size(); i++) {
arr[i] += inc;
inc = arr[i].val() >= base;
if (inc) arr[i] -= base;
}
if (inc) arr.push_back(1);
}
static void flatten_sub(bigint &a) {
vector<mint>& arr = a.res;
int carry = 0;
for (size_t i = 0; i < arr.size(); i++) {
arr[i] -= carry;
carry = arr[i].val() >= (mint::mod() - 100 * base);
if (carry) arr[i] += base;
}
while (!arr.empty() && !arr.back().val()) arr.pop_back();
}
static bigint div_inv(const bigint& arr) {
size_t da = arr.res.size();
size_t da0 = (da >> 1) + 1;
if (da == 1) {
bigint tmp{base * base / arr.res[0].val()};
flatten(tmp);
return tmp;
}
else if (da == 2) {
bigint tmp{base * base * base * base / (arr.res[1].val() * base + arr.res[0].val())};
flatten(tmp);
return tmp;
}
else {
bigint tmp(div_inv(arr >> (da - da0)));
tmp = tmp << (da - da0);
bigint tem{2};
tem = tem << (da << 1);
tem -= arr * tmp;
flatten_sub(tem);
return (tmp * tem) >> (da << 1);
}
}
static bigint div_inv_accurate(const bigint& b) {
vector<bigint> t(7);
t[0] = b;
for (int i = 1; i ^ 7; ++i) {
t[i] = t[i - 1] + t[i - 1];
}
size_t n = b.res.size();
size_t err = 0;
auto _res = div_inv(b), diff = (bigint{1} << (n << 1)) - b * _res;
flatten_sub(diff);
for (int i = 6; i >= 0; --i) if (!(diff < t[i])) {
diff -= t[i];
flatten_sub(diff);
err |= 1u << i;
}
_res.res[0] += err;
flatten(_res);
return _res;
}
bigint sqr() const {
return bigint{::sqr(res)};
}
static void recurse_exp(bigint& p, bigint& q, size_t a, size_t b) {
if (b == a + 1) {
p = 1;
q = b;
return;
}
bigint p0, p1, q0, q1;
size_t m = (a + b) >> 1;
bigint::recurse_exp(p0, q0, a, m);
bigint::recurse_exp(p1, q1, m, b);
q = q0 * q1;
p = p0 * q1 + p1;
}
vector<mint> res;
bool negative;
};
void recurse(bigint& p, bigint& q, bigint& r, size_t a, size_t b) {
if (b == a + 1) {
p = -bigint(6*a - 5)*bigint(2*a - 1)*bigint(6*a - 1);
q = bigint(10939058860032000) * bigint(a) * bigint(a) * bigint(a);
r = p * bigint(545140134*a + 13591409);
return;
}
size_t m = (a + b) >> 1;
bigint p0, q0, r0, p1, q1, r1;
recurse(p0, q0, r0, a, m);
recurse(p1, q1, r1, m, b);
p = p0 * p1;
q = q0 * q1;
r = q1 * r0 + p0 * r1;
}
int main() {
get_rev();
size_t n = 100;
cin >> n;
bigint p, q, r;
recurse(p, q, r, 1, std::max(n / 10, (size_t)10));
cout << p.output() << '\n' << q.output() << '\n' << r.output() << '\n';
bigint a(10005);
a <<= std::max(n >> 1, (size_t)10);
#if 1
a = a.sqrt() * 426880;
#else // 16993ms ^^^ / vvv 39ms
mpz_t num;
mpz_init_set_str(num, a.output().c_str(), 10);
mpz_sqrt(num, num);
char* buf = mpz_get_str(NULL, 10, num);
mpz_clear(num);
a.input(buf);
a *= 426880;
free(buf);
#endif
a = (a * q) / (q * 13591409 + r);
std::string s = a.output().substr(1);
cout << "3.\n";
for (size_t i = 0; i < n; i++) {
cout << s[i];
if (i % 10 == 9) cout << ' ';
if (i % 50 == 49) cout << '\n';
}
cout << '\n';
}
Credits:
mint
来自 hly1204namespace ntt
内的所有内容来自 Atcoder Library