看看不一样的ConcurrentHashMap

HashMap是Java中常见的数据结构,它结合了数组和链表的特点,查找和增删改操作均十分高效,但HashMap不适合在多线程环境下使用(非线程安全的集合),在多线程下对HashMap进行操作可能出现各种问题:

  • 多线程put的时候可能导致元素丢失
  • JDK 1.8 扩容采用的是“头插法”,在高并发下会出现死循环
  • HashMap 在并发下存在数据覆盖、遍历的同时进行修改会抛出 ConcurrentModificationException 异常等问题

HashMap虽好,但不适合在多线程环境下使用,那有线程安全的HashMap吗,答案是有的,JDK为我们提供了三种线程安全的HashMap:

  1. HashTable
  2. synchronizedMap
  3. ConcurrentHashMap

synchronizedMap是通过Collections.synchronizedMap()方法得到的,它将所有Map操作通过synchronized块进行修饰,与HashTable类似(HashTable是在所有方法加上synchronized关键字),这两个Map在多线程环境下是线程安全的,但他们的并发性能很差,同一时刻只能有一个线程进行操作,效率太低。
为了解决HashMap线程不安全以及synchronizedMap和HashTable并发效率低下的问题,Doug Lea大师为我们准备了兼具高效和安全的HashMap --> ConcurrentHashMap
本文就来讲讲ConcurrentHashMap是如何做到高效和安全的,由于ConcurrentHashMap在JDK1.7 和JDK1.8的实现不同,本文就分别介绍这两个版本的实现原理
图片说明

JDK1.7的ConcurrentHashMap

内部结构

ConcurrentHashMap内部是由Segments数组结构和HashEntry数组结构组成,Segment是一种可重入锁(继承自ReentrantLock),在ConcurrentHashMap扮演锁的角色;HashEntry则是真正用于存储数据的数据结构。一个ConcurrentHashMap中包含一个Segments数组,一个Segment中包含一个HashEntry数组,HashEntry是一个链表结构,所以Segment是一个散列表的结构(与HashMap类似)。ConcurrentHashMap的结构如图所示:
图片说明
Segments数组实现了分段锁,对ConcurrentHashMap进行访问时需要获取对应Segment的锁,这样多线程在访问容器不同数据段中的数据时就不会互相影响,线程之间锁竞争就会大大减小,从而提高了并发效率,同时也能保证安全访问数据

Segment相关代码如下

    /**
     * 段是哈希表的专用版本。该子类是ReentrantLock的子类,为了简化一些锁并避免单独构造。
     */
    static final class Segment<K,V> extends ReentrantLock implements Serializable {

        private static final long serialVersionUID = 2249069246763182397L;

        /**
         * 在可能阻塞获取以准备锁定的段操作之前,尝试在预扫描中尝试锁定的最大次数。
         * 在多处理器上,使用有限数量的重试可以维护在定位节点时获取的缓存。
         */
        static final int MAX_SCAN_RETRIES =
            Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;

        /**
         * 每个段。元素是通过entryAtsetEntryAt访问的,提供可变的语义。
         */
        transient volatile HashEntry<K,V>[] table;

        /**
         * 元素数。
         * 仅在锁内或其他保持可见性的易失性读取中访问。
         */
        transient int count;

        /**
         * HashEntry的操作总数。即使这可能溢出32位,它也为CHM isEmpty()和size()方法中的稳定性检查提供了足够的准确性。
         * 仅在锁内或其他保持可见性的易失性读取中访问。
         */
        transient int modCount;

        /**
         * 当表的大小超过此阈值时,将对其进行扩容并重新哈希处理。 
         * 此字段的值始终为(int)(capacity * loadFactor)
         */
        transient int threshold;

        /**
         * 哈希表的负载因子。即使所有段的该值都相同,也将复制该值以避免需要链接到外部对象。
         */
        final float loadFactor;

        Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
            this.loadFactor = lf;
            this.threshold = threshold;
            this.table = tab;
        }
    }

HashEntry的相关代码如下:

    /**
     * ConcurrentHashMap列表条目。它永远不会导出为用户可见的Map.Entry
     */
    static final class HashEntry<K,V> {
        final int hash;
        final K key;
        volatile V value;
        volatile HashEntry<K,V> next;

        HashEntry(int hash, K key, V value, HashEntry<K,V> next) {
            this.hash = hash;
            this.key = key;
            this.value = value;
            this.next = next;
        }

        /**
         * 使用易失性写语义设置下一个字段。
         */
        final void setNext(HashEntry<K,V> n) {
            UNSAFE.putOrderedObject(this, nextOffset, n);
        }

        // Unsafe mechanics
        static final sun.misc.Unsafe UNSAFE;
        static final long nextOffset;
        static {
            try {
                UNSAFE = sun.misc.Unsafe.getUnsafe();
                Class k = HashEntry.class;
                nextOffset = UNSAFE.objectFieldOffset
                    (k.getDeclaredField("next"));
            } catch (Exception e) {
                throw new Error(e);
            }
        }
    }

构造函数

先来看看几个重要的参数:

    /**
     * 默认初始容量,在没有在构造函数中另外指定时使用
     */
    static final int DEFAULT_INITIAL_CAPACITY = 16;

    /**
     * 默认加载因子,在没有在构造函数中另外指定时使用
     */
    static final float DEFAULT_LOAD_FACTOR = 0.75f;

    /**
     * 默认并发级别,在没有在构造函数中另外指定时使用。
     */
    static final int DEFAULT_CONCURRENCY_LEVEL = 16;

    /**
     * 最大容量,如果两个构造函数都使用参数隐式指定了更高的值,则使用该容量。
     * 必须是2的幂且小于等于 1 << 30,以确保条目可以使用int进行索引
     */
    static final int MAXIMUM_CAPACITY = 1 << 30;

    /**
     * 每段表的最小容量。必须为2的幂,至少为2的幂,以免在延迟构造后立即调整下次使用时的大小。
     */
    static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

    /**
     * 允许的最大段数;用于绑定构造函数参数。必须是小于1 << 24的2的幂。
     */
    static final int MAX_SEGMENTS = 1 << 16; // slightly conservative

    /**
     * 在锁定整个表之前,size和containsValue()方法的不同步重试次数。
     * 如果表进行连续修改,这将用于避免无限制的重试,这将导致无法获得准确的结果。
     */
    static final int RETRIES_BEFORE_LOCK = 2;

    /**
     * 用于编入段的掩码值。密钥的哈希码的高位用于选择段。
     */
    final int segmentMask;

    /**
     * 段内索引的移位值。
     */
    final int segmentShift;

    /**
     * 段,每个段都是一个专用的哈希表
     */
    final Segment<K,V>[] segments;

