最近在练习rust的时候,给数据库插入数据的时候,需要生成主键id。一般情况下会有以下几种生成主键的方式:
- 主键自增
- UUID
- 第三方工具包
- 雪花算法
但是在这个练手项目中,我们还是选择了自己实现一个雪花算法。
介绍
关于雪花算法的组成部分(64bit)
- 第一位占用1bit,它的值始终为0。
- 时间戳占用41bit,精确到毫秒,可以容纳69年的时间。
- 工作机器ID占用10bit,高位5bit是数据节点ID,低位5bit是工作节点ID,最多可以容纳1024个节点。
- 序列号占用12bit,每个节点每毫秒开始从0不断累加,最多可以累加到4095,一共可以产生4096个ID。
也就是说,按照以上这个设计,同一毫秒内可以生成:4096*1024=4194304个ID。
rust实现
- 首先,我们先把一些初始的静态值定义好
// 开始时间戳(2022-08-01) const TWEPOCH: u128 = 1659283200000; // 机器id所占的位数 const WORKER_ID_BITS: u128 = 5; // 数据节点所占的位数 const DATA_CENTER_ID_BITS: u128 = 5; // 支持最大的机器ID,最大是31 const MAX_WORKER_ID: u128 = (-1 ^ (-1 << WORKER_ID_BITS)) as u128; // 支持的最大数据节点ID,结果是31 const MAX_DATA_CENTER_ID: u128 = (-1 ^ (-1 << DATA_CENTER_ID_BITS)) as u128; // 序列号所占的位数 const SEQUENCE_BITS: u128 = 12; // 工作节点标识ID向左移12位 const WORKER_ID_SHIFT: u128 = SEQUENCE_BITS; // 数据节点标识ID向左移动17位(12位序列号+5位工作节点) const DATA_CENTER_ID_SHIFT: u128 = SEQUENCE_BITS + WORKER_ID_BITS; // 时间戳向左移动22位(12位序列号+5位工作节点+5位数据节点) const TIMESTAMP_LEFT_SHIFT: u128 = SEQUENCE_BITS + WORKER_ID_BITS + DATA_CENTER_ID_BITS; // 生成的序列掩码,这里是4095 const SEQUENCE_MASK: u128 = (-1 ^ (-1 << SEQUENCE_BITS)) as u128; 复制代码
- 上面的TWEPOCH、WORKER_ID_BITS、DATA_CENTER_ID_BITS、SEQUENCE_BITS这几个值可以根据自己的实际情况做出对应的调整。
【注意】TWEPOCH的值不能超过当前的日期对应的时间戳。 - 定义结构体SnowflakeIdWorkerInner
// 这是一个内部结构体,只在这个mod里面使用 struct SnowflakeIdWorkerInner { // 工作节点ID worker_id: u128, // 数据节点ID data_center_id: u128, // 序列号 sequence: u128, // 上一次时间戳 last_timestamp: u128, } impl SnowflakeIdWorkerInner { fn new(worker_id: u128, data_center_id: u128) -> Result<SnowflakeIdWorkerInner> { // 校验worker_id合法性 if worker_id > MAX_WORKER_ID { return Err(Error::msg(format!("workerId:{} must be less than {}", worker_id, MAX_WORKER_ID))); } // 校验data_center_id合法性 if data_center_id > MAX_DATA_CENTER_ID { return Err(Error::msg(format!("datacenterId:{} must be less than {}", data_center_id, MAX_DATA_CENTER_ID))); } // 创建SnowflakeIdWorkerInner对象 Ok(SnowflakeIdWorkerInner { worker_id, data_center_id, sequence: 0, last_timestamp: 0, }) } // 获取下一个id fn next_id(&mut self) -> Result<u128> { // 获取当前时间戳 let mut timestamp = Self::get_time()?; // 如果当前时间戳小于上一次的时间戳,那么跑异常 if timestamp < self.last_timestamp { return Err(Error::msg(format!("Clock moved backwards. Refusing to generate id for {} milliseconds", self.last_timestamp - timestamp))); } // 如果当前时间戳等于上一次的时间戳,那么计算出序列号目前是第几位 if timestamp == self.last_timestamp { self.sequence = (self.sequence + 1) & SEQUENCE_MASK; // 如果计算出来的序列号等于0,那么重新获取当前时间戳 if self.sequence == 0 { timestamp = Self::til_next_mills(self.last_timestamp)?; } } else { // 如果当前时间戳大于上一次的时间戳,序列号置为0。因为又开始了新的毫秒,所以序列号要从0开始。 self.sequence = 0; } // 把当前时间戳赋值给last_timestamp,以便下一次计算next_id self.last_timestamp = timestamp; // 把上面计算得到的对应数值按要求移位拼接起来 Ok(((timestamp - TWEPOCH) << TIMESTAMP_LEFT_SHIFT) | (self.data_center_id << DATA_CENTER_ID_SHIFT) | (self.worker_id << WORKER_ID_SHIFT) | self.sequence) } // 计算一个大于上一次时间戳的时间戳 fn til_next_mills(last_timestamp: u128) -> Result<u128> { // 获取当前时间戳 let mut timestamp = Self::get_time()?; // 如果当前时间戳一直小于上次时间戳,那么一直循环获取,直至当前时间戳大于上次获取的时间戳 while timestamp <= last_timestamp { timestamp = Self::get_time()?; } // 返回满足要求的时间戳 Ok(timestamp) } // 获取当前时间戳 fn get_time() -> Result<u128> { match SystemTime::now().duration_since(UNIX_EPOCH) { Ok(s) => { Ok(s.as_millis()) } Err(_) => { Err(Error::msg("get_time error!")) } } } } 复制代码
- 在上述代码中,一共做了两件事情:
- 根据定义的结构体实现了new函数,让开发者可以根据自己的需求创建对应的SnowflakeIdWorkerInner实例对象。
- 最核心的next_id函数,里面按照雪花算法的设计实现了如何获取next_id的逻辑。
- 【注意】next_id函数需要保证多线程并发安全。
并发安全
为了解决开发人员在使用这个功能的安全,我们还要做一步最关键的加锁机制:
// 定义一个结构体包装SnowflakeIdWorkerInner,并使用Mutex锁 // 相当于我们在SnowflakeIdWorker里面包装了一个SnowflakeIdWorkerInner的引用 // 这也就意味着我们从SnowflakeIdWorker里面获取的永远是同一个SnowflakeIdWorkerInner引用。 #[derive(Clone)] pub struct SnowflakeIdWorker(Arc<Mutex<SnowflakeIdWorkerInner>>); impl SnowflakeIdWorker { // 给出一个new函数,以便开发人员创建SnowflakeIdWorker对象 pub fn new(worker_id: u128, data_center_id: u128) -> Result<SnowflakeIdWorker> { Ok( Self(Arc::new(Mutex::new(SnowflakeIdWorkerInner::new(worker_id, data_center_id)?))) ) } // 真正的加锁next_id pub fn next_id(&self) -> Result<u128> { // 先获取锁 let mut inner = self.0.lock().map_err(|e| Error::msg(e.to_string()))?; // 再调用内部的next_id();这一步直至最后都是线程安全的。 inner.next_id() // 这一步后会自动释放inner锁。 } } 复制代码
上面这一步代码非常关键,它保证了开发人员在使用过程中可以不用关心如何保证线程安全的问题,可以直接无脑调api即可达到目的。
多线程测试
为了测试我们的这个雪花算法是否能正常运行,我们编写了以下测试代码:
fn main(){ // 创建一个SnowflakeIdWorker对象 let mut id_generator = SnowflakeIdWorker::new(2, 2).unwrap(); // 创建一个数组用来装子线程 let mut handles = vec![]; for _ in 0..6 { // 每个子线程clone一份SnowflakeIdWorker,但是可以SnowflakeIdWorker包装的引用是同一份 let mut id_generator = id_generator.clone(); // 创建子线程 let handle = thread::spawn(move || { // 调用生成id let id = (&mut id_generator).next_id(); // 打印内存地址,证明引用是同一个 println!("{:p}", &id_generator); // 打印生成的id println!("{}", id.unwrap()); }); // 存储子线程 handles.push(handle); } // 主线程等待所有子线程执行完毕 for handle in handles { handle.join().unwrap(); } } 复制代码
根据最后打印输出的结果,可以验证我们的代码是能达到我们的目标的。
最终的代码
use std::sync::{Arc, Mutex}; use std::time::{SystemTime, UNIX_EPOCH}; use anyhow::{Result, Error}; // 开始时间戳(2022-08-01) const TWEPOCH: u128 = 1659283200000; // 机器id所占的位数 const WORKER_ID_BITS: u128 = 5; // 数据节点所占的位数 const DATA_CENTER_ID_BITS: u128 = 5; // 支持最大的机器ID,最大是31 const MAX_WORKER_ID: u128 = (-1 ^ (-1 << WORKER_ID_BITS)) as u128; // 支持的最大数据节点ID,结果是31 const MAX_DATA_CENTER_ID: u128 = (-1 ^ (-1 << DATA_CENTER_ID_BITS)) as u128; // 序列号所占的位数 const SEQUENCE_BITS: u128 = 12; // 工作节点标识ID向左移12位 const WORKER_ID_SHIFT: u128 = SEQUENCE_BITS; // 数据节点标识ID向左移动17位(12位序列号+5位工作节点) const DATA_CENTER_ID_SHIFT: u128 = SEQUENCE_BITS + WORKER_ID_BITS; // 时间戳向左移动22位(12位序列号+5位工作节点+5位数据节点) const TIMESTAMP_LEFT_SHIFT: u128 = SEQUENCE_BITS + WORKER_ID_BITS + DATA_CENTER_ID_BITS; // 生成的序列掩码,这里是4095 const SEQUENCE_MASK: u128 = (-1 ^ (-1 << SEQUENCE_BITS)) as u128; #[derive(Clone)] pub struct SnowflakeIdWorker(Arc<Mutex<SnowflakeIdWorkerInner>>); impl SnowflakeIdWorker { pub fn new(worker_id: u128, data_center_id: u128) -> Result<SnowflakeIdWorker> { Ok( Self(Arc::new(Mutex::new(SnowflakeIdWorkerInner::new(worker_id, data_center_id)?))) ) } pub fn next_id(&self) -> Result<u128> { let mut inner = self.0.lock().map_err(|e| Error::msg(e.to_string()))?; inner.next_id() } } // 这是一个内部结构体,只在这个mod里面使用 struct SnowflakeIdWorkerInner { // 工作节点ID worker_id: u128, // 数据节点ID data_center_id: u128, // 序列号 sequence: u128, // 上一次时间戳 last_timestamp: u128, } impl SnowflakeIdWorkerInner { fn new(worker_id: u128, data_center_id: u128) -> Result<SnowflakeIdWorkerInner> { // 校验worker_id合法性 if worker_id > MAX_WORKER_ID { return Err(Error::msg(format!("workerId:{} must be less than {}", worker_id, MAX_WORKER_ID))); } // 校验data_center_id合法性 if data_center_id > MAX_DATA_CENTER_ID { return Err(Error::msg(format!("datacenterId:{} must be less than {}", data_center_id, MAX_DATA_CENTER_ID))); } // 创建SnowflakeIdWorkerInner对象 Ok(SnowflakeIdWorkerInner { worker_id, data_center_id, sequence: 0, last_timestamp: 0, }) } // 获取下一个id fn next_id(&mut self) -> Result<u128> { // 获取当前时间戳 let mut timestamp = Self::get_time()?; // 如果当前时间戳小于上一次的时间戳,那么跑异常 if timestamp < self.last_timestamp { return Err(Error::msg(format!("Clock moved backwards. Refusing to generate id for {} milliseconds", self.last_timestamp - timestamp))); } // 如果当前时间戳等于上一次的时间戳,那么计算出序列号目前是第几位 if timestamp == self.last_timestamp { self.sequence = (self.sequence + 1) & SEQUENCE_MASK; // 如果计算出来的序列号等于0,那么重新获取当前时间戳 if self.sequence == 0 { timestamp = Self::til_next_mills(self.last_timestamp)?; } } else { // 如果当前时间戳大于上一次的时间戳,序列号置为0。因为又开始了新的毫秒,所以序列号要从0开始。 self.sequence = 0; } // 把当前时间戳赋值给last_timestamp,以便下一次计算next_id self.last_timestamp = timestamp; // 把上面计算得到的对应数值按要求移位拼接起来 Ok(((timestamp - TWEPOCH) << TIMESTAMP_LEFT_SHIFT) | (self.data_center_id << DATA_CENTER_ID_SHIFT) | (self.worker_id << WORKER_ID_SHIFT) | self.sequence) } // 计算一个大于上一次时间戳的时间戳 fn til_next_mills(last_timestamp: u128) -> Result<u128> { // 获取当前时间戳 let mut timestamp = Self::get_time()?; // 如果当前时间戳一直小于上次时间戳,那么一直循环获取,直至当前时间戳大于上次获取的时间戳 while timestamp <= last_timestamp { timestamp = Self::get_time()?; } // 返回满足要求的时间戳 Ok(timestamp) } // 获取当前时间戳 fn get_time() -> Result<u128> { match SystemTime::now().duration_since(UNIX_EPOCH) { Ok(s) => { Ok(s.as_millis()) } Err(_) => { Err(Error::msg("get_time error!")) } } } }