vlambda博客
学习文章列表

深入源码分析golang之WaitGroup

点击关注 不迷路






深入源码分析golang之WaitGroup

什么是sync.WaitGroup


        官方文档对其的描述是:WaitGroup等待一组goroutine的任务完成。主goroutine调用添加以设置要等待的goroutine的数量。然后,每个goroutine都会运行并在完成后调用Done。同时,可以使用Wait来阻塞,直到所有goroutine完成。我们来看官网给的一个例子:

 1package main
2
3import (
4    "sync"
5)
6
7type httpPkg struct{}
8
9func (httpPkg) Get(url string) {}
10
11var http httpPkg
12
13func main() {
14    var wg sync.WaitGroup
15    var urls = []string{
16        "http://www.golang.org/",
17        "http://www.google.com/",
18        "http://www.somestupidname.com/",
19    }
20    for _, url := range urls {
21        // 增加waitGroup计数
22        wg.Add(1)
23        // 启动goroutine获取url
24        go func(url string) {
25            //等获取url的goroutine完成,将waitGroup计数减1
26            defer wg.Done()
27            // 获取url
28            http.Get(url)
29        }(url)
30    }
31    // 等待所有goroutine完成
32    wg.Wait()
33}


深入源码分析golang之WaitGroup

源码剖析

深入源码分析golang之WaitGroup

WaitGroup的实现:WaitGroup的数据结构主要包括一个noCopy的辅助字段,以及一个具有复合含义的state1字段。接下来分别来了解下这两个字段的内部逻辑。


noCopy机制:Go中没有原生的禁止拷贝的方式,所以如果有的结构体,你希望使用者无法拷贝,只能指针传递保证全局唯一的话,可以这么干,定义一个结构体叫noCopy,要实现sync.Locker 这个接口。

1type noCopy struct{}
2
3// nocopy 只有在使用 go vet 检查时才能显示错误,编译正常
4func (*noCopy) Lock() {}
5func (*noCopy) UnLock() {}


state1处理:总共分配了12个字节,在这里被设计成三种状态。其中对齐的8个字节作为状态位(state),高32位为记录计数器的数量,低32位为等待goroutine的数量值。其余的4个字节作为信号量存储(sema)。由于操作系统分为32位和64位,64位的原子操作需要64位对齐,但是32位编译器保证不了,于是这里就采用了动态识别当前我们操作的64位数到底是不是在8字节对齐的位置上面。具体见源码state方法:

 1// 得到state的地址和信号量的地址
2func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
3    if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
4        // 如果地址是64bit对齐的,数组前两个元素做state,后一个元素做信号量
5        return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
6    } else {
7        // 如果地址是32bit对齐的,数组后两个元素用来做state,它可以用来做64bit的原子操作,第一个元素32bit用来做信号量
8        return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
9    }
10}

深入源码分析golang之WaitGroup


Add方法实现:主要操作的state1字段中计数值部分,计数器部分的逻辑主要是通过state(),在上面有提及。每次调用Add方法就会增加相应数量的计数器。如果计数器为零,则释放等待时阻塞的所有goroutine。如果计数器变为负数,请添加恐慌。如果计数器值大于0,说明此时还有任务没有完成,那么调用者就变成等待者,需要加入wait队列,并且阻塞自己。参数可正可负数。

Add方法源码分析如下

