
Commit 274cc06

Author: ngimel
set device guard for multi tensor optimizer implementations (NVIDIA#927)
* add device guards to the optimizers
* add untracked file
* set deviceGuard in multi_tensor_apply
* address review comments; fix lamb
* indent
* typo
1 parent 5b53121 commit 274cc06

File tree: 8 files changed (+167, -157 lines)
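The recurring Python-side fix in this commit is replacing torch.cuda.IntTensor([0]), which always allocates on the current CUDA device, with torch.tensor(..., device=...) taken from the first parameter, so the helper buffer follows the parameters. A minimal sketch of the difference (the two-GPU setup and device ids are hypothetical, not part of the diff):

import torch

# Hypothetical: parameters live on a non-default GPU.
param = torch.nn.Parameter(torch.zeros(16, device="cuda:1"))

# Old pattern: allocates on the *current* device (cuda:0 unless it was changed),
# which may not match the parameters.
buf_old = torch.cuda.IntTensor([0])

# New pattern from this commit: follow the parameter's device explicitly.
buf_new = torch.tensor([0], dtype=torch.int, device=param.device)

print(buf_old.device, buf_new.device)  # cuda:0 vs. cuda:1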

apex/optimizers/fused_lamb.py

Lines changed: 3 additions & 2 deletions
@@ -76,7 +76,7 @@ def __init__(self, params, lr=1e-3, bias_correction=True,
             import amp_C
             self.multi_tensor_l2norm=amp_C.multi_tensor_l2norm
             # Skip buffer
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
             self.multi_tensor_lamb = amp_C.multi_tensor_lamb
         else:
             raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
@@ -117,7 +117,8 @@ def step(self, closure=None):
                 else:
                     raise RuntimeError('FusedLAMB only support fp16 and fp32.')

-        g_norm_32, g_norm_16 = torch.zeros(1, device='cuda'), torch.zeros(1, device='cuda')
+        device = self.param_groups[0]["params"][0].device
+        g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
         # compute grad norm for two lists
         if len(g_all_32) > 0:
             g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm,
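Besides the overflow buffer, step() now derives the device for the gradient-norm accumulators from the first parameter instead of the hard-coded 'cuda'. A hedged usage sketch of what this enables (assumes apex was built with CUDA extensions and a second GPU is available; the model is hypothetical):

import torch
from apex.optimizers import FusedLAMB

# Hypothetical model placed entirely on the second GPU.
model = torch.nn.Linear(128, 128).to("cuda:1")
optimizer = FusedLAMB(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 128, device="cuda:1")).sum()
loss.backward()
# With this commit, the overflow buffer and g_norm tensors are created on
# cuda:1 to match the parameters, rather than on the default device.
optimizer.step()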

apex/optimizers/fused_sgd.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def __init__(self, params, lr=required, momentum=0, dampening=0,
         if multi_tensor_applier.available:
             import amp_C
             # Skip buffer
-            self._dummy_overflow_buf = torch.cuda.IntTensor([0])
+            self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
             self.multi_tensor_sgd = amp_C.multi_tensor_sgd
         else:
             raise RuntimeError('apex.optimizers.FusedSGD requires cuda extensions')

csrc/multi_tensor_apply.cuh

Lines changed: 10 additions & 7 deletions
@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDAGuard.h>
 #include "compat.h"

 #include <assert.h>
@@ -34,7 +35,7 @@ __global__ void multi_tensor_apply_kernel(
   ArgTypes... args)
 {
   // Hand the chunk information to the user-supplied functor to process however it likes.
-  callable(chunk_size, noop_flag, tl, args...);
+  callable(chunk_size, noop_flag, tl, args...);
 }

 template<int depth, typename T, typename... ArgTypes>
@@ -49,8 +50,9 @@ void multi_tensor_apply(
   TORCH_CHECK(tensor_lists.size() == depth, "tensor_lists.size() != depth");
   int len0 = tensor_lists[0].size();
   TORCH_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
-
-  for(int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
+  auto ref_device = tensor_lists[0][0].device();
+  TORCH_CHECK(ref_device.type() == at::kCUDA, "expected input to be on cuda");
+  for (int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
   {
     TORCH_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
     for(int t = 0; t < tensor_lists[l].size(); t++)
@@ -61,7 +63,7 @@ void multi_tensor_apply(
       contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
 #endif
       TORCH_CHECK(contiguous_memory, "A tensor was not contiguous.");
-      TORCH_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
+      TORCH_CHECK(tensor_lists[l][t].device() == ref_device, "A tensor was not on the same device as the first tensor");
       TORCH_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
     }
   }
@@ -70,8 +72,9 @@ void multi_tensor_apply(

   TensorListMetadata<depth> tl;

+  const at::cuda::OptionalCUDAGuard device_guard(device_of(tensor_lists[0][0]));
   auto stream = at::cuda::getCurrentCUDAStream();
-
+
   tl.start_tensor_this_launch = 0;
   int loc_block_info = 0;
   int loc_tensor_info = 0;
@@ -90,7 +93,7 @@ void multi_tensor_apply(
       tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
       tl.block_to_chunk[loc_block_info] = chunk;
       loc_block_info++;
-
+
       bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
                            chunk == chunks_this_tensor - 1);
       bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
@@ -112,7 +115,7 @@ void multi_tensor_apply(
         if(chunk == chunks_this_tensor - 1)
         {
           // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
-          loc_tensor_info = 0;
+          loc_tensor_info = 0;
           tl.start_tensor_this_launch = t + 1;
         }
         else
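The harness now records the device of the first tensor, requires every tensor in every list to be a CUDA tensor on that same device, and launches the kernel under an OptionalCUDAGuard for it. A sketch of what that means for a caller (assumes the amp_C extension is built and two GPUs are visible; the trailing False follows the per_tensor calling convention of the l2norm use in fused_lamb):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

noop_flag = torch.zeros(1, dtype=torch.int, device="cuda:1")
grads = [torch.randn(1024, device="cuda:1") for _ in range(4)]

# Fine: everything is on cuda:1, and the device guard launches the kernel there
# even if the current device is cuda:0.
norm, _ = multi_tensor_applier(amp_C.multi_tensor_l2norm, noop_flag, [grads], False)

# Fails the new check: "A tensor was not on the same device as the first tensor".
mixed = grads[:-1] + [torch.randn(1024, device="cuda:0")]
multi_tensor_applier(amp_C.multi_tensor_l2norm, noop_flag, [mixed], False)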

csrc/multi_tensor_l2norm_kernel.cu

Lines changed: 3 additions & 2 deletions
@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
+#include <c10/cuda/CUDAGuard.h>
 // Another possibility:
 // #include <torch/all.h>

@@ -335,13 +336,13 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
       max_chunks_per_tensor);)

   AT_CUDA_CHECK(cudaGetLastError());
-
   // AT_CUDA_CHECK(cudaDeviceSynchronize());

   // This involves one more small kernel launches, but will be negligible end to end.
   // I could get rid of these by hacking the functor + multi tensor harness with persistence
   // logic, but keeping it simple for now
   auto ret = at::empty({1}, output.options());
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(output));
   auto stream = at::cuda::getCurrentCUDAStream();
   cleanup<<<per_tensor ? ntensors : 1, 512, 0, stream>>>(
     output.DATA_PTR<float>(),
@@ -369,7 +370,7 @@ void multi_tensor_norm_out_cuda(
   const int norm_type)
 {
   auto float_options = tensor_lists[0][0].options().dtype(at::kFloat);
-
+  TORCH_CHECK(tensor_lists[0][0].device() == noop_flag.device(), "noop flag should be on the same device as tensors");
   // we don't need global thus uses empty here
   auto output = at::empty({320}, float_options);

csrc/multi_tensor_sgd_kernel.cu

Lines changed: 2 additions & 0 deletions
@@ -160,6 +160,8 @@ void multi_tensor_sgd_cuda(
     TORCH_CHECK(tensor_lists[3][i].scalar_type() == at::ScalarType::Half,
                 "Additional output tensors should always be fp16.");

+  TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");
+
   // We have 3 possibilities to handle here, in terms of
   // grad_type, param_type, momentum_type, requires_fp16_copy
   // 1. fp16, fp16, fp16, No
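Both multi_tensor_l2norm_kernel.cu and multi_tensor_sgd_kernel.cu now check that the noop/overflow flag sits on the same device as the tensor lists; the optimizer changes above are what keep that invariant on the Python side. A small sketch of the invariant (device id hypothetical):

import torch

params = [torch.nn.Parameter(torch.zeros(64, device="cuda:1"))]

# The updated constructors allocate the flag on the parameters' device,
# which is exactly what the new TORCH_CHECKs require.
overflow_buf = torch.tensor([0], dtype=torch.int, device=params[0].device)
assert overflow_buf.device == params[0].device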

tests/L0/run_optimizers/test_adagrad.py

Lines changed: 0 additions & 114 deletions
This file was deleted.