上面的参数知道大概就行,在接下来的代码中就能理解这些参数的作用的,接下来看看构造方法:

    /**
     * 使用指定的初始容量,负载因子和并发级别创建一个新的空映射。
     *
     * @param 初始容量。该实现执行内部大小调整以容纳许多元素。
     * 
     * @param loadFactor  负载系数阈值,用于控制调整大小。
     * 当每个仓的平均元素数超过此阈值时,可以执行大小调整。
     * 
     * @param concurrencyLevel 估计的并发更新线程数。该实现执行内部大小调整以尝试容纳这么多线程。
     * 
     * @throws IllegalArgumentException 如果初始容量为负,或者负载因子或concurrencyLevel为非正数。
     */
    @SuppressWarnings("unchecked")
    public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) {
        //对非法输入进行处理
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        // 若并发线程数大于最大段数,则等于最大段数    
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        // 为保证能通过位与运算的散列算法来定位segments数组索引,要保证数组长度为2的幂,查找最适合参数的二乘幂
        int sshift = 0;
        int ssize = 1;
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;
        }
        this.segmentShift = 32 - sshift;
        this.segmentMask = ssize - 1;

        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        //初始化每个segment中的HashEntry长度
        int c = initialCapacity / ssize;
        //如果c大于1,cap会取大于等于c的2次方,所以cap要么等于1要么等于2的幂次方
        if (c * ssize < initialCapacity)
            ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)
            cap <<= 1;
        // 创建segments数组,并初始化segments[0]
        Segment<K,V> s0 = new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
        UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
        this.segments = ss;
    }

    public ConcurrentHashMap(int initialCapacity, float loadFactor) {
        this(initialCapacity, loadFactor, DEFAULT_CONCURRENCY_LEVEL);
    }

    public ConcurrentHashMap(int initialCapacity) {
        this(initialCapacity, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }

    public ConcurrentHashMap() {
        this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
    }

    public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
        this(Math.max((int) (m.size() / DEFAULT_LOAD_FACTOR) + 1, DEFAULT_INITIAL_CAPACITY), DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
        putAll(m);
    }

由上述代码可以看到ConcurrentHashMap的构造方法最终会调用到同一个方法,只要集中看这个方法即可。
ConcurrentHashMap初始化的时候做了以下几个事情:

  1. 处理非法输入
  2. 计算segmentShiftsegmentMask
  3. 计算每个Segments中HashEntry的容量cap
  4. 初始化segments数组同时初始化Segment,并加入到segments[0]的位置上

segments数组的大小ssize是通过concurrencyLevel来计算的,为保证能通过位与运算的散列算法来定位segments数组索引,要保证数组长度为2的幂,需要计算一个大于或等于concurrencyLevel的最小2的N次方值来作为数组的长度。
segmentShift用于定位参与散列运算的位数,segmentShift等于32 - sshift,使用32是因为ConcurrentHashMap的hash() 方法输出值最大是32位的。
segmentMask是散列运算的掩码,segmentMask记录的是ssize - 1的值,ssize是2的幂次方,所以segmentMask每个二进制位都是1.
这两个参数有些难理解,后面讲到get() 方法时就明白这两个参数的含义了。

get()方法

话不多说,先上代码

public V get(Object key) {
        Segment<K,V> s; // manually integrate access methods to reduce overhead
        HashEntry<K,V>[] tab;
        //将key的hashcode进行再散列,减少hash冲突
        int h = hash(key);
        //散列算法,定位元素在segments数组的位置
        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
        //Segment不为空且Segment内部的HashEntry不为空,则继续查找
        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null && (tab = s.table) != null) {
            //散列算法定位HashEntry,遍历HashEntry,直到找到对应key,没有则退出循环
            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
                 e != null; e = e.next) {
                K k;
                if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                    return e.value;
            }
        }
        return null;
    }

get() 方法主要做了几件事:

  1. 通过散列计算获取元素在segments数组的位置
  2. 获取对应的Segment,若不为空则继续则通过散列算法获取Segment内部的HashEntry数组对应的HashEntry
  3. 若HashEntry不为空则遍历HashEntry找到key对应的值

get() 操作之所以高效是因为ConcurrentHashMap中的共享变量都被定义成volatile类型,从Segment和HashEntry的成员变量中可以看出,这样做能保证所有线程都能看到最新的值,根据Java内存模型的happen before 原则,对volatile类型的写入操作先于读取操作,即两个线程同时修改和获取volatile变量,get操作也能得到最新值,这样get() 无需加锁,进而提高了并发效率。

get() 方法并发高效的原因知道了,那ConcurrentHashMap是如何定位元素位置的呢?get() 方法中出现了hash() 方法,先看一下hash() 方法

    private int hash(Object k) {
        int h = hashSeed;

        if ((0 != h) && (k instanceof String)) {
            return sun.misc.Hashing.stringHash32((String) k);
        }

        h ^= k.hashCode();

        // 使用单字Wang/Jenkins哈希的变体来扩展位,以规范化段和索引位置。
        h += (h <<  15) ^ 0xffffcd7d;
        h ^= (h >>> 10);
        h += (h <<   3);
        h ^= (h >>>  6);
        h += (h <<   2) + (h << 14);
        return h ^ (h >>> 16);
    }

这个方法官方的解释是应用于hashcode上,使hashcode的高位和低位充分参与散列算法,减少散列冲突。这么做可以使元素均匀分布在不同的Segment中从而提高存取效率。如果散列的质量差到极点,那所有的元素都会位于同一个Segment中,分段锁的意义就没了,而且并发效率会大大降低。
让hashcode充分散列之后就可以使用散列算法定位元素的位置了,定位HashEntry和定位Segment的算法是一样,但有些细微的差别

//定位Segment
(h >>> segmentShift) & segmentMask

//定位HashEntry
(tab.length - 1) & h

上面说过segmentMask记录的是ssize-1的值,即segments长度-1的值,segmentShift是散列值向右偏移的位数,即实际是向右偏移,让散列值的高位参与运算;定位Segment是用再散列的值的高位进行运算,而定位HashEntry则是用再散列的值直接进行运算,这么做的目的是避免两次散列后的值一样,使元素在Segment散列开了,而没有在HashEntry内部散列开,进而增加了冲突的可能。

看完了get() 方法,接下来看看put() 方法,看看大师的手法o( ̄▽ ̄)o

put()方法

