容斥原理应用
样例解释
n = 10, p1=2,p2=3p1=2,p2=3, 求1-10中能满足能整除p1p1或p2p2的个数, 即2,3,4,6,8,9,10,共7个
解题思路
记SiSi为1-n 中能整除pipi的集合,那么根据容斥原理, 所有数的个数为各个集合的并集,计算公式如下
⋃i=1mSi=S1+S2+…+Sm−(S1⋂S2+S1⋂S3+…+Sm−1⋂Sm)+(S1⋂S2⋂S3+…+Sm−2⋂Sm−1⋂Sm)+…+(−1)m−1(⋂i=1mS)
⋃i=1mSi=S1+S2+…+Sm−(S1⋂S2+S1⋂S3+…+Sm−1⋂Sm)+(S1⋂S2⋂S3+…+Sm−2⋂Sm−1⋂Sm)+…+(−1)m−1(⋂i=1mS)
以题目样例为例
S1={2,4,6,8,10},S2={3,6,9},S1⋂S2={6},故S1⋃S2={2,3,4,6,8,9,10}S1={2,4,6,8,10},S2={3,6,9},S1⋂S2={6},故S1⋃S2={2,3,4,6,8,9,10}
实现思路
每个集合实际上并不需要知道具体元素是什么,只要知道这个集合的大小,大小为|Si|=n/pi|Si|=n/pi, 比如题目中|S1|=10/2=5,|S2|=10/3=3|S1|=10/2=5,|S2|=10/3=3
交集的大小如何确定?因为pipi均为质数,这些质数的乘积就是他们的最小公倍数,n除这个最小公倍数就是交集的大小,故|S1⋂S2|=n/(p1∗p2)=10/(2∗3)=1|S1⋂S2|=n/(p1∗p2)=10/(2∗3)=1
如何用代码表示每个集合的状态?这里使用的二进制,以m = 4为例,所以需要4个二进制位来表示每一个集合选中与不选的状态,1101m=41101⏞m=4,这里表示选中集合S1,S2,S4S1,S2,S4,故这个集合中元素的个数为 n/(p1∗p2∗p4)n/(p1∗p2∗p4), 因为集合个数是3个,根据公式,前面的系数为(−1)3−1=1(−1)3−1=1。所以到当前这个状态时,应该是res+=n/(p1∗p2∗p4)res+=n/(p1∗p2∗p4) 。这样就可以表示的范围从0000到11110000到1111的每一个状态
用二进制表示状态的小技巧非常常用,后面的状态压缩DP也用到了这个技巧,因此一定要掌握
AC代码
#include using namespace std; typedef long long LL; const int N = 20; int p[N], n, m; int main() { cin >> n >> m; for(int i = 0; i < m; i++) cin >> p[i];
int res = 0; //枚举从1 到 1111...(m个1)的每一个集合状态, (至少选中一个集合) for(int i = 1; i < 1 << m; i++) { int t = 1; //选中集合对应质数的乘积 int s = 0; //选中的集合数量 //枚举当前状态的每一位 for(int j = 0; j < m; j++){ //选中一个集合 if(i >> j & 1){ //乘积大于n, 则n/t = 0, 跳出这轮循环 if((LL)t * p[j] > n){ t = -1; break; } s++; //有一个1,集合数量+1 t *= p[j]; } } if(t == -1) continue; if(s & 1) res += n / t; //选中奇数个集合, 则系数应该是1, n/t为当前这种状态的集合数量 else res -= n / t; //反之则为 -1 } cout << res << endl; return 0; }
参考文献
简单的容斥原理介绍请看下图:
C++ 代码
简单的容斥原理介绍请看下图:
本题思路:
将题目所给出的m个数可以看成是m位的二进制数,例如
当p[N]={2,3}时,此时会有01,10,11三种情况
而二进制的第零位表示的是p[0]上面的数字2,第1位表示p[1]上面的数字3
所以当i=1时表示只选择2的情况,当i=2(10)时,表示只选择3的情况,当i=3时,表示2和3相乘的情况
在过程中可以用标记变量t记录,可以按照t的值来选择是”+”还是“-”
代码如下:
#include
#include
using namespace std;
typedef long long LL;
const int N=20;
int p[N];
int main()
{
int n,m;
cin>>n>>m;
for(int i=0;i<m;i++) cin>>p[i];
int res=0; for(int i=1;i<1<<m;i++)//1<<m表示小于2^m { int t=1,cnt=0; for(int j=0;j<m;j++) if(i>>j&1) { if((LL)t*p[j]>n) { t=-1; break; } t*=p[j]; ++cnt; } if(t!=-1) { if(cnt%2) res=res+n/t; else res=res-n/t; } } printf("%d",res); return 0; }