Skip to content

Commit 683371f

Browse files
feat(schema/appdata): async listener mux'ing (#20879)
Co-authored-by: cool-developer <[email protected]>
1 parent 897f4f8 commit 683371f

File tree

4 files changed

+569
-0
lines changed

4 files changed

+569
-0
lines changed

schema/appdata/async.go

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package appdata
2+
3+
import (
4+
"context"
5+
"sync"
6+
)
7+
8+
// AsyncListenerOptions are options for async listeners and listener mux's.
9+
type AsyncListenerOptions struct {
10+
// Context is the context whose Done() channel listeners use will use to listen for completion to close their
11+
// goroutine. If it is nil, then context.Background() will be used and goroutines may be leaked.
12+
Context context.Context
13+
14+
// BufferSize is the buffer size of the channels to use. It defaults to 0.
15+
BufferSize int
16+
17+
// DoneWaitGroup is an optional wait-group that listener goroutines will notify via Add(1) when they are started
18+
// and Done() after they are cancelled and completed.
19+
DoneWaitGroup *sync.WaitGroup
20+
}
21+
22+
// AsyncListenerMux returns a listener that forwards received events to all the provided listeners asynchronously
23+
// with each listener processing in a separate go routine. All callbacks in the returned listener will return nil
24+
// except for Commit which will return an error or nil once all listeners have processed the commit. The context
25+
// is used to signal that the listeners should stop listening and return. bufferSize is the size of the buffer for the
26+
// channels used to send events to the listeners.
27+
func AsyncListenerMux(opts AsyncListenerOptions, listeners ...Listener) Listener {
28+
asyncListeners := make([]Listener, len(listeners))
29+
commitChans := make([]chan error, len(listeners))
30+
for i, l := range listeners {
31+
commitChan := make(chan error)
32+
commitChans[i] = commitChan
33+
asyncListeners[i] = AsyncListener(opts, commitChan, l)
34+
}
35+
mux := ListenerMux(asyncListeners...)
36+
muxCommit := mux.Commit
37+
mux.Commit = func(data CommitData) error {
38+
if muxCommit != nil {
39+
err := muxCommit(data)
40+
if err != nil {
41+
return err
42+
}
43+
}
44+
45+
for _, commitChan := range commitChans {
46+
err := <-commitChan
47+
if err != nil {
48+
return err
49+
}
50+
}
51+
return nil
52+
}
53+
54+
return mux
55+
}
56+
57+
// AsyncListener returns a listener that forwards received events to the provided listener listening in asynchronously
58+
// in a separate go routine. The listener that is returned will return nil for all methods including Commit and
59+
// an error or nil will only be returned in commitChan once the sender has sent commit and the receiving listener has
60+
// processed it. Thus commitChan can be used as a synchronization and error checking mechanism. The go routine
61+
// that is being used for listening will exit when context.Done() returns and no more events will be received by the listener.
62+
// bufferSize is the size of the buffer for the channel that is used to send events to the listener.
63+
// Instead of using AsyncListener directly, it is recommended to use AsyncListenerMux which does coordination directly
64+
// via its Commit callback.
65+
func AsyncListener(opts AsyncListenerOptions, commitChan chan<- error, listener Listener) Listener {
66+
packetChan := make(chan Packet, opts.BufferSize)
67+
res := Listener{}
68+
ctx := opts.Context
69+
if ctx == nil {
70+
ctx = context.Background()
71+
}
72+
done := ctx.Done()
73+
74+
go func() {
75+
if opts.DoneWaitGroup != nil {
76+
opts.DoneWaitGroup.Add(1)
77+
}
78+
79+
var err error
80+
for {
81+
select {
82+
case packet := <-packetChan:
83+
if err != nil {
84+
// if we have an error, don't process any more packets
85+
// and return the error and finish when it's time to commit
86+
if _, ok := packet.(CommitData); ok {
87+
commitChan <- err
88+
return
89+
}
90+
} else {
91+
// process the packet
92+
err = listener.SendPacket(packet)
93+
// if it's a commit
94+
if _, ok := packet.(CommitData); ok {
95+
commitChan <- err
96+
if err != nil {
97+
return
98+
}
99+
}
100+
}
101+
102+
case <-done:
103+
close(packetChan)
104+
if opts.DoneWaitGroup != nil {
105+
opts.DoneWaitGroup.Done()
106+
}
107+
return
108+
}
109+
}
110+
}()
111+
112+
if listener.InitializeModuleData != nil {
113+
res.InitializeModuleData = func(data ModuleInitializationData) error {
114+
packetChan <- data
115+
return nil
116+
}
117+
}
118+
119+
if listener.StartBlock != nil {
120+
res.StartBlock = func(data StartBlockData) error {
121+
packetChan <- data
122+
return nil
123+
}
124+
}
125+
126+
if listener.OnTx != nil {
127+
res.OnTx = func(data TxData) error {
128+
packetChan <- data
129+
return nil
130+
}
131+
}
132+
133+
if listener.OnEvent != nil {
134+
res.OnEvent = func(data EventData) error {
135+
packetChan <- data
136+
return nil
137+
}
138+
}
139+
140+
if listener.OnKVPair != nil {
141+
res.OnKVPair = func(data KVPairData) error {
142+
packetChan <- data
143+
return nil
144+
}
145+
}
146+
147+
if listener.OnObjectUpdate != nil {
148+
res.OnObjectUpdate = func(data ObjectUpdateData) error {
149+
packetChan <- data
150+
return nil
151+
}
152+
}
153+
154+
if listener.Commit != nil {
155+
res.Commit = func(data CommitData) error {
156+
packetChan <- data
157+
return nil
158+
}
159+
}
160+
161+
return res
162+
}

