
Commit 6417a70

Mike Ruberry authored and facebook-github-bot committed on Sep 28, 2020
Updates linalg warning + docs (pytorch#45415)
Summary: Changes the deprecation of torch.norm from a runtime warning to a docs-only deprecation, since PyTorch components still rely on torch.norm, and some of its behavior, such as automatically flattening tensors, may still need to be ported to torch.linalg.norm. The documentation is also updated to clarify that torch.norm and torch.linalg.norm are distinct functions.

Pull Request resolved: pytorch#45415

Reviewed By: ngimel

Differential Revision: D23958252

Pulled By: mruberry

fbshipit-source-id: fd54e807c59a2655453a6bcd9f4073cb2c12e8ac
1 parent 7818a21 commit 6417a70
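
As context for the flattening behavior the summary mentions, here is a minimal sketch (not part of this commit) of how the two APIs differ, assuming a PyTorch build where torch.linalg is available; the exact error behavior of torch.linalg.norm may vary by release:

import torch

t = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)

# torch.norm takes `p` (default "fro") and, with dim=None, flattens the
# input first, so a 3-D tensor is accepted and a scalar is returned.
torch.norm(t)                        # tensor(11.8322)

# torch.linalg.norm takes `ord` instead of `p`. An explicit matrix norm
# such as ord='fro' expects a 2-D input (or a 2-element `dim`), mirroring
# numpy.linalg.norm, so the flattening shortcut above has no direct
# equivalent here and the commented call is expected to raise.
# torch.linalg.norm(t, ord='fro')
torch.linalg.norm(t[0], ord='fro')   # tensor(3.7417)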


2 files changed: +3 −18 lines

 

test/test_linalg.py (−13 lines)
@@ -1,7 +1,6 @@
 import torch
 import unittest
 import itertools
-import warnings
 from math import inf, nan, isnan
 
 from torch.testing._internal.common_utils import \
@@ -654,18 +653,6 @@ def run_test_case(input, ord, dim, keepdim, should_error):
             for ord in ord_matrix:
                 run_test_case(input, ord, dim, keepdim, ord in error_ords)
 
-    def test_norm_deprecated(self, device):
-        expected_message = (
-            r'torch.norm is deprecated and may be removed in a future PyTorch release. '
-            r'Use torch.linalg.norm instead.')
-        with warnings.catch_warnings(record=True) as w:
-            warnings.simplefilter("always")
-            for func in [torch.norm, torch.functional.norm]:
-                func(torch.rand(10, device=device))
-            self.assertEqual(len(w), 2)
-            for wi in w:
-                self.assertEqual(str(wi.message), expected_message)
-
     def test_norm_fastpaths(self, device):
         x = torch.randn(3, 5, device=device)
 

torch/functional.py (+3 −5 lines)
@@ -9,7 +9,6 @@
 from .overrides import has_torch_function, handle_torch_function
 from ._jit_internal import boolean_dispatch, List
 from ._jit_internal import _overload as overload
-import warnings
 
 Tensor = torch.Tensor
 from torch import _VF
@@ -1214,7 +1213,9 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa
     .. warning::
 
         torch.norm is deprecated and may be removed in a future PyTorch release.
-        Use :func:`torch.linalg.norm` instead.
+        Use :func:`torch.linalg.norm` instead, but note that :func:`torch.linalg.norm`
+        has a different signature and slightly different behavior that is
+        more consistent with NumPy's numpy.linalg.norm.
 
     Args:
         input (Tensor): the input tensor
@@ -1273,9 +1274,6 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa
         >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
         (tensor(3.7417), tensor(11.2250))
     """
-    warnings.warn((
-        "torch.norm is deprecated and may be removed in a future PyTorch release. "
-        "Use torch.linalg.norm instead."))
 
     if not torch.jit.is_scripting():
         if type(input) is not Tensor and has_torch_function((input,)):
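
The reworded warning above points readers at numpy.linalg.norm for the intended semantics. As a small, hedged spot check of that consistency (not part of this diff; assumes NumPy is installed alongside a torch.linalg-enabled PyTorch):

from math import inf

import numpy as np
import torch

v = torch.tensor([1.0, -2.0, 3.0])
# Vector inf-norm: the maximum absolute value in both libraries.
assert torch.linalg.norm(v, ord=inf).item() == np.linalg.norm(v.numpy(), ord=np.inf)

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float64)
# Matrix 2-norm: the largest singular value in both libraries.
assert abs(torch.linalg.norm(A, ord=2).item()
           - float(np.linalg.norm(A.numpy(), ord=2))) < 1e-6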
