Skip to content

Commit b9336b1

Browse files
Fix use of multi_tensor_l2norm, remove test using deprecated syntax
1 parent 47da14a commit b9336b1

File tree

3 files changed

+5
-24
lines changed

3 files changed

+5
-24
lines changed

apex/amp/_process_optimizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,10 +263,11 @@ def post_backward_with_master_weights_FusedAdam(self, scaler):
263263
norm_groups = []
264264
skip = False
265265
for grad_group in stash.grads:
266-
norm = multi_tensor_applier(
266+
norm, _ = multi_tensor_applier(
267267
stash.multi_tensor_l2norm,
268268
stash.dummy_overflow_buf,
269-
[grad_group])
269+
[grad_group],
270+
False)
270271
# Still syncing here for now.
271272
norm = float(norm)
272273
norm_groups.append(norm)

tests/L0/run_amp/test_basic_casts.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -137,26 +137,6 @@ def test_sum_is_float(self):
137137
fn = lambda x: x.sum()
138138
run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))
139139

140-
class TestDisabledCasts(unittest.TestCase):
141-
def setUp(self):
142-
self.handle = amp.init(enabled=False)
143-
common_init(self)
144-
145-
def test_disabled_linear(self):
146-
m = nn.Linear(self.h, self.h)
147-
f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
148-
input_shape = (self.b, self.h)
149-
150-
for fn in [m, f]:
151-
x = torch.randn(input_shape, dtype=torch.float).requires_grad_()
152-
y = fn(x)
153-
self.assertEqual(y.type(), FLOAT)
154-
y.sum().backward()
155-
self.assertEqual(x.grad.type(), FLOAT)
156-
157-
x = torch.randn(input_shape, dtype=torch.half).requires_grad_()
158-
self.assertRaises(RuntimeError, fn, x)
159-
160140
# TODO: maybe more tests on disabled casting?
161141

162142
if __name__ == '__main__':

tests/L1/cross_product/run.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22

3-
DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/"
3+
# DATADIR="/home/mcarilli/Desktop/pt18data/apex_stale/examples/imagenet/bare_metal_train_val/"
44
# DATADIR="/opt/home/apex/examples/imagenet/"
55
cp ../common/* .
6-
bash run_test.sh single_gpu $1 $DATADIR yes
6+
bash run_test.sh single_gpu $1

0 commit comments

Comments
 (0)