废话不多说,还是先看代码

    public V put(K key, V value) {
        Segment<K,V> s;
        if (value == null)
            throw new NullPointerException();
        int hash = hash(key);
        int j = (hash >>> segmentShift) & segmentMask;
        //易失性,在ensureSegment重新检查
        if ((s = (Segment<K,V>)UNSAFE.getObject(segments, (j << SSHIFT) + SBASE)) == null) 
            s = ensureSegment(j);
        return s.put(key, hash, value, false);
    }

    private Segment<K,V> ensureSegment(int k) {
        final Segment<K,V>[] ss = this.segments;
        long u = (k << SSHIFT) + SBASE; // raw offset
        Segment<K,V> seg;
        //Segment不存在,新建一个
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
            //使用segments[0]的参数创建新的Segment
            Segment<K,V> proto = ss[0]; // use segment 0 as prototype
            int cap = proto.table.length;
            float lf = proto.loadFactor;
            int threshold = (int)(cap * lf);
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
            //二次检查,防止其他线程先创建了Segment,而覆盖其创建的Segment
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { 
                // recheck
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
                    //通过CAS将新建的Segment加到segments数组中
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }

put() 方法做了几件事:

  1. 通过散列算法获取key对应的Segment的下标
  2. 获取该Segment,若Segment不存在则调用ensureSegment() 方法重新确认 (创建Segment在这个方法中,Segment的参数从Segment[0]获取)
  3. 调用Segment的put() 将元素加入到Segment中

Segment.put()

put() 方法较为简单(其他的可以看看注释),详细的接下来看Segment中put() 方法,这个方法才是核心

    /**
     * ConcurrentHashMap.Segment
     */
    final V put(K key, int hash, V value, boolean onlyIfAbsent) {
        // 尝试锁定,若无法获取锁,则先扫描Segment并不断尝试获取锁直到成功
        HashEntry<K,V> node = tryLock() ? null : scanAndLockForPut(key, hash, value);
        V oldValue;
        try {
            HashEntry<K,V>[] tab = table;
            int index = (tab.length - 1) & hash;
            // 查找key对应的HashEntry
            HashEntry<K,V> first = entryAt(tab, index);
            // 遍历查找key对应的元素所在位置
            for (HashEntry<K,V> e = first;;) {
                // 如果查找的HashEntry不为空,即存在hash冲突,则先比较key是否相同,
                // 相同则根据onlyIfAbsent决定是否要更新,不相同则将指针指向next,继续查找,
                // 直到找到相同的key或是遍历到链表最后一个节点
                if (e != null) {
                    K k;
                    if ((k = e.key) == key ||
                        (e.hash == hash && key.equals(k))) {
                        oldValue = e.value;
                        //如果不允许修改已经存在的元素,则跳过
                        if (!onlyIfAbsent) {
                            e.value = value;
                            ++modCount;
                        }
                        break;
                    }
                    e = e.next;
                }
                else {
                    // 在等待获取锁时扫描Segment,并可能会返回一个新的HashEntry
                    // 节点不为空则将该节点的next指向first节点
                    // 节点为空则创建一个
                    if (node != null)
                        node.setNext(first);
                    else
                        node = new HashEntry<K,V>(hash, key, value, first);
                    int c = count + 1;
                    // 若节点超过了阈值,则需要扩容,并将节点插入到新的HashEntry数组中
                    if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                        rehash(node);
                    else
                        setEntryAt(tab, index, node);
                    ++modCount;
                    count = c;
                    oldValue = null;
                    break;
                }
            }
        } finally {
            // 释放锁
            unlock();
        }
        return oldValue;
    }

Segment中的put() 方法主要做了这几件事:
4. 尝试获取当前Segment的锁,若获取成功则继续下面的操作;若未能获取到锁则扫描Segment中的HashEntry同时不断尝试获取锁
5. 计算要插入元素在HashEntrys数组中的下标,通过一个死循环来不断查找要插入的位置,将元素插入到HashEntry对应的位置中:若定位的位置存在节点则遍历该节点找到key相同的节点,并替换对应的值,若没有相同的key则新建一个HashEntry对象并插入到HashEntry链表头部;若定位的位置不存在节点,则新建一个HashEntry对象
6. 统计当前节点总数是否超过阈值,超过则执行扩容操作
7. 将新建的HashEntry对象插入到对应位置,然后释放锁

为了防止当前线程未能获取到锁而阻塞的情况,put() 方法让当前线程扫描Segment中的HashEntry,并不断尝试获取锁,即让当前线程空转,来看一下scanAndLockForPut() 方法:

private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
    HashEntry<K,V> first = entryForHash(this, hash);
    HashEntry<K,V> e = first;
    HashEntry<K,V> node = null;
    // 定位节点时为负
    int retries = -1; 
    while (!tryLock()) {
        HashEntry<K,V> f; 
        // 在下面重新检查
        if (retries < 0) {
            if (e == null) {
                // 推测性的创建节点,并让当前线程空转直到获取到锁
                if (node == null) 
                    node = new HashEntry<K,V>(hash, key, value, null);
                retries = 0;
            }
            // 存在key相同节点,则只让当前线程空转直到获取到锁
            else if (key.equals(e.key))
                retries = 0;
            else
                e = e.next;
        }
        // 循环次数+1,若循环次数超过上限值直接获取锁,获取锁之后退出循环
        else if (++retries > MAX_SCAN_RETRIES) {
            lock();
            break;
        }
        // 每循环两次对定位的HashEntry进行检测是否发生变化,发生变化则需要重新遍历
        else if ((retries & 1) == 0 && (f = entryForHash(this, hash)) != first) {
            // 当前HashEntry发生变化,需要重新遍历
            e = first = f; 
            retries = -1;
        }
    }
    return node;
}

对于这个方法官方的解释是这样的:

/**
 * Scans for a node containing given key while trying to
 * acquire lock, creating and returning one if not found. Upon
 * return, guarantees that lock is held. UNlike in most
 * methods, calls to method equals are not screened: Since
 * traversal speed doesn't matter, we might as well help warm
 * up the associated code and accesses as well.
 *
 * @return a new node if key not found, else null
 */

/**
 * 在尝试获取锁的同时扫描包含给定密钥的节点,如果找不到,则创建并返回一个节点。返回时,保证锁定被保持。
 * 与大多数方法不同,不筛选对方法等于的调用:由于遍历速度并不重要,因此我们也可能会帮助预热关联的代码和访问。 
 * 如果找不到密钥,则返回一个新节点,否则返回null
 */

这个方法让未能获取锁的线程空转同时扫描HashEntry数组,起到预热的效果。

rehash()

当新建一个HashEntry以后会先进行扩容判断,若节点总数超过了阈值且HashEntry数组长度小于最大值,则进行扩容,扩容源码如下:

    /**
     * 将表的大小加倍并重新打包条目,还将给定的节点添加到新表中
     */
    @SuppressWarnings("unchecked")
    private void rehash(HashEntry<K,V> node) {
        /*
         * 将每个列表中的节点重新分类为新表。
         * 因为我们使用的是2的幂次展开,所以每个桶中的元素必须保持相同的索引或以2的幂进行偏移。
         * 通过捕获旧节点可以重复使用的情况(因为它们的下一个字段不会更改),我们消除了不必要的节点创建。
         * 统计上,在默认阈值下,当表加倍时,只有大约六分之一需要克隆。
         * 一旦它们被并发遍历表中的任何读取器线程不再引用,它们替换的节点将立即被垃圾回收。
         * 条目访问使用纯数组索引,因为它们后面是易失表写入。
         */
        HashEntry<K,V>[] oldTable = table;
        int oldCapacity = oldTable.length;
        int newCapacity = oldCapacity << 1;
        threshold = (int)(newCapacity * loadFactor);
        HashEntry<K,V>[] newTable =
            (HashEntry<K,V>[]) new HashEntry[newCapacity];
        int sizeMask = newCapacity - 1;
        for (int i = 0; i < oldCapacity ; i++) {
            HashEntry<K,V> e = oldTable[i];
            if (e != null) {
                HashEntry<K,V> next = e.next;
                int idx = e.hash & sizeMask;
                // 单节点的情况
                if (next == null)  
                    // 直接将原节点复制过去
                    newTable[idx] = e;
                else { 
                    // next不为空
                    // 重用同一插槽中的连续序列
                    HashEntry<K,V> lastRun = e;
                    int lastIdx = idx;
                    for (HashEntry<K,V> last = next; last != null; last = last.next) {
                        int k = last.hash & sizeMask;
                        if (k != lastIdx) {
                            lastIdx = k;
                            lastRun = last;
                        }
                    }
                    newTable[lastIdx] = lastRun;
                    // 复制其他节点
                    for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                        V v = p.value;
                        int h = p.hash;
                        int k = h & sizeMask;
                        HashEntry<K,V> n = newTable[k];
                        newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                    }
                }
            }
        }
        // 添加新节点
        int nodeIndex = node.hash & sizeMask; 
        node.setNext(newTable[nodeIndex]);
        newTable[nodeIndex] = node;
        table = newTable;
    }

rehash() 方法主要做了几件事:

  1. 创建新的HashEntry数组大小为原来的2倍
  2. 遍历原数组,重新计算原数组的元素在新数组的下标,并将元素复制到对应的位置上,若是遇到元素后面还有元素的,先计算复用属于同一插槽的同意序列,再复制其他节点到新数组
  3. 最后将新节点插入到扩容后的新数组中