schema/appdata/async_test.go

+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
package appdata
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
"testing"
8+
)
9+
10+
func TestAsyncListenerMux(t *testing.T) {
11+
t.Run("empty", func(t *testing.T) {
12+
listener := AsyncListenerMux(AsyncListenerOptions{}, Listener{}, Listener{})
13+
14+
if listener.InitializeModuleData != nil {
15+
t.Error("expected nil")
16+
}
17+
if listener.StartBlock != nil {
18+
t.Error("expected nil")
19+
}
20+
if listener.OnTx != nil {
21+
t.Error("expected nil")
22+
}
23+
if listener.OnEvent != nil {
24+
t.Error("expected nil")
25+
}
26+
if listener.OnKVPair != nil {
27+
t.Error("expected nil")
28+
}
29+
if listener.OnObjectUpdate != nil {
30+
t.Error("expected nil")
31+
}
32+
33+
// commit is not expected to be nil
34+
})
35+
36+
t.Run("call cancel", func(t *testing.T) {
37+
ctx, cancel := context.WithCancel(context.Background())
38+
wg := &sync.WaitGroup{}
39+
var calls1, calls2 []string
40+
listener1 := callCollector(1, func(name string, _ int, _ Packet) {
41+
calls1 = append(calls1, name)
42+
})
43+
listener2 := callCollector(2, func(name string, _ int, _ Packet) {
44+
calls2 = append(calls2, name)
45+
})
46+
res := AsyncListenerMux(AsyncListenerOptions{
47+
BufferSize: 16, Context: ctx, DoneWaitGroup: wg,
48+
}, listener1, listener2)
49+
50+
callAllCallbacksOnces(t, res)
51+
52+
expectedCalls := []string{
53+
"InitializeModuleData",
54+
"StartBlock",
55+
"OnTx",
56+
"OnEvent",
57+
"OnKVPair",
58+
"OnObjectUpdate",
59+
"Commit",
60+
}
61+
62+
checkExpectedCallOrder(t, calls1, expectedCalls)
63+
checkExpectedCallOrder(t, calls2, expectedCalls)
64+
65+
// cancel and expect the test to finish - if all goroutines aren't canceled the test will hang
66+
cancel()
67+
wg.Wait()
68+
})
69+
70+
t.Run("error on commit", func(t *testing.T) {
71+
var calls1, calls2 []string
72+
listener1 := callCollector(1, func(name string, _ int, _ Packet) {
73+
calls1 = append(calls1, name)
74+
})
75+
listener1.Commit = func(data CommitData) error {
76+
return fmt.Errorf("error")
77+
}
78+
listener2 := callCollector(2, func(name string, _ int, _ Packet) {
79+
calls2 = append(calls2, name)
80+
})
81+
res := AsyncListenerMux(AsyncListenerOptions{}, listener1, listener2)
82+
83+
err := res.Commit(CommitData{})
84+
if err == nil || err.Error() != "error" {
85+
t.Fatalf("expected error, got %v", err)
86+
}
87+
})
88+
}
89+
90+
func TestAsyncListener(t *testing.T) {
91+
t.Run("call cancel", func(t *testing.T) {
92+
commitChan := make(chan error)
93+
ctx, cancel := context.WithCancel(context.Background())
94+
wg := &sync.WaitGroup{}
95+
var calls []string
96+
listener := callCollector(1, func(name string, _ int, _ Packet) {
97+
calls = append(calls, name)
98+
})
99+
res := AsyncListener(AsyncListenerOptions{BufferSize: 16, Context: ctx, DoneWaitGroup: wg},
100+
commitChan, listener)
101+
102+
callAllCallbacksOnces(t, res)
103+
104+
err := <-commitChan
105+
if err != nil {
106+
t.Fatalf("expected nil, got %v", err)
107+
}
108+
109+
checkExpectedCallOrder(t, calls, []string{
110+
"InitializeModuleData",
111+
"StartBlock",
112+
"OnTx",
113+
"OnEvent",
114+
"OnKVPair",
115+
"OnObjectUpdate",
116+
"Commit",
117+
})
118+
119+
calls = nil
120+
121+
// expect wait group to return after cancel is called
122+
cancel()
123+
wg.Wait()
124+
})
125+
126+
t.Run("error", func(t *testing.T) {
127+
commitChan := make(chan error)
128+
var calls []string
129+
listener := callCollector(1, func(name string, _ int, _ Packet) {
130+
calls = append(calls, name)
131+
})
132+
133+
listener.OnKVPair = func(updates KVPairData) error {
134+
return fmt.Errorf("error")
135+
}
136+
137+
res := AsyncListener(AsyncListenerOptions{BufferSize: 16}, commitChan, listener)
138+
139+
callAllCallbacksOnces(t, res)
140+
141+
err := <-commitChan
142+
if err == nil || err.Error() != "error" {
143+
t.Fatalf("expected error, got %v", err)
144+
}
145+
146+
checkExpectedCallOrder(t, calls, []string{"InitializeModuleData", "StartBlock", "OnTx", "OnEvent"})
147+
})
148+
}

0 commit comments

Comments
 (0)