初识线段树
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,实际应用时一般还要开4N的数组以免越界,因此有时需要离散化让空间压缩。
题目一:
现在有100000个正整数,编号从1到100000。
现给定一个区间[L,R]。
求得区间L到R的总和为多少
方法一:直接for(int i=L;i<=R;i++)来遍历100000个数字,全部加起来
方法二:通过求取前缀和来简化计算,另前缀和数组为B[100010],那么结果就是B[R]-B[L-1]就是结果,不难看出来,方法二比方法一更加快
题目二:
现在有100000个正整数,编号从1到100000。
现给定一个区间[L,R]和一个正整数k,c。
将第k个数加上c之后,对区间L到R求其总和
如果继续使用方法一它的时间复杂度是不会变化的。
但对于方法二来说,加了一个数之后,它的前缀和数组就要发生改变了,假如k=10,那么[10,100000]这整段区间的前缀和全部都需要修改,这就会大大降低计算速度
从上面的两个例子可以看出来
方法一:求和慢,但修改很快
方法二:求和快,但求和很慢
那么有没有一种方法可以兼顾这两种方法的优点呢,求和以及修改都快,这就是这篇要介绍的线段树了,线段数的插入的时间复杂度都是logN
线段树的划分
线段树是一颗二叉树,给定一个区间[L,R]之后,我们不断将区间平分,直到L==R
。
如何定义一个线段树
由图可知,线段树是由很多个区间组成的,每一个区间都记录了区间的左端点
和右端点
,以及区间内的数值之和,所以我们需要定义一个结构体
struct node { int l, r; int sum; }tr[4*N];
数组大小需要开四倍,原因就不证明了,先记住即可
如何计算每个区间的值呢?
自下而上计算
可以从线段树的叶子节点(只有自己的节点),比如区间[1,2]可以通过计算node[i].l+node[i].r(1+2)。
从下往上依次计算。
void push_up(int u) { tr[u].sum = tr[2 * u].sum + tr[2 * u + 1].sum;//2*u为左儿子,2*u+1为右儿子 }
如何建立起一个线段树呢?
void build(int u, int l, int r) { if (l == r) tr[u] = { l,r ,w[l]};//如果达到了叶子节点,就赋值 else { tr[u] = { l,r };//没有到达叶子节点,就先记录下当前区间的左端点和右端点 int mid = l + r >> 1;//将区间平分 build(2 * u, l, mid);//递归左儿子 build(2 * u + 1, mid + 1, r);//递归右儿子 push_up(u);//回溯的时候依次通过左右儿子算得sum } }
如何对某个值进行修改呢?
void modify(int u, int x, int v) { if (tr[u].l == tr[u].r)//递归到了叶子节点的时候 { tr[u].sum += v; return; } else { int mid = (tr[u].l + tr[u].r) / 2; if (x <= mid) modify(u * 2, x, v);//如果当前序列在左边,那么就递归左区间 else modify(u * 2 + 1, x, v);//在右边就递归右区间 push_up(u);//修改了之后,还要需要修改一些节点的值,重新自下而上计算 } }
如何求得某个区间的和呢?
需要设计到的区间有[4],[5,6],[7,8],[9,10],[11]。
int query(int u, int l, int r) { //需要累加所有在这个范围内的区间 if (l <= tr[u].l && r >= tr[u].r) return tr[u].sum; //否则的话就需要递归计算 int mid = (tr[u].l + tr[u].r) / 2; int sum = 0; if (mid >= l) sum += query(u*2, l, r);//如果左区间和要求的区间有交集,那么递归左区间 if (r >= mid + 1) sum += query(u * 2 + 1, l, r);//如果右区间和要求的区间有交集,那么递归右区间 return sum; }
经典例题:
AC代码:
#include<iostream> using namespace std; const int N = 100010; int n, m; int w[N];//权值 //定义线段树节点 struct node { int l, r; int sum; }tr[4*N];//要开四倍大小 //向上累加 void push_up(int u) { tr[u].sum = tr[2 * u].sum + tr[2 * u + 1].sum; } //建树 void build(int u, int l, int r) { if (l == r) tr[u] = { l,r ,w[l]};//如果达到了叶子节点,就赋值 else { tr[u] = { l,r };//没有到达叶子节点,就先记录下当前区间的左端点和右端点 int mid = l + r >> 1;//将区间平分 build(2 * u, l, mid);//递归左儿子 build(2 * u + 1, mid + 1, r);//递归右儿子 push_up(u);//回溯的时候依次通过左右儿子算得sum } } //区间查询 int query(int u, int l, int r) { //需要累加所有在这个范围内的区间 if (l <= tr[u].l && r >= tr[u].r) return tr[u].sum; //否则的话就需要递归计算 int mid = (tr[u].l + tr[u].r) / 2; int sum = 0; if (mid >= l) sum += query(u*2, l, r);//如果左区间和要求的区间有交集,那么递归左区间 if (r >= mid + 1) sum += query(u * 2 + 1, l, r);//如果右区间和要求的区间有交集,那么递归右区间 return sum; } //修改 void modify(int u, int x, int v) { if (tr[u].l == tr[u].r)//递归到了叶子节点的时候 { tr[u].sum += v; return; } else { int mid = (tr[u].l + tr[u].r) / 2; if (x <= mid) modify(u * 2, x, v);//如果当前序列在左边,那么就递归左区间 else modify(u * 2 + 1, x, v);//在右边就递归右区间 push_up(u);//修改了之后,还要需要修改一些节点的值,重新自下而上计算 } } int main(void) { cin >> n >> m; for (int i = 1; i <= n; i++) scanf("%d", &w[i]); build(1, 1, n); while (m--) { int k, a, b; cin >> k >> a >> b; if (k == 0) cout << query(1, a, b) << endl; else { modify(1, a, b); } } return 0; }
没有完全AC代码(太慢了):
#include<iostream> #include<algorithm> using namespace std; const int N = 100010; int w[N]; int n, m; struct node { int l,r; int maxv; }tr[N*4]; void push_up(int u) { tr[u].maxv = max(tr[u * 2].maxv, tr[u * 2 + 1].maxv); } void build(int u, int l, int r) { if (l == r) { tr[u] = { l,r,w[l] }; return; } else { tr[u] = { l,r}; int mid = (l + r) >> 1; build(u * 2, l, mid); build(u * 2 + 1, mid+1, r); push_up(u); } } int query(int u, int l, int r) { if (tr[u].l >= l && tr[u].r <= r) return tr[u].maxv; int mid = (tr[u].l + tr[u].r) / 2; int maxv = -10000000; if (l <= mid) maxv = max(maxv, query(u * 2, l, r)); if (r > mid + 1) maxv = max(maxv, query(u * 2 + 1, l, r)); return maxv; } int main() { int l, r; scanf("%d %d", &n, &m); for (int i = 1; i <= n; ++i) scanf("%d", &w[i]); build(1, 1, n); while (m--) { scanf("%d %d", &l, &r); printf("%d\n", query(1, l, r)); } return 0; }