这样ConcurrentHashMap整个put过程就梳理完成了,上一张图总结一下put() 做了哪些事
图片说明

size()方法

除了get()put() 方法,ConcurrentHashMap中的size()方法也是比较特殊的,来看看它的实现

    public int size() {
        // 尝试几次以获得准确的计数。如果由于表中的持续异步更改而导致失败,请使用锁定。
        final Segment<K,V>[] segments = this.segments;
        int size;
        // 如果大小溢出32位,则返回true
        boolean overflow; 
        // modCounts的总和
        long sum;         
        // 之前的总和
        long last = 0L;   
        // 不加锁重试次数
        int retries = -1; 
        try {
            for (;;) {
                // 每次不加锁读取每个Segment中的count,若重试次数超过两次则将所有Segment加锁,然后统计总数
                if (retries++ == RETRIES_BEFORE_LOCK) {
                    for (int j = 0; j < segments.length; ++j)
                        ensureSegment(j).lock(); 
                }
                sum = 0L;
                size = 0;
                overflow = false;
                // 统计每个Segment中的元素总数和modCount
                for (int j = 0; j < segments.length; ++j) {
                    Segment<K,V> seg = segmentAt(segments, j);
                    if (seg != null) {
                        sum += seg.modCount;
                        int c = seg.count;
                        if (c < 0 || (size += c) < 0)
                            overflow = true;
                    }
                }
                if (sum == last)
                    break;
                last = sum;
            }
        } finally {
            // 释放锁
            if (retries > RETRIES_BEFORE_LOCK) {
                for (int j = 0; j < segments.length; ++j)
                    segmentAt(segments, j).unlock();
            }
        }
        return overflow ? Integer.MAX_VALUE : size;
    }

size() 方法先尝试不加锁遍历两次统计每个Segment中元素总数,若在两次遍历中modCount发生了变化,则两次统计的modCount也是不一样的(对ConcurrentHashMap修改一次,就会让modCount+1),此时就会将所有Segment锁住,然后再统计元素总数。
这么做的原因是,在累加count的过程中,之前累加过的count发生变化的概率比较小,这样可以减小加锁的开销以及加锁导致其他线程被阻塞导致效率低下(当然在写多的环境下,还是要加锁重读才能保证准确性)

JDK1.7的ConcurrentHashMap主要方法分析完了,其他部分这里没有介绍,感兴趣的可以去看源码,源码里有相关的注释,理解起来也方便。

JDK1.8的ConcurrentHashMap

JDK1.8的ConcurrentHashMap相比JDK1.7的ConcurrentHashMap做了很大改动,弃用了Segment这一个内部对象,改成了CAS+synchronized来提高并发效率和安全性,同时内部还用了红黑树来提高查找效率。接下来看看JDK1.8的ConcurrentHashMap到底做了哪些改动吧。

内部结构

由于废弃了JDK1.7 Segment+HashEntry的结构,JDK1.8的ConcurrentHashMap的内部结构和JDK1.8的HashMap结构一样,只是对于读写操作有一些不同。JDK1.8的ConcurrentHashMap的内部结构图如下:
ConcurrentHashMap
ConcurrentHashMap内部除了Node内部类,还定义了TreeNode、TreeBin、ForwardingNode这几个内部类
来看一下它们的定义:
先看一下定义的参数

    /*
     * 节点哈希字段的编码
     */
    // 当前节点正在被转移
    static final int MOVED     = -1; 
    // 当前节点是树节点
    static final int TREEBIN   = -2;
    // 临时保留的hash值 
    static final int RESERVED  = -3; 
    // 普通节点哈希的可用位
    static final int HASH_BITS = 0x7fffffff;

Node

    /**
     * 键值输入。此类永远不会导出为可变的Map.Entry,但可以用于只读遍历。
     * 具有负哈希字段的Node的子类是特殊的,并且包含空键和值(但从不导出)。
     * 键和值永远不会为null。
     */
    static class Node<K,V> implements Map.Entry<K,V> {
        final int hash;
        final K key;
        volatile V val;
        volatile Node<K,V> next;

        Node(int hash, K key, V val, Node<K,V> next) {
            this.hash = hash;
            this.key = key;
            this.val = val;
            this.next = next;
        }

        public final K getKey()       { return key; }
        public final V getValue()     { return val; }
        public final int hashCode()   { return key.hashCode() ^ val.hashCode(); }
        public final String toString(){ return key + "=" + val; }
        public final V setValue(V value) {
            throw new UnsupportedOperationException();
        }

        public final boolean equals(Object o) {
            Object k, v, u; Map.Entry<?,?> e;
            return ((o instanceof Map.Entry) &&
                    (k = (e = (Map.Entry<?,?>)o).getKey()) != null &&
                    (v = e.getValue()) != null &&
                    (k == key || k.equals(key)) &&
                    (v == (u = val) || v.equals(u)));
        }

        /**
         * 对map.get()的虚拟支持;在子类中重写。
         */
        Node<K,V> find(int h, Object k) {
            Node<K,V> e = this;
            if (k != null) {
                do {
                    K ek;
                    if (e.hash == h &&
                        ((ek = e.key) == k || (ek != null && k.equals(ek))))
                        return e;
                } while ((e = e.next) != null);
            }
            return null;
        }
    }

TreeBin

    static final class TreeBin<K,V> extends Node<K,V> {
        TreeNode<K,V> root;
        volatile TreeNode<K,V> first;
        volatile Thread waiter;
        volatile int lockState;
        // 锁状态
        // 写入时
        static final int WRITER = 1; 
        // 等待获取锁
        static final int WAITER = 2; 
        // 设置读取锁定的增量值
        static final int READER = 4; 

        /***省略部分代码***/

        TreeBin(TreeNode<K,V> b) {
            super(TREEBIN, null, null, null);
            this.first = b;
            TreeNode<K,V> r = null;
            for (TreeNode<K,V> x = b, next; x != null; x = next) {
                next = (TreeNode<K,V>)x.next;
                x.left = x.right = null;
                if (r == null) {
                    x.parent = null;
                    x.red = false;
                    r = x;
                }
                else {
                    K k = x.key;
                    int h = x.hash;
                    Class<?> kc = null;
                    for (TreeNode<K,V> p = r;;) {
                        int dir, ph;
                        K pk = p.key;
                        if ((ph = p.hash) > h)
                            dir = -1;
                        else if (ph < h)
                            dir = 1;
                        else if ((kc == null &&
                                  (kc = comparableClassFor(k)) == null) ||
                                 (dir = compareComparables(kc, k, pk)) == 0)
                            dir = tieBreakOrder(k, pk);
                            TreeNode<K,V> xp = p;
                        if ((p = (dir <= 0) ? p.left : p.right) == null) {
                            x.parent = xp;
                            if (dir <= 0)
                                xp.left = x;
                            else
                                xp.right = x;
                            r = balanceInsertion(r, x);
                            break;
                        }
                    }
                }
            }
            this.root = r;
            assert checkInvariants(root);
        }
        /***省略部分代码***/
    }

TreeNode

    static final class TreeNode<K,V> extends Node<K,V> {
        TreeNode<K,V> parent;  // red-black tree links
        TreeNode<K,V> left;
        TreeNode<K,V> right;
        TreeNode<K,V> prev;    // needed to unlink next upon deletion
        boolean red;

        TreeNode(int hash, K key, V val, Node<K,V> next,
                 TreeNode<K,V> parent) {
            super(hash, key, val, next);
            this.parent = parent;
        }
        /***省略部分代码***/
    }

