同学出的板子题,结果90分RE了,求调qwq
代码如下:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e5 + 5;
int n, t, lc, rc;
struct node {
int l, r, cnt;
ll sum, min, sec, add1, add2;
} tr[N << 2];
inline void pushup (int u) {
lc = u << 1, rc = u << 1 | 1;
tr[u].sum = tr[lc].sum + tr[rc].sum;
if (tr[lc].min < tr[rc].min) tr[u].min = tr[lc].min, tr[u].cnt = tr[lc].cnt, tr[u].sec = min(tr[lc].sec, tr[rc].min);
else if (tr[lc].min > tr[rc].min) tr[u].min = tr[rc].min, tr[u].cnt = tr[rc].cnt, tr[u].sec = min(tr[rc].sec, tr[lc].min);
else tr[u].min = tr[lc].min, tr[u].cnt = tr[lc].cnt + tr[rc].cnt, tr[u].sec = min(tr[lc].sec, tr[rc].sec);
}
inline void change (long long k1, long long k2, int u) {
if (tr[u].sec != LLONG_MAX) tr[u].sec += k2;
tr[u].sum += k1 * tr[u].cnt + k2 * (tr[u].r - tr[u].l + 1 - tr[u].cnt), tr[u].min += k1;
tr[u].add1 += k1, tr[u].add2 += k2;
}
inline void pushdown (int u) {
lc = u << 1, rc = u << 1 | 1;
long long tmp = min(tr[lc].min, tr[rc].min);
if (tr[lc].min == tmp) change(tr[u].add1, tr[u].add2, lc);
else change(tr[u].add2, tr[u].add2, lc);
if (tr[rc].min == tmp) change(tr[u].add1, tr[u].add2, rc);
else change(tr[u].add2, tr[u].add2, rc);
tr[u].add1 = tr[u].add2 = 0;
}
void build (int u, int l, int r) {
tr[u].l = l, tr[u].r = r;
tr[u].add1 = tr[u].add2 = 0, tr[u].sec = LLONG_MAX;
if (l == r) {
scanf("%lld", &tr[u].sum);
tr[u].min = tr[u].sum, tr[u].cnt = 1;
return;
}
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify_add (int u, int l, int r, long long k) {
if (l <= tr[u].l && r >= tr[u].r) {
tr[u].sum += k * (tr[u].r - tr[u].l + 1), tr[u].min += k;
if (tr[u].sec != LLONG_MAX) tr[u].sec += k;
tr[u].add1 += k, tr[u].add2 += k;
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if (l <= mid) modify_add(u << 1, l, r, k);
if (r > mid) modify_add(u << 1 | 1, l, r, k);
pushup(u);
}
void modify_max (int u, int l, int r) {
if (tr[u].min >= 0) return;
if (l <= tr[u].l && r >= tr[u].r && tr[u].sec > 0) {
tr[u].add1 -= tr[u].min, tr[u].sum -= tr[u].min * tr[u].cnt, tr[u].min = 0;
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if (l <= mid) modify_max(u << 1, l, r);
if (r > mid) modify_max(u << 1 | 1, l, r);
pushup(u);
}
ll query (int u, int l, int r) {
if (l <= tr[u].l && r >= tr[u].r) return tr[u].sum;
pushdown(u);
ll res = 0;
int mid = (tr[u].l + tr[u].r) >> 1;
if (l <= mid) res = query(u << 1, l, r);
if (r > mid) res += query(u << 1 | 1, l, r);
return res;
}
int main () {
ll k;
scanf("%d%d", &n, &t), build(1, 1, n);
for (int i = 1, op, l, r; i <= t; i++) {
scanf("%d%d%d", &op, &l, &r);
if (op == 1) printf("%lld\n", query(1, l, r));
else if (op == 2) scanf("%lld", &k), modify_add(1, l, r, k);
else if (op == 3) scanf("%lld", &k), modify_add(1, l, l, -k), modify_max(1, l, l), modify_add(1, r, r, k);
else scanf("%lld", &k), modify_add(1, l, r, -k), modify_max(1, l, r);
}
return 0;
}
感谢