题目描述
输入输出
样例
样例输入#1
5 4
5 -1 3 4 -1
样例输出#1
5
样例输入#2
3 0
-1 2 -3
样例输出#2
4
样例输入#3
4 -1
-2 1 -2 3
样例输出#3
3
数据范围
思路
枚举1−n 的每个数作为结尾,找到每个数前面前缀大于s[i]−t 的前缀有多少。
直接开一个前缀数组 s, 从小到大排序。
从前遍历,先将 s[i−1](即上一轮的pre),在排序的前缀和数组 s 中,二分出大于 pre 的下标 l,然后插入到树状数组中。
当前枚举的前缀和 pre(pre+=a[i];),在排序的前缀和数组s 中,二分出大于pre−t 的下标r。
大于这个 r 的即为答案做贡献。
这里需要注意,0 号位也参与。1<=l<=n,0<=l-1<=n-1
树状数组维护的是 i 前面有几个数小于等于它。
权值线段树
s[r] - s[l - 1] < t
s[l-1]>=s[r] - t + 1
s[r]<=s[l-1] + t - 1
与树状数组思路一致。
权值线段树维护值域,用于查询值落于区间[valL,valR] 的个数。
由于前缀和,所以最小 −2e14, 由于−t 所以值最小会在 −4e14。因为权值线段树维护的都是正值,所以将所有数加上 4e14,拉到正的范畴。
正向遍历
s[l−1]>=s[r]−t+1
将 s[i−1]插入树中,r即视为i,(前面小于i的位置都已插入树中),相当于询问前面 >=s[i]−t+1 的值的个数
for (int i = 1; i <= n; i++) { modify(root, L, R, s[i - 1], 1); res += query(root, L, R, s[i] - t + 1, R); }
反向遍历
s[r]<=s[l−1]+t−1
将s[i]插入树中,l即视为i,(后面大于i的位置都已插入树中),相当于询问后面<=s[i−1]+t−1
for (int i = n; i >= 1; i--) { modify(root, L, R, s[i], 1); res += query(root, L, R, 1, s[i - 1] + t - 1); }
代码
- 树状数组
int n, t; int tr[N], s[N], a[N]; int sum(int x) { int res = 0; for (; x; x -= x & -x) res += tr[x]; return res; } void add(int x, int c) { for (; x <= n + 1; x += x & -x) tr[x] += c; } void solve() { cin >> n >> t; for (int i = 1; i <= n; i++) cin >> a[i], s[i] = s[i - 1] + a[i]; sort(s, s + 1 + n); int res = 0, pre = 0; for (int i = 1; i <= n; i++) { int l = upper_bound(s, s + 1 + n, pre) - s; add(l, 1); pre += a[i]; int r = upper_bound(s, s + 1 + n, pre - t) - s; res += sum(n + 1) - sum(r); } cout << res << endl; }
- 权值线段树
int n, t; struct node { int l, r; int v; } tr[N << 5]; int s[N], root, idx; const int L = 1; const int R = 1e15; void pushup(int p) { tr[p].v = tr[tr[p].l].v + tr[tr[p].r].v; } void modify(int &p, int l, int r, int x, int v) { if(!p) p = ++idx; if(l == r) { tr[p].v += v; return ;} int mid = l + r >> 1; if(x <= mid) modify(tr[p].l, l, mid, x, v); if(x > mid) modify(tr[p].r, mid + 1, r, x, v); pushup(p); } int query(int p, int l, int r, int ql, int qr) { if(!p) return 0; if(l >= ql && r <= qr) return tr[p].v; int mid = l + r >> 1; int v = 0; if(ql <= mid) v = query(tr[p].l, l, mid, ql, qr); if(qr > mid) v += query(tr[p].r, mid + 1, r, ql, qr); return v; } void solve() { cin >> n >> t; for (int i = 1; i <= n; i++) { int x; cin >> x; s[i] = s[i - 1] + x; } for (int i = 0; i <= n; i++) s[i] += 4e14; int res = 0; // s[r] - s[l - 1] < t // s[l - 1] >= s[r] - t + 1 // for (int i = 1; i <= n; i++) { // modify(root, L, R, s[i - 1], 1); // res += query(root, L, R, s[i] - t + 1, R); // } // s[r] <= s[l - 1] + t - 1 for (int i = n; i >= 1; i--) { modify(root, L, R, s[i], 1); res += query(root, L, R, 1, s[i - 1] + t - 1); } cout << res << endl; }