ForwardingNode

    static final class ForwardingNode<K,V> extends Node<K,V> {
        final Node<K,V>[] nextTable;
        ForwardingNode(Node<K,V>[] tab) {
            super(MOVED, null, null, null);
            this.nextTable = tab;
        }
    }

从代码上来看,TreeNode、TreeBin和ForwardingNode都是Node的子类,TreeNode是红黑树的真实节点,TreeBin是封装了多个TreeNode的红黑树,所有关于红黑树的操作都由TreeBin实现。ForwardingNode是表在进行扩容时,用来标记需要移动的节点用的。ForwardingNode和TreeBin的hash均为负值,这么做的目的是在进行查询操作的时候可以将这些特殊情况一起处理(看get() 方法可以理解)

构造方法

先来看一下几个重要的参数:

    /**
     * 默认为null,初始化发生在第一次插入操作,默认大小为16的数组,
     * 用来存储Node节点数据,扩容时大小总是2的幂次方。
     */
    transient volatile Node<K,V>[] table;

    /**
     * 默认为null,扩容时新生成的数组,其大小为原数组的两倍
     */
    private transient volatile Node<K,V>[] nextTable;

    /**
     * 基本计数器值,主要在没有竞争时使用,通过CAS更新。
     */
    private transient volatile long baseCount;

    /**
     * 表初始化和大小调整控制。
     * 如果为负,则表将被初始化或调整大小:-1表示在初始化,-N表示有N-1个线程正在进行扩容操作。
     * 否则,当table为null时,保留创建时要使用的初始表大小,或者默认为0。
     * 初始化后,保留下一个要调整表大小的元素计数值
     */
    private transient volatile int sizeCtl;

    /**
     * 调整大小时要拆分的下一个表索引(加1)。
     */
    private transient volatile int transferIndex;

    /**
    * 因为sizeCtl处理了多种状态,需要其他的属性协助参与
    */
    // 扩容时用到,通过sizeCtl生成标记,回归标记状态,说明扩容完毕,状态有效位为16位
    private static int RESIZE_STAMP_BITS = 16;
    // 扩容线程最大数量,为 2^16,
    private static final int MAX_RESIZERS = (1 << (32 - RESIZE_STAMP_BITS)) - 1;
    // 扩容时用到,还原sizeCtl记录的扩容标记
    private static final int RESIZE_STAMP_SHIFT = 32 - RESIZE_STAMP_BITS;

ConcurrentHashMap使用sizeCtl这个参数来控制并发扩容的操作,这个参数讲到put() 方法时会看到是它是如何控制扩容的。
接下来看看ConcurrentHashMap的构造方法:

    /**
     * 默认构造函数,构造参数都是使用默认的
     */
    public ConcurrentHashMap() {
    }

    public ConcurrentHashMap(int initialCapacity) {
        if (initialCapacity < 0)
            throw new IllegalArgumentException();
        int cap = ((initialCapacity >= (MAXIMUM_CAPACITY >>> 1)) ?
                   MAXIMUM_CAPACITY :
                   tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1));
        this.sizeCtl = cap;
    }

    public ConcurrentHashMap(int initialCapacity, float loadFactor) {
        this(initialCapacity, loadFactor, 1);
    }

    public ConcurrentHashMap(int initialCapacity, float loadFactor, int concurrencyLevel) {
        if (!(loadFactor > 0.0f) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        // 指定并发线程数
        if (initialCapacity < concurrencyLevel) 
            initialCapacity = concurrencyLevel;  
        long size = (long)(1.0 + (long)initialCapacity / loadFactor);
        int cap = (size >= (long)MAXIMUM_CAPACITY) ?
            MAXIMUM_CAPACITY : tableSizeFor((int)size);
        this.sizeCtl = cap;
    }

    public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
        this.sizeCtl = DEFAULT_CAPACITY;
        putAll(m);
    }

    public void putAll(Map<? extends K, ? extends V> m) {
        // 尝试调整表的大小来放入Map中所有元素
        tryPresize(m.size());
        for (Map.Entry<? extends K, ? extends V> e : m.entrySet())
            putVal(e.getKey(), e.getValue(), false);
    }

    private static final int tableSizeFor(int c) {
        int n = c - 1;
        n |= n >>> 1;
        n |= n >>> 2;
        n |= n >>> 4;
        n |= n >>> 8;
        n |= n >>> 16;
        return (n < 0) ? 1 : (n >= MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : n + 1;
    }

ConcurrentHashMap的构造方法没有做太多的事,只是简单设置了一下参数,通过tableSizeFor() 保证了内部表的长度是2的幂次方,到这里与HashMap的构造方法十分相似。
看完了构造函数,接下来看看get() 做了什么

get()方法

话不多说先来看看源码

    public V get(Object key) {
        Node<K,V>[] tab; Node<K,V> e, p; int n, eh; K ek;
        int h = spread(key.hashCode());
        if ((tab = table) != null && (n = tab.length) > 0 && (e = tabAt(tab, (n - 1) & h)) != null) {
            // 若所在位置存在元素且key相同则返回对应的val
            if ((eh = e.hash) == h) {
                if ((ek = e.key) == key || (ek != null && key.equals(ek)))
                    return e.val;
            }
            // 当节点在TreeBin中或者正在扩容时,节点的hash值会被设置成负数 (分析1)
            else if (eh < 0)
                return (p = e.find(h, key)) != null ? p.val : null;
            // 若元素存在链表中,遍历查找    
            while ((e = e.next) != null) {
                if (e.hash == h && ((ek = e.key) == key || (ek != null && key.equals(ek))))
                    return e.val;
            }
        }
        return null;
    }

从源码来看,ConcurrentHashMap的get() 方法与HashMap的get() 方法差异较大,这里分析一下:

  1. 通过spread() 方法对hashcode再散列
  2. 获取table,若为空则直接返回null
  3. 通过散列算法定位key在table中的位置,
    1). 若所在位置存在元素且key相同则返回对应的val
    2). 若所在位置存在元素,但元素的hash值为负数,说明这个元素在TreeBin中或者是表正在扩容,则调用find() 方法去找元素
    3). 若所在位置存在元素,但key不相同,则说明元素在链表中,则遍历链表查找元素

Node.find()

get() 方法实际上理解起来还是很轻松的,这里有点让人费解是分析1的代码,为什么不是在TreeNode中查找,这是因为一方面在查找的过程中也可能会发生扩容(概率较小),所以需要标记一下,防止脏读,另一方面表的默认结构是数组+链表,除非链表上的节点超过阈值(默认为8)才会转成红黑树,所以ConcurrentHashMap中红黑树其实很少(甚至可能没有),综合这两种特殊情况就直接合并到一起,定义find() 方法来查找元素(当然find() 的实现是不同的)
在上面我们可以知道ForwardingNode、TreeNode、TreeBin都是属于Node的子类,Node中虽然定义了find() 方法,但只是实现了Node链表的遍历查找,对于特殊情况需要子类自己实现,来看看这三个子类是如何实现find() 方法的
1.ForwardingNode的find()

