原题索引:https://www.luogu.com.cn/problem/P1005
题目描述
帅帅经常跟同学玩一个矩阵取数游戏:对于一个给定的 �×�n×m 的矩阵,矩阵中的每个元素 ��,�ai,j 均为非负整数。游戏规则如下:
- 每次取数时须从每行各取走一个元素,共 �n 个。经过 �m 次后取完矩阵内所有元素;
- 每次取走的各个元素只能是该元素所在行的行首或行尾;
- 每次取数都有一个得分值,为每行取数的得分之和,每行取数的得分 = 被取走的元素值 ×2�×2i,其中 �i 表示第 �i 次取数(从 11 开始编号);
- 游戏结束总得分为 �m 次取数得分之和。
帅帅想请你帮忙写一个程序,对于任意矩阵,可以求出取数后的最大得分。
输入格式
输入文件包括 �+1n+1 行:
第一行为两个用空格隔开的整数 �n 和 �m。
第 2∼�+12∼n+1 行为 �×�n×m 矩阵,其中每行有 �m 个用单个空格隔开的非负整数。
输出格式
输出文件仅包含 11 行,为一个整数,即输入矩阵取数后的最大得分。
输入输出样例
输入 #1复制
2 3
1 2 3
3 4 2
输出 #1复制
82
说明/提示
【数据范围】
对于 60%60% 的数据,满足 1≤�,�≤301≤n,m≤30,答案不超过 10161016。
对于 100%100% 的数据,满足 1≤�,�≤801≤n,m≤80,0≤��,�≤10000≤ai,j≤1000。
【题目来源】
NOIP 2007 提高第三题。
题解
首先这个题就是在每行上的[i,j]区间的最大值
dp[i][j]=max{dp[i+1][j]+2^(m-(j-i))×v[i],dp[i][j-1]+2^(m-(j-i))×v[j]}
但是这题会爆long long,得用高精,但是高精很慢,这个方程又是O(nm^2)的,于是我受到讨论“为什么int128会CE”的启发,自己手写了个128位的整数运算,一点循环都没用,比高精快多了,主要就是各种高低位的操作。不过值得一提的就是在输出时肯定要模10除10,我用了这样一个除法公式:
设A为128位整数,AH为高64位,AL为低64位,则:
A/n=((AH/n)<<64)|((((AH%n)<<64)+AL)/n)
A%n=(((AH%n)<<64)+AL)%n
也就是说,商的高位为AH/n,低位就是AH%n与AL组合成的128位整数除以n的商(余数就是A%n了),这个商肯定不会爆64位。
其余的涉及到的运算如下:
若b为int,则:
A\*b=(AH\*b)<<64+(AL\*b)
若B为128位,则:
A+B=(AH+BH)<<64+(AL+BL)
使用以上两个公式的时候要注意进位问题。
于是,有了128位运算,我就用记搜解决了此题。代码如下:
#include<iostream> #include<cstdio> #include<stack> #include<cstring> using namespace std; //从这里到100行都是int128的运算代码 typedef struct _tword{ unsigned long long rah;//高64位 unsigned long long ral;//低64位 friend bool operator <(const _tword &a,const _tword &b){ if(a.rah<b.rah)return(1); if(a.rah>b.rah)return(0); return(a.ral<b.ral); } }tword;//这个结构体就是128位整数了 unsigned long long hbt=1; unsigned long long man32;//0~31位全1 unsigned long long man32h;//32~63位全1 void add(tword a,tword b,tword &res){ unsigned long long al=a.ral&man32; unsigned long long ah=(a.ral&man32h)>>32; unsigned long long bl=b.ral&man32; unsigned long long bh=(b.ral&man32h)>>32; //因为这里需要处理进位,所以我们要把128位拆成4个32位并把它们扩充至64位,这样就能通过对结果的高32位的处理来加上进位 al+=bl; ah=ah+bh+((al&man32h)>>32); res.rah=a.rah+b.rah+((ah&man32h)>>32); res.ral=((ah&man32)<<32)|(al&man32); }//128位加法 void kuomul(unsigned long long a,unsigned long long b,tword &res){//计算64位×64位=128位 unsigned long long al=a&man32; unsigned long long ah=(a&man32h)>>32; unsigned long long bl=b&man32; unsigned long long bh=(b&man32h)>>32; //ah、al为a的高低32位,bh、bl为b的高低32位,则a*b=(ah*bh)<<32+(ah<<32)*bl+(bh<<32)*al+al*bl unsigned long long albl=al*bl; unsigned long long ahbh=ah*bh; unsigned long long albh=al*bh; unsigned long long ahbl=ah*bl; tword r1,r2,r3,r4; r1.rah=ahbh; r1.ral=0; r2.rah=0; r2.ral=albl; r3.ral=(albh&man32)<<32; r3.rah=(albh&man32h)>>32; r4.ral=(ahbl&man32)<<32; r4.rah=(ahbl&man32h)>>32; res.rah=0; res.ral=0; add(res,r1,res); add(res,r2,res); add(res,r3,res); add(res,r4,res);//把四项相加,得出结果 } void mul(tword a,int b,tword &res){//128乘int的运算 tword tmp; unsigned long long ah=a.rah*b; kuomul(a.ral,b,tmp); unsigned long long al=tmp.ral; res.rah=ah+tmp.rah; res.ral=al; } int mod10(tword &hint){//把128位hint除以10,并返回余数 unsigned long long eah=hint.rah / 10; unsigned long long mod=((hint.rah%10)<<32)|((hint.ral&man32h)>>32); unsigned long long low=hint.ral&man32; hint.rah=eah; unsigned long long eal=(mod/10)<<32; low=low|((mod%10)<<32); eal=eal|(low/10); hint.ral=eal; return(low%10); } stack<char> ss;//字符栈 void print(tword hint){//输出128位整数 if(hint.rah==0&&hint.ral==0){ printf("0");//为0直接输出0 } else{ while(hint.rah!=0||hint.ral!=0){ ss.push(mod10(hint)+'0'); } while(!ss.empty()){ putchar(ss.top()); ss.pop(); } } } int v[81];//矩阵的一行 int line; unsigned char bv[81][81]; tword f[81][81]; int m; void dp(int l,int r,tword &res){//记搜 if(bv[l][r]){res=f[l][r];return;} //注意这里一旦找到就赶紧return,一般的记搜都是带返回值的,我因为要传递结构体就没带返回值,开始的时候忘了加return然后就一直出错 bv[l][r]=1; tword resx; resx.rah=0; resx.ral=0; int bit=m-(r-l); unsigned long long bt=1; if(bit<64){ bt=bt<<bit; resx.ral=resx.ral|bt; } else{ bt=bt<<(bit-64); resx.rah=resx.rah|bt; }//resx即为2^i,直接按位就行了 if(l==r){//区间内只有一个数 mul(resx,v[l],resx); f[l][r]=resx; res=resx; } else{ tword left; dp(l+1,r,left); tword lefta=resx; mul(lefta,v[l],lefta); add(left,lefta,left); tword right; dp(l,r-1,right); tword righta=resx; mul(righta,v[r],righta); add(right,righta,right); //left为取左边能获得的最大值,right为取右边能获得的最大值,最后在他们中间取个最大的 if(left<right){ f[l][r]=right; res=right; } else{ f[l][r]=left; res=left; } } } inline int get(){//读入优化 char c; int n; while(1){ c=getchar(); if(c>='0'&&c<='9')break; } n=c-'0'; while(1){ c=getchar(); if(c>='0'&&c<='9'){ n=n*10+c-'0'; } else{ return(n); } } } int main(){ man32=0xffffffff; man32h=0xffffffff00000000; int height,width; cin>>height>>width; m=width; tword sum; sum.rah=0; sum.ral=0; for(register int i=1;i<=height;i++){ memset(bv,0,sizeof(bv)); for(register int j=1;j<=width;j++){ v[j]=get(); } tword tmp; dp(1,width,tmp); add(sum,tmp,sum);//把各行的最大得分依次相加 } print(sum);//输出 return(0); }