引入
众所周知,普通的线段树,我们一般开 4 n 的数组以避免越界。然而,当一些题目n 很大的时候,如果正常的去建一颗线段树,开 4 n倍的空间,会超出空间限制。
普通的线段树是开一个满二叉树来建树的,这也就意味着很多情况下线段树的节点并没有被完全使用,造成了极大的空间浪费。
这时候,就需要动态开点线段树。
动态开点线段树
普通的线段树左右儿子表示方法:
左儿子:u<<1
右儿子:u<<1|1
这样我们可以通过父节点u直接计算出左右儿子,但是这极大的浪费了空间。所以我们考虑先不建树,等到用到没有开的节点时,再将这个节点加入树中。这样的话我们就可以减少空间开销,从而优化空间复杂度。
但因为是动态开点,所以节点本身是无序的,其排列顺序仅与遍历顺序有关。在这里,我们采用链式存储法,即对一个节点建立左右儿子指针。这样,建立新的节点时,就不会浪费多余的空间,我们只需要开 2 n 的数组即可,也就是我们一个满二叉树的节点的个数。
作用
一般作用:
节约空间
用于主席树中
实现代码
P3372 【模板】线段树 1
struct
因为递归方式是固定的,所以动态开点线段树的结构体不需要加上区间限制。
struct node { int l, r; // 指向左右儿子 int add, sum; } tr[N << 1];
build
if(!p) p = ++idx
: 即为动态开点核心代码
void build(int &p, int l, int r) { if(!p) p = ++idx; if(l == r) { tr[p].sum = a[l]; return ;} int mid = l + r >> 1; build(tr[p].l, l, mid), build(tr[p].r, mid + 1, r); pushup(p); }
modify
l,r:是当前节点 p所维护的区间。
ql,qr:是询问的区间。
void modify(int &p, int l, int r, int ql, int qr, int k) { if(!p) p = ++idx; if(l >= ql && r <= qr) { tr[p].sum += (r - l + 1) * k; tr[p].add += k; return ; } pushdown(p, l, r); int mid = l + r >> 1; if(ql <= mid) modify(tr[p].l, l, mid, ql, qr, k); if(qr > mid) modify(tr[p].r, mid + 1, r, ql, qr, k); pushup(p); }
query
int query(int p, int l, int r, int ql, int qr) { if(l >= ql && r <= qr) { return tr[p].sum; } int mid = l + r >> 1; pushdown(p, l, r); int v = 0; if(ql <= mid) v = query(tr[p].l, l, mid, ql, qr); if(qr > mid) v += query(tr[p].r, mid + 1, r, ql, qr); return v; }
pushup
void pushup(int p) { tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum; }
pushdown
void pushdown(int p, int l, int r) { if(tr[p].add) { int mid = l + r >> 1; tr[tr[p].l].sum += (mid - l + 1) * tr[p].add, tr[tr[p].l].add += tr[p].add; tr[tr[p].r].sum += (r - mid) * tr[p].add, tr[tr[p].r].add += tr[p].add; tr[p].add = 0; } }
完整代码:
- 直接建全树
#include <bits/stdc++.h> using namespace std; const int N = 100010; #define int long long struct node { int l, r; int add, sum; } tr[N << 1]; // 正常线段树,这里不开4倍大小会RE int n, m, idx, root; int a[N]; void pushup(int p) { tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum; } void pushdown(int p, int l, int r) { if(tr[p].add) { int mid = l + r >> 1; tr[tr[p].l].sum += (mid - l + 1) * tr[p].add, tr[tr[p].l].add += tr[p].add; tr[tr[p].r].sum += (r - mid) * tr[p].add, tr[tr[p].r].add += tr[p].add; tr[p].add = 0; } } void build(int &p, int l, int r) { if(!p) p = ++idx; if(l == r) { tr[p].sum = a[l]; return ;} int mid = l + r >> 1; build(tr[p].l, l, mid), build(tr[p].r, mid + 1, r); pushup(p); } void modify(int &p, int l, int r, int ql, int qr, int k) { if(!p) p = ++idx; if(l >= ql && r <= qr) { tr[p].sum += (r - l + 1) * k; tr[p].add += k; return ; } pushdown(p, l, r); int mid = l + r >> 1; if(ql <= mid) modify(tr[p].l, l, mid, ql, qr, k); if(qr > mid) modify(tr[p].r, mid + 1, r, ql, qr, k); pushup(p); } int query(int p, int l, int r, int ql, int qr) { if(l >= ql && r <= qr) { return tr[p].sum; } int mid = l + r >> 1; pushdown(p, l, r); int v = 0; if(ql <= mid) v = query(tr[p].l, l, mid, ql, qr); if(qr > mid) v += query(tr[p].r, mid + 1, r, ql, qr); return v; } signed main() { cin >> n >> m; for (int i = 1; i <= n; i++) cin >> a[i]; build(root, 1, n); int op, x, y, k; while(m--) { cin >> op >> x >> y; if(op == 1) { cin >> k; modify(root, 1, n, x, y, k); } else { cout << query(root, 1, n, x, y) << endl; } } return 0; }
- 也可以动态插入节点
#include <bits/stdc++.h> using namespace std; const int N = 100010; #define int long long struct node { int l, r; int add, sum; } tr[N << 1]; int n, m, idx, root; void pushup(int p) { tr[p].sum = tr[tr[p].l].sum + tr[tr[p].r].sum; } void pushdown(int p, int l, int r) { if(tr[p].add) { int mid = l + r >> 1; tr[tr[p].l].sum += (mid - l + 1) * tr[p].add, tr[tr[p].l].add += tr[p].add; tr[tr[p].r].sum += (r - mid) * tr[p].add, tr[tr[p].r].add += tr[p].add; tr[p].add = 0; } } void modify(int &p, int l, int r, int ql, int qr, int k) { if(!p) p = ++idx; if(l >= ql && r <= qr) { tr[p].sum += (r - l + 1) * k; tr[p].add += k; return ; } pushdown(p, l, r); int mid = l + r >> 1; if(ql <= mid) modify(tr[p].l, l, mid, ql, qr, k); if(qr > mid) modify(tr[p].r, mid + 1, r, ql, qr, k); pushup(p); } int query(int p, int l, int r, int ql, int qr) { if(l >= ql && r <= qr) { return tr[p].sum; } int mid = l + r >> 1; pushdown(p, l, r); int v = 0; if(ql <= mid) v = query(tr[p].l, l, mid, ql, qr); if(qr > mid) v += query(tr[p].r, mid + 1, r, ql, qr); return v; } signed main() { cin >> n >> m; for (int i = 1; i <= n; i++) { int x; cin >> x; modify(root, 1, n, i, i, x); } int op, x, y, k; while(m--) { cin >> op >> x >> y; if(op == 1) { cin >> k; modify(root, 1, n, x, y, k); } else { cout << query(root, 1, n, x, y) << endl; } } return 0; }