Node<K,V> find(int h, Object k) {
    // 循环,以避免在转发节点上任意深度递归
    outer: for (Node<K,V>[] tab = nextTable;;) {
        Node<K,V> e; int n;
        if (k == null || tab == null || (n = tab.length) == 0 ||
            (e = tabAt(tab, (n - 1) & h)) == null)
            return null;
        for (;;) {
            int eh; K ek;
            if ((eh = e.hash) == h &&
                ((ek = e.key) == k || (ek != null && k.equals(ek))))
                return e;
            if (eh < 0) {
                // 节点为ForwardingNode则指向扩容后的新数组,在新数组中查找元素
                if (e instanceof ForwardingNode) {
                    tab = ((ForwardingNode<K,V>)e).nextTable;
                    continue outer;
                }
                else
                    // 属于TreeBin的情况
                    return e.find(h, k);
            }
            if ((e = e.next) == null)
                return null;
        }
    }
}

2.TreeNode的find()

Node<K,V> find(int h, Object k) {
    return findTreeNode(h, k, null);
}

/**
 * 从给定的根开始,返回给定键的TreeNode(如果找不到,则返回null)。
 * 真正的查询操作
 */
final TreeNode<K,V> findTreeNode(int h, Object k, Class<?> kc) {
    if (k != null) {
        TreeNode<K,V> p = this;
        do  {
            int ph, dir; K pk; TreeNode<K,V> q;
            TreeNode<K,V> pl = p.left, pr = p.right;
            if ((ph = p.hash) > h)
                p = pl;
            else if (ph < h)
                p = pr;
            else if ((pk = p.key) == k || (pk != null && k.equals(pk)))
                return p;
            else if (pl == null)
                p = pr;
            else if (pr == null)
                p = pl;
            else if ((kc != null ||
                      (kc = comparableClassFor(k)) != null) &&
                     (dir = compareComparables(kc, k, pk)) != 0)
                p = (dir < 0) ? pl : pr;
            else if ((q = pr.findTreeNode(h, k, kc)) != null)
                return q;
            else
                p = pl;
        } while (p != null);
    }
    return null;
}

3.TreeBin的find()

/**
 * 返回匹配的节点;如果没有,则返回null。
 * 尝试从树的根开始比较进行搜索,但是在锁不可用时继续线性搜索。
 * 封装了findTreeNode()方法
 */
final Node<K,V> find(int h, Object k) {
    if (k != null) {
        for (Node<K,V> e = first; e != null; ) {
            int s; K ek;
            if (((s = lockState) & (WAITER|WRITER)) != 0) {
                if (e.hash == h &&
                    ((ek = e.key) == k || (ek != null && k.equals(ek))))
                    return e;
                e = e.next;
            }
            else if (U.compareAndSwapInt(this, LOCKSTATE, s, s + READER)) {
                TreeNode<K,V> r, p;
                try {
                    p = ((r = root) == null ? null : r.findTreeNode(h, k, null));
                } finally {
                    Thread w;
                    if (U.getAndAddInt(this, LOCKSTATE, -READER) == (READER|WAITER) && (w = waiter) != null)
                        LockSupport.unpark(w);
                }
                return p;
            }
        }
    }
    return null;
}

TreeBin的find() 并不是实际执行查找的,它只是对TreeNode的findTreeNode() 方法进行封装,保证了在树节点正在被修改时查找的准确性,TreeNode中的find() 方法才是真正执行查找的方法,内部也是直接交给了findTreeNode() 方法;被ForwardingNode标记的节点表示正在移动,所以ForwardingNode的find() 方***指向新数组,在新数组中查找元素。

spread()

与JDK1.7版本一样,JDK1.8的ConcurrentHashMap会对hashcode再散列降低散列冲突的概率,调用的方法是spread()
spread() 方法源码如下:

    /**
     * 将散列的较高位散布(XOR)较低,也将最高位强制为0。
     * 由于该表使用2的幂掩码,因此仅在当前掩码上方的位中变化的哈希集将始终发生冲突。 (众所周知的示例是在小表中包含连续整数的Float键集。)
     * 因此,我们应用了将向下扩展较高位的影响的变换。
     * 在速度,实用性和位扩展质量之间需要权衡。
     * 由于许多常见的哈希集已经合理分布(因此无法从扩展中受益),并且由于我们使用树来处理容器中的大量冲突,因此我们仅以最便宜的方式对一些移位后的位进行XOR,以减少系统损失,以及合并最高位的影响,否则由于表范围的限制,这些位将永远不会在索引计算中使用。
     */
    static final int spread(int h) {
        return (h ^ (h >>> 16)) & HASH_BITS;
    }

JDK1.8版本的spread() 方法相比较JDK1.7的hash() 方法简洁多了,这是应为JDK1.8的ConcurrentHashMap内部结构比起JDK1.7版本更为复杂,即使hash冲突非常严重,在查找,插入时的效率不会像JDK1.7降低的那么明显(遍历链表的时间复杂度为O(n),而转成红黑树之后为O(logn)),所以权衡速度,实用性和位扩展质量之后,简化了再散列的算法,让hashcode的高16位和低16位进行异或运算,实际让高16位的二进制值也参与到散列定位元素的算法中,可以降低hash冲突的概率。

接下来看看put() 方法的实现

put()方法

话不多说先来看代码:

    public V put(K key, V value) {
        return putVal(key, value, false);
    }

    /** put和putIfAbsent的实现 */
    final V putVal(K key, V value, boolean onlyIfAbsent) {
        if (key == null || value == null) throw new NullPointerException();
        // 再hash,降低冲突
        int hash = spread(key.hashCode());
        int binCount = 0;
        for (Node<K,V>[] tab = table;;) {
            Node<K,V> f; int n, i, fh;
            // 表为空,则先执行初始化
            if (tab == null || (n = tab.length) == 0)
                tab = initTable();
            // 散列算法确定要插入元素的位置,并获取当前对应Node对象f,
            // 使用Unsafe.getObjectVolatile()方法来获取,而不是直接使用table[index]来获取,在java内存模型中,
            // 每个线程都有一个工作内存,里面存储着table的副本,虽然table是volatile修饰的,
            // 但不能保证线程每次都拿到table中的最新元素(即不能保证对table操作的原子性),
            // Unsafe.getObjectVolatile可以直接获取指定内存的数据,保证了每次拿到数据都是最新的。    
            else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) {
                // 定位位置为null时(不存在Node对象),使用CAS插入
                // 如果CAS成功,说明Node节点已经插入,随后addCount(1L, binCount)方***检查当前容量是否需要进行扩容。
                // 如果CAS失败,说明有其它线程提前插入了节点,自旋重新尝试在这个位置插入节点。
                if (casTabAt(tab, i, null, new Node<K,V>(hash, key, value, null)))
                    break;                   
            }
            // 定位位置所在区域正在扩容,则先等扩容完毕获得新的table后再进行插入
            else if ((fh = f.hash) == MOVED)
                tab = helpTransfer(tab, f);
            // 定位位置不为null(存在Node对象,即发生了hash冲突),则对当前位置对应的Node进行加锁后插入元素
            else {
                V oldVal = null;
                synchronized (f) {
                    // 在节点f上进行同步,节点插入之前,再次利用tabAt(tab, i) == f判断,防止被其它线程修改
                    if (tabAt(tab, i) == f) {
                        // 当前Node对象在链表中
                        if (fh >= 0) {
                            // 用于统计链表长度
                            binCount = 1;
                            for (Node<K,V> e = f;; ++binCount) {
                                K ek;
                                // 链表中存在key相同的元素,则根据onlyIfAbsent来确定是否更新value
                                if (e.hash == hash && ((ek = e.key) == key ||  (ek != null && key.equals(ek)))) {
                                    oldVal = e.val;
                                    if (!onlyIfAbsent)
                                        e.val = value;
                                    break;
                                }
                                Node<K,V> pred = e;
                                // 链表中不存在key相同的元素,则插入到链表末尾
                                if ((e = e.next) == null) {
                                    pred.next = new Node<K,V>(hash, key, value, null);
                                    break;
                                }
                            }
                        }
                        // 当前Node对象在红黑树中,调用putTreeVal()插入元素
                        else if (f instanceof TreeBin) {
                            Node<K,V> p;
                            binCount = 2;
                            if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key, value)) != null) {
                                oldVal = p.val;
                                if (!onlyIfAbsent)
                                    p.val = value;
                            }
                        }
                    }
                }
                if (binCount != 0) {
                    // 加如链表中的节点个数超过阈值(默认为8),则进行树化,提高查找效率
                    if (binCount >= TREEIFY_THRESHOLD)
                        treeifyBin(tab, i);
                    if (oldVal != null)
                        return oldVal;
                    break;
                }
            }
        }
        // 统计节点总数,若超出阈值则需要扩容
        addCount(1L, binCount);
        return null;
    }

