概述
在并发编程中,控制主程序等待所有 Goroutine 完成任务是一项关键任务。Go 语言提供了 sync.WaitGroup 来解决这一问题。
本文将讲解 sync.WaitGroup 的使用方法、原理以及在实际项目中的应用场景,用清晰的代码示例和详细的注释,助力读者掌握并发编程中等待组的使用技巧。
1. 基本使用
1
package main import ( "fmt" "sync" "time") func main() { var wg sync.WaitGroup for i := 1; i <= 3; i++ { wg.Add(1) go worker(i, &wg) } wg.Wait() fmt.Println("All workers have completed.")} func worker(id int, wg *sync.WaitGroup) { defer wg.Done() fmt.Printf("Worker %d started\n", id) time.Sleep(2 * time.Second) fmt.Printf("Worker %d completed\n", id)}
在上面示例中,用一个 sync.WaitGroup 实例 wg,然后使用 wg.Add(1) 来增加计数,表示有一个 Goroutine 需要等待。
在每个 Goroutine 的结束处,使用 defer wg.Done() 来减少计数,表示一个 Goroutine 已完成。
最后,用 wg.Wait() 来等待所有 Goroutine 完成。
1.
package main import ( "fmt" "sync" "time") func main() { var wg sync.WaitGroup for i := 1; i <= 3; i++ { wg.Add(1) go workerWithError(i, &wg) } wg.Wait() fmt.Println("All workers have completed.")} func workerWithError(id int, wg *sync.WaitGroup) { defer wg.Done() fmt.Printf("Worker %d started\n", id) time.Sleep(2 * time.Second) // 模拟错误发生 if id == 2 { fmt.Printf("Worker %d encountered an error\n", id) return } fmt.Printf("Worker %d completed\n", id)}
有时候,需要在 Goroutine 中处理错误。在这个示例中,当 id 为 2 时,模拟了一个错误的情况。
通过在错误发生时提前返回,可以确保计数正确减少,避免等待组出现死锁。
2. 多级等待组
2.1
package main import ( "fmt" "sync" "time") func main() { var outerWG sync.WaitGroup var innerWG sync.WaitGroup for i := 1; i <= 2; i++ { outerWG.Add(1) go outerWorker(i, &outerWG, &innerWG) } outerWG.Wait() fmt.Println("All outer workers have completed.")} func outerWorker(id int, outerWG, innerWG *sync.WaitGroup) { defer outerWG.Done() fmt.Printf("Outer Worker %d started\n", id) for j := 1; j <= 3; j++ { innerWG.Add(1) go innerWorker(id, j, innerWG) } innerWG.Wait() fmt.Printf("Outer Worker %d completed\n", id)} func innerWorker(outerID, innerID int, wg *sync.WaitGroup) { defer wg.Done() fmt.Printf("Inner Worker %d of Outer Worker %d started\n", innerID, outerID) time.Sleep(2 * time.Second) fmt.Printf("Inner Worker %d of Outer Worker %d completed\n", innerID, outerID)}
在示例中,使用了嵌套的 sync.WaitGroup。
外部的等待组 outerWG 等待所有外部 Goroutine 完成,而每个外部 Goroutine 内部的 innerWG 则等待其内部的所有 Goroutine 完成。
2.
package main import ( "fmt" "sync" "time") func main() { var dynamicWG sync.WaitGroup for i := 1; i <= 3; i++ { dynamicWG.Add(1) go dynamicWorker(i, &dynamicWG) } // 模拟动态添加更多任务 time.Sleep(1 * time.Second) for i := 4; i <= 6; i++ { dynamicWG.Add(1) go dynamicWorker(i, &dynamicWG) } dynamicWG.Wait() fmt.Println("All dynamic workers have completed.")} func dynamicWorker(id int, wg *sync.WaitGroup) { defer wg.Done() fmt.Printf("Dynamic Worker %d started\n", id) time.Sleep(2 * time.Second) fmt.Printf("Dynamic Worker %d completed\n", id)}
在上述示例中,创建了一个等待组 dynamicWG,然后在运行时动态添加了更多的任务。
用这种方式,可以动态地管理需要等待的 Goroutine 数量。
3. 超时处理
3.1
package main import ( "fmt" "sync" "time") func main() { var timeoutWG sync.WaitGroup for i := 1; i <= 3; i++ { timeoutWG.Add(1) go timeoutWorker(i, &timeoutWG) } // 等待最多5秒,超时则不再等待 timeout := time.After(5 * time.Second) done := make(chan struct{}) go func() { timeoutWG.Wait() close(done) }() select { case <-done: fmt.Println("All timeout workers have completed.") case <-timeout: fmt.Println("Timeout reached. Not all workers have completed.") }} func timeoutWorker(id int, wg *sync.WaitGroup) { defer wg.Done() fmt.Printf("Timeout Worker %d started\n", id) time.Sleep(time.Duration(id) * time.Second) fmt.Printf("Timeout Worker %d completed\n", id) }
在上面示例中,用 time.After 创建了一个 5 秒的超时通道。
在另一个 Goroutine 中监听等待组的完成情况,可以在超时或任务完成时得知等待的最终结果。
3.2
package main import ( "errors" "fmt" "sync" "time") func main() { var timeoutWG sync.WaitGroup for i := 1; i <= 3; i++ { timeoutWG.Add(1) go timeoutWorkerWithError(i, &timeoutWG) } // 等待最多5秒,超时则返回错误 err := waitWithTimeout(&timeoutWG, 5*time.Second) if err != nil { fmt.Printf("Timeout reached. Not all workers have completed. Error: %v\n", err) } else { fmt.Println("All timeout workers have completed.") }} func timeoutWorkerWithError(id int, wg *sync.WaitGroup) { defer wg.Done() fmt.Printf("Timeout Worker %d started\n", id) time.Sleep(time.Duration(id) * time.Second) // 模拟错误发生 if id == 2 { fmt.Printf("Timeout Worker %d encountered an error\n", id) return } fmt.Printf("Timeout Worker %d completed\n", id)} func waitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) error { done := make(chan struct{}) go func() { defer close(done) wg.Wait() }() select { case <-done: return nil case <-time.After(timeout): return errors.New("timeout reached") }}
有时候,希望在程序超时的时候返回一个错误。
在这个示例中,用封装等待组的超时检查,可以在主程序中获得一个清晰的错误提示。
总结
通过讨论 sync.WaitGroup 的基本用法、避免常见错误以及实际应用,深入了解了这个强大的同步工具。
在 Go 语言并发编程中,合理使用 sync.WaitGroup 能够优雅地处理并发等待,确保主程序在所有任务完成后再继续执行。