上代码
很容易想到 三元上升子序列的个数=每个数前比它小的数的个数*每个数后比它大的数的个数 之和
我们只需要维护 每个数前比它小的数的个数 和 每个数后比它大的数的个数 并记录下来就可以了,
那么怎样用线段树去维护呢?以一组数据 An={ 1,4,5,3 } 为例:
首先我们开个桶sum[n]记录每个数的出现次数,那么一开始这个桶就是 0 0 0 0 0
插入A1=1,此时的桶为 1 0 0 0 0 ,查询A1前有没有比它小的数: 1 0 0 0 0,好吧一个都没有,记smaller[1]=0;
插入A2=4,此时的桶为 1 0 0 1 0 ,查询A2前有没有比它小的数:(1 0 0)1 0,发现此时1比4小,smaller[2]=1;
插入A3=5,此时的桶为 1 0 0 1 1 ,查询A3前有没有比它小的数:(1 0 0 1)1,发现此时1,4比5小,smaller[3]=2;
插入A4=3,此时的桶为 1 0 1 1 1 ,查询A4前有没有比它小的数:(1 0)1 1 1,发现此时1比3小,smaller[4]=1;
按上面的操作步骤,即每次更新,将sum[A[n]]+1,然后sum[1]~sum[A[n]-1]的和即是small[n]的值,这也是用线段树求逆序对的方法
那么用同样的方法求出所有数后比它大的数,得到bigger[1~4],最后ans=ans+smaller[i]*bigger[i],1<=i<=4
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int N = 3e4 + 10; //nlongn 实现 int n; ll a[N], b[N]; struct node { ll l; ll r; ll w; } tr[N << 2]; ll smaller[N]; ll bigger[N]; void pushup(ll u) { tr[u].w = tr[u << 1].w + tr[u << 1 | 1].w; } void build(ll u, ll l, ll r) { tr[u] = {l, r, 0}; if (l == r) return; ll mid = l + r >> 1; build(u << 1, l, mid); build(u << 1 | 1, mid + 1, r); return; } void add(ll u, ll x) { if (tr[u].l == tr[u].r && tr[u].l == x) { tr[u].w++; return; } ll mid = tr[u].l + tr[u].r >> 1; if (x <= mid) add(u << 1, x); else add(u << 1 | 1, x); pushup(u); } void clear(ll u, ll x) { if (tr[u].l == tr[u].r && tr[u].l == x) { tr[u].w = 0; return; } ll mid = tr[u].l + tr[u].r >> 1; if (x <= mid) clear(u << 1, x); else clear(u << 1 | 1, x); pushup(u); } ll query(ll u, ll l, ll r) { if (tr[u].l >= l && tr[u].r <= r) { return tr[u].w; } ll s = 0; ll mid = tr[u].l + tr[u].r >> 1; if (l <= mid) s += query(u << 1, l, r); if (r > mid) s += query(u << 1 | 1, l, r); return s; } int main() { scanf("%d", &n); for (int i = 1; i <= n; i++) scanf("%lld", &a[i]), b[i] = a[i]; sort(b + 1, b + 1 + n); int last = unique(b + 1, b + 1 + n) - b - 1; for (int i = 1; i <= n; i++) { a[i] = lower_bound(b + 1, b + 1 + last, a[i]) - b; } build(1, 1, last); for (int i = 1; i <= n; i++) { add(1, a[i]); ll te = 0; te = query(1, 1, a[i] - 1); smaller[i] = te; } for (int i = 1; i <= last; i++) clear(1, i); for (int i = n; i >= 1; i--) { add(1, a[i]); ll te = 0; te = query(1, a[i] + 1, last); bigger[i] = te; } ll res = 0; // for (int i = 1; i <= n; i++) cout << smaller[i] << ' '; // cout << "\n"; // for (int i = 1; i <= n; i++) cout << bigger[i] << ' '; // cout << "\n"; for (int i = 1; i <= n; i++) res += smaller[i] * bigger[i]; printf("%lld\n", res); }