
Fix a recently introduced race condition in DataLayer #2998

Merged: 1 commit merged into BVLC:master on Aug 30, 2015

Conversation

@longjon (Contributor) commented on Aug 29, 2015

This PR fixes a race condition introduced to DataLayer by #2903, which can cause incorrect data to appear on top after a GPU forward. It affects both single- and multi-GPU usage, with any data source.

DataLayer's forward copies prefetch GPU data -> top GPU data using caffe_copy. Meanwhile, the prefetch thread copies prefetch CPU data -> prefetch GPU data using a non-blocking CUDA stream.

caffe_copy is asynchronous wrt the host (when device -> device). That means these two copies can happen in any order, giving you either this batch's data or the next's (or some combination?). If you have two synchronized data sources (e.g., separate images and labels), this can be catastrophic.

Note that the queue pair is no help here; the batch is reinserted into the free queue immediately after the copy is issued, before it's completed.

To reproduce this issue easily, set PREFETCH_COUNT to 1, and put the copy (https://github.com/BVLC/caffe/blob/master/src/caffe/layers/base_data_layer.cu#L14) in a loop that executes, e.g., 1000 times. That shouldn't affect correctness, but it gives the race enough time to occur reliably (on my system, at least).
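
As a rough illustration of that reproduction recipe (not part of the patch; the loop count is arbitrary), the copy inside Forward_gpu could be wrapped like this:

// Reproduction sketch (not part of the patch): with PREFETCH_COUNT set to 1,
// repeating the prefetch -> top copy gives the prefetch thread's CPU -> GPU
// copy time to land in between, so the race shows up reliably.
for (int i = 0; i < 1000; ++i) {
  caffe_copy(batch->data_.count(), batch->data_.gpu_data(),
      top[0]->mutable_gpu_data());
}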

The fix here explicitly synchronizes the null stream used by caffe_copy. However, I think it requires CUDA 7. @thatguymike or others, what's the right way to do this without switching right away to CUDA 7?

It would be nice if there were some way to test that this doesn't happen again, but that seems difficult...

Please note: Caffe as such, with few exceptions, uses the default stream with no explicit synchronization. Layer calls are asynchronous wrt the host. (That's why there's, e.g., #2077.) caffe_copy (as cudaMemcpy) is asynchronous wrt the host when device -> device. If you create a non-blocking stream, don't expect it to be synchronous wrt any existing Caffe GPU code.
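
To make that interaction concrete, here is a minimal standalone sketch of the two copies involved (not Caffe code; top_gpu, prefetch_gpu, prefetch_cpu, and bytes stand in for the corresponding buffers and their size):

// The prefetch thread's stream is created non-blocking, so it does not
// implicitly synchronize with work on the default stream.
cudaStream_t prefetch_stream;
CUDA_CHECK(cudaStreamCreateWithFlags(&prefetch_stream, cudaStreamNonBlocking));

// Forward (host) thread: device -> device copy on the default stream.
// The call returns to the host before the copy has completed.
CUDA_CHECK(cudaMemcpy(top_gpu, prefetch_gpu, bytes, cudaMemcpyDeviceToDevice));

// Prefetch thread: the next batch's CPU -> GPU copy in the non-blocking
// stream. Nothing orders it after the copy above, so it may overwrite
// prefetch_gpu before the device -> device copy has read it.
CUDA_CHECK(cudaMemcpyAsync(prefetch_gpu, prefetch_cpu, bytes,
    cudaMemcpyHostToDevice, prefetch_stream));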

@@ -13,6 +13,9 @@ void BasePrefetchingDataLayer<Dtype>::Forward_gpu(
   // Copy the data
   caffe_copy(batch->data_.count(), batch->data_.gpu_data(),
       top[0]->mutable_gpu_data());
+  // Ensure the copy is synchronous wrt the host, so that the next batch isn't
+  // copied in meanwhile.
+  CUDA_CHECK(cudaStreamSynchronize(cudaStreamLegacy));
Member commented on the diff:

Thanks @longjon for catching this and the explanation!

However, I am still a little confused about why cudaStreamSynchronize(cudaStreamLegacy) is called here rather than several lines below, after batch->label_ is also copied. To avoid the data race, I think this synchronization should occur right before prefetch_free_.push(batch) (correct me if I'm wrong).

@longjon (Contributor, Author) replied:

Though I'll admit I overlooked this because I wasn't using labels, I don't actually see a race there. It looks like only the data portion is copied in a non-blocking stream (https://github.com/BVLC/caffe/blob/master/src/caffe/layers/base_data_layer.cpp#L87), so any label access should be synchronized with the default stream.

However, I think it would be clearer and safer to keep all copies synchronous between queues, as you suggest, and I don't see any significant performance implications, so I'll update accordingly.
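
For reference, a rough paraphrase of the prefetch thread's loop at the line cited above (not a verbatim excerpt; exact member accessors may differ): only the data blob is pushed to the GPU in the thread's non-blocking stream, while the label blob stays on the CPU until Forward_gpu copies it with caffe_copy on the default stream.

// Paraphrase of the prefetch loop in base_data_layer.cpp: only batch->data_
// is pushed to the GPU in the prefetch thread's non-blocking stream, and the
// stream is synchronized before the batch is handed to the forward pass.
load_batch(batch);
#ifndef CPU_ONLY
if (Caffe::mode() == Caffe::GPU) {
  batch->data_.data()->async_gpu_push(stream);
  CUDA_CHECK(cudaStreamSynchronize(stream));
}
#endif
prefetch_full_.push(batch);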

Member replied:
Oh, I see. Thanks!

@ronghanghu added the bug label on Aug 29, 2015
@cypof (Member) commented on Aug 30, 2015

Thanks @longjon for finding this. Synchronizing on the default stream is done in a couple of places in parallel.cpp; if cudaStreamLegacy needs to be used instead for old CUDA versions, I can change it. P2PSync used to start async copies and call cudaStreamSynchronize once, but after @thatguymike's tuning it syncs after each copy, so we could have the same code as base_data_layer.cu.

Commit message:

Previously, the prefetch GPU -> top GPU and prefetch CPU -> prefetch GPU
copies were launched concurrently in separate streams, allowing the next
batch to be copied in before the current one is read.

This patch explicitly synchronizes the prefetch -> top copy wrt the
host, preventing the CPU -> GPU copy from being launched until it
completes.
@longjon (Contributor, Author) commented on Aug 30, 2015

@cypof, I think what I'm saying about cudaStreamLegacy is exactly the opposite, i.e., I should switch it to cudaStreamDefault here (which I've now done). The only reason I used cudaStreamLegacy here (which requires CUDA 7) is that it was the only way I knew of, from the CUDA 7 blog post, to refer to the default stream. (By the way, is there a place I can find comprehensive, authoritative, and up-to-date CUDA documentation? Googling CUDA things usually yields the 4.1 doxygen... I've been baffled by the non-discoverability of CUDA documentation for a while now.)
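
Concretely, per the update described above, the added synchronization line drops the CUDA 7-only handle and refers to the default (null) stream via cudaStreamDefault:

// Ensure the copy is synchronous wrt the host, so that the next batch isn't
// copied in meanwhile; cudaStreamDefault (the null stream) does not require
// CUDA 7, unlike cudaStreamLegacy.
CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault));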

@cypof (Member) commented on Aug 30, 2015

Ah, OK, let's keep cudaStreamDefault everywhere then. For the docs I don't know; they have the programming guides and samples, which can be helpful in addition to the headers/doxygen.

@ronghanghu (Member) commented:

This PR should receive high priority since it affects the correctness of all ongoing training with DataLayer, whether single-GPU or multi-GPU. I hope to merge this as soon as possible (if no one opposes).

@cypof (Member) commented on Aug 30, 2015

+1

@jeffdonahue (Contributor) commented:

LGTM, thanks for tracking this down @longjon!

jeffdonahue added a commit that referenced this pull request on Aug 30, 2015: Fix a recently introduced race condition in DataLayer

@jeffdonahue merged commit d362894 into BVLC:master on Aug 30, 2015