深度讲解Go语言-WaitGroup
大家好,我是gopher_looklook,现任某独角兽企业Go语言工程师,喜欢钻研Go源码,发掘各项技术在大型Go微服务项目中的最佳实践,期待与各位小伙伴多多交流,共同进步!
概念
sync.WaitGroup是Go语言中用于协调多个goroutine同步的核心工具,用于等待一组goroutine完成它们的任务。
简单示例
package main import ( "fmt" "sync" "time" ) func worker(id int, wg *sync.WaitGroup) { defer wg.Done() fmt.Printf("Worker %d starting\n", id) time.Sleep(1 * time.Second) fmt.Printf("Worker %d done\n", id) } func main() { var wg sync.WaitGroup numWorkers := 3 wg.Add(numWorkers) for i := 0; i < numWorkers; i++ { go worker(i, &wg) } wg.Wait() // 最常见的用法,此时只有一个等待者 fmt.Println("All workers have finished") }
- 输出
- 分析
可以看到,当调用wg.Add(numWorkers)时,表示我们要执行numWorkers组子goroutine,执行wg.Wait()则会阻塞当前goroutine。只有当全部子goroutine全部执行完成,最后一个子goroutine执行完wg.Done()后,当前被wg.Wait()阻塞的goroutine才能继续往下执行。
背景知识-Go源码中信号量的实现
Go语言提供了两个函数用于实现信号量的控制runtime_Semrelease和runtime_Semacquire。
func runtime_Semrelease(s *uint32, handoff bool, skipframes int) func runtime_Semacquire(s *uint32)
1.runtime_Semrelease 函数的主要作用是释放信号量。
- 当一个 goroutine 使用完资源后,会调用 runtime_Semrelease 将信号量的值加 1,表示释放了一个资源。
2.runtime_Semacquire 函数的主要作用是尝试获取信号量。
- 如果当前信号量的值大于 0,runtime_Semacquire 会将信号量的值减 1,表示成功获取了一个资源,然后立即返回,调用该函数的 goroutine 可以继续执行后续操作。
- 如果当前信号量的值为 0,说明没有可用资源,调用 runtime_Semacquire 的 goroutine 会被阻塞。该 goroutine 会被放入等待队列中,直到有其他 goroutine 调用 runtime_Semrelease 释放信号量,才有可能被唤醒继续执行。
sync.WaitGroup正是使用了获取和释放信号量的操作,实现了等待一组子goroutine完成任务,并通知到正在等待中的goroutine的功能。
源码解读
- go源码版本: go 1.23.0
- 源码
package sync import ( "sync/atomic" "unsafe" ) type WaitGroup struct { noCopy noCopy state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count. sema uint32 } func (wg *WaitGroup) Add(delta int) { state := wg.state.Add(uint64(delta) << 32) v := int32(state >> 32) w := uint32(state) if v < 0 { panic("sync: negative WaitGroup counter") } if w != 0 && delta > 0 && v == int32(delta) { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } if v > 0 || w == 0 { return } // This goroutine has set counter to 0 when waiters > 0. // Now there can't be concurrent mutations of state: // - Adds must not happen concurrently with Wait, // - Wait does not increment waiters if it sees counter == 0. // Still do a cheap sanity check to detect WaitGroup misuse. if wg.state.Load() != state { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } // Reset waiters count to 0. wg.state.Store(0) for ; w != 0; w-- { runtime_Semrelease(&wg.sema, false, 0) } } // Done decrements the [WaitGroup] counter by one. func (wg *WaitGroup) Done() { wg.Add(-1) } // Wait blocks until the [WaitGroup] counter is zero. func (wg *WaitGroup) Wait() { for { state := wg.state.Load() v := int32(state >> 32) w := uint32(state) // Increment waiters count. if wg.state.CompareAndSwap(state, state+1) { runtime_Semacquire(&wg.sema) if wg.state.Load() != 0 { panic("sync: WaitGroup is reused before previous Wait has returned") } return } } }
这里我删减了部分与竞态检测相关的代码。可以看到,sync.WaitGroup的核心功能集中体现在3个方法(Add/Done/Wait)上。 这3个方法互相依赖,相辅相成,需要一起配合使用。
WaitGroup.Add
func (wg *WaitGroup) Add(delta int) { state := wg.state.Add(uint64(delta) << 32) v := int32(state >> 32) w := uint32(state) if v < 0 { panic("sync: negative WaitGroup counter") } if w != 0 && delta > 0 && v == int32(delta) { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } if v > 0 || w == 0 { return } // This goroutine has set counter to 0 when waiters > 0. // Now there can't be concurrent mutations of state: // - Adds must not happen concurrently with Wait, // - Wait does not increment waiters if it sees counter == 0. // Still do a cheap sanity check to detect WaitGroup misuse. if wg.state.Load() != state { panic("sync: WaitGroup misuse: Add called concurrently with Wait") } // Reset waiters count to 0. wg.state.Store(0) for ; w != 0; w-- { runtime_Semrelease(&wg.sema, false, 0) } }
首先是Add方法。Add方法会将64位int型的整数 wg.state 拆分成了高32位和低32位,分别用于存储不同含义的值:v 和 w。
- v:计数器,代表有多少个子goroutine未完成任务。
- w:等待者数量,代表有多少个正在等待所有任务完成的goroutine的数量。
在我们上述的例子中,有1个主goroutine在等待3个子goroutine完成任务。因此v=3,w=1,但是w的值需要等到调用Wait方法时才设置。
刚开始调用Add方法时,记录了要等待完成任务的子goroutine数量,存储在wg.state字段的高32位。赋值完成后,v等于3,w等于0(还没赋值)。经过几个异常判断的if条件检验。程序会在以下代码跳出Add函数。
func (wg *WaitGroup) Add(delta int) { ...... if v > 0 || w == 0 { return } ...... }
WaitGroup.Wait
func (wg *WaitGroup) Wait() { for { state := wg.state.Load() v := int32(state >> 32) w := uint32(state) // Increment waiters count. if wg.state.CompareAndSwap(state, state+1) { runtime_Semacquire(&wg.sema) if wg.state.Load() != 0 { panic("sync: WaitGroup is reused before previous Wait has returned") } return } } }
Wait() 函数的执行时间在Done() 之前,因此优先分析Wait函数。Wait函数写了一个for循环。在for循环里,首先将wg.state的高32位和低32位拆分开,分别赋值给v和w,分别代表计数器和等待者数量。计数器不需要更新,而等待者数量w每调用一次Wait函数都需要自增1,也就是下面这行代码:
wg.state.CompareAndSwap(state, state+1)
之后会调用runtime_Semacquire函数,确保调用Wait函数的地方都会阻塞当前goroutine的进一步执行
runtime_Semacquire(&wg.sema)
WaitGroup.Done
func (wg *WaitGroup) Done() { wg.Add(-1) }
通过以上对Add函数和Wait函数分析,我们知道了在等待所有子goroutine完成任务之前,外层等待的goroutine都会被阻塞,阻塞的原因是由于Wait函数在for循环中尝试获取信号量,但是并没有可用的信号量可以获取。
Done函数实际调用了Add函数,因此只需要分析Add函数的执行过程即可。在最后一个子goroutine执行Done()之前,每次调用Done函数,都会将wg.state字段的高32位减1,即将计数器减1。并在以下这行代码跳出Add函数。
func (wg *WaitGroup) Add(delta int) { ...... if v > 0 || w == 0 { return } ...... }
最后一个子goroutine调用Done函数时,wg.state字段的32位减到0(计数器归0),w>0(存在等待者)。此时会将wg.state重新设置为0,并执行到下面这段代码,用于释放w个信号量,唤醒w个阻塞的等待者。之后退出Add函数,也即正常退出Done函数。
func (wg *WaitGroup) Add(delta int) { for ; w != 0; w-- { runtime_Semrelease(&wg.sema, false, 0) } }
处于等待的goroutine(在我们的例子当中是调用了wg.Wait函数的main goroutine)由于runtime_Semrelease释放了信号,for循环尝试获取信号量成功,wg.state也已经被重新设置为0,执行return跳出for循环,程序不再阻塞。
func (wg *WaitGroup) Wait() { for { ...... if wg.state.CompareAndSwap(state, state+1) { runtime_Semacquire(&wg.sema) if wg.state.Load() != 0 { panic("sync: WaitGroup is reused before previous Wait has returned") } return } } }
工作流程
为了方便理解,我们可以利用流程图来演示程序执行的不同时刻,各个goroutine的执行情况。这里以上述的简单示例代码为例。
常见误区
大部分情况下,我们在使用WaitGroup时只会调用一次wg.Wait() 函数。通过上述对源码的分析,我们知道WaitGroup其实是允许有多个goroutine等待操作完成的,例如下面这段代码。
package main import ( "fmt" "sync" "sync/atomic" "time" ) func main() { var wg sync.WaitGroup // 设置等待组的计数器为 3,代表有 3 个任务要完成 wg.Add(3) var num atomic.Int32 // 模拟 3 个子goroutine完成 for i := 0; i < 3; i++ { _i := i go func() { time.Sleep(2 * time.Second) num.Add(1) fmt.Printf("sub goroutine %d is working. \n", _i) wg.Done() }() } // 创建 10 个等待的 goroutine for i := 0; i < 10; i++ { go func(id int) { // 进入等待状态,直到等待组的计数器变为 0 wg.Wait() fmt.Printf("The waiting goroutine %d has been awakened, get num: %d\n", id, num.Load()) }(i) } // 等待一段时间,确保所有输出都能显示 time.Sleep(4 * time.Second) fmt.Println("All goroutines have been awakened.") }
总结
在本篇文章中,我们通过一段常见的示例代码,演示了sync.WaitGroup的基础用法。之后按照程序执行时间线逐步分析源码,探究了sync.WaitGroup可以等待一组子goroutine执行完成的原因。
以上便是我对WaitGroup源码的分析和总结,如果这篇文章对屏幕前的牛友学习Go语言有帮助的话,欢迎点赞+关注,你的支持是我创作的最大动力!