「Golang」sync.WaitGroup源码讲解

sync.WaitGroup介绍

当我们在开发过程中,经常需要在开启多个goroutine后,等待全部的goroutine执行完毕后才进行下一步的业务逻辑执行。此时我们可能会采用轮询的方式去定时侦测已经开启的多个goroutine的业务是否执行完毕,但是这样性能很低,并且持续占用cpu时间片很消耗cpu的资源,此时我们就该使用sync.WaitGroup来完成此次操作。举个🌰,下列代码是开了10个goroutine后等待其各睡眠5秒之后进行后续操作的sync.WaitGroup方法实现。

func main() {
   
	// 创建对象
	var wait sync.WaitGroup
	for i := 0; i < 10; i++ {
   
		// 为需要等待结束的goroutine数量+1
		wait.Add(1)
		go func() {
   
			time.Sleep(5*time.Second)
			// 结束 使得需要等待的数量-1
			// 等同于 wait.Add(-1)
			wait.Done()
		}()
	}
	// 等待所有执行完毕
	wait.Wait()
	fmt.Println("wait done")
}

上述代码的睡眠5秒可以替换为任何需要在goroutine中执行的业务逻辑,上述代码中出现了下述几个sync.WaitGroup中提供的方法,提供方法很少很简洁,接下来就开始解析一下sync.WaitGroup的相关信息。

func (wg *WaitGroup) Add(delta int) 
func (wg *WaitGroup) Done() 
func (wg *WaitGroup) Wait()

sync.WaitGroup源代码解析

1:sync.WaitGroup结构体的解析

type WaitGroup struct {
   
	// 一个防止sync.WaitGroup被复制的标记结构体
	noCopy noCopy
	// 该数组在32为系统与64位系统中代表的用途不同
	// 首先说64位系统:
	// state1[0]代表当前sync.WaitGroup 调用Add方法增加了多少的couter
	// state1[1]代表调用了Wait方法等待结束的Waiter的数量
	// state1[2]代表Waiter的信号量
	// 其中 state1[0]与state1[1]作者称为计数标记
	// 在32位系统中:
	// state1[0]代表Waiter的信号量
	// state1[1]代表当前sync.WaitGroup 调用Add方法增加了多少的couter
	// state1[2]代表调用了Wait方法等待结束的Waiter的数量
	// 其中 state1[1]与state1[2]作者称为计数标记
	state1 [3]uint32
}

Add(delta int)方法源代码解析