1func (wg *WaitGroup) Add(delta int) {
2    //获取state1中的状态位和信号量
3    statep, semap := wg.state()
4    //用来goroutine的竞争检测,可忽略。
5    if race.Enabled {
6        _ = *statep 
7        if delta < 0 {
8            race.ReleaseMerge(unsafe.Pointer(wg))
9        }
10        race.Disable()
11        defer race.Enable()
12    }
13    // uint64(delta)<<32 将delta左移32
14    // 因为高32位表示计数器,所以delta左移32位,
15    // 增加到计数位。
16    state := atomic.AddUint64(statep, uint64(delta)<<32)
17    // 当前计数器的值
18    v := int32(state >> 32)
19    // 阻塞的wait goroutine数量
20    w := uint32(state)
21    if race.Enabled && delta > 0 && v == int32(delta) {
22        race.Read(unsafe.Pointer(semap))
23    }
24    // 计数器的值<0,panic
25    if v < 0 {
26        panic("sync: negative WaitGroup counter")
27    }
28    // 当wait goroutine数量不为0时,累加后的counter值和delta相等,
29    // 说明Add()和Wait()同时调用了,所以发生panic,
30    // 因为正确的做法是先Add()后Wait(),
31    // 也就是已经调用了wait()就不允许再添加任务了
32    if w != 0 && delta > 0 && v == int32(delta) {
33        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
34    }
35    // add调用结束
36    if v > 0 || w == 0 {
37        return
38    }
39    // 能走到这里说明当前Goroutine Counter计数器为0,
40    // Waiter Counter计数器大于0, 
41    // 到这里数据也就是允许发生变动了,如果发生变动了,则出发panic
42    if *statep != state {
43        panic("sync: WaitGroup misuse: Add called concurrently with Wait")
44    }
45    // 所有的状态位清0
46    *statep = 0
47    for ; w != 0; w-- {
48        // 首先让信号量加一,然后检查是否有正在等待的Goroutine,如果没有,直接返回;
49        // 如果有,调用goready函数唤醒一个Goroutine。
50        runtime_Semrelease(semap, false0)
51    }
52
Add方法调用流程图如下

深入源码分析golang之WaitGroup


Done方法实现:内部调用了Add(-1)的方法。详情看Add方法

1//Done方法其实就是Add(-1)
2func (wg *WaitGroup) Done() {
3    wg.Add(-1)
4}


Wait方法实现:阻塞主goroutine直到WaitGroup计数器变为0。

Wait方法源码分析如下

1// 等待并阻塞,直到WaitGroup计数器为0
2func (wg *WaitGroup) Wait() {
3    // 获取waitgroup状态位和信号量
4    statep, semap := wg.state() 
5    if race.Enabled { 
6        _ = *statep 
7        race.Disable()
8    }
9    for {
10        // 使用原子操作读取state,是为了保证Add中的写入操作已经完成
11        state := atomic.LoadUint64(statep)
12        v := int32(state >> 32//获取计数器(高32位)
13        w := uint32(state) //获取wait goroutine数量(低32位)
14        if v == 0 { // 计数器为0,跳出死循环,不用阻塞
15            if race.Enabled {
16                race.Enable()
17                race.Acquire(unsafe.Pointer(wg))
18            }
19            return
20        }
21        // 使用CAS操作对`waiter Counter`计数器进行+1操作,
22        // 外面有for循环保证这里可以进行重试操作
23        if atomic.CompareAndSwapUint64(statep, state, state+1) {
24            if race.Enabled && w == 0 {
25                race.Write(unsafe.Pointer(semap))
26            }
27            // 在这里获取信号量,使线程进入睡眠状态,
28            // 与Add方法中runtime_Semrelease增加信号量相对应,
29            // 也就是当最后一个任务调用Done方法
30            // 后会调用Add方法对goroutine counter的值减到0,
31            // 就会走到最后的增加信号量
32            runtime_Semacquire(semap)
33            // 在Add方法中增加信号量时已经将statep的值设为0了,
34            // 如果这里不是0,说明在wait之后又调用了Add方法,
35            // 使用时机不对,触发panic
36            if *statep != 0 {
37                panic("sync: WaitGroup is reused before previous Wait has returned")
38            }
39            if race.Enabled {
40                race.Enable()
41                race.Acquire(unsafe.Pointer(wg))
42            }
43            return
44        }
45    }
46}

Wait方法调用流程图如下

 
 



扫码关注

获取更多干货内容