题目描述
在含有 n 个整数的序列{a_1,a_2,...,a_n}中,当且仅当i
输入格式
开始一行一个正整数 n。
以后一行n 个整数a_1,a_2,...,a_n
输出格式
一行一个整数表示thair的个数。
输入输出样例
输入 #1
4
2 1 3 4
输出 #1
2
输入 #2
5
1 2 2 3 4
输出 #2
7
思路
作为一个三元的上升序列,我们很容易想到子序列枚举中间的元素。
L[i] 为 a[i] 左边小于a[i] 的元素个数。
R[i] 为 a[i] 右边大于a[i] 的元素个数。
乘法原理 :以a[i] 为中间元素的合法序列个数为L[i] * R[i]。
我们可以拿权值线段树来维护一段值域内数的个数,借此来计算出L 和R 数组。
代码
// #include <bits/stdc++.h> #include <iostream> #include <algorithm> using namespace std; #define ls u<<1 #define rs u<<1|1 #define int long long const int N = 100010; struct node { int l, r; int sum; } tr[N << 2]; int n, a[N]; int L[N], R[N]; void pushup(int u) { tr[u].sum = tr[ls].sum + tr[rs].sum; } void build(int u, int l, int r) { tr[u] = {l, r}; if(l == r) { tr[u].sum = 0; return ;} int mid = l + r >> 1; build(ls, l, mid), build(rs, mid + 1, r); pushup(u); } void modify(int u, int x) { if(tr[u].l >= x && tr[u].r <= x) { tr[u].sum += 1; return ; } int mid = tr[u].l + tr[u].r >> 1; if(x <= mid) modify(ls, x); else modify(rs, x); pushup(u); } int query(int u, int l, int r) { if(tr[u].l >= l && tr[u].r <= r) { return tr[u].sum; } int mid = tr[u].l + tr[u].r >> 1; int res = 0; if(l <= mid) res = query(ls, l, r); if(r > mid) res += query(rs, l, r); return res; } signed main() { cin >> n; int mx = 0; for (int i = 1; i <= n; i++) cin >> a[i], mx = max(mx, a[i]); build(1, 1, mx); for (int i = 1; i <= n; i++) { modify(1, a[i]); L[i] = query(1, 1, a[i] - 1); } build(1, 1, mx); for (int i = n; i >= 1; i--) { modify(1, a[i]); R[i] = query(1, a[i] + 1, mx); } int res = 0; for (int i = 1; i <= n; i++) res += L[i] * R[i]; cout << res << endl; return 0; }
M元上升子序列
DP + 树状数组
思路
const int M = 3; int n; int a[N], tr[N], b[N]; int f[M + 1][N]; int sum(int x) { int res = 0; for (; x; x -= x & -x) res += tr[x]; return res; } void add(int x, int c) { for (; x <= n; x += x & -x) tr[x] += c; } void solve() { cin >> n; for (int i = 1; i <= n; i++) cin >> a[i], b[i] = a[i], f[1][i] = 1; sort(a + 1, a + 1 + n); int m = unique(a + 1, a + 1 + n) - a - 1; for (int i = 1; i <= n; i++) b[i] = lower_bound(a + 1, a + 1 + m, b[i]) - a; for (int i = 2; i <= M; i++) { memset(tr, 0, sizeof tr); for (int j = 1; j <= n; j++) { f[i][j] = sum(b[j] - 1); add(b[j], f[i - 1][j]); } } int res = 0; for (int i = 1; i <= n; i++) res += f[M][i]; cout << res << endl; }
【变种】上升四元组
统计上升四元组
枚举 l,统计j 为中间的三元组数量(v[j]: i < j < k , a [ i ] < a [ k ] < a [ j ] i
class Solution { public: long long countQuadruplets(vector<int>& nums) { int n = nums.size(); vector<int> v(n, 0); long long res = 0; for (int i = 0; i < n; i++) { for (int j = 0; j < i; j++) if (nums[j] < nums[i]) res += v[j]; // count:j之前比a[i]小的数个数 // v[j]+=count:j作为中间值,i作为k,统计上三元组个数。 for (int j = 0, count = 0; j < i; j++) { if (nums[j] > nums[i]) v[j] += count; count += nums[j] < nums[i]; } } return res; } };