[P3373模板]线段树 2
题目描述
如题,已知一个数列,你需要进行下面三种操作:
- 将某区间每一个数乘上 x
- 将某区间每一个数加上x
- 求出某区间每一个数的和
输入格式
输出格式
输出包含若干行整数,即为所有操作 3 的结果。
样例
样例输入
5 5 38 1 5 4 2 3 2 1 4 1 3 2 5 1 2 4 2 2 3 5 5 3 1 4
样例输出
17 2
提示
【数据范围】
样例说明:
1. 区间加法
s[pos].add = (s[pos].add + k) % mod; s[pos].sum = (s[pos].sum + k * (s[pos].r - s[pos].l + 1)) % mod;
2. 区间乘法
这里就有点不一样了。
先把 mul
和 sum
乘上 k
。
对于之前已经有的 add
,把它乘上 k
即可。在这里,我们把乘之后的值直接更新add的值。
你想, add
其实应该加到 sum
里面,所有乘上 k
后,运用乘法分配律, (sum + add) * k == sum * k + add * k
。
这样来实现 add
和 sum
有序进行。
s[pos].add = (s[pos].add * k) % mod; s[pos].mul = (s[pos].mul * k) % mod; s[pos].sum = (s[pos].sum * k) % mod;
3. pushdown的维护
现在要下传两个标记: add
和 mul
。
sum
:因为 add
之前已经乘过,所以在子孩子乘过 mul
后直接加就行。
mul
:直接乘。
add
:因为 add
的值是要包括乘之后的值,所以子孩子要先乘上 mul
。
s[pos << 1].sum = (s[pos << 1].sum * s[pos].mul + s[pos].add * (s[pos << 1].r - s[pos << 1].l + 1)) % mod; s[pos << 1].mul = (s[pos << 1].mul * s[pos].mul) % mod; s[pos << 1].add = (s[pos << 1].add * s[pos].mul + s[pos].add) % mod;
4. 位运算
在此注释: <<
和 |
是位运算,n << 1 == n * 2
,n << 1 | 1 == n * 2 + 1
(再具体的自己百度)。
5. 完整代码
import java.io.*; class Node { int l; int r; long sum; long add; long mul; public Node(int l, int r, long sum, long add, long mul) { this.l = l; this.r = r; this.sum = sum; this.add = add; this.mul = mul; } } public class Main { static BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(System.out)); static int MAXN = 100010; static int[] a = new int[MAXN]; static Node[] s = new Node[MAXN * 4]; static int n; static int m; static int mod; public static void main(String[] args) throws IOException { Read read = new Read(); String[] s0 = read.getStringLine().split(" "); n = Integer.parseInt(s0[0]); m = Integer.parseInt(s0[1]); mod = Integer.parseInt(s0[2]); String[] s2 = read.getStringLine().split(" "); for (int i = 1; i <= n; i++) { a[i] = Integer.parseInt(s2[i - 1]); } for (int i = 1; i < s.length; i++) { s[i] = new Node(0,0,0,0,1); } buildTree(1, 1, n); for (int i = 1; i <= m; i++) { int opt; int x; int y; String[] si = read.getStringLine().split(" "); opt = Integer.parseInt(si[0]); x = Integer.parseInt(si[1]); y = Integer.parseInt(si[2]); if (opt == 1) { int k = Integer.parseInt(si[3]); ChangeMul(1, x, y, k); } else if (opt == 2) { int k = Integer.parseInt(si[3]); ChangeAdd(1, x, y, k); } else if (opt == 3) { writer.write(AskRange(1, x, y) + "\n"); } } writer.flush(); writer.close(); } static void update(int pos) { s[pos].sum = (s[pos << 1].sum + s[pos << 1 | 1].sum) % mod; } static void pushdown(int pos) { s[pos << 1].sum = (s[pos << 1].sum * s[pos].mul + s[pos].add * (s[pos << 1].r - s[pos << 1].l + 1)) % mod; s[pos << 1 | 1].sum = (s[pos << 1 | 1].sum * s[pos].mul + s[pos].add * (s[pos << 1 | 1].r - s[pos << 1 | 1].l + 1)) % mod; s[pos << 1].mul = (s[pos << 1].mul * s[pos].mul) % mod; s[pos << 1 | 1].mul = (s[pos << 1 | 1].mul * s[pos].mul) % mod; s[pos << 1].add = (s[pos << 1].add * s[pos].mul + s[pos].add) % mod; s[pos << 1 | 1].add = (s[pos << 1 | 1].add * s[pos].mul + s[pos].add) % mod; s[pos].add = 0; s[pos].mul = 1; } static void buildTree(int pos, int l, int r) { //建树 s[pos].l = l; s[pos].r = r; s[pos].mul = 1; if (l == r) { s[pos].sum = a[l] % mod; return; } int mid = (l + r) >> 1; buildTree(pos << 1, l, mid); buildTree(pos << 1 | 1, mid + 1, r); update(pos); } static void ChangeMul(int pos, int x, int y, int k) { //区间乘法 if (x <= s[pos].l && s[pos].r <= y) { s[pos].add = (s[pos].add * k) % mod; s[pos].mul = (s[pos].mul * k) % mod; s[pos].sum = (s[pos].sum * k) % mod; return; } pushdown(pos); int mid = (s[pos].l + s[pos].r) >> 1; if (x <= mid) { ChangeMul(pos << 1, x, y, k); } if (y > mid) { ChangeMul(pos << 1 | 1, x, y, k); } update(pos); return; } static void ChangeAdd(int pos, int x, int y, int k) { //区间加法 if (x <= s[pos].l && s[pos].r <= y) { s[pos].add = (s[pos].add + k) % mod; s[pos].sum = (s[pos].sum + (long) k * (s[pos].r - s[pos].l + 1)) % mod; return; } pushdown(pos); int mid = (s[pos].l + s[pos].r) >> 1; if (x <= mid) { ChangeAdd(pos << 1, x, y, k); } if (y > mid) { ChangeAdd(pos << 1 | 1, x, y, k); } update(pos); return; } static long AskRange(int pos, int x, int y) { //区间询问 if (x <= s[pos].l && s[pos].r <= y) { return s[pos].sum; } pushdown(pos); long val = 0; int mid = (s[pos].l + s[pos].r) >> 1; if (x <= mid) { val = (val + AskRange(pos << 1, x, y)) % mod; } if (y > mid) { val = (val + AskRange(pos << 1 | 1, x, y)) % mod; } return val; } } class Read { BufferedReader reader = new BufferedReader(new InputStreamReader(System.in)); StreamTokenizer st = new StreamTokenizer(new InputStreamReader(System.in)); public int nextInt() throws IOException { st.nextToken(); return (int) st.nval; } public double nextDouble() throws IOException { st.nextToken(); return st.nval; } public String nextString() throws IOException { st.nextToken(); return st.sval; } public String getStringLine() throws IOException { return reader.readLine(); } }
参考这位大佬提供的C++语言版本的模板