|
20 | 20 | import static org.mockito.Mockito.when;
|
21 | 21 |
|
22 | 22 | import com.google.devtools.build.lib.remote.grpc.SharedConnectionFactory.SharedConnection;
|
| 23 | +import com.google.devtools.build.lib.remote.util.RxNoGlobalErrorsRule; |
23 | 24 | import io.reactivex.rxjava3.core.Single;
|
24 | 25 | import io.reactivex.rxjava3.observers.TestObserver;
|
25 |
| -import io.reactivex.rxjava3.plugins.RxJavaPlugins; |
26 | 26 | import java.io.IOException;
|
27 | 27 | import java.util.concurrent.Semaphore;
|
28 | 28 | import java.util.concurrent.atomic.AtomicBoolean;
|
29 | 29 | import java.util.concurrent.atomic.AtomicInteger;
|
30 | 30 | import java.util.concurrent.atomic.AtomicReference;
|
31 |
| -import org.junit.After; |
32 | 31 | import org.junit.Before;
|
33 | 32 | import org.junit.Rule;
|
34 | 33 | import org.junit.Test;
|
|
42 | 41 | @RunWith(JUnit4.class)
|
43 | 42 | public class SharedConnectionFactoryTest {
|
44 | 43 | @Rule public final MockitoRule mockito = MockitoJUnit.rule();
|
45 |
| - |
46 |
| - private final AtomicReference<Throwable> rxGlobalThrowable = new AtomicReference<>(null); |
| 44 | + @Rule public final RxNoGlobalErrorsRule rxNoGlobalErrorsRule = new RxNoGlobalErrorsRule(); |
47 | 45 |
|
48 | 46 | @Mock private Connection connection;
|
49 | 47 | @Mock private ConnectionFactory connectionFactory;
|
50 | 48 |
|
51 | 49 | @Before
|
52 | 50 | public void setUp() {
|
53 |
| - RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set); |
54 |
| - |
55 | 51 | when(connectionFactory.create()).thenAnswer(invocation -> Single.just(connection));
|
56 | 52 | }
|
57 | 53 |
|
58 |
| - @After |
59 |
| - public void tearDown() throws Throwable { |
60 |
| - // Make sure rxjava didn't receive global errors |
61 |
| - Throwable t = rxGlobalThrowable.getAndSet(null); |
62 |
| - if (t != null) { |
63 |
| - throw t; |
64 |
| - } |
65 |
| - } |
66 |
| - |
67 | 54 | @Test
|
68 | 55 | public void create_smoke() {
|
69 | 56 | SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 1);
|
@@ -125,32 +112,37 @@ public void create_belowMaxConcurrency_shareConnections() {
|
125 | 112 |
|
126 | 113 | @Test
|
127 | 114 | public void create_concurrentCreate_shareConnections() throws InterruptedException {
|
128 |
| - SharedConnectionFactory factory = new SharedConnectionFactory(connectionFactory, 2); |
129 |
| - Semaphore semaphore = new Semaphore(0); |
130 |
| - AtomicBoolean finished = new AtomicBoolean(false); |
131 |
| - Thread t = |
132 |
| - new Thread( |
133 |
| - () -> { |
134 |
| - factory |
135 |
| - .create() |
136 |
| - .doOnSuccess( |
137 |
| - conn -> { |
138 |
| - assertThat(conn.getUnderlyingConnection()).isEqualTo(connection); |
139 |
| - semaphore.release(); |
140 |
| - Thread.sleep(Integer.MAX_VALUE); |
141 |
| - finished.set(true); |
142 |
| - }) |
143 |
| - .blockingSubscribe(); |
144 |
| - |
145 |
| - finished.set(true); |
146 |
| - }); |
147 |
| - t.start(); |
148 |
| - semaphore.acquire(); |
| 115 | + int maxConcurrency = 10; |
| 116 | + SharedConnectionFactory factory = |
| 117 | + new SharedConnectionFactory(connectionFactory, maxConcurrency); |
| 118 | + AtomicReference<Throwable> error = new AtomicReference<>(null); |
| 119 | + Runnable runnable = |
| 120 | + () -> { |
| 121 | + try { |
| 122 | + TestObserver<SharedConnection> observer = factory.create().test(); |
| 123 | + |
| 124 | + observer |
| 125 | + .assertNoErrors() |
| 126 | + .assertValue(conn -> conn.getUnderlyingConnection() == connection) |
| 127 | + .assertComplete(); |
| 128 | + } catch (Throwable e) { |
| 129 | + error.set(e); |
| 130 | + } |
| 131 | + }; |
| 132 | + Thread[] threads = new Thread[maxConcurrency]; |
| 133 | + for (int i = 0; i < threads.length; ++i) { |
| 134 | + threads[i] = new Thread(runnable); |
| 135 | + } |
149 | 136 |
|
150 |
| - TestObserver<SharedConnection> observer = factory.create().test(); |
| 137 | + for (Thread thread : threads) { |
| 138 | + thread.start(); |
| 139 | + } |
| 140 | + for (Thread thread : threads) { |
| 141 | + thread.join(); |
| 142 | + } |
151 | 143 |
|
152 |
| - observer.assertValue(conn -> conn.getUnderlyingConnection() == connection).assertComplete(); |
153 |
| - assertThat(finished.get()).isFalse(); |
| 144 | + assertThat(error.get()).isNull(); |
| 145 | + verify(connectionFactory, times(1)).create(); |
154 | 146 | }
|
155 | 147 |
|
156 | 148 | @Test
|
|
0 commit comments