题目描述
现在请求你维护一个数列,要求提供以下两种操作:
1、 查询操作。
语法:Q L
功能:查询当前数列中末尾 LL 个数中的最大的数,并输出这个数的值。
限制: L 不超过当前数列的长度。(L>0)
2、 插入操作。
语法:A n
功能:将 n 加上 t,其中 t 是最近一次查询操作的答案(如果还未执行过查询操作,则 t=0),并将所得结果对一个固定的常数 D 取模,将所得答案插入到数列的末尾。
限制:n 是整数(可能为负数)并且在长整范围内。
注意:初始时数列是空的,没有一个数。
输入描述
第一行两个整数,M 和 D,其中 M 表示操作的个数,D 如上文中所述。
接下来的 M 行,每行一个字符串,描述一个具体的操作。语法如上文所述。
其中,1≤M≤2×10^5,1≤D≤2×10^9。
输出描述
对于每一个查询操作,你应该按照顺序依次输出结果,每个结果占一行。
输入输出样例
示例 1
输入
1. 5 100 2. A 96 3. Q 1 4. A 97 5. Q 1 6. Q 2
输出
1. 96 2. 93 3. 96
运行限制
- 最大运行时间:1s
- 最大运行内存: 128M
思路:
线段树原理
我们以 {1,4,5,8,6,2,3,9,10,7}为例,讲解线段树的原理。首先用一颗满二叉树实现线段树,用于查询任意子区间的最小值。如图
每个结点上圆圈内的数字是这棵子树的最小值。圆圈旁边的数字,例如根结点的"1:[1,10]",1 表示结点的编号,[1,10] 是这个结点代表的元素范围,即第 1 到第 10 个元素
那我们来说说查询任意区间 [i, j]的最小值。例如查区间 [4,9] 的最小值,递归查询到区间 [4, 5]、[6,8]、[9,9],见图中画横线的线段,得最小值min{6, 2, 10} = 2
我们编码时可以使用标准的二叉树数据结构。
用数组 tree[] 实现一棵满二叉树。每个结点 x 的左右儿子是:
- 左儿子:p<<1,即p×2。例如,根结点 tree[1]的左儿子是tree[2],结点 tree[12] 的左儿子是 tree[24]
- 右儿子:p<<1|1,即 p×2+1。例如根结点tree[1] 的右儿子是tree[3],结点tree[12] 的左儿子是 tree[25]。
当有 N个数时,需要把二叉树的空间开到 4N 大,假设有一棵处理 n 个元素(叶子结点有 n 个)的线段树,且它的最后一层只有 1 个叶子,其他层都是满的;如果用满二叉树表示,它的结点总数是:最后一层有 2n 个结点(其中2n−1 个都浪费了没用到),而前面所有的层有 2n 个结点,加起来共 4n 个结点。
具体题目
回到本题中,我们以python代码为例子讲解。设计三个函数来实现线段是的功能:
1. build(p,l,r):
我们通过此函数建立一颗空树,其中p是tr[p],它代表区间【l,r】。buiild()是一个递归函数,递归到最底的叶子结点,赋初始值tree[p] = 0。
建树用二分法,从根结点开始逐层二分到叶子结点。此外我们通过下面这行代码,实现了从底往上的值的返回
tr[p].sum=max(tr[p<<1].sum,tr[p<<1|1].sum)
2.modify(x,y,z,p,l,r):
我们通过这个函数实现区间【x,y】的更新,p 表示结点tr[p],l 是左子树,r 是右子树。在这道题中我们使得[x,y]为【cnt,cnt】从而实现了对新增结点的赋值为(z+t)
3.query(x,y,p,l,r):
查询函数,查找区间【x,y】内的最大值,会有四种情况:
- 如果这棵子树完全被 [L, R]覆盖,也就是说这棵子树在要查询的区间之内,那么直接返回 tree[p]的值。见下列代码的 3- 4 行。这一步体现了线段树的高效率。
- 如果不能覆盖,那么需要把这棵子树二分,再继续下面两步的查询。
- 如果 L 与左部分有重叠。查询【l,mid】部分
- 如果 R与右部分右重叠。查询【mid+1,r】部分
query也是递归函数
注:python代码中的函数上面的注释说明的其参数和c++参数的对应关系
c++代码:
1. #include<bits/stdc++.h> 2. using namespace std; 3. const int N = 200001; 4. const int INF = 0X7FFFFFFF; 5. 6. int ls(int p){ return p<<1; } //左儿子,编号是 p*2 7. int rs(int p){ return p<<1|1;} //右儿子,编号是 p*2+1 8. 9. int tree[N<<2]; //4倍空间 10. 11. void push_up(int p){ //从下往上传递区间值 12. //tree[p] = tree[ls(p)] + tree[rs(p)]; //区间和 13. tree[p] = max(tree[ls(p)], tree[rs(p)]); //区间最大值 14. } 15. void build(int p,int pl,int pr){ //结点编号p指向区间[pl, pr] 16. if(pl==pr){ //到达最底层的叶子,存叶子的值 17. tree[p] = -INF; 18. return; 19. } 20. int mid = (pl+pr) >> 1; //分治:折半 21. build(ls(p),pl,mid); //递归左儿子 22. build(rs(p),mid+1,pr); //递归右儿子 23. push_up(p); //从下往上传递区间值 24. } 25. 26. void update(int p,int pl,int pr,int L,int R,int d){ 27. //区间修改,更新[L, R]内最大值 28. if(L<=pl && pr<=R){ 29. //完全覆盖,直接返回这个结点,它的子树不用再深入了 30. tree[p] = d; 31. return; 32. } 33. int mid=(pl+pr)>>1; 34. if(L<=mid) update(ls(p),pl,mid,L,R,d); //递归左子树 35. if(R>mid) update(rs(p),mid+1,pr,L,R,d); //递归右子树 36. push_up(p); //更新 37. return; 38. } 39. 40. int query(int p,int pl,int pr,int L,int R){ //在查询区间[L, R]的最大值 41. int res = -INF; 42. if (L<=pl && pr<=R) 43. return tree[p]; //完全覆盖 44. int mid=(pl+pr)>>1; 45. if (L<=mid) res = max(res, query(ls(p),pl,mid,L,R)); 46. //L与左子结点有重叠 47. if (R>mid) res = max(res, query(rs(p),mid+1,pr,L,R)); 48. //R与右子结点有重叠 49. return res; 50. } 51. int main (){ 52. int t=0,cnt=0,m,D; 53. scanf ("%d%d",&m,&D); 54. build(1,1,N); //这样写也行: update(1,1,N,1,N,-INF); 55. for (int b=1;b<=m;++b){ 56. char c[2];int x; 57. scanf ("%s %d",c,&x); 58. if (c[0]=='A'){ 59. cnt++; 60. update(1,1,N,cnt,cnt,(x+t)%D); 61. //update(1,1,N,cnt,(x+t)%D); 62. } 63. else { 64. t = query(1,1,N,cnt-x+1,cnt); 65. printf ("%d\n",t); 66. } 67. } 68. return 0; 69. }
python代码
1. N=100001 2. class SegTree(): 3. def __init__(self): 4. self.sum=0 5. 6. tr=[SegTree() for i in range(N<<2)] 7. 8. def build(p,l,r): 9. if l==r: 10. tr[p].sum=0 11. return 12. mid=(l+r)>>1 13. build(p<<1,l,mid) 14. build(p<<1|1,mid+1,r) 15. tr[p].sum=max(tr[p<<1].sum,tr[p<<1|1].sum) 16. 17. #x=L,y=R,z=d,l=pl,r=pr 18. def modify(x,y,z,p,l,r): 19. if x<=l and r<=y: 20. tr[p].sum=z 21. return 22. mid=(r+l)>>1 23. if x<=mid: 24. modify(x,y,z,p<<1,l,mid) 25. if y>mid: 26. modify(x,y,z,p<<1|1,mid+1,r) 27. tr[p].sum=max(tr[p<<1].sum,tr[p<<1|1].sum) 28. return 29. 30. #x=L,y=R,l=pl,r=pr 31. def query(x,y,p,l,r): 32. res=-1000 33. if x<=l and r<=y: 34. return tr[p].sum 35. mid=(l+r)>>1 36. if x<=mid: 37. res=max(res,query(x,y,p<<1,l,mid)) 38. if y>mid: 39. res=max(res,query(x,y,p<<1|1,mid+1,r)) 40. return res 41. 42. 43. [m,D]=list(map(int,input().split())) 44. build(1,1,N) 45. cnt=0 46. t=0 47. while m>0: 48. m-=1 49. op=list(input().split()) 50. if op[0]=='A': 51. cnt+=1 52. modify(cnt,cnt,(int(op[1])+t)%D,1,1,N) 53. if op[0]=='Q': 54. t=query(cnt-int(op[1])+1,cnt,1,1,N) 55. print(t)