put() 方法的实际实现是putVal() 方法,这个方法主要做了几件事:

  1. 处理非法输入,再散列要插入元素的key的hashcode
  2. 若表为空,则先执行初始化,初始化完成之后再继续插入元素
  3. 计算得到要插入的位置,当这个位置为空时,使用CAS进行插入;当这个位置被标记为正在移动,则等扩容后得到新的table再执行插入;若这个位置不为空(产生hash冲突),将当前位置的Node对象进行加锁,加锁后执行插入:
    1)当前Node对象在链表中,则遍历插入,并统计链表长度
    2)当前Node对象在红黑树中,则调用putTreeVal() 插入元素
    3)若链表长度大于等于阈值,则执行树化操作,转成红黑树
  4. 统计插入的节点总数,若超出阈值则执行扩容操作

详细的put流程在源码中用注释标注了,接下来看看几个重要的操作

initTable()

ConcurrentHashMap将table初始化延迟到第一次put() 进行,ConcurrentHashMap是支持并发扩容的,那它是如何保证table只被初始化一次的呢,来看看initTable() 方法

    private final Node<K,V>[] initTable() {
        Node<K,V>[] tab; int sc;
        // table为空,则执行初始化
        while ((tab = table) == null || tab.length == 0) {
            // 发现sizeCtl<0,意味着有线程执行CAS成功,则让出时间片,让执行初始化的线程执行
            if ((sc = sizeCtl) < 0)
                Thread.yield(); 
            // 通过CAS将sizeCtl修改为-1    
            else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
                try {
                    // 二次判断table是否被初始化,防止二次初始化
                    if ((tab = table) == null || tab.length == 0) {
                        int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                        // 创建table
                        @SuppressWarnings("unchecked")
                        Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                        table = tab = nt;
                        // 标记当前能容纳的最大数据量,即 table长度 * 加载因子。
                        sc = n - (n >>> 2);
                    }
                } finally {
                    sizeCtl = sc;
                }
                break;
            }
        }
        return tab;
    }

sizeCtl默认为0,如果ConcurrentHashMap实例化时有传参数,sizeCtl会是一个2的幂次方的值。所以执行第一次put操作的线程会执行Unsafe.compareAndSwapInt()方法修改sizeCtl为-1,有且只有一个线程能够修改成功,其它线程通过Thread.yield()让出CPU时间片等待table初始化完成。

addCount()

将元素插入到表中以后,调用addCount() 方法进行统计元素总个数,超出阈值就要进行扩容处理,来看一下源码实现:

    /**
     * 增加计数,若超过阈值,则执行扩容。
     * 如果已经调整大小,则在工作可用时帮助执行转移。
     * 转移后重新检查占用率,以查看是否已经需要其他调整大小,因为调整大小是滞后的。
     *
     * @param x 要添加的计数
     * @param check 小于0则不检测是否发生扩容, 小于等于1时检查是否有其他线程参与
     */
    private final void addCount(long x, int check) {
        CounterCell[] as; long b, s;
        if ((as = counterCells) != null || !U.compareAndSwapLong(this, BASECOUNT, b = baseCount, s = b + x)) {
            CounterCell a; long v; int m;
            boolean uncontended = true;
            if (as == null || (m = as.length - 1) < 0 || (a = as[ThreadLocalRandom.getProbe() & m]) == null || !(uncontended = U.compareAndSwapLong(a, CELLVALUE, v = a.value, v + x))) {
                fullAddCount(x, uncontended);
                return;
            }
            if (check <= 1)
                return;
            s = sumCount();
        }
        // 重点看这边
        if (check >= 0) {
            Node<K,V>[] tab, nt; int n, sc;
            // 超过阈值,则进入循环执行扩容
            while (s >= (long)(sc = sizeCtl) && (tab = table) != null && (n = tab.length) < MAXIMUM_CAPACITY) {
                // 标记位,以调整大小为n的表的大小。向左移动RESIZE_STAMP_SHIFT时必须为负。
                int rs = resizeStamp(n);
                if (sc < 0) {
                    // 不满足扩容条件,则直接结束
                    if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 || sc == rs + MAX_RESIZERS || (nt = nextTable) == null || transferIndex <= 0)
                        break;
                    // 当前扩容正在进行,则参与迁移元素
                    if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))  
                        // 转移节点到新的数组
                        transfer(tab, nt);
                }
                // 超过阈值执行扩容,只有一个线程能执行创建新数组的操作
                else if (U.compareAndSwapInt(this, SIZECTL, sc, (rs << RESIZE_STAMP_SHIFT) + 2))
                    transfer(tab, null);
                s = sumCount();
            }
        }
    }

    private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {

        /***这部分用来确认每个分段的长度,若nextTab为空则创建新的数组***/

        int n = tab.length, stride;
        if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
            // 每个线程负责扩容区域的大小
            stride = MIN_TRANSFER_STRIDE; 
        // 创建新数组    
        if (nextTab == null) {            
            try {
                @SuppressWarnings("unchecked")
                Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1];
                nextTab = nt;
            } catch (Throwable ex) {      
                // 尝试应付 OOME
                sizeCtl = Integer.MAX_VALUE;
                return;
            }
            nextTable = nextTab;
            transferIndex = n;
        }
        int nextn = nextTab.length;
        // 将新数组用ForwardingNode进行标记
        ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
        // 表示是否要向下一个区域进行扩容
        boolean advance = true;
        // 扩容完成标记
        boolean finishing = false; 

        /****************************/

        for (int i = 0, bound = 0;;) {

            /***这部分用来确认每个线程扩容的区域***/

            Node<K,V> f; int fh;
            while (advance) {
                // nextIndex:用于记录当前的transferIndex值
                // nextBound:当前正在处理的区域的边界
                int nextIndex, nextBound;
                // 当前线程还没处理完此区域的扩容任务,因为还未超过边界bound
                if (--i >= bound || finishing)
                    advance = false;
                // 意味着已经没有区域要分配了,要不就是都处理完了,要不就是所有区域都有其他线程在处理了,不再需要当前线程的参与了    
                else if ((nextIndex = transferIndex) <= 0) {
                    i = -1;
                    advance = false;
                }
                // 通过CAS拿到自己负责的区域,并更新transferIndex,记住边界bound,以及当前处理的索引位置
                else if (U.compareAndSwapInt(this, TRANSFERINDEX, nextIndex, nextBound = (nextIndex > stride ? nextIndex - stride : 0))) {
                    bound = nextBound;
                    i = nextIndex - 1;
                    advance = false;
                }
            }

            /**********************************/

            /**********扩容前处理********/

            if (i < 0 || i >= n || i + n >= nextn) {
                int sc;
                // 扩容结束了
                if (finishing) { 
                    nextTable = null;
                    table = nextTab;
                    // 更新当前table的最大容量,相当于n * 1.5 = 2n * 0.75
                    sizeCtl = (n << 1) - (n >>> 1);
                    return;
                }
                // 说明当前线程完成了区域的扩容任务,更新sizeCtl的值成功
                if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
                    // 说明整个扩容还没结束,但是自己的扩容任务也完成了,直接返回
                    if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                        return;
                    finishing = advance = true;
                    i = n; // recheck before commit
                }
            }
            else if ((f = tabAt(tab, i)) == null)
                // f记住了当前位置的 node,将其更新为ForwardingNode,以标记此位置已经处理,因为没有内容,尝试直接标记为处理了,失败会再进来
                advance = casTabAt(tab, i, null, fwd);
            // 表示已经在处理过程中    
            else if ((fh = f.hash) == MOVED)
                advance = true; 
            /**********************************/ 

            /********执行扩容操作(迁移元素,这部分与HashMap相似)*******/
            else {
                // 通过分段锁锁住节点保证不会被其他线程修改
                synchronized (f) {
                    if (tabAt(tab, i) == f) {
                        Node<K,V> ln, hn;
                        if (fh >= 0) {
                            int runBit = fh & n;
                            Node<K,V> lastRun = f;
                            for (Node<K,V> p = f.next; p != null; p = p.next) {
                                int b = p.hash & n;
                                if (b != runBit) {
                                    runBit = b;
                                    lastRun = p;
                                }
                            }
                            if (runBit == 0) {
                                ln = lastRun;
                                hn = null;
                            }
                            else {
                                hn = lastRun;
                                ln = null;
                            }
                            for (Node<K,V> p = f; p != lastRun; p = p.next) {
                                int ph = p.hash; K pk = p.key; V pv = p.val;
                                if ((ph & n) == 0)
                                    ln = new Node<K,V>(ph, pk, pv, ln);
                                else
                                    hn = new Node<K,V>(ph, pk, pv, hn);
                            }
                            setTabAt(nextTab, i, ln);
                            setTabAt(nextTab, i + n, hn);
                            setTabAt(tab, i, fwd);
                            advance = true;
                        }
                        else if (f instanceof TreeBin) {
                            TreeBin<K,V> t = (TreeBin<K,V>)f;
                            TreeNode<K,V> lo = null, loTail = null;
                            TreeNode<K,V> hi = null, hiTail = null;
                            int lc = 0, hc = 0;
                            for (Node<K,V> e = t.first; e != null; e = e.next) {
                                int h = e.hash;
                                TreeNode<K,V> p = new TreeNode<K,V>
                                    (h, e.key, e.val, null, null);
                                if ((h & n) == 0) {
                                    if ((p.prev = loTail) == null)
                                        lo = p;
                                    else
                                        loTail.next = p;
                                    loTail = p;
                                    ++lc;
                                }
                                else {
                                    if ((p.prev = hiTail) == null)
                                        hi = p;
                                    else
                                        hiTail.next = p;
                                    hiTail = p;
                                    ++hc;
                                }
                            }
                            ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
                                (hc != 0) ? new TreeBin<K,V>(lo) : t;
                            hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
                                (lc != 0) ? new TreeBin<K,V>(hi) : t;
                            setTabAt(nextTab, i, ln);
                            setTabAt(nextTab, i + n, hn);
                            setTabAt(tab, i, fwd);
                            advance = true;
                        }
                    }
                }
            }
            /**********************************/
        }
    }

当table容量不足的时候,即table的元素数量达到容量阈值sizeCtl,需要对table进行扩容。 整个扩容分为两部分:

  1. 构建一个nextTable,大小为table的两倍。
  2. 把table的数据复制到nextTable中。

这两个过程在单线程下实现很简单,但是ConcurrentHashMap是支持并发插入的,扩容操作自然也会有并发的出现,这种情况下,第二步可以支持节点的并发复制,详细的扩容过程在注释中标注了,小结一下扩容的过程

  1. 通过Unsafe.compareAndSwapInt修改sizeCtl值,保证只有一个线程能够初始化nextTable,扩容后的数组长度为原来的两倍,阈值为原来的1.5倍
  2. 以stride为长度单位,将table划分区域
  3. 每个参与扩容的线程,通过CAS竞争更新transferIndex,分配到负责的区域
  4. 每个线程知道自己负责的区域,不断循环将负责区域内的table[i]逐一进行迁移
  5. 如果table[i]是链表的头节点,就构造一个反序链表,把他们分别放在nextTable的i和i+n的位置上,移动完成,采用Unsafe.putObjectVolatile方法给table原位置赋值fwd
  6. 如果table[i]是TreeBin节点,也做一个反序处理,并判断是否需要untreeify(即是否转回成链表),把处理的结果分别放在nextTable的i和i+n的位置上,移动完成,同样采用Unsafe.putObjectVolatile方法给table原位置赋值fwd。
  7. 当处理完区域后,向上返回。由更上层控制,要不要继续进入transfer()

关于ConcurrentHashMap扩容原理更深入的部分可以参考这篇文章:见识不一样的ConcurrentHashMap

size()方法

由于ConcurrentHashMap将可能在各个位置上并发,因此当前容量是一个估值,如果加锁去确认容量有点得不偿失。ConcurrentHashMap就以内部类CounterCell对象,对应记录每个线程对数据的增删次数,然后综合得出结果,可以从size() 方法看出

    public int size() {
        long n = sumCount();
        return ((n < 0L) ? 0 :
                (n > (long)Integer.MAX_VALUE) ? Integer.MAX_VALUE :
                (int)n);
    }

    final long sumCount() {
        CounterCell[] as = counterCells; CounterCell a;
        long sum = baseCount;
        if (as != null) {
            for (int i = 0; i < as.length; ++i) {
                if ((a = as[i]) != null)
                    sum += a.value;
            }
        }
        return sum;
    }

JDK1.8的ConcurrentHashMap主要方法分析完了,至于其他方法可以自行阅读源码,相比较JDK1.7版本,JDK1.8的ConcurrentHashMap更为复杂,但对以前版本有重大提升。

总结

对于ConcurrentHashMap,分析了JDK1.7和JDK1.8的源码,两个版本实现方式差异较大,JDK1.8的版本是JDK1.7的优化版本,性能和效率相比较而言会高于JDK1.7版本,具体体现在下面几点:

  • 引入红黑树,解决多个元素出现hash冲突时查找效率的低下
  • 加锁方式从ReentrantLock变为synchronized:每个Segment是一个ReentrantLock(独占锁),操作table[i]时,就要显示通过加解锁方法,没获取到锁的线程就会被挂起,而基于synchronized的性能得到了升级,支持了锁升级的膨胀过程,当table[i]上为偏向锁和轻量级锁时,表现好于ReentrantLock;当table[i]上为重量级锁时,表现与ReentrantLock相差不多。总体来看,效率将好于ReentrantLock。
  • 支持并发扩容,提高扩容的效率

以上,有错的地方,请赐教

全部评论

相关推荐

点赞 收藏 评论
分享
牛客网
牛客企业服务