Skip to content

Commit 93782cc

Browse files
jonjohnsonjrgopherbot
authored andcommitted
errgroup: use WithCancelCause to cancel context
Fixes golang/go#59355 Change-Id: Ib6a88e7e5fefe7b0d5672035af16d109aabcbf1e Reviewed-on: https://go-review.googlesource.com/c/sync/+/481255 TryBot-Result: Gopher Robot <[email protected]> Run-TryBot: Bryan Mills <[email protected]> Reviewed-by: Bryan Mills <[email protected]> Run-TryBot: Ian Lance Taylor <[email protected]> Reviewed-by: Michael Knyszek <[email protected]> Auto-Submit: Bryan Mills <[email protected]>
1 parent 4966af6 commit 93782cc

File tree

4 files changed

+89
-5
lines changed

4 files changed

+89
-5
lines changed

errgroup/errgroup.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ type token struct{}
2020
// A zero Group is valid, has no limit on the number of active goroutines,
2121
// and does not cancel on error.
2222
type Group struct {
23-
cancel func()
23+
cancel func(error)
2424

2525
wg sync.WaitGroup
2626

@@ -43,7 +43,7 @@ func (g *Group) done() {
4343
// returns a non-nil error or the first time Wait returns, whichever occurs
4444
// first.
4545
func WithContext(ctx context.Context) (*Group, context.Context) {
46-
ctx, cancel := context.WithCancel(ctx)
46+
ctx, cancel := withCancelCause(ctx)
4747
return &Group{cancel: cancel}, ctx
4848
}
4949

@@ -52,7 +52,7 @@ func WithContext(ctx context.Context) (*Group, context.Context) {
5252
func (g *Group) Wait() error {
5353
g.wg.Wait()
5454
if g.cancel != nil {
55-
g.cancel()
55+
g.cancel(g.err)
5656
}
5757
return g.err
5858
}
@@ -76,7 +76,7 @@ func (g *Group) Go(f func() error) {
7676
g.errOnce.Do(func() {
7777
g.err = err
7878
if g.cancel != nil {
79-
g.cancel()
79+
g.cancel(g.err)
8080
}
8181
})
8282
}
@@ -105,7 +105,7 @@ func (g *Group) TryGo(f func() error) bool {
105105
g.errOnce.Do(func() {
106106
g.err = err
107107
if g.cancel != nil {
108-
g.cancel()
108+
g.cancel(g.err)
109109
}
110110
})
111111
}

errgroup/go120.go

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// Copyright 2023 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build go1.20
6+
// +build go1.20
7+
8+
package errgroup
9+
10+
import "context"
11+
12+
func withCancelCause(parent context.Context) (context.Context, func(error)) {
13+
return context.WithCancelCause(parent)
14+
}

errgroup/go120_test.go

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright 2023 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build go1.20
6+
// +build go1.20
7+
8+
package errgroup_test
9+
10+
import (
11+
"context"
12+
"errors"
13+
"testing"
14+
15+
"golang.org/x/sync/errgroup"
16+
)
17+
18+
func TestCancelCause(t *testing.T) {
19+
errDoom := errors.New("group_test: doomed")
20+
21+
cases := []struct {
22+
errs []error
23+
want error
24+
}{
25+
{want: nil},
26+
{errs: []error{nil}, want: nil},
27+
{errs: []error{errDoom}, want: errDoom},
28+
{errs: []error{errDoom, nil}, want: errDoom},
29+
}
30+
31+
for _, tc := range cases {
32+
g, ctx := errgroup.WithContext(context.Background())
33+
34+
for _, err := range tc.errs {
35+
err := err
36+
g.TryGo(func() error { return err })
37+
}
38+
39+
if err := g.Wait(); err != tc.want {
40+
t.Errorf("after %T.TryGo(func() error { return err }) for err in %v\n"+
41+
"g.Wait() = %v; want %v",
42+
g, tc.errs, err, tc.want)
43+
}
44+
45+
if tc.want == nil {
46+
tc.want = context.Canceled
47+
}
48+
49+
if err := context.Cause(ctx); err != tc.want {
50+
t.Errorf("after %T.TryGo(func() error { return err }) for err in %v\n"+
51+
"context.Cause(ctx) = %v; tc.want %v",
52+
g, tc.errs, err, tc.want)
53+
}
54+
}
55+
}

errgroup/pre_go120.go

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Copyright 2023 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
//go:build !go1.20
6+
// +build !go1.20
7+
8+
package errgroup
9+
10+
import "context"
11+
12+
func withCancelCause(parent context.Context) (context.Context, func(error)) {
13+
ctx, cancel := context.WithCancel(parent)
14+
return ctx, func(error) { cancel() }
15+
}

0 commit comments

Comments
 (0)