全T求助
查看原帖
全T求助
1068414
pystraf11楼主2025/1/28 10:06
// Problem: P4681 [THUSC 2015] 平方运算
// Contest: Luogu
// URL: https://www.luogu.com.cn/problem/P4681
// Memory Limit: 250 MB
// Time Limit: 2000 ms
// 
// Powered by CP Editor (https://cpeditor.org)

#include <bits/stdc++.h>
using namespace std;

// Short aliases for the arithmetic types used in this file. The numeric suffix
// appears to encode the size in bytes (f4 = float, f8 = double); note that
// sizeof(long double) is platform-dependent (16 on typical x86-64 GCC).
using i64 = long long;
using ui64 = unsigned long long;
using i128 = __int128;            // GCC/Clang extension, not standard C++
using ui128 = unsigned __int128;  // GCC/Clang extension, not standard C++
using f4 = float;
using f8 = double;
using f16 = long double;

// Raises `a` to `b` when `b` is strictly larger.
// Returns true exactly when `a` was modified.
template<class T>
bool chmax(T &a, const T &b){
	const bool grows = a < b;
	if (grows) a = b;
	return grows;
}

// Lowers `a` to `b` when `b` is strictly smaller.
// Returns true exactly when `a` was modified.
template<class T>
bool chmin(T &a, const T &b){
	if (!(a > b)) return false;
	a = b;
	return true;
}

// One segment-tree node (used by SegTree below).
struct Node {
    int l, r;           // covered index range [l, r], inclusive
    bool cycle;         // true when every element in [l, r] already lies on its squaring cycle
    int now, tag;       // now: index into sum of the current range sum; tag: pending rotation not yet pushed to children
    array<i64, 60> sum; // cyclic node: sum[k] = range sum k squarings after the stored base state (k < M);
                        // non-cyclic node: only sum[0] is meaningful (now stays 0).
                        // NOTE(review): capacity 60 assumes the lcm M of all cycle lengths never exceeds 60 — confirm for all inputs.
};
// Left child of node u in the 0-rooted implicit layout (children of u are 2u+1 and 2u+2).
inline int ls(int u) { return u + u + 1; }
// Right child of node u: the slot immediately after its left child (2u+2).
inline int rs(int u) { return 2 * (u + 1); }

struct SegTree {
    vector<Node> tr;
    vector<int> P, vis;
    int M, mod;
    
    inline SegTree() {}
    inline SegTree(const vector<int>& a, int _mod):
                   P(_mod), vis(_mod), M(1), mod(_mod) {
        for (int i = 0; i < mod; i++) get_loop(i);
        for (int i = 0; i < mod; i++)
            if (P[i] != 0) M = lcm(M, P[i]);
        
        const int n = a.size();
        tr.resize(n << 2);
        build(0, 0, n - 1, a);
    }
    
    inline void get_loop(int x) {
        for (int i = 0, y = x; ; i++, y = y * y % mod) {
            if (vis[y] != -1) {
                P[y] = i - vis[y];
                break;
            }
            else vis[y] = i;
        }
        for (int y = x; vis[y]; y = y * y % mod) vis[y] = -1;
    }
    
    inline void upd(int u) {
        if (P[tr[u].sum[0]] != 0) {
			for (int i = 1; i < M; i++) 
			    tr[u].sum[i] = tr[u].sum[i - 1] * tr[u].sum[i - 1] % mod;
			tr[u].now = 0;
			tr[u].cycle = 1;
		}
		else tr[u].now = tr[u].cycle = 0;
    }
    
    inline void apply(int u, int k) {
        tr[u].tag = (tr[u].tag + k) % M;
        tr[u].now = (tr[u].now + k) % M;
    }
    
    inline void pushup(int u) {
        tr[u].cycle = tr[ls(u)].cycle && tr[rs(u)].cycle;
        tr[u].now = 0;
        if (!tr[u].cycle)
            tr[u].sum[0] = tr[ls(u)].sum[tr[ls(u)].now] + tr[rs(u)].sum[tr[rs(u)].now];
        else {
            int nowL = tr[ls(u)].now, nowR = tr[rs(u)].now;
            for (int i = 0; i < M; i++) {
                tr[u].sum[i] = tr[ls(u)].sum[nowL] + tr[rs(u)].sum[nowR];
                nowL = (nowL + 1) % M;
                nowR = (nowR + 1) % M;
            }
        }
    }
    
    inline void pushdown(int u) {
        if (tr[u].tag) {
            apply(ls(u), tr[u].tag);
            apply(rs(u), tr[u].tag);
            tr[u].tag = 0;
        }
    }
    
    inline void build(int u, int l, int r, const vector<int>& a) {
        tr[u].l = l;
        tr[u].r = r;
        if (l == r) {
            tr[u].sum[0] = a[l];
            tr[u].tag = 0;
            return upd(u);
        }
        const int mid = (l + r) >> 1;
        build(ls(u), l, mid, a);
        build(rs(u), mid + 1, r, a);
        pushup(u);
    }
    
    inline void square(int u, int l, int r) {
        if (l <= tr[u].l && tr[u].r <= r && tr[u].cycle) return apply(u, 1);
        if (tr[u].l == tr[u].r) {
            tr[u].sum[0] = tr[u].sum[0] * tr[u].sum[0] % mod;
            return upd(u);
        }
        const int mid = (tr[u].l + tr[u].r) >> 1;
        pushdown(u);
        if (l <= mid) square(ls(u), l, r);
        if (r > mid) square(rs(u), l, r);
        pushup(u);
    }
    
    inline i64 query(int u, int l, int r) {
        if (l <= tr[u].l && tr[u].r <= r) return tr[u].sum[tr[u].now];
        const int mid = (tr[u].l + tr[u].r) >> 1;
        i64 res = 0;
        pushdown(u);
        if (l <= mid) res += query(ls(u), l, r);
        if (r > mid) res += query(rs(u), l, r);
        return res;
    }
};

// Reads n, m, p and the n initial values. Then processes m operations
// "op l r" (1-based, inclusive):
//   op == 1 -> square a[l..r] modulo p
//   else    -> print the sum of a[l..r]
signed main() {
	std::ios::sync_with_stdio(false);
	std::cin.tie(nullptr);
	std::cout.tie(nullptr);

	int n, m, p;
	scanf("%d %d %d", &n, &m, &p);

	vector<int> a(n);
	for (int &value : a) scanf("%d", &value);
	SegTree seg(a, p);

	for (int q = 0; q < m; ++q) {
		int op, l, r;
		scanf("%d %d %d", &op, &l, &r);
		// Convert to the tree's 0-based indices.
		--l;
		--r;
		if (op == 1) {
			seg.square(0, l, r);
		} else {
			printf("%lld\n", seg.query(0, l, r));
		}
	}

	return 0;
}
2025/1/28 10:06
加载中...