
Commit 1ff54b8

[sync BN] (NVIDIA#792)
* [sync BN] support non-uniform batch size across process group. TODO: test should be added once cleaned up.
* updating unit tests
* new unit tests for different inputs
* cleaning
1 parent 43a6f9f commit 1ff54b8
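
The core of this change: instead of assuming every rank contributes the same number of elements, the forward pass now all-gathers a per-rank element count and combines the statistics weighted by those counts. Below is a rough pure-PyTorch sketch of the combination the syncbn.welford_parallel kernel is expected to perform once it receives a count tensor; the function name combine_welford and the CPU-side math are illustrative, not the actual CUDA code.

import torch

def combine_welford(mean_all, var_biased_all, count_all, eps):
    # mean_all, var_biased_all: [world_size, C]; count_all: [world_size] per-rank element counts.
    # Reference sketch of merging per-rank statistics weighted by each rank's count.
    counts = count_all.to(mean_all.dtype).unsqueeze(1)        # [world_size, 1]
    total = counts.sum(0)                                      # global element count N
    mean = (mean_all * counts).sum(0) / total                  # global per-channel mean
    # E[x^2] per rank is var_biased + mean^2; average it with the same weights.
    ex2 = ((var_biased_all + mean_all * mean_all) * counts).sum(0) / total
    var_biased = ex2 - mean * mean                             # global biased variance
    inv_std = 1.0 / torch.sqrt(var_biased + eps)
    var = var_biased * total / (total - 1)                     # unbiased variance
    return mean, var, inv_std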

File tree

7 files changed: +290 -83 lines changed

apex/parallel/optimized_sync_batchnorm_kernel.py

Lines changed: 21 additions & 15 deletions
@@ -28,16 +28,24 @@ def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, tr
             if torch.distributed.is_initialized():
                 if not process_group:
                     process_group = torch.distributed.group.WORLD
+                device = mean.device
                 world_size = torch.distributed.get_world_size(process_group)
-                mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
-                var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
+                mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
+                var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
+                count_all = torch.cuda.IntTensor(world_size, device=device)
                 mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
                 var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
+                count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
                 torch.distributed.all_gather(mean_l, mean, process_group)
                 torch.distributed.all_gather(var_l, var_biased, process_group)
-                mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
-                # TODO(Jie): should do fp32 math instead!
+                torch.distributed.all_gather(
+                    count_l,
+                    torch.cuda.IntTensor([count], device=device),
+                    process_group)
+                mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count_all, eps)
             else:
+                device = mean.device
+                count_all = torch.cuda.IntTensor([count], device=device)
                 inv_std = 1.0 / torch.sqrt(var_biased + eps)
                 var = var_biased * (count) / (count-1)
 
@@ -52,7 +60,7 @@ def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, tr
             mean = running_mean.data
             inv_std = 1.0 / torch.sqrt(running_variance.data + eps)
 
-        ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
+        ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all)
         ctx.process_group = process_group
         ctx.channel_last = channel_last
         ctx.world_size = world_size
@@ -71,7 +79,7 @@ def backward(ctx, grad_output):
         # mini batch mean & var are calculated by forward path.
         # mu = 1./N*np.sum(h, axis = 0)
        # var = 1./N*np.sum((h-mu)**2, axis = 0)
-        saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
+        saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors
         process_group = ctx.process_group
         channel_last = ctx.channel_last
         world_size = ctx.world_size
@@ -83,26 +91,24 @@ def backward(ctx, grad_output):
         if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
             grad_z = grad_output.clone()
 
-        # TODO(jie): why do I have to clone here? life time of grad_output?
+        # TODO: update kernel to not pre_divide by item_num
         if channel_last:
-            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
+            sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
         else:
-            mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
+            sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
 
         # calculate grad_input
         if ctx.needs_input_grad[0]:
 
             if torch.distributed.is_initialized():
                 torch.distributed.all_reduce(
-                    mean_dy, ReduceOp.SUM, process_group)
-                mean_dy = mean_dy / world_size
+                    sum_dy, ReduceOp.SUM, process_group)
                 torch.distributed.all_reduce(
-                    mean_dy_xmu, ReduceOp.SUM, process_group)
-                mean_dy_xmu = mean_dy_xmu / world_size
+                    sum_dy_xmu, ReduceOp.SUM, process_group)
             if channel_last:
-                grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
+                grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
             else:
-                grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
+                grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
 
         if weight is None or not ctx.needs_input_grad[2]:
             grad_weight = None
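
In the backward pass the reduction kernels now return per-channel sums (sum_dy, sum_dy_xmu) instead of pre-divided means; those sums are all-reduced across the process group, and the division by the true global element count moves into the elementwise backward kernel, which receives the saved count tensor. A hedged pure-PyTorch reference of the gradient those kernels are expected to compute is sketched below; the function name, argument names, and shapes are illustrative, not the syncbn API.

import torch

def grad_input_reference(grad_output, x, mean, inv_std, weight, sum_dy, sum_dy_xmu, total_count):
    # grad_output, x: [N, C, *spatial]; mean, inv_std, sum_dy, sum_dy_xmu: [C]
    # total_count: global number of elements per channel, summed over all ranks.
    shape = [1, -1] + [1] * (x.dim() - 2)                      # broadcast [C] over batch/spatial dims
    mean_dy = (sum_dy / total_count).view(shape)
    mean_dy_xmu = (sum_dy_xmu / total_count).view(shape)
    w = (weight if weight is not None else torch.ones_like(mean)).view(shape)
    # standard batch-norm input gradient, written against global sums
    xmu_term = (x - mean.view(shape)) * inv_std.view(shape) ** 2 * mean_dy_xmu
    return (grad_output - mean_dy - xmu_term) * inv_std.view(shape) * w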

csrc/syncbn.cpp

Lines changed: 11 additions & 9 deletions
@@ -12,7 +12,7 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
 // implemented using welford
 std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
                                               const at::Tensor var_biased_feature_nodes,
-                                              int numel,
+                                              const at::Tensor numel,
                                               const float eps);
 
 // elementwise BN operation, returns output
@@ -24,7 +24,7 @@ at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
                                   const at::optional<at::Tensor> weight,
                                   const at::optional<at::Tensor> shift);
 
-// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
+// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;
 // mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
 // implemented using kahan summation
@@ -36,14 +36,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
 
 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;
-// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
+// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
 at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
                                    const at::Tensor input,
                                    const at::Tensor mean,
                                    const at::Tensor inv_std,
                                    const at::optional<at::Tensor> weight,
-                                   const at::Tensor mean_dy,
-                                   const at::Tensor mean_dy_xmu);
+                                   const at::Tensor sum_dy,
+                                   const at::Tensor sum_dy_xmu,
+                                   const at::Tensor count);
 
 // returns {mean, biased_var}
 // implemented using welford
@@ -62,7 +63,7 @@ at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
                                          const at::optional<at::Tensor> shift,
                                          const bool fuse_relu);
 
-// backward BN operation, returns {mean_dy, mean_dy_xmu, grad_weight, grad_bias}
+// backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
 // grad_output/input should have identical data type;
 // mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
 // expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
@@ -74,15 +75,16 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
 
 // elementwise backward BN operation, returns grad_input
 // grad_output/input/weight precision could be fp16/fp32;
-// mean/inv_std/mean_dy/mean_dy_xmu precision is fp32
+// mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
 // expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
 at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
                                           const at::Tensor input,
                                           const at::Tensor mean,
                                           const at::Tensor inv_std,
                                           const at::optional<at::Tensor> weight,
-                                          const at::Tensor mean_dy,
-                                          const at::Tensor mean_dy_xmu);
+                                          const at::Tensor sum_dy,
+                                          const at::Tensor sum_dy_xmu,
+                                          const at::Tensor count);
 
 at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
                                      const at::Tensor input,
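
With both the Python autograd function and the C++ entry points taking a count tensor, ranks in the same process group can run SyncBatchNorm with different batch sizes. The sketch below is a hypothetical usage example, assuming apex is built with the CUDA syncbn extension and the script is launched with a standard torch.distributed launcher; the per-rank batch sizes are illustrative.

import torch
import torch.distributed as dist
from apex.parallel import SyncBatchNorm

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

bn = SyncBatchNorm(64).cuda()                    # 64 channels, statistics synced across ranks

# Each rank feeds a different batch size; the gathered per-rank counts keep the
# shared mean/var and the backward reduction correct despite the imbalance.
local_batch = 4 + rank                           # illustrative: rank 0 -> 4 samples, rank 1 -> 5, ...
x = torch.randn(local_batch, 64, 32, 32, device="cuda", requires_grad=True)
y = bn(x)
y.mean().backward()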
