@@ -28,16 +28,24 @@ def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, tr
            if torch.distributed.is_initialized():
                if not process_group:
                    process_group = torch.distributed.group.WORLD
+               device = mean.device
                world_size = torch.distributed.get_world_size(process_group)
-               mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=mean.device)
-               var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=var_biased.device)
+               mean_all = torch.empty(world_size, mean.size(0), dtype=mean.dtype, device=device)
+               var_all = torch.empty(world_size, var_biased.size(0), dtype=var_biased.dtype, device=device)
+               count_all = torch.cuda.IntTensor(world_size, device=device)
                mean_l = [mean_all.narrow(0, i, 1) for i in range(world_size)]
                var_l = [var_all.narrow(0, i, 1) for i in range(world_size)]
+               count_l = [count_all.narrow(0, i, 1) for i in range(world_size)]
                torch.distributed.all_gather(mean_l, mean, process_group)
                torch.distributed.all_gather(var_l, var_biased, process_group)
-               mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count, eps)
-               # TODO(Jie): should do fp32 math instead!
+               torch.distributed.all_gather(
+                   count_l,
+                   torch.cuda.IntTensor([count], device=device),
+                   process_group)
+               mean, var, inv_std = syncbn.welford_parallel(mean_all, var_all, count_all, eps)
            else:
+               device = mean.device
+               count_all = torch.cuda.IntTensor([count], device=device)
                inv_std = 1.0 / torch.sqrt(var_biased + eps)
                var = var_biased * (count) / (count - 1)

@@ -52,7 +60,7 @@ def forward(ctx, input, z, weight, bias, running_mean, running_variance, eps, tr
            mean = running_mean.data
            inv_std = 1.0 / torch.sqrt(running_variance.data + eps)

-       ctx.save_for_backward(input, weight, mean, inv_std, z, bias)
+       ctx.save_for_backward(input, weight, mean, inv_std, z, bias, count_all)
        ctx.process_group = process_group
        ctx.channel_last = channel_last
        ctx.world_size = world_size
@@ -71,7 +79,7 @@ def backward(ctx, grad_output):
        # mini batch mean & var are calculated by forward path.
        # mu = 1./N*np.sum(h, axis = 0)
        # var = 1./N*np.sum((h-mu)**2, axis = 0)
-       saved_input, weight, mean, inv_std, z, bias = ctx.saved_tensors
+       saved_input, weight, mean, inv_std, z, bias, count = ctx.saved_tensors
        process_group = ctx.process_group
        channel_last = ctx.channel_last
        world_size = ctx.world_size
@@ -83,26 +91,24 @@ def backward(ctx, grad_output):
        if isinstance(z, torch.Tensor) and ctx.needs_input_grad[1]:
            grad_z = grad_output.clone()

-       # TODO(jie): why do I have to clone here? life time of grad_output?
+       # TODO: update kernel to not pre_divide by item_num
        if channel_last:
-           mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
+           sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn_c_last(grad_output, saved_input, mean, inv_std, weight)
        else:
-           mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)
+           sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output, saved_input, mean, inv_std, weight)

        # calculate grad_input
        if ctx.needs_input_grad[0]:

            if torch.distributed.is_initialized():
                torch.distributed.all_reduce(
-                   mean_dy, ReduceOp.SUM, process_group)
-               mean_dy = mean_dy / world_size
+                   sum_dy, ReduceOp.SUM, process_group)
                torch.distributed.all_reduce(
-                   mean_dy_xmu, ReduceOp.SUM, process_group)
-               mean_dy_xmu = mean_dy_xmu / world_size
+                   sum_dy_xmu, ReduceOp.SUM, process_group)
            if channel_last:
-               grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
+               grad_input = syncbn.batchnorm_backward_c_last(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)
            else:
-               grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, mean_dy, mean_dy_xmu)
+               grad_input = syncbn.batchnorm_backward(grad_output, saved_input, mean, inv_std, weight, sum_dy, sum_dy_xmu, count)

        if weight is None or not ctx.needs_input_grad[2]:
            grad_weight = None
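Note on the change: the forward pass now gathers per-rank element counts alongside the per-rank mean and biased variance, and the backward pass all-reduces sums (sum_dy, sum_dy_xmu) instead of averaging by world_size, so ranks with different batch sizes are weighted correctly. As a reference point only, the count-weighted combination that the gathered tensors feed into can be sketched in plain PyTorch roughly as below; combine_stats is a hypothetical helper for illustration, not part of apex, and the real computation happens inside the syncbn CUDA extension (syncbn.welford_parallel).

import torch

def combine_stats(mean_all, var_all, count_all, eps):
    # mean_all:  (world_size, C) per-rank feature means
    # var_all:   (world_size, C) per-rank biased variances
    # count_all: (world_size,)   per-rank element counts (may differ across ranks)
    counts = count_all.to(mean_all.dtype).unsqueeze(1)   # (world_size, 1)
    total = counts.sum(0)                                # total element count N
    # Global mean: count-weighted average of the per-rank means.
    mean = (mean_all * counts).sum(0) / total
    # Global biased variance via E[x^2] - E[x]^2, accumulated with per-rank
    # counts (equivalent in result to the parallel Welford combination).
    mean_sq = ((var_all + mean_all * mean_all) * counts).sum(0) / total
    var_biased = mean_sq - mean * mean
    inv_std = torch.rsqrt(var_biased + eps)
    return mean, var_biased, inv_std

The backward change follows the same reasoning: dividing the all-reduced sum_dy and sum_dy_xmu by the true global count inside the kernel, rather than computing mean_dy / world_size, keeps grad_input correct when ranks contribute unequal numbers of elements, which is why count_all is saved in the forward pass and handed to the backward kernels.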