HashMap无论是 1.7 还是 1.8 其实都能看出 JDK 没有对它做任何的同步操作,所以并发会出问题,甚至出现死循环导致系统不可用。这个问题就交给ConcurrentHashMap。
ConcurrentHashMap是一个 在juc包下的 map, 线程安全。 在jdk.1.8 之前采用数组+ 链表的结构 并且采用分段锁机制 来保证线程安全,而jdk1.8 改成了 数组+ 链表+ 红黑树,线程安全方面也改成了 cas+ synchronized 来保证线程安全。
ConcurrentHashMap类图如下:
本篇博文我们分析JDK1.7下ConcurrentHashMap的实现。
【1】核心属性和构造
① 核心属性
// table的默认初始化容量 static final int DEFAULT_INITIAL_CAPACITY = 16; // table的默认负载因子 static final float DEFAULT_LOAD_FACTOR = 0.75f; /** * The default concurrency level for this table, used when not * otherwise specified in a constructor. */ // table的默认并发级别,换句话说其实是默认多少个“segment” static final int DEFAULT_CONCURRENCY_LEVEL = 16; // 最大容量 static final int MAXIMUM_CAPACITY = 1 << 30; // per-segment tables的最小容量,就是每段最少2个哈希桶位置 static final int MIN_SEGMENT_TABLE_CAPACITY = 2; //段的最大数量 static final int MAX_SEGMENTS = 1 << 16; // slightly conservative // 在size()和containsValue()方法中,如果循环次数==RETRIES_BEFORE_LOCK , //则对每一段都进行加锁 static final int RETRIES_BEFORE_LOCK = 2; /** * Mask value for indexing into segments. The upper bits of a * key's hash code are used to choose the segment. */ //用于索引到段的掩码值。key的散列码的高位用于选择段。 final int segmentMask; //段内索引的移位值 final int segmentShift; /** * The segments, each of which is a specialized hash table. */ // 段数组,每一个端都是一个特殊的哈希表 final Segment<K,V>[] segments; //三个常见的数据对象,使用了transient 不参与序列化和反序列 transient Set<K> keySet; transient Set<Map.Entry<K,V>> entrySet; transient Collection<V> values;
这里可能对segmentMask、segmentShift以及segments比较疑惑,别着急我们慢慢往下看。
// 默认情况下segmentShift =28 segmentMask =15 this.segmentShift = 32 - sshift; this.segmentMask = ssize - 1;
② 核心对象HashEntry
如下所示,从成员来讲与HashMap中的Entry类似,都是hash、key、value、next
。不同的是这里value和next使用了volatile 修饰,保证其他线程能够读取到当前变量的最新值。而且其内部使用了安全类 UNSAFE来保证volatile 语义
static final class HashEntry<K,V> { final int hash; final K key; volatile V value; volatile HashEntry<K,V> next; HashEntry(int hash, K key, V value, HashEntry<K,V> next) { this.hash = hash; this.key = key; this.value = value; this.next = next; } /** * Sets next field with volatile write semantics. (See above * about use of putOrderedObject.) */ // 这里使用了安全类 UNSAFE来保证volatile 语义 final void setNext(HashEntry<K,V> n) { UNSAFE.putOrderedObject(this, nextOffset, n); } // Unsafe mechanics static final sun.misc.Unsafe UNSAFE; static final long nextOffset; static { try { UNSAFE = sun.misc.Unsafe.getUnsafe(); Class k = HashEntry.class; nextOffset = UNSAFE.objectFieldOffset (k.getDeclaredField("next")); } catch (Exception e) { throw new Error(e); } } }
③ 核心类Segment
如下所示,其实ConcurrentHashMap首先构建了Segment[]
,然后每一个Segment又包含了table,table最小长度是2。
static final class Segment<K,V> extends ReentrantLock implements Serializable { //tryLock 的最大次数 static final int MAX_SCAN_RETRIES = Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1; // 每一段的哈希桶, //元素通过entryAt/setEntryAt方法访问或者赋值以确保volatile 语义 transient volatile HashEntry<K,V>[] table; //元素的个数 transient int count; // 结构修改计数器 transient int modCount; // 需要rehashed的临界值/阈值 = capacity * loadFactor // 注意这些都是针对整个段来讲的,而不是某个tab[index]. transient int threshold; // 负载因子 ,对所有segment来说是一致的,其是一个副本以避免与外部对象关联 final float loadFactor; //每一段的构造,主要包括负载因子、阈值以及哈希桶 Segment(float lf, int threshold, HashEntry<K,V>[] tab) { this.loadFactor = lf; this.threshold = threshold; this.table = tab; } //... }
③ 核心构造函数
如下所示是一系列重载的构造函数:
// DEFAULT_CONCURRENCY_LEVEL=16 public ConcurrentHashMap(int initialCapacity, float loadFactor) { this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL); } public ConcurrentHashMap(int initialCapacity) { this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL); } // DEFAULT_INITIAL_CAPACITY = 16 // DEFAULT_LOAD_FACTOR = 0.75 // DEFAULT_CONCURRENCY_LEVEL = 16 public ConcurrentHashMap() { this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL); } //使用给定的Map初始化 public ConcurrentHashMap(Map<? extends K, ? extends V> m) { this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1, DEFAULT_INITIAL_CAPACITY), DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL); putAll(m); }
可以看到其本质都是依赖于下面这个构造函数,这也是我们需要重点分析的。
// 假设为16 0.75 16 public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) { if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0) throw new IllegalArgumentException(); //最大值65536 if (concurrencyLevel > MAX_SEGMENTS) concurrencyLevel = MAX_SEGMENTS; // Find power-of-two sizes best matching arguments int sshift = 0; // 段的个数,如果小于16就增长到16 sshift记录增长的次数 int ssize = 1; while (ssize < concurrencyLevel) { ++sshift; ssize <<= 1; } // 默认情况下 = 32-4 = 28 this.segmentShift = 32 - sshift; //默认情况下 = 16-1 = 15 this.segmentMask = ssize - 1; if (initialCapacity > MAXIMUM_CAPACITY) initialCapacity = MAXIMUM_CAPACITY; //默认c=16/16=1 int c = initialCapacity / ssize; if (c * ssize < initialCapacity) ++c; // MIN_SEGMENT_TABLE_CAPACITY=2 //cap 为每段内数组的大小,默认是2 int cap = MIN_SEGMENT_TABLE_CAPACITY; while (cap < c) cap <<= 1; // create segments and segments[0] // 默认情况下so(0.75,1,HashEntry[2]) Segment<K,V> s0 = new Segment<K,V>(loadFactor, (int)(cap * loadFactor), (HashEntry<K,V>[])new HashEntry[cap]); //初始化16个段 Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize]; //把s0放到ss中 UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0] this.segments = ss; }
从上面代码可知s0段内tab[]
的默认大小为2,阈值为1。这也就意味着上面初始化Segment<K,V> s0
只能装入一个HashEntry即在插入第一个元素的时候不会触发扩容,插入第二个元素的时候就会进行第一次扩容。
也就是说,默认情况下数据结构如下图:
从上图可以发现其底层数据结构本质还是数组+链表。无非是在最外层分成了不同的段Segment,段内持有最少两个数组索引位置
。同一个索引位置,通过next构成了链表。
【2】核心方法get
这里核心逻辑是首先定位到某个Segment,然后获取到Segment持有的tab[],再根据hash(key)定位到某个tab[i](索引位置或者称之为槽位)
。
public V get(Object key) { Segment<K,V> s; // manually integrate access methods to reduce overhead HashEntry<K,V>[] tab; // 获取到key的散列值 int h = hash(key); // 默认情况下 h 无符号右移28位 & 15 ,然后 左移SSHIFT ,然后 +SBASE // 其实也就是定位哪个Segment long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE; // 获取定位到的段 if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null && (tab = s.table) != null) { // 使用(tab.length - 1) & h定位段内哪个数组位置 / 索引位置 for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE); e != null; e = e.next) { K k; // 这段for循环就是基本的链表遍历 if ((k = e.key) == key || (e.hash == h && key.equals(k))) return e.value; } } return null; }
方法总结如下:
- ① 计算key的散列值并进而计算出属于哪个段Segment
- ② UNSAFE获取到定位目标段
- ③
((tab.length - 1) & h)) << TSHIFT) + TBASE
定位到段内tab[]
中索引位置 - ④ 基本的链表遍历
可以看到无论是获取段还是获取某个元素,这里都是用了UNSAFE.getObjectVolatile
来保证读取到目标的最新值(内存可见性)。
【3】核心方法put
虽然 HashEntry 中的 value 是用 volatile 关键词修饰的,但是并不能保证并发的原子性,所以 put 操作时仍然需要加锁处理。volatile 关键词只保证读取时候的内存可见性(读取到最新值)。
put是首先定位到段,对段进行加锁,然后put,最后解锁。故而其支持最大N个并发(N是段的个数,默认是16)。
public V put(K key, V value) { Segment<K,V> s; if (value == null) throw new NullPointerException(); //计算key的散列值 int hash = hash(key); // hash 右移 28位 然后与 15 进行 & 操作 // segmentMask:散列运算的掩码 int j = (hash >>> segmentShift) & segmentMask; // 尝试获取段,判断是否为null,为null则创建 if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck (segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment s = ensureSegment(j); return s.put(key, hash, value, false); }
可以看到这里首先计算key的hash值与段的索引位置,尝试获取到段然后触发s.put(key, hash, value, false)
。
① ensureSegment
ensureSegment方法是为了确保Segment,如果不存在则创建Segment。这里需要特别注意的是使用到了“自旋”(while循环)和CAS(UNSAFE.compareAndSwapObject
)。
private Segment<K,V> ensureSegment(int k) { final Segment<K,V>[] ss = this.segments; // 段的索引位置进行偏移 long u = (k << SSHIFT) + SBASE; // raw offset Segment<K,V> seg; //如果获取到的段位null if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { Segment<K,V> proto = ss[0]; // use segment 0 as prototype int cap = proto.table.length;//获取到cap float lf = proto.loadFactor;//负载因子 int threshold = (int)(cap * lf);//计算阈值 //实例化tab[] HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap]; //再次判断是否为null if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // recheck //实例化段 Segment<K,V> s = new Segment<K,V>(lf, threshold, tab); //这一步是自旋,当获取到段位null时,调用UNSAFE的CAS算法进行赋值 //while循环,直到成功 while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s)) break; } } } return seg; }
ok,获取到段后我们继续往下看如何put。
② put(K key, int hash, V value, boolean onlyIfAbsent)
final V put(K key, int hash, V value, boolean onlyIfAbsent) { //如果tryLock返回true,那么node为null;这里触发的是父类ReentrantLock的tryLock //否则scanAndLockForPut 自旋获取锁 HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value); V oldValue; try { HashEntry<K,V>[] tab = table; // 数组中的索引位置 int index = (tab.length - 1) & hash; // 处于索引位置的结点 HashEntry<K,V> first = entryAt(tab, index); for (HashEntry<K,V> e = first;;) { if (e != null) {// K k; // 如果key相等,则直接覆盖旧值 if ((k = e.key) == key || (e.hash == hash && key.equals(k))) { oldValue = e.value; if (!onlyIfAbsent) { e.value = value; ++modCount; } break; } // 向后遍历 e = e.next; } //如果不存在当前key,那么头插法,插入链表,first作为node的next结点 else { if (node != null) node.setNext(first); else node = new HashEntry<K,V>(hash, key, value, first); int c = count + 1; //如果元素个数大于threshold ,且tab.length < 1 << 30 if (c > threshold && tab.length < MAXIMUM_CAPACITY) //其实就是扩容 rehash(node); else //如果不需要扩容就将node放到目标位置 setEntryAt(tab, index, node); ++modCount; count = c; oldValue = null; break; } } } finally { // 释放锁 unlock(); } return oldValue; }
put流程如下:
① 尝试获取锁,tryLock() 或者scanAndLockForPut
② 将当前 Segment 中的 table 通过 key 的 hash (tab.length - 1) & hash定位到 tab[]中的索引位置。
③ 遍历索引位置的链表,获取每一个结点进行判断。如果不为空则判断传入的 key 和当前遍历的 key 是否相等,相等则覆盖旧的 value。
④ 如果第三步不成功则判断node是否为空,如果不为空则node.setNext(first);否则需要新建一个 HashEntry 。判断是否需要扩容,如果需要就进行rehash(node),不需要扩容就setEntryAt(tab, index, node);。
⑤ 最后会解除在 1 中所获取当前 Segment 的锁。
可以看到这里链表插入元素采用了“头插法”。
③ scanAndLockForPut(K key, int hash, V value)
在put第一步的时候会尝试获取锁,如果获取失败肯定就有其他线程存在竞争,则利用 scanAndLockForPut() 自旋获取锁。
在尝试获取锁时扫描包含给定key的节点,如果未找到,则可能创建并返回一个。返回时,保证lock被保持。这个方法返回的node不一定为null,但是一定持有了锁。
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) { // 根据hash确定其在哪个段的哪个数组位置 HashEntry<K,V> first = entryForHash(this, hash); HashEntry<K,V> e = first; HashEntry<K,V> node = null; int retries = -1; // negative while locating node // 自旋获取锁 while (!tryLock()) { HashEntry<K,V> f; // to recheck first below if (retries < 0) { if (e == null) { if (node == null) // speculatively create node node = new HashEntry<K,V>(hash, key, value, null); retries = 0; } else if (key.equals(e.key)) retries = 0; else e = e.next; } // 当自旋次数大于MAX_SCAN_RETRIES(1 或者 64),直接使用lock()来获取锁 else if (++retries > MAX_SCAN_RETRIES) { lock(); break; } // 当retries 为偶数时 且 //f = entryForHash(this, hash)) != first-也就是entry发生了改变 //比如第一次、第三次、第五次 else if ((retries & 1) == 0 && (f = entryForHash(this, hash)) != first) { e = first = f; // re-traverse if entry changed retries = -1; } } return node; }
原理上来说:ConcurrentHashMap 采用了分段锁技术,其中 Segment 继承于 ReentrantLock。不会像 HashTable 那样不管是 put 还是 get 操作都需要做同步处理。理论上 ConcurrentHashMap 支持 CurrencyLevel (Segment 数组数量)的线程并发。每当一个线程占用锁访问一个 Segment 时,不会影响到其他的 Segment。
【4】核心方法rehash
这里会对某个段持有的tab[]
进行二倍扩容,然后重新梳理链表进行定位并将新结点node放入。
private void rehash(HashEntry<K,V> node) { HashEntry<K,V>[] oldTable = table; int oldCapacity = oldTable.length; //二倍扩容 int newCapacity = oldCapacity << 1; //新的临界值 threshold = (int)(newCapacity * loadFactor); //实例化扩容后的数组进行迁移 HashEntry<K,V>[] newTable = (HashEntry<K,V>[]) new HashEntry[newCapacity]; // 容量大小掩码 其实就是length-1 int sizeMask = newCapacity - 1; for (int i = 0; i < oldCapacity ; i++) { //数组索引位置的头结点 HashEntry<K,V> e = oldTable[i]; if (e != null) { HashEntry<K,V> next = e.next; // 在新tab[]中的索引位置 int idx = e.hash & sizeMask; //只有一个节点,直接换地方 if (next == null) // Single node on list newTable[idx] = e; else { // Reuse consecutive sequence at same slot // 链表遍历 HashEntry<K,V> lastRun = e; //记录idx=e.hash & sizeMask int lastIdx = idx; for (HashEntry<K,V> last = next; last != null; last = last.next) { //当前遍历节点的索引位置 int k = last.hash & sizeMask; if (k != lastIdx) { lastIdx = k;//修改lastIdx lastRun lastRun = last; } } //记录 i 位置链表遍历最后的lastIdx,放到新的数组里面 //把lastRun及以后的节点指向 lastIdx位置 newTable[lastIdx] = lastRun; // 其他节点则采用头插法放到k = h & sizeMask位置, //k 可能等于idx等于lastIdx // Clone remaining nodes for (HashEntry<K,V> p = e; p != lastRun; p = p.next) { V v = p.value; int h = p.hash; int k = h & sizeMask; HashEntry<K,V> n = newTable[k]; newTable[k] = new HashEntry<K,V>(h, p.key, v, n); } } } } //把新结点node采用头插法插入newTable int nodeIndex = node.hash & sizeMask; // add the new node node.setNext(newTable[nodeIndex]); newTable[nodeIndex] = node; table = newTable; }
这个方法流程还是很清晰的,梳理如下:
这里首先对oldCapacity进行了遍历,对每一个tab[i]进行链表遍历确定其在新数组的位置。
在链表遍历过程中会尝试找到index发生变化的lastIdx与lastRun,通过newTable[lastIdx] = lastRun;代码把lastRun及以后的节点指向 lastIdx位置。
然后再处理节点e到节点lastRun的节点,采用头插法插入到newTable[k]位置。
【5】统计元素个数size
也就是统计map中key-value键值对的个数。如果map包含的元素个数超过了Integer.MAX_VALUE,那么就返回Integer.MAX_VALUE。
尝试几次以获得准确的计数。如果由于表中的连续异步更改而导致失败,则求助于锁定(也就是会锁住全部Segment)。
public int size() { // Try a few times to get accurate count. On failure due to // continuous async changes in table, resort to locking. final Segment<K,V>[] segments = this.segments; int size; boolean overflow; // true if size overflows 32 bits long sum; // sum of modCounts long last = 0L; // previous sum int retries = -1; // first iteration isn't retry try { // 无限循环 for (;;) { // 如果尝试次数达到了RETRIES_BEFORE_LOCK ,就将每一个segment加锁 // 先拿retries与RETRIES_BEFORE_LOCK进行==判断,然后retries+1 if (retries++ == RETRIES_BEFORE_LOCK) { for (int j = 0; j < segments.length; ++j) ensureSegment(j).lock(); // force creation } sum = 0L; size = 0; overflow = false; for (int j = 0; j < segments.length; ++j) { Segment<K,V> seg = segmentAt(segments, j); if (seg != null) { sum += seg.modCount;//结构修改次数 int c = seg.count; //元素个数 // 判断是否溢出 if (c < 0 || (size += c) < 0) overflow = true; } } // 如果前后一致,break,否则就更新last为当前sum if (sum == last) break; last = sum; } } finally { // 解锁 if (retries > RETRIES_BEFORE_LOCK) { for (int j = 0; j < segments.length; ++j) segmentAt(segments, j).unlock(); } } // 返回size ,如果溢出了,就返回Integer.MAX_VALUE=2^31-1 return overflow ? Integer.MAX_VALUE : size; }
也就是说首先直接将所有的 Segment 不加锁, 直接统计数量。统计过程中同时对每个 Segment 的 modCount 进行加总(modCount 记录了每个 Segment 被修改的次数)。重复上面的过程, 然后比较前后两次 modCount 总和是否一样, 相等就说明中间没有线程更改过结构(比如添加或者移除), 直接返回得到的 size 大小即可。
如果重试次数达到了3次,也就是总共循环了四次,那么直接将所有的Segment加锁进行元素数量统计。