// state 方法用于根据系统是32位还是64位返回对应的state1字段的对应的计数和信号量的地址
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
   
	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
   
		// 64位系统返回state1与state1[2],由于数组是连续内存所以可以通过首地址
		// 取出state1[0]与state1[1]的所有二进制位
		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
	} else {
   
		// 32位系统返回state1与state1[2],由于数组是连续内存所以可以通过首地址
		// 取出state1[0]与state1[1]的所有二进制位
		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
	}
}
// Add 方法传入一个增量 可以为赋值则代表 Done
// 例如 调用Done方法就是对Add方法传入了-1 
// 即Add(-1)
func (wg *WaitGroup) Add(delta int) {
   
	// 根据系统位数返回计数标记和信号量标记
	statep, semap := wg.state()
	// 做了race检查和异常的检查
	if race.Enabled {
   
		_ = *statep // 如果信号量是个空指针,则报错
		if delta < 0 {
   
			race.ReleaseMerge(unsafe.Pointer(wg))
		}
		race.Disable()
		defer race.Enable()
	}
	// 将计数标记的高32位的值+delta
	// 如果是64位系统则代表state1[0]+delta
	state := atomic.AddUint64(statep, uint64(delta)<<32)
	// 转换成正常的数字
	// 例如 以64为系统为例 当原state1[0]为0时
	// atomic.AddUint64(statep, uint64(delta)<<32)
	// 当delta==1时
	// 结果为 2^32-1
	// 该操作就是将此数值转变为1
	v := int32(state >> 32)
	// w代表需要等待的数量即 64位系统中的state1[1]为state
	w := uint32(state)
	// 做race判断,并判断delta是否为负数 如果是的话并且v与delta相等则做一些race的同步
	// fixme 此处解释存疑
	if race.Enabled && delta > 0 && v == int32(delta) {
   
		// The first increment must be synchronized with Wait.
		// Need to model this as a read, because there can be
		// several concurrent wg.counter transitions from 0.
		race.Read(unsafe.Pointer(semap))
	}
	// 如果v<0则代表delta的传入的为负值,并且该负值与原couter相减后小于0
	// 说白了一点就说Add传入的负值超出了原有couter的数量
	if v < 0 {
   
		panic("sync: negative WaitGroup counter")
	}
	// 如果等待数量不是0 并且delta>0 且v==delta 则代表出现了
	// 同时并发调用了Add和Wait
	if w != 0 && delta > 0 && v == int32(delta) {
   
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// 正常情况 直接返回
	if v > 0 || w == 0 {
   
		return
	}
	
	// 也是同时调用Add和Wait
	if *statep != state {
   
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// 如果计数值v为0并且waiter的数量w不为0
	// 则代表delta传入的值使得couter变为了0,但是还是有waiter在等待的话
	// 就把statep即state1[0]与state1[1]设置为0
	// 并唤醒所有的正在等待的阻塞goroutine
	*statep = 0
	for ; w != 0; w-- {
   
		runtime_Semrelease(semap, false, 0)
	}
}

Add(delta int)方法的总体执行过程大概如下:

  1. 根据64位还是32位系统获取对应的标记位
  2. 对计数标记位做delta的Add
  3. 判断Add之后的state是否为一些非法情况,比如v<0等
  4. 如果v和w分别为大于0和等于0则正常返回,代表Add成功
  5. 否则判断当v为0时则代表没有counter了,但是还有waiter那么就把state整体设置为0,随后唤醒调用了 Wait()的阻塞的goroutine。

Done() 方法解析

// Done 没什么可说的 就是调用了一下Add(-1)
func (wg *WaitGroup) Done() {
   
	wg.Add(-1)
}

Wait() 方法解析

func (wg *WaitGroup) Wait() {
   
	// 根据系统位数返回计数标记和信号量标记
	statep, semap := wg.state()
	// race检测
	if race.Enabled {
   
		_ = *statep // trigger nil deref early
		race.Disable()
	}
	// 循环校验是否所有的goroutine都调用了Done
	for {
   
		//原子获取值
		state := atomic.LoadUint64(statep)
		// v 代表 couter数量 ,即高32位
		v := int32(state >> 32)
		// w 代表waiter数量 ,即低32位
		w := uint32(state)
		// 如果v==0 代表没有 couter了则不用等待了直接返回
		if v == 0 {
   
			// Counter is 0, no need to wait.
			if race.Enabled {
   
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
		// 如果statep的值与state相等 还有需要等待完成的goroutine 此时则waiter+1
		if atomic.CompareAndSwapUint64(statep, state, state+1) {
   
		// race检测
			if race.Enabled && w == 0 {
   
				race.Write(unsafe.Pointer(semap))
			}
			// 阻塞等待直到被Add唤醒
			runtime_Semacquire(semap)
			// 如果被唤醒了 但是发现地址中的值不是0 代表唤醒错误 panic
			if *statep != 0 {
   
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			// race检测
			if race.Enabled {
   
				race.Enable()
				race.Acquire(unsafe.Pointer(wg))
			}
			return
		}
	}
}

Wait()方法的总体执行过程大概如下:

  1. 获取标记位的地址
  2. 获取值
  3. 判断v是否为0,如果是则无须等待。
  4. 如果v不是0并且statep的值与state相等,则代表还有需要等待完成的goroutine,此时则waiter+1,然后阻塞等待Add方法唤醒

总结

sync.WaitGroup 使用的规范

  1. Add(delta int) 方法可以设置为负值,但是必须要确保这个负值delta加上当前计数器的数量的结果大于0。否则会panic。🌰如下
func main() {
   
	var wait sync.WaitGroup
	wait.Add(1) // 没问题 计数器为1
	wait.Add(-1)// 没问题 计数器为0
	
	wait.Add(-3) // panic 此时计数器为0-3 出错
}


func main() {
   
	var wait sync.WaitGroup
	wait.Add(1) // 没问题 计数器为1
	wait.Add(1)// 没问题 计数器为2
	
	wait.Done() // 没问题 计数器为1
	wait.Done()// 没问题 计数器为0
	
	wait.Done()// panic 此时计数器为-1
}

2.Add(delta int)方法必须在 Wait() 方法调用之前全部调用完毕,否则会出现panic。举个🌰,本实例的设想是进入goroutine后Add(delta int)但是Wait() 方法调用早于goroutine中Add(delta int),所以此时Wait()计数器为0,则不等待直接跳过。代码如下:

func main() {
   
	var wait sync.WaitGroup
	
	go func() {
   
		// 故意sleep 代表执行逻辑
		time.Sleep(1*time.Millisecond)
		fmt.Println("Add")
		wait.Add(1)
		wait.Done()
	}()
	go func() {
   
		// 故意sleep 代表执行逻辑
		time.Sleep(1*time.Millisecond)
		fmt.Println("Add")
		wait.Add(1)
		wait.Done()
	}()
	wait.Wait()
	fmt.Println("Done")
}

3.不可以在前一个Wait()还未结束时,复用sync.WaitGroup ,举个例子代码:

func main() {
   
	var wait sync.WaitGroup
	wait.Add(1)
	go func() {
   
		// 故意sleep 代表执行逻辑
		time.Sleep(1*time.Millisecond)
		wait.Done() // 正常结束
		wait.Add(1) // 此处panic 因为wait还未结束就再次复用
	}()
	wait.Wait()
	fmt.Println("Done")
}

sync.WaitGroup 虽然可以重用,但是是有一个前提的,那就是必须等到上一轮的 Wait() 完成之后,才能重用 sync.WaitGroup 执行下一轮的 Add(delta int)/Wait() ,如果你在 Wait() 还没执行完的时候就调用下一轮 Add 方法,就有可能出现 panic。

至此sync.WaitGroup的源码解析就解析完了,可能有些地方有些理解上的错误,请各位谅解并且帮忙指出修改意见,如果这篇文章能帮到你,这是我的荣幸。

全部评论

相关推荐

头像
04-26 15:00
已编辑
算法工程师
点赞 评论 收藏
转发
点赞 收藏 评论
分享
牛客网
牛客企业服务