linkkkkk
题意:
将一个长度为 n 的序列分为 m段,使得总价值最大。
一段区间的价值表示为区间内不同数字的个数。
n≤35000,m≤50
思路:
考虑朴素的d p方程:d p [ i ] [ j ]表示将前i个数字分成j段得到的最大价值。
转移为d p [ i ] [ j ] = m a x ( d p [ k ] [ j − 1 ] + c n t [ k + 1 ] [ i ] )
其中c n t [ k + 1 ] [ i ]表示区间[ k + 1 , i ]的价值。
这样时间复杂度是O ( k n 2 )的,O ( k n )是无法优化的,看怎么能够优化c n t [ k + 1 ] [ i ]的计算。
正常的思路是开一个桶暴力维护,但是每个数作用的区间都是可以预处理出来的。假设第i个数上一次出现的位置为p r e i,那么只有在[ p r e i + 1 , i ]时,a i的贡献才是1。这样就变成了区间修改,区间求和,用线段树就可以维护。
由于m很小,在外层枚举m。对于d p [ i ] [ j ]建树,初值为d p [ i ] [ j − 1 ],表示第i个数跟前面的段在一起。然后对于每个a i都更新贡献的范围,同时更新后求出[ j − 1 , i ]的最大值。时间复杂度为O ( k n l o g n )
还有一种决策单调性优化dp的解法
代码:
// Problem: D. The Bakery // Contest: Codeforces - Codeforces Round #426 (Div. 2) // URL: https://codeforces.com/contest/834/problem/D?mobile=true // Memory Limit: 256 MB // Time Limit: 2500 ms // // Powered by CP Editor (https://cpeditor.org) #include<bits/stdc++.h> using namespace std; typedef long long ll; typedef pair<int, int>PII; inline ll read(){ll x = 0, f = 1;char ch = getchar();while(ch < '0' || ch > '9'){if(ch == '-')f = -1;ch = getchar();}while(ch >= '0' && ch <= '9'){x = x * 10 + ch - '0';ch = getchar();}return x * f;} inline void write(ll x){if (x < 0) x = ~x + 1, putchar('-');if (x > 9) write(x / 10);putchar(x % 10 + '0');} #define rep(i,a,b) for(int i=(a);i<=(b);i++) #define per(i,a,b) for(int i=(a);i>=(b);i--) ll ksm(ll a, ll b,ll mod){ll res = 1;while(b){if(b&1)res=res*a%mod;a=a*a%mod;b>>=1;}return res;} #define read read() #define debug(x) cout<<#x<<":"<<x<<endl; const int maxn=35050,inf=0x3f3f3f3f; int n,m,a[maxn]; int dp[maxn][55]; int pre[maxn],pos[maxn]; struct node{ int l,r,laz,maxx; }tr[maxn*4]; void pushup(int u){ tr[u].maxx=max(tr[u<<1].maxx,tr[u<<1|1].maxx); } void pushdown(int u){ if(tr[u].laz){ tr[u<<1].maxx+=tr[u].laz; tr[u<<1|1].maxx+=tr[u].laz; tr[u<<1].laz+=tr[u].laz; tr[u<<1|1].laz+=tr[u].laz; tr[u].laz=0; } } void build(int u,int l,int r,int p){ tr[u].l=l,tr[u].r=r; tr[u].laz=0;tr[u].maxx=0; if(l==r){ tr[u].maxx=dp[l-1][p-1]; return ; } int mid=(l+r)/2; build(u<<1,l,mid,p);build(u<<1|1,mid+1,r,p); pushup(u); } void update(int u,int l,int r,int ql,int qr){ if(ql<=l&&r<=qr){ tr[u].maxx++; tr[u].laz++; return ; } pushdown(u); int mid=(l+r)/2; if(ql<=mid) update(u<<1,l,mid,ql,qr); if(qr>mid) update(u<<1|1,mid+1,r,ql,qr); pushup(u); } int query(int u,int l,int r,int ql,int qr){ if(ql<=l&&r<=qr){ return tr[u].maxx; } pushdown(u); int mid=(l+r)/2; int ans=0; if(ql<=mid) ans=max(ans,query(u<<1,l,mid,ql,qr)); if(qr>mid) ans=max(ans,query(u<<1|1,mid+1,r,ql,qr)); return ans; } int main(){ n=read,m=read; rep(i,1,n){ a[i]=read; pre[i]=pos[a[i]]+1; pos[a[i]]=i; } for(int j=1;j<=m;j++){ build(1,1,n,j); for(int i=1;i<=n;i++){ update(1,1,n,pre[i],i); if(j-1<=i) dp[i][j]=query(1,1,n,j-1,i); } } write(dp[n][m]); return 0; }
思路2:
再用莫队维护某段区间的不同数的个数。
代码:
// Problem: D. The Bakery // Contest: Codeforces - Codeforces Round #426 (Div. 2) // URL: https://codeforces.com/contest/834/problem/D?mobile=true // Memory Limit: 256 MB // Time Limit: 2500 ms // // Powered by CP Editor (https://cpeditor.org) #include<bits/stdc++.h> using namespace std; typedef long long ll; typedef unsigned long long ull; typedef pair<ll, ll>PLL; typedef pair<int, int>PII; typedef pair<double, double>PDD; typedef pair<string,string>PSS; #define I_int ll inline ll read(){ll x = 0, f = 1;char ch = getchar();while(ch < '0' || ch > '9'){if(ch == '-')f = -1;ch = getchar();}while(ch >= '0' && ch <= '9'){x = x * 10 + ch - '0';ch = getchar();}return x * f;} inline void write(ll x){if (x < 0) x = ~x + 1, putchar('-');if (x > 9) write(x / 10);putchar(x % 10 + '0');} #define read read() #define closeSync ios::sync_with_stdio(0);cin.tie(0);cout.tie(0) #define multiCase int T;cin>>T;for(int t=1;t<=T;t++) #define rep(i,a,b) for(int i=(a);i<=(b);i++) #define repp(i,a,b) for(int i=(a);i<(b);i++) #define per(i,a,b) for(int i=(a);i>=(b);i--) ll ksm(ll a, ll b,ll mod){ll res = 1;while(b){if(b&1)res=res*a%mod;a=a*a%mod;b>>=1;}return res;} const int maxn=35050,mod=1e9+7; const double pi = acos(-1); int dp[maxn][50],a[maxn],n,m; int cnt[maxn],ans,L=1,R; void add(int x){ cnt[x]++; if(cnt[x]==1) ans++; } void del(int x){ cnt[x]--; if(!cnt[x]) ans--; } int cul(int l,int r){ while(L<l) del(a[L++]); while(L>l) add(a[--L]); while(R<r) add(a[++R]); while(R>r) del(a[R--]); return ans; } void solve(int l,int r,int ql,int qr,int tot){ if(l>r) return ; int mid=(l+r)/2,qmid=ql; for(int i=ql;i<=min(qr,mid);i++){ int now=dp[i-1][tot-1]+cul(i,mid); if(now>dp[mid][tot]){ dp[mid][tot]=now;qmid=i; } } solve(l,mid-1,ql,qmid,tot); solve(mid+1,r,qmid,qr,tot); } int main(){ n=read,m=read; rep(i,1,n) a[i]=read; // memset(dp,0x3f,sizeof dp); dp[0][0]=0; for(int i=1;i<=m;i++) solve(1,n,1,n,i); cout<<dp[n][m]<<"\n"; return 0; }