Commit ef885c1

Mike Ruberry authored; facebook-github-bot committed on Sep 22, 2020

[pytorch] Add triplet margin loss with custom distance (pytorch#43680)

Summary: Pull Request resolved: pytorch#43680. As discussed in pytorch#43342, this adds a Python-only implementation of the triplet margin loss that takes a custom distance function. Whether it belongs in PyTorch Core is still under discussion.

Test Plan: python test/run_tests.py

Imported from OSS

Reviewed By: albanD

Differential Revision: D23363898

fbshipit-source-id: 1cafc05abecdbe7812b41deaa1e50ea11239d0cb

1 parent 10f2875 commit ef885c1
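For orientation, here is a minimal usage sketch of the API this commit introduces (names match the diff below; the tensor shapes and margin value are illustrative only):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Three batches of embeddings: anchor, positive, negative.
anchor, positive, negative = (torch.randn(16, 32, requires_grad=True) for _ in range(3))

# Cosine distance is a nonnegative custom distance function, as in the commit's own examples.
def cosine_distance(x, y):
    return 1.0 - F.cosine_similarity(x, y)

loss_fn = nn.TripletMarginWithDistanceLoss(distance_function=cosine_distance, margin=0.5)
loss_fn(anchor, positive, negative).backward()
```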

File tree: 8 files changed, +258 −11 lines

‎docs/source/nn.functional.rst

+5 −2

```diff
@@ -483,6 +483,11 @@ Loss functions
 
 .. autofunction:: triplet_margin_loss
 
+:hidden:`triplet_margin_with_distance_loss`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: triplet_margin_with_distance_loss
+
 Vision functions
 ----------------

@@ -533,5 +538,3 @@ DataParallel functions (multi-GPU, distributed)
 ~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autofunction:: torch.nn.parallel.data_parallel
-
-
```
‎docs/source/nn.rst

+2 −1

```diff
@@ -10,7 +10,7 @@ These are the basic building block for graphs
    :depth: 2
    :local:
    :backlinks: top
-
+
 
 .. currentmodule:: torch.nn

@@ -269,6 +269,7 @@ Loss Functions
    nn.CosineEmbeddingLoss
    nn.MultiMarginLoss
    nn.TripletMarginLoss
+   nn.TripletMarginWithDistanceLoss
 
 Vision Layers
 ----------------
```
‎test/test_nn.py

+80

```diff
@@ -9866,6 +9866,7 @@ def v(fn):
         v(lambda: F.multilabel_margin_loss(input, zeros, reduction=reduction))
 
         v(lambda: F.triplet_margin_loss(input, input, input, reduction=reduction))
+        v(lambda: F.triplet_margin_with_distance_loss(input, input, input, reduction=reduction))
         v(lambda: F.margin_ranking_loss(input, input, input.sign(), reduction=reduction))
         v(lambda: F.cosine_embedding_loss(input, input, input[:, 0].sign(), reduction=reduction))

@@ -12185,6 +12186,85 @@ def test_threshold_inplace_overlap(self, device):
         F.threshold(x, 0.5, 0.5, inplace=True)
         F.threshold_(x, 0.5, 0.5)
 
+    @onlyOnCPUAndCUDA
+    def test_triplet_margin_with_distance_loss_default_parity(self, device):
+        # Checks parity of `nn.TripletMarginWithDistanceLoss` and
+        # `F.triplet_margin_with_distance_loss` against the respective
+        # non-distance-agnostic implementations of triplet margin loss
+        # (`nn.TripletMarginLoss` and `F.triplet_margin_loss`) under
+        # default arguments.
+
+        for extra_args in \
+                itertools.product((0.5, 1, 1.5), (True, False), ('none', 'mean', 'sum')):
+            kwargs = {'margin': extra_args[0], 'swap': extra_args[1], 'reduction': extra_args[2]}
+
+            anchor = torch.randn(5, 10, device=device, requires_grad=True)
+            positive = torch.randn(5, 10, device=device, requires_grad=True)
+            negative = torch.randn(5, 10, device=device, requires_grad=True)
+
+            # Test forward, functional
+            expected = F.triplet_margin_loss(anchor, positive, negative, **kwargs)
+            actual = F.triplet_margin_with_distance_loss(anchor, positive, negative, **kwargs)
+            self.assertEqual(actual, expected, rtol=1e-6, atol=1e-6)
+
+            # Test forward, module
+            loss_ref = nn.TripletMarginLoss(**kwargs)
+            loss_op = nn.TripletMarginWithDistanceLoss(**kwargs)
+            self.assertEqual(loss_op(anchor, positive, negative),
+                             loss_ref(anchor, positive, negative),
+                             rtol=1e-6, atol=1e-6)
+
+            # Test backward
+            self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
+                a, p, n, **kwargs), (anchor, positive, negative)))
+            self.assertTrue(gradcheck(lambda a, p, n: loss_op(a, p, n),
+                            (anchor, positive, negative)))
+
+    @onlyOnCPUAndCUDA
+    def test_triplet_margin_with_distance_loss(self, device):
+        # Checks parity between `nn.TripletMarginWithDistanceLoss` and
+        # `F.triplet_margin_with_distance_loss` across distance functions,
+        # reductions, margins, and swap settings.
+
+        pairwise_distance = nn.PairwiseDistance()
+
+        def cosine_distance(x, y):
+            return 1.0 - F.cosine_similarity(x, y)
+
+        distance_functions = (pairwise_distance, cosine_distance,
+                              lambda x, y: 1.0 - F.cosine_similarity(x, y))
+
+        reductions = ('mean', 'none', 'sum')
+        margins = (1.0, 1.5, 0.5)
+        swaps = (True, False)
+
+        for distance_fn, reduction, margin, swap \
+                in itertools.product(distance_functions, reductions, margins, swaps):
+            anchor = torch.randn(5, 10, device=device, requires_grad=True)
+            positive = torch.randn(5, 10, device=device, requires_grad=True)
+            negative = torch.randn(5, 10, device=device, requires_grad=True)
+
+            # Test backward
+            self.assertTrue(gradcheck(lambda a, p, n: F.triplet_margin_with_distance_loss(
+                a, p, n, distance_function=distance_fn, reduction=reduction, margin=margin, swap=swap),
+                (anchor, positive, negative)))
+            loss_op = nn.TripletMarginWithDistanceLoss(distance_function=distance_fn,
+                                                       reduction=reduction, margin=margin, swap=swap)
+            self.assertTrue(gradcheck(lambda a, p, n: loss_op(
+                a, p, n), (anchor, positive, negative)))
+            traced_loss_op = torch.jit.trace(loss_op, (anchor, positive, negative))
+            self.assertTrue(gradcheck(lambda a, p, n: traced_loss_op(
+                a, p, n), (anchor, positive, negative)))
+
+            # Test forward parity
+            functional = F.triplet_margin_with_distance_loss(anchor, positive, negative,
+                                                             distance_function=distance_fn,
+                                                             reduction=reduction, margin=margin, swap=swap)
+            modular = loss_op(anchor, positive, negative)
+            traced = traced_loss_op(anchor, positive, negative)
+            self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6)
+            self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6)
+
+
 class TestModuleGlobalHooks(TestCase):
 
     def tearDown(self):
```
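Default-args parity is expected because `distance_function=None` falls back to pairwise L2 distance (`nn.PairwiseDistance` with `p=2`, `eps=1e-6`), matching the defaults of `F.triplet_margin_loss`. A standalone sketch of the property the first test asserts (not part of the commit):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
a, p, n = (torch.randn(5, 10) for _ in range(3))

# Under the default distance, the new loss should reproduce the classic one.
old = F.triplet_margin_loss(a, p, n, margin=1.0, swap=True)
new = F.triplet_margin_with_distance_loss(a, p, n, margin=1.0, swap=True)
assert torch.allclose(old, new, atol=1e-6)
```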

‎torch/nn/functional.py

+36

```diff
@@ -3728,6 +3728,42 @@ def triplet_margin_loss(anchor, positive, negative, margin=1.0, p=2, eps=1e-6, s
                                          swap, reduction_enum)
 
 
+def triplet_margin_with_distance_loss(anchor, positive, negative, *, distance_function=None,
+                                      margin=1.0, swap=False, reduction="mean"):
+    # type: (Tensor, Tensor, Tensor, Optional[Callable[[Tensor, Tensor], Tensor]], float, bool, str) -> Tensor
+    r"""
+    See :class:`~torch.nn.TripletMarginWithDistanceLoss` for details.
+    """
+    if torch.jit.is_scripting():
+        raise NotImplementedError("F.triplet_margin_with_distance_loss does not support JIT scripting: "
+                                  "functions requiring Callables cannot be scripted.")
+
+    tens_ops = (anchor, positive, negative)
+    if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
+        return handle_torch_function(
+            triplet_margin_with_distance_loss, tens_ops, anchor, positive, negative,
+            distance_function=distance_function, margin=margin, swap=swap, reduction=reduction)
+
+    distance_function = distance_function if distance_function is not None else pairwise_distance
+
+    positive_dist = distance_function(anchor, positive)
+    negative_dist = distance_function(anchor, negative)
+
+    if swap:
+        swap_dist = distance_function(positive, negative)
+        negative_dist = torch.min(negative_dist, swap_dist)
+
+    output = torch.clamp(positive_dist - negative_dist + margin, min=0.0)
+
+    reduction_enum = _Reduction.get_enum(reduction)
+    if reduction_enum == 1:
+        return output.mean()
+    elif reduction_enum == 2:
+        return output.sum()
+    else:
+        return output
+
+
 def normalize(input, p=2, dim=1, eps=1e-12, out=None):
     # type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
     r"""Performs :math:`L_p` normalization of inputs over specified dimension.
```
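The reduction branch relies on `_Reduction.get_enum`, which maps `'none'`, `'mean'`, `'sum'` to `0`, `1`, `2` respectively (hence the `== 1` and `== 2` checks above). A hand-rolled sketch of the unreduced computation, assuming only the code in this hunk:

```python
import torch
import torch.nn.functional as F

a, p, n = (torch.randn(4, 8) for _ in range(3))

def dist(x, y):  # stand-in distance function
    return (x - y).norm(p=2, dim=1)

# clamp(d(a, p) - d(a, n) + margin, min=0), with no reduction applied
manual = torch.clamp(dist(a, p) - dist(a, n) + 1.0, min=0.0)
out = F.triplet_margin_with_distance_loss(a, p, n, distance_function=dist, reduction='none')
assert torch.allclose(out, manual)
```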

‎torch/nn/functional.pyi.in

+8 −3

```diff
@@ -22,9 +22,9 @@ GRID_SAMPLE_PADDING_MODES = Dict[str, int]
 # This was necessary since the JIT uses BroadcastingList* types but static checking with mypy etc requires a `Sequence`
 # type. There is no way to express the expected lengths of these lists in the current Python typing system.
 #
-# Functions created via `_add_docstr` in `functional.py` where merely typed as `Any` by `stubgen`, so those were
-# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code
-# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system
+# Functions created via `_add_docstr` in `functional.py` were merely typed as `Any` by `stubgen`, so those were
+# deleted from the stub and replaced by generated declarations. See `gen_pyi` for the implementation of the code
+# generation logic for those functions. In the future, it might be worth looking into using the mypy plugin system
 # to encode the type semantics of `_add_docstr`, should that system ever become widespread.

@@ -319,6 +319,11 @@ def triplet_margin_loss(anchor: Tensor, positive: Tensor, negative: Tensor, marg
                         reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...
 
 
+def triplet_margin_with_distance_loss(anchor: Tensor, positive: Tensor, negative: Tensor, *,
+                                      distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = ...,
+                                      margin: float = ..., swap: bool = ..., reduction: str = ...) -> Tensor: ...
+
+
 def normalize(input: Tensor, p: float = ..., dim: int = ..., eps: float = ...,
               out: Optional[Tensor] = ...) -> Tensor: ...
```
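Note the bare `*` in the stub: everything after it is keyword-only, mirroring the runtime signature. A quick illustration (a hypothetical session, not part of the commit):

```python
import torch
import torch.nn.functional as F

a, p, n = (torch.randn(4, 8) for _ in range(3))

# OK: distance_function passed by keyword.
F.triplet_margin_with_distance_loss(a, p, n, distance_function=lambda x, y: (x - y).norm(dim=1))

# TypeError: distance_function cannot be passed positionally.
# F.triplet_margin_with_distance_loss(a, p, n, lambda x, y: (x - y).norm(dim=1))
```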

‎torch/nn/modules/__init__.py

+3 −3

```diff
@@ -8,8 +8,8 @@
     Hardsigmoid, Hardswish, SiLU
 from .loss import L1Loss, NLLLoss, KLDivLoss, MSELoss, BCELoss, BCEWithLogitsLoss, NLLLoss2d, \
     CosineEmbeddingLoss, CTCLoss, HingeEmbeddingLoss, MarginRankingLoss, \
-    MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, \
-    SmoothL1Loss, SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, PoissonNLLLoss
+    MultiLabelMarginLoss, MultiLabelSoftMarginLoss, MultiMarginLoss, SmoothL1Loss, \
+    SoftMarginLoss, CrossEntropyLoss, TripletMarginLoss, TripletMarginWithDistanceLoss, PoissonNLLLoss
 from .container import Container, Sequential, ModuleList, ModuleDict, ParameterList, ParameterDict
 from .pooling import AvgPool1d, AvgPool2d, AvgPool3d, MaxPool1d, MaxPool2d, MaxPool3d, \
     MaxUnpool1d, MaxUnpool2d, MaxUnpool3d, FractionalMaxPool2d, FractionalMaxPool3d, LPPool1d, LPPool2d, \

@@ -54,5 +54,5 @@
     'ConstantPad3d', 'Bilinear', 'CosineSimilarity', 'Unfold', 'Fold',
     'AdaptiveLogSoftmaxWithLoss', 'TransformerEncoder', 'TransformerDecoder',
     'TransformerEncoderLayer', 'TransformerDecoderLayer', 'Transformer',
-    'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU',
+    'Flatten', 'Unflatten', 'Hardsigmoid', 'Hardswish', 'SiLU', 'TripletMarginWithDistanceLoss'
 ]
```
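With both the import and the `__all__` entry in place, the class is re-exported from the top-level `torch.nn` namespace; a quick sanity check (a sketch, not part of the commit):

```python
import torch.nn as nn

assert hasattr(nn, 'TripletMarginWithDistanceLoss')
loss_fn = nn.TripletMarginWithDistanceLoss(margin=0.5)
```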

‎torch/nn/modules/loss.py

+121 −2

```diff
@@ -1,11 +1,12 @@
 import warnings
 
+from .distance import PairwiseDistance
 from .module import Module
 from .. import functional as F
 from .. import _reduction as _Reduction
 
 from torch import Tensor
-from typing import Optional
+from typing import Callable, Optional
 
 
 class _Loss(Module):

@@ -1191,6 +1192,9 @@ class TripletMarginLoss(_Loss):
     .. math::
         d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p
 
+    See also :class:`~torch.nn.TripletMarginWithDistanceLoss`, which computes the
+    triplet margin loss for input tensors using a custom distance function.
+
     Args:
         margin (float, optional): Default: :math:`1`.
         p (int, optional): The norm degree for pairwise distance. Default: :math:`2`.

@@ -1215,7 +1219,8 @@ class TripletMarginLoss(_Loss):
 
     Shape:
         - Input: :math:`(N, D)` where :math:`D` is the vector dimension.
-        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
+        - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
+          otherwise.
 
     >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
     >>> anchor = torch.randn(100, 128, requires_grad=True)

@@ -1246,6 +1251,120 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
                                      eps=self.eps, swap=self.swap, reduction=self.reduction)
 
 
+class TripletMarginWithDistanceLoss(_Loss):
+    r"""Creates a criterion that measures the triplet loss given input
+    tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
+    positive, and negative examples, respectively), and a nonnegative,
+    real-valued function ("distance function") used to compute the relationship
+    between the anchor and positive example ("positive distance") and the
+    anchor and negative example ("negative distance").
+
+    The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``)
+    can be described as:
+
+    .. math::
+        \ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad
+        l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
+
+    where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function
+    quantifying the closeness of two tensors, referred to as the :attr:`distance_function`;
+    and :math:`margin` is a nonnegative margin representing the minimum difference
+    between the positive and negative distances that is required for the loss to
+    be 0. The input tensors have :math:`N` elements each and can be of any shape
+    that the distance function can handle.
+
+    If :attr:`reduction` is not ``'none'`` (the default is ``'mean'``), then:
+
+    .. math::
+        \ell(x, y) =
+        \begin{cases}
+            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
+            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
+        \end{cases}
+
+    See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet
+    loss for input tensors using the :math:`l_p` distance as the distance function.
+
+    Args:
+        distance_function (callable, optional): A nonnegative, real-valued function that
+            quantifies the closeness of two tensors. If not specified,
+            `nn.PairwiseDistance` will be used. Default: ``None``
+        margin (float, optional): A nonnegative margin representing the minimum difference
+            between the positive and negative distances required for the loss to be 0. Larger
+            margins penalize cases where the negative examples are not distant enough from the
+            anchors, relative to the positives. Default: :math:`1`.
+        swap (bool, optional): Whether to use the distance swap described in the paper
+            `Learning shallow convolutional feature descriptors with triplet losses` by
+            V. Balntas, E. Riba et al. If True, and if the positive example is closer to the
+            negative example than the anchor is, swaps the positive example and the anchor in
+            the loss computation. Default: ``False``.
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+            ``'mean'``: the sum of the output will be divided by the number of
+            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
+
+    Shape:
+        - Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions
+          as supported by the distance function.
+        - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
+          otherwise.
+
+    Examples::
+
+        >>> # Initialize embeddings
+        >>> embedding = nn.Embedding(1000, 128)
+        >>> anchor_ids = torch.randint(0, 1000, (1,))
+        >>> positive_ids = torch.randint(0, 1000, (1,))
+        >>> negative_ids = torch.randint(0, 1000, (1,))
+        >>> anchor = embedding(anchor_ids)
+        >>> positive = embedding(positive_ids)
+        >>> negative = embedding(negative_ids)
+        >>>
+        >>> # Built-in Distance Function
+        >>> triplet_loss = \
+        >>>     nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance())
+        >>> output = triplet_loss(anchor, positive, negative)
+        >>> output.backward()
+        >>>
+        >>> # Custom Distance Function
+        >>> def l_infinity(x1, x2):
+        >>>     return torch.max(torch.abs(x1 - x2), dim=1).values
+        >>>
+        >>> triplet_loss = \
+        >>>     nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)
+        >>> output = triplet_loss(anchor, positive, negative)
+        >>> output.backward()
+        >>>
+        >>> # Custom Distance Function (Lambda)
+        >>> triplet_loss = \
+        >>>     nn.TripletMarginWithDistanceLoss(
+        >>>         distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))
+        >>> output = triplet_loss(anchor, positive, negative)
+        >>> output.backward()
+
+    Reference:
+        V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses:
+        http://www.bmva.org/bmvc/2016/papers/paper119/index.html
+    """
+    __constants__ = ['margin', 'swap', 'reduction']
+    margin: float
+    swap: bool
+
+    def __init__(self, *, distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
+                 margin: float = 1.0, swap: bool = False, reduction: str = 'mean'):
+        super(TripletMarginWithDistanceLoss, self).__init__(size_average=None, reduce=None, reduction=reduction)
+        self.distance_function = distance_function if distance_function is not None else PairwiseDistance()
+        self.margin = margin
+        self.swap = swap
+
+    def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
+        return F.triplet_margin_with_distance_loss(anchor, positive, negative,
+                                                   distance_function=self.distance_function,
+                                                   margin=self.margin, swap=self.swap, reduction=self.reduction)
+
+
 class CTCLoss(_Loss):
     r"""The Connectionist Temporal Classification loss.
```
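To make the `swap` option concrete: with `swap=True`, the negative distance becomes `min(d(a, n), d(p, n))` before the hinge is applied. A hand-computation that should agree with the module (a sketch, not part of the commit):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d = nn.PairwiseDistance()
a, p, n = (torch.randn(3, 5) for _ in range(3))

neg = torch.min(d(a, n), d(p, n))                   # swapped negative distance
manual = torch.clamp(d(a, p) - neg + 1.0, min=0.0).mean()

loss = nn.TripletMarginWithDistanceLoss(swap=True)  # margin=1.0, reduction='mean'
assert torch.allclose(loss(a, p, n), manual)
```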

‎torch/overrides.py

+3

```diff
@@ -624,6 +624,9 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.nn.functional.threshold: lambda input, threshold, value, inplace=False: -1,
     torch.nn.functional.triplet_margin_loss: (lambda anchor, positive, negative, margin=1.0, p=2, eps=1e-06,
                                               swap=False, size_average=None, reduce=None, reduction='mean': -1),
+    torch.nn.functional.triplet_margin_with_distance_loss: (lambda anchor, positive, negative, *,
+                                                            distance_function=None, margin=1.0,
+                                                            swap=False, reduction='mean': -1),
     torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
     torch.nonzero: lambda input, as_tuple=False: -1,
     torch.norm: lambda input, p='fro', dim=None, keepdim=False, out=None, dtype=None: -1,
```
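The new entry registers a dummy lambda whose signature mirrors the real functional; the `__torch_function__` test suite checks this table to verify override coverage. One way to inspect it (a sketch; `torch.overrides.get_testing_overrides` is the public accessor for this table):

```python
import inspect
import torch
from torch.overrides import get_testing_overrides

dummy = get_testing_overrides()[torch.nn.functional.triplet_margin_with_distance_loss]
# Only the signature matters for the tests; the dummy itself just returns -1.
print(inspect.signature(dummy))
```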
