子序列的定义:对于一个序列a=a[1],a[2],…a[n]。则非空序列a’=a[p1],a[p2]…a[pm]为a的一个子序列,其中1<=p1<p2<…<pm<=n。
例如4,14,2,3和14,1,2,3都为4,13,14,1,2,3的子序列。对于给出序列a,有些子序列可能是相同的,这里只算做1个,请输出a的不同子序列的数量。由于答案比较大,输出Mod 10^9 + 7的结果即可。
输入
第1行:一个数N,表示序列的长度(1 <= N <= 100000)
第2 - N + 1行:序列中的元素(1 <= a[i] <= 100000)
输出
输出a的不同子序列的数量Mod 10^9 + 7。
输入样例
4
1
2
3
2
输出样例
13
我们知道如果不存在重复的数,那么dp[i]=dp[i-1]*2(含空集的情况)。现在考虑出现了重复的数。比如当前要取的数为a[i],且a[i]最近一次在之前的j位置出现过了。那么有dp[i]=dp[i-1]*2-dp[j-1]。所以我们利用一个数组mark记录下a[i]出现的位置就好了,没有出现过为0。
假设子序列的前k个数的子序列个数为d(k),那么前k - 1个子序列的个数就为d(k - 1)个子序列,从k - 1 到k的变化是怎样的呢?
1、假设数组a[N]第k个数为a[k],如果a[k] 与前面的k - 1个数都不相同,那么就有 : d(k) = d(k - 1) + 【d(k - 1) + 1】 = 2d(k - 1) + 1,为什么呢?可以这样想,对于前k- 1项的子序列个数为d(k - 1),那前k个数,无非就是在前k - 1项的基础上多加了一个数a[k](a[k]与前k - 1个数任意一个都相等),那就在原来的组合上加上a[k],就有d(k - 1)个,还有一个a[k]自己构成一个子序列,所以还要加1;
2、假设a[k] 与前面的k - 1个数其中一个相等,那依旧加上前k - 1个子序列个数 d(k - 1),但是由于前面有与a[k]相等的数,所以要减掉重复的部分,如何找到重复的部分呢,假设离k最近的一个与a[k]相等的数为第t个a[t] = a[k],即序列(a[1], a[2], ……,a[t],……,a[k - 1],a[k]),a[t] = a[k];我们已经知道序列(a[1], a[2], ……,a[t])的序列个数为d(t),那么d(t - 1)就是重复的部分,这里需要自己做好思考,也是算法的关键部分,这里我要解释的地方是,为什么只需要找到离k最近的t使得a[t] = a[k]?给出的解释是:我们是从1 - n对数组进行遍历的,计算d(i)的i就是从1到n依次计算的,那么第一次遇到a[k] = a[t]的情况满足条件:有且仅有一个t使得a[t] = a[k],比如序列(1, 2, 3, 2, 4, 2),分别计算d(1),d(2),d(3),d(4),d(5),d(6);我们在计算d(4)的时候发现a[4] = a[2](假设下标从1开始),所以d(4) = 2*d(3) - d(2 -1) = 2d(3) - d(1);当计算d(6)的时候也有a[6] = a[4] = a[2],但是由于我们前面已经把a[2]重复的部分减掉了,所以不需要再减,d(6) = 2 * d(5) - d(4 - 1) = 2d(5) - d(3).
过程繁琐,我总结一下结论:
状态转移方程为:
d(k) = 2 * d(k - 1) + 1; a[k] != a[i],i = 1,2,3……k - 1;
d(k) = 2 * d(k - 1) - d(t - 1); 从k往前搜索,存在离k最近的t,使得a[t] = a[k].
#include <bits/stdc++.h> using namespace std; const int mod = 1e9 + 7; const int maxn = 1e6 + 5; int a[maxn], dp[maxn], mark[maxn]; int main() { ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); int n; cin >> n; for (int i = 1; i <= n; i++) { cin >> a[i]; } dp[0] = 1; for (int i = 1; i <= n; i++) { if (mark[a[i]]) { dp[i] = 2 * dp[i - 1] - dp[mark[a[i]] - 1]; dp[i] = (dp[i] % mod + mod) % mod; } else { dp[i] = 2 * dp[i - 1] % mod; } mark[a[i]] = i; } cout << ((dp[n] - 1) % mod + mod) % mod << endl; return 0; }