Skip to content

Commit 9d0c732

Browse files
coeuvrephilwo
authored andcommitted
Remote: Use AsyncTaskCache inside RemoteActionInputFetcher.
When using the dynamic scheduler, local actions may get interrupted or cancelled when the remote strategy is faster (e.g., remote cache hit). Ordinarily this isn't a problem, except when the local action is sharing a file download future with another local action. The interrupted thread of the local action cancels the future, and causes a CancellationExeception when the other local action thread tries to retrieve it. This resolves that problem by not letting threads/callers share the same future instance. The shared download future is only cancelled if all callers have requested cancellation. Fixes bazelbuild#12927. PiperOrigin-RevId: 362009791
1 parent ccad56c commit 9d0c732

File tree

3 files changed

+272
-114
lines changed

3 files changed

+272
-114
lines changed

src/main/java/com/google/devtools/build/lib/remote/RemoteActionInputFetcher.java

+52-70
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
import com.google.common.base.Preconditions;
2020
import com.google.common.collect.ImmutableSet;
2121
import com.google.common.flogger.GoogleLogger;
22-
import com.google.common.util.concurrent.FutureCallback;
23-
import com.google.common.util.concurrent.Futures;
2422
import com.google.common.util.concurrent.ListenableFuture;
2523
import com.google.common.util.concurrent.MoreExecutors;
2624
import com.google.devtools.build.lib.actions.ActionInput;
@@ -34,17 +32,17 @@
3432
import com.google.devtools.build.lib.profiler.SilentCloseable;
3533
import com.google.devtools.build.lib.remote.common.CacheNotFoundException;
3634
import com.google.devtools.build.lib.remote.common.RemoteActionExecutionContext;
35+
import com.google.devtools.build.lib.remote.util.AsyncTaskCache;
3736
import com.google.devtools.build.lib.remote.util.DigestUtil;
37+
import com.google.devtools.build.lib.remote.util.RxFutures;
3838
import com.google.devtools.build.lib.remote.util.TracingMetadataUtils;
3939
import com.google.devtools.build.lib.remote.util.Utils;
4040
import com.google.devtools.build.lib.sandbox.SandboxHelpers;
4141
import com.google.devtools.build.lib.vfs.Path;
42+
import io.reactivex.rxjava3.core.Completable;
4243
import java.io.IOException;
4344
import java.util.HashMap;
44-
import java.util.HashSet;
4545
import java.util.Map;
46-
import java.util.Set;
47-
import javax.annotation.concurrent.GuardedBy;
4846

4947
/**
5048
* Stages output files that are stored remotely to the local filesystem.
@@ -55,17 +53,10 @@
5553
class RemoteActionInputFetcher implements ActionInputPrefetcher {
5654

5755
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
56+
private final AsyncTaskCache.NoResult<Path> downloadCache = AsyncTaskCache.NoResult.create();
5857

5958
private final Object lock = new Object();
6059

61-
/** Set of successfully downloaded output files. */
62-
@GuardedBy("lock")
63-
private final Set<Path> downloadedPaths = new HashSet<>();
64-
65-
@VisibleForTesting
66-
@GuardedBy("lock")
67-
final Map<Path, ListenableFuture<Void>> downloadsInProgress = new HashMap<>();
68-
6960
private final String buildRequestId;
7061
private final String commandId;
7162
private final RemoteCache remoteCache;
@@ -110,11 +101,8 @@ public void prefetchFiles(
110101

111102
Path path = execRoot.getRelative(input.getExecPath());
112103
synchronized (lock) {
113-
if (downloadedPaths.contains(path)) {
114-
continue;
115-
}
116-
ListenableFuture<Void> download = downloadFileAsync(path, metadata);
117-
downloadsToWaitFor.putIfAbsent(path, download);
104+
downloadsToWaitFor.computeIfAbsent(
105+
path, key -> RxFutures.toListenableFuture(downloadFileAsync(path, metadata)));
118106
}
119107
}
120108
}
@@ -143,65 +131,59 @@ public void prefetchFiles(
143131
}
144132

145133
ImmutableSet<Path> downloadedFiles() {
146-
synchronized (lock) {
147-
return ImmutableSet.copyOf(downloadedPaths);
148-
}
134+
return downloadCache.getFinishedTasks();
135+
}
136+
137+
ImmutableSet<Path> downloadsInProgress() {
138+
return downloadCache.getInProgressTasks();
139+
}
140+
141+
@VisibleForTesting
142+
AsyncTaskCache.NoResult<Path> getDownloadCache() {
143+
return downloadCache;
149144
}
150145

151146
void downloadFile(Path path, FileArtifactValue metadata)
152147
throws IOException, InterruptedException {
153-
Utils.getFromFuture(downloadFileAsync(path, metadata));
148+
Utils.getFromFuture(RxFutures.toListenableFuture(downloadFileAsync(path, metadata)));
154149
}
155150

156-
private ListenableFuture<Void> downloadFileAsync(Path path, FileArtifactValue metadata)
157-
throws IOException {
158-
synchronized (lock) {
159-
if (downloadedPaths.contains(path)) {
160-
return Futures.immediateFuture(null);
161-
}
151+
private Completable downloadFileAsync(Path path, FileArtifactValue metadata) {
152+
Completable download =
153+
RxFutures.toCompletable(
154+
() -> {
155+
RequestMetadata requestMetadata =
156+
TracingMetadataUtils.buildMetadata(
157+
buildRequestId, commandId, metadata.getActionId());
158+
RemoteActionExecutionContext context =
159+
RemoteActionExecutionContext.create(requestMetadata);
160+
161+
Digest digest = DigestUtil.buildDigest(metadata.getDigest(), metadata.getSize());
162+
163+
return remoteCache.downloadFile(context, path, digest);
164+
},
165+
MoreExecutors.directExecutor())
166+
.doOnComplete(() -> finalizeDownload(path))
167+
.doOnError(error -> deletePartialDownload(path))
168+
.doOnDispose(() -> deletePartialDownload(path));
169+
170+
return downloadCache.executeIfNot(path, download);
171+
}
162172

163-
ListenableFuture<Void> download = downloadsInProgress.get(path);
164-
if (download == null) {
165-
RequestMetadata requestMetadata =
166-
TracingMetadataUtils.buildMetadata(buildRequestId, commandId, metadata.getActionId());
167-
RemoteActionExecutionContext context = RemoteActionExecutionContext.create(requestMetadata);
168-
169-
Digest digest = DigestUtil.buildDigest(metadata.getDigest(), metadata.getSize());
170-
download = remoteCache.downloadFile(context, path, digest);
171-
downloadsInProgress.put(path, download);
172-
Futures.addCallback(
173-
download,
174-
new FutureCallback<Void>() {
175-
@Override
176-
public void onSuccess(Void v) {
177-
synchronized (lock) {
178-
downloadsInProgress.remove(path);
179-
downloadedPaths.add(path);
180-
}
181-
182-
try {
183-
path.chmod(0755);
184-
} catch (IOException e) {
185-
logger.atWarning().withCause(e).log("Failed to chmod 755 on %s", path);
186-
}
187-
}
188-
189-
@Override
190-
public void onFailure(Throwable t) {
191-
synchronized (lock) {
192-
downloadsInProgress.remove(path);
193-
}
194-
try {
195-
path.delete();
196-
} catch (IOException e) {
197-
logger.atWarning().withCause(e).log(
198-
"Failed to delete output file after incomplete download: %s", path);
199-
}
200-
}
201-
},
202-
MoreExecutors.directExecutor());
203-
}
204-
return download;
173+
private void finalizeDownload(Path path) {
174+
try {
175+
path.chmod(0755);
176+
} catch (IOException e) {
177+
logger.atWarning().withCause(e).log("Failed to chmod 755 on %s", path);
178+
}
179+
}
180+
181+
private void deletePartialDownload(Path path) {
182+
try {
183+
path.delete();
184+
} catch (IOException e) {
185+
logger.atWarning().withCause(e).log(
186+
"Failed to delete output file after incomplete download: %s", path);
205187
}
206188
}
207189
}

src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java

+108-39
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,21 @@
1313
// limitations under the License.
1414
package com.google.devtools.build.lib.remote.util;
1515

16-
import com.google.common.base.Preconditions;
16+
import static com.google.common.base.Preconditions.checkState;
17+
1718
import com.google.common.collect.ImmutableSet;
19+
import io.reactivex.rxjava3.annotations.NonNull;
1820
import io.reactivex.rxjava3.core.Completable;
19-
import io.reactivex.rxjava3.core.Observable;
2021
import io.reactivex.rxjava3.core.Single;
22+
import io.reactivex.rxjava3.core.SingleObserver;
23+
import io.reactivex.rxjava3.disposables.Disposable;
24+
import io.reactivex.rxjava3.subjects.AsyncSubject;
2125
import java.util.HashMap;
2226
import java.util.Map;
2327
import java.util.Optional;
28+
import java.util.concurrent.CancellationException;
2429
import java.util.concurrent.atomic.AtomicInteger;
30+
import java.util.concurrent.atomic.AtomicReference;
2531
import javax.annotation.concurrent.GuardedBy;
2632
import javax.annotation.concurrent.ThreadSafe;
2733

@@ -42,11 +48,13 @@
4248
*/
4349
@ThreadSafe
4450
public final class AsyncTaskCache<KeyT, ValueT> {
45-
@GuardedBy("this")
51+
private final Object lock = new Object();
52+
53+
@GuardedBy("lock")
4654
private final Map<KeyT, ValueT> finished;
4755

48-
@GuardedBy("this")
49-
private final Map<KeyT, Observable<ValueT>> inProgress;
56+
@GuardedBy("lock")
57+
private final Map<KeyT, Execution> inProgress;
5058

5159
public static <KeyT, ValueT> AsyncTaskCache<KeyT, ValueT> create() {
5260
return new AsyncTaskCache<>();
@@ -59,14 +67,14 @@ private AsyncTaskCache() {
5967

6068
/** Returns a set of keys for tasks which is finished. */
6169
public ImmutableSet<KeyT> getFinishedTasks() {
62-
synchronized (this) {
70+
synchronized (lock) {
6371
return ImmutableSet.copyOf(finished.keySet());
6472
}
6573
}
6674

6775
/** Returns a set of keys for tasks which is still executing. */
6876
public ImmutableSet<KeyT> getInProgressTasks() {
69-
synchronized (this) {
77+
synchronized (lock) {
7078
return ImmutableSet.copyOf(inProgress.keySet());
7179
}
7280
}
@@ -82,6 +90,65 @@ public Single<ValueT> executeIfNot(KeyT key, Single<ValueT> task) {
8290
return execute(key, task, false);
8391
}
8492

93+
private class Execution {
94+
private final Single<ValueT> task;
95+
private final AsyncSubject<ValueT> asyncSubject = AsyncSubject.create();
96+
private final AtomicInteger subscriberCount = new AtomicInteger(0);
97+
private final AtomicReference<Disposable> taskDisposable = new AtomicReference<>(null);
98+
99+
Execution(Single<ValueT> task) {
100+
this.task = task;
101+
}
102+
103+
public Single<ValueT> start() {
104+
if (taskDisposable.get() == null) {
105+
task.subscribe(
106+
new SingleObserver<ValueT>() {
107+
@Override
108+
public void onSubscribe(@NonNull Disposable d) {
109+
taskDisposable.compareAndSet(null, d);
110+
}
111+
112+
@Override
113+
public void onSuccess(@NonNull ValueT value) {
114+
asyncSubject.onNext(value);
115+
asyncSubject.onComplete();
116+
}
117+
118+
@Override
119+
public void onError(@NonNull Throwable e) {
120+
asyncSubject.onError(e);
121+
}
122+
});
123+
}
124+
125+
return Single.fromObservable(asyncSubject)
126+
.doOnSubscribe(d -> subscriberCount.incrementAndGet())
127+
.doOnDispose(
128+
() -> {
129+
if (subscriberCount.decrementAndGet() == 0) {
130+
Disposable d = taskDisposable.get();
131+
if (d != null) {
132+
d.dispose();
133+
}
134+
asyncSubject.onError(new CancellationException("disposed"));
135+
}
136+
});
137+
}
138+
}
139+
140+
/** Returns count of subscribers for a task. */
141+
public int getSubscriberCount(KeyT key) {
142+
synchronized (lock) {
143+
Execution execution = inProgress.get(key);
144+
if (execution != null) {
145+
return execution.subscriberCount.get();
146+
}
147+
}
148+
149+
return 0;
150+
}
151+
85152
/**
86153
* Executes a task.
87154
*
@@ -93,50 +160,47 @@ public Single<ValueT> executeIfNot(KeyT key, Single<ValueT> task) {
93160
public Single<ValueT> execute(KeyT key, Single<ValueT> task, boolean force) {
94161
return Single.defer(
95162
() -> {
96-
synchronized (this) {
163+
synchronized (lock) {
97164
if (!force && finished.containsKey(key)) {
98165
return Single.just(finished.get(key));
99166
}
100167

101168
finished.remove(key);
102169

103-
Observable<ValueT> execution =
170+
Execution execution =
104171
inProgress.computeIfAbsent(
105172
key,
106173
missingKey -> {
107174
AtomicInteger subscribeTimes = new AtomicInteger(0);
108-
return Single.defer(
109-
() -> {
110-
int times = subscribeTimes.incrementAndGet();
111-
Preconditions.checkState(
112-
times == 1, "Subscribed more than once to the task");
113-
return task;
114-
})
115-
.doOnSuccess(
116-
value -> {
117-
synchronized (this) {
118-
finished.put(key, value);
119-
inProgress.remove(key);
120-
}
121-
})
122-
.doOnError(
123-
error -> {
124-
synchronized (this) {
125-
inProgress.remove(key);
126-
}
127-
})
128-
.doOnDispose(
129-
() -> {
130-
synchronized (this) {
131-
inProgress.remove(key);
132-
}
133-
})
134-
.toObservable()
135-
.publish()
136-
.refCount();
175+
return new Execution(
176+
Single.defer(
177+
() -> {
178+
int times = subscribeTimes.incrementAndGet();
179+
checkState(times == 1, "Subscribed more than once to the task");
180+
return task;
181+
})
182+
.doOnSuccess(
183+
value -> {
184+
synchronized (lock) {
185+
finished.put(key, value);
186+
inProgress.remove(key);
187+
}
188+
})
189+
.doOnError(
190+
error -> {
191+
synchronized (lock) {
192+
inProgress.remove(key);
193+
}
194+
})
195+
.doOnDispose(
196+
() -> {
197+
synchronized (lock) {
198+
inProgress.remove(key);
199+
}
200+
}));
137201
});
138202

139-
return Single.fromObservable(execution);
203+
return execution.start();
140204
}
141205
});
142206
}
@@ -174,5 +238,10 @@ public ImmutableSet<KeyT> getFinishedTasks() {
174238
public ImmutableSet<KeyT> getInProgressTasks() {
175239
return cache.getInProgressTasks();
176240
}
241+
242+
/** Returns count of subscribers for a task. */
243+
public int getSubscriberCount(KeyT key) {
244+
return cache.getSubscriberCount(key);
245+
}
177246
}
178247
}

0 commit comments

Comments
 (0)