维护区间加减和区间查询):
模板一:
// Segment tree with lazy propagation: range add + range sum (1-indexed).
// Input: n m, then a[1..n], then m operations:
//   "1 x y k" -> add k to every a[i] with x <= i <= y
//   "2 x y"   -> print a[x] + ... + a[y]
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
int n, m;
typedef long long ll;
// ans[p]: sum of node p's segment; tag[p]: pending add not yet pushed to p's children.
ll ans[N << 2], a[N], tag[N << 2];
inline int ls(int x) { return x << 1; }
inline int rs(int x) { return x << 1 | 1; }

// Push node p's pending add (p covers [l, r]) down to its two children.
void pushdown(int p, int l, int r) {
    int mid = (l + r) >> 1;
    ans[ls(p)] += tag[p] * (mid - l + 1);
    ans[rs(p)] += tag[p] * (r - mid);
    tag[ls(p)] += tag[p];
    tag[rs(p)] += tag[p];
    tag[p] = 0;
}

// Recompute node p's sum from its children.
void pushup(int p) { ans[p] = ans[ls(p)] + ans[rs(p)]; }

// Build node p over [l, r] from a[].
void build(int p, int l, int r) {
    tag[p] = 0;
    if (l == r) {
        ans[p] = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(ls(p), l, mid);
    build(rs(p), mid + 1, r);
    pushup(p);
}

// Add k to every element of [nl, nr]; node p covers [l, r].
// k is 64-bit so that k * (r - l + 1) is evaluated in 64 bits; with two
// ints the product overflows (e.g. k = 1e5 over 1e5 elements).
void update(int nl, int nr, int l, int r, int p, ll k) {
    if (nl <= l && nr >= r) {
        tag[p] += k;
        ans[p] += k * (r - l + 1);
        return;
    }
    pushdown(p, l, r);
    int mid = (l + r) >> 1;
    if (nl <= mid) update(nl, nr, l, mid, ls(p), k);
    if (nr > mid) update(nl, nr, mid + 1, r, rs(p), k);
    pushup(p);
}

// Sum of the elements in [nl, nr]; node p covers [l, r].
ll query(int nl, int nr, int l, int r, int p) {
    if (nl <= l && nr >= r) return ans[p];
    pushdown(p, l, r);
    int mid = (l + r) >> 1;
    ll res = 0;
    if (nl <= mid) res += query(nl, nr, l, mid, ls(p));
    if (nr > mid) res += query(nl, nr, mid + 1, r, rs(p));
    return res;
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
    build(1, 1, n);
    while (m--) {
        int x, y, z;
        ll k;
        scanf("%d", &z);
        if (z == 1) {
            scanf("%d%d%lld", &x, &y, &k);
            update(x, y, 1, n, 1, k);
        } else {
            scanf("%d%d", &x, &y);
            printf("%lld\n", query(x, y, 1, n, 1));
        }
    }
    return 0;
}
线段树模板二:
// Segment tree with lazy propagation where every node stores its own [l, r]:
// range add + range sum (1-indexed).
// Input: n m, then a[1..n], then m operations:
//   "1 x y k" -> add k to a[x..y]
//   "2 x y"   -> print a[x] + ... + a[y]
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
// Pending add per node. 64-bit because tags accumulate over many updates
// and are multiplied by segment lengths in pushdown.
long long tag[N * 4];
struct node {
    int l;
    int r;
    long long w;  // sum of a[l..r]
} tr[N * 4];
int n, m;
long long a[N];
int x, y;
long long k;
int s;

// Recompute node p's sum from its children.
void pushup(int p) { tr[p].w = tr[p << 1].w + tr[p << 1 | 1].w; }

// Push node p's pending add down to its two children.
void pushdown(int p) {
    int mid = (tr[p].l + tr[p].r) >> 1;
    tr[p << 1].w += tag[p] * (mid - tr[p].l + 1);
    tr[p << 1 | 1].w += tag[p] * (tr[p].r - mid);
    tag[p << 1] += tag[p];
    tag[p << 1 | 1] += tag[p];
    tag[p] = 0;
}

// Build node p over [l, r] from a[].
void build(int p, int l, int r) {
    tr[p] = {l, r, 0};
    if (l == r) {
        tr[p].w = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(p << 1, l, mid);
    build(p << 1 | 1, mid + 1, r);
    pushup(p);
}

// Add k to every element of [x, y]; p is the current node.
// k is 64-bit so k * length cannot overflow int arithmetic.
void modify(int p, int x, int y, long long k) {
    if (tr[p].l >= x && tr[p].r <= y) {
        tr[p].w += k * (tr[p].r - tr[p].l + 1);
        tag[p] += k;
        return;
    }
    pushdown(p);
    int mid = (tr[p].l + tr[p].r) >> 1;
    if (x <= mid) modify(p << 1, x, y, k);
    if (y > mid) modify(p << 1 | 1, x, y, k);
    pushup(p);
}

// Sum of the elements in [x, y]; p is the current node.
long long query(int p, int x, int y) {
    if (tr[p].l >= x && tr[p].r <= y) return tr[p].w;
    pushdown(p);
    int mid = (tr[p].l + tr[p].r) >> 1;
    long long t = 0;
    if (x <= mid) t += query(p << 1, x, y);
    if (y > mid) t += query(p << 1 | 1, x, y);
    return t;
}

int main() {
    // Only iostreams are used, so C stdio sync can be turned off safely.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i];
    build(1, 1, n);
    for (int i = 1; i <= m; i++) {
        cin >> s;
        if (s == 1) {
            cin >> x >> y >> k;
            modify(1, x, y, k);
        } else {
            cin >> x >> y;
            cout << query(1, x, y) << '\n';  // '\n' instead of endl: no flush per line
        }
    }
    return 0;
}
2.线段树模板(同时支持区间加法和乘法,以及求区间和):
#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<algorithm>
#include<iomanip>
#include<vector>
#include<queue>
#include<map>
#include<unordered_map>
#include<set>
#include<stack>
#include<utility>
#include<cstdlib>
#include<cmath>
using namespace std;
// Segment tree with lazy propagation supporting, modulo p (1-indexed):
//   "1 x y k" -> multiply a[x..y] by k
//   "2 x y k" -> add k to a[x..y]
//   "3 x y"   -> print (a[x] + ... + a[y]) mod p
// Input: n m p, then a[1..n], then m operations.
const int N = 1e5 + 10;
typedef long long ll;
// Node u covers [l, r]; w is its sum mod p. (mul, add) is a pending affine
// tag: every value v below u must still become v * mul + add (mod p).
struct node {
    ll l;
    ll r;
    ll add;
    ll mul;
    ll w;
} tr[N * 4];
ll n, m, p;
ll a[N];

// Map any v into [0, p), including negative v (C++ % keeps the dividend's sign).
inline ll reduce_mod(ll v) { return (v % p + p) % p; }

// Recompute node u's sum from its children.
inline void pushup(ll u) { tr[u].w = (tr[u << 1].w + tr[u << 1 | 1].w) % p; }

// Compose the affine map v -> v * mul + add onto node u: update its sum and
// fold the map into u's pending tag. mul and add are expected in [0, p).
inline void apply_tag(ll u, ll mul, ll add) {
    tr[u].w = (tr[u].w * mul + add * (tr[u].r - tr[u].l + 1)) % p;
    tr[u].mul = tr[u].mul * mul % p;
    tr[u].add = (tr[u].add * mul + add) % p;
}

// Hand node u's pending tag to both children and reset it to the identity.
inline void pushdown(ll u) {
    apply_tag(u << 1, tr[u].mul, tr[u].add);
    apply_tag(u << 1 | 1, tr[u].mul, tr[u].add);
    tr[u].mul = 1;
    tr[u].add = 0;
}

// Build node u over [l, r] from a[] (values reduced into [0, p)).
void build(ll u, ll l, ll r) {
    tr[u] = { l, r, 0, 1, 0 };
    if (l == r) {
        tr[u].w = reduce_mod(a[l]);
        return;
    }
    ll mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);  // pushup already reduces mod p
}

// Multiply every element of [l, r] by k (mod p); u is the current node.
void up1(ll u, ll l, ll r, ll k) {
    // Reduce first so w * k is at most about p * p, and never negative.
    k = reduce_mod(k);
    if (tr[u].l >= l && tr[u].r <= r) {
        apply_tag(u, k, 0);
        return;
    }
    pushdown(u);
    ll mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) up1(u << 1, l, r, k);
    if (r > mid) up1(u << 1 | 1, l, r, k);
    pushup(u);
}

// Add k to every element of [l, r] (mod p); u is the current node.
void up2(ll u, ll l, ll r, ll k) {
    // Reduce first so k * length stays small, and never negative.
    k = reduce_mod(k);
    if (tr[u].l >= l && tr[u].r <= r) {
        apply_tag(u, 1, k);
        return;
    }
    pushdown(u);
    ll mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) up2(u << 1, l, r, k);
    if (r > mid) up2(u << 1 | 1, l, r, k);
    pushup(u);
}

// Sum of [l, r] modulo p (the caller passes the global modulus as p).
ll query(ll u, ll l, ll r, ll p) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].w % p;
    pushdown(u);
    ll ans = 0;
    ll mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) ans = query(u << 1, l, r, p) % p;
    if (r > mid) ans = (ans + query(u << 1 | 1, l, r, p)) % p;
    return ans % p;
}

int main() {
    scanf("%lld%lld%lld", &n, &m, &p);
    for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
    build(1, 1, n);
    for (int i = 1; i <= m; i++) {
        long long x, y, k;
        ll op;
        scanf("%lld", &op);
        if (op == 1) {
            scanf("%lld%lld%lld", &x, &y, &k);
            up1(1, x, y, k);
        } else if (op == 2) {
            scanf("%lld%lld%lld", &x, &y, &k);
            up2(1, x, y, k);
        } else {
            scanf("%lld%lld", &x, &y);
            printf("%lld\n", query(1, x, y, p));
        }
    }
    return 0;
}
/*********************************************************************
    Program:
    Copyright:
    Author:      Joecai
    Date:        2022-05-16 13:57
    Description: segment tree with range assignment and range sum.
                 Query "1 x y" sets a[x] = y, query "2 y" sets every
                 element to y; after each query the total sum is printed.
*********************************************************************/
#include <bits/stdc++.h>
using namespace std;
#define x first
#define y second
# define rep(i,be,en) for(int i=be;i<=en;i++)
# define pre(i,be,en) for(int i=be;i>=en;i--)
#define ll long long
#define endl "\n"
#define LOCAL
#define pb push_back
#define int long long
typedef pair<ll, ll> PII;
#define eb emplace_back
#define sp(i) setprecision(i)
const int N = 2e5 + 10, INF = 0x3f3f3f3f;
int n, q;
// Node u covers [l, r] and stores the sum of that segment. While has_tag is
// set, every element of the segment equals `tag`, but u's children have not
// been updated yet.
struct node {
    int l;
    int r;
    int sum;
    int tag;
    bool has_tag;
} tr[N << 2];
int a[N];

// Recompute node u's sum from its children.
void pushup(int u) { tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum; }

// Set every element of node u's segment to v and record v as u's pending tag.
void assign_node(int u, int v) {
    tr[u].sum = (tr[u].r - tr[u].l + 1) * v;
    tr[u].tag = v;
    tr[u].has_tag = true;
}

// Push node u's pending assignment, if any, down to its two children.
// A separate has_tag flag (instead of testing tag != 0) makes sure an
// assignment of the value 0 is propagated as well.
void pushdown(int u) {
    if (!tr[u].has_tag) return;
    assign_node(u << 1, tr[u].tag);
    assign_node(u << 1 | 1, tr[u].tag);
    tr[u].tag = 0;
    tr[u].has_tag = false;
}

// Build node u over [l, r] from a[].
void build(int u, int l, int r) {
    tr[u] = {l, r};  // sum = 0, no pending tag
    if (l == r) {
        tr[u].sum = a[l];
        return;
    }
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

// Assign s to every element of [l, r]; u is the current node.
void update(int u, int l, int r, int s) {
    if (tr[u].l >= l && tr[u].r <= r) {
        assign_node(u, s);
        return;
    }
    pushdown(u);
    int mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) update(u << 1, l, r, s);
    if (r > mid) update(u << 1 | 1, l, r, s);
    pushup(u);
}

// Sum of the elements in [l, r]; u is the current node.
int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    pushdown(u);
    int mid = (tr[u].l + tr[u].r) >> 1;
    int sum = 0;
    if (l <= mid) sum += query(u << 1, l, r);
    if (r > mid) sum += query(u << 1 | 1, l, r);
    return sum;
}

void solve() {
    cin >> n >> q;
    for (int i = 1; i <= n; i++) cin >> a[i];
    build(1, 1, n);
    for (int i = 1; i <= q; i++) {
        int op, x, y;
        cin >> op;
        if (op == 1) {
            cin >> x >> y;
            update(1, x, x, y);  // point assignment a[x] = y
        } else {
            cin >> y;
            update(1, 1, n, y);  // assign y to every element
        }
        cout << query(1, 1, n) << endl;
    }
}

signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    //#ifdef LOCAL
    //freopen("data.in.txt","r",stdin);
    //freopen("data.out.txt","w",stdout);
    //#endif
    int tests = 1;
    //cin>>tests;
    while (tests--) {
        solve();
    }
    return 0;
}
/*
Sample input:
5 5
1 2 3 4 5
1 1 5
2 10
1 5 11
1 4 1
2 1
Expected output:
19
50
51
42
5
*/