@@ -1,11 +1,12 @@
 import warnings

+from .distance import PairwiseDistance
 from .module import Module
 from .. import functional as F
 from .. import _reduction as _Reduction

 from torch import Tensor
-from typing import Optional
+from typing import Callable, Optional


 class _Loss(Module):
@@ -1191,6 +1192,9 @@ class TripletMarginLoss(_Loss):
     .. math::
         d(x_i, y_i) = \left\lVert {\bf x}_i - {\bf y}_i \right\rVert_p

+    See also :class:`~torch.nn.TripletMarginWithDistanceLoss`, which computes the
+    triplet margin loss for input tensors using a custom distance function.
+
     Args:
         margin (float, optional): Default: :math:`1`.
         p (int, optional): The norm degree for pairwise distance. Default: :math:`2`.
@@ -1215,7 +1219,8 @@

     Shape:
         - Input: :math:`(N, D)` where :math:`D` is the vector dimension.
-        - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
+        - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
+          otherwise.

     >>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
     >>> anchor = torch.randn(100, 128, requires_grad=True)
@@ -1246,6 +1251,120 @@ def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
                                      eps=self.eps, swap=self.swap, reduction=self.reduction)


+class TripletMarginWithDistanceLoss(_Loss):
+    r"""Creates a criterion that measures the triplet loss given input
+    tensors :math:`a`, :math:`p`, and :math:`n` (representing anchor,
+    positive, and negative examples, respectively), and a nonnegative,
+    real-valued function ("distance function") used to compute the relationship
+    between the anchor and positive example ("positive distance") and the
+    anchor and negative example ("negative distance").
+
+    The unreduced loss (i.e., with :attr:`reduction` set to ``'none'``)
+    can be described as:
+
+    .. math::
+        \ell(a, p, n) = L = \{l_1,\dots,l_N\}^\top, \quad
+        l_i = \max \{d(a_i, p_i) - d(a_i, n_i) + {\rm margin}, 0\}
+
+    where :math:`N` is the batch size; :math:`d` is a nonnegative, real-valued function
+    quantifying the closeness of two tensors, referred to as the :attr:`distance_function`;
+    and :math:`margin` is a non-negative margin representing the minimum difference
+    between the positive and negative distances that is required for the loss to
+    be 0. The input tensors have :math:`N` elements each and can be of any shape
+    that the distance function can handle.
+
+    If :attr:`reduction` is not ``'none'``
+    (default ``'mean'``), then:
+
+    .. math::
+        \ell(a, p, n) =
+        \begin{cases}
+            \operatorname{mean}(L), & \text{if reduction} = \text{`mean';}\\
+            \operatorname{sum}(L),  & \text{if reduction} = \text{`sum'.}
+        \end{cases}
+
+    See also :class:`~torch.nn.TripletMarginLoss`, which computes the triplet
+    loss for input tensors using the :math:`l_p` distance as the distance function.
+
+    Args:
+        distance_function (callable, optional): A nonnegative, real-valued function that
+            quantifies the closeness of two tensors. If not specified,
+            `nn.PairwiseDistance` will be used. Default: ``None``
+        margin (float, optional): A non-negative margin representing the minimum difference
+            between the positive and negative distances required for the loss to be 0. Larger
+            margins penalize cases where the negative examples are not distant enough from the
+            anchors, relative to the positives. Default: :math:`1`.
+        swap (bool, optional): Whether to use the distance swap described in the paper
+            `Learning shallow convolutional feature descriptors with triplet losses` by
+            V. Balntas, E. Riba et al. If True, and if the positive example is closer to the
+            negative example than the anchor is, swaps the positive example and the anchor in
+            the loss computation. Default: ``False``.
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
+            ``'mean'``: the sum of the output will be divided by the number of
+            elements in the output, ``'sum'``: the output will be summed. Default: ``'mean'``
+
+
+    Shape:
+        - Input: :math:`(N, *)` where :math:`*` represents any number of additional dimensions
+          as supported by the distance function.
+        - Output: A Tensor of shape :math:`(N)` if :attr:`reduction` is ``'none'``, or a scalar
+          otherwise.
+
+    Examples::
+
+        >>> # Initialize embeddings
+        >>> embedding = nn.Embedding(1000, 128)
+        >>> anchor_ids = torch.randint(0, 1000, (1,))
+        >>> positive_ids = torch.randint(0, 1000, (1,))
+        >>> negative_ids = torch.randint(0, 1000, (1,))
+        >>> anchor = embedding(anchor_ids)
+        >>> positive = embedding(positive_ids)
+        >>> negative = embedding(negative_ids)
+        >>>
+        >>> # Built-in Distance Function
+        >>> triplet_loss = \
+        ...     nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance())
+        >>> output = triplet_loss(anchor, positive, negative)
+        >>> output.backward()
+        >>>
+        >>> # Custom Distance Function
+        >>> def l_infinity(x1, x2):
+        ...     return torch.max(torch.abs(x1 - x2), dim=1).values
+        >>>
+        >>> triplet_loss = \
+        ...     nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)
+        >>> output = triplet_loss(anchor, positive, negative)
+        >>> output.backward()
+        >>>
+        >>> # Custom Distance Function (Lambda)
+        >>> triplet_loss = \
+        ...     nn.TripletMarginWithDistanceLoss(
+        ...         distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))
+        >>> output = triplet_loss(anchor, positive, negative)
+        >>> output.backward()
+
+    Reference:
+        V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses:
+        http://www.bmva.org/bmvc/2016/papers/paper119/index.html
+    """
+    __constants__ = ['margin', 'swap', 'reduction']
+    margin: float
+    swap: bool
+
+    def __init__(self, *, distance_function: Optional[Callable[[Tensor, Tensor], Tensor]] = None,
+                 margin: float = 1.0, swap: bool = False, reduction: str = 'mean'):
+        super(TripletMarginWithDistanceLoss, self).__init__(size_average=None, reduce=None, reduction=reduction)
+        self.distance_function = distance_function if distance_function is not None else PairwiseDistance()
+        self.margin = margin
+        self.swap = swap
+
+    def forward(self, anchor: Tensor, positive: Tensor, negative: Tensor) -> Tensor:
+        return F.triplet_margin_with_distance_loss(anchor, positive, negative,
+                                                   distance_function=self.distance_function,
+                                                   margin=self.margin, swap=self.swap, reduction=self.reduction)
+
+
 class CTCLoss(_Loss):
     r"""The Connectionist Temporal Classification loss.
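Below is a short, hedged sketch that checks the formula in the new docstring by hand: it computes l_i = max{d(a_i, p_i) - d(a_i, n_i) + margin, 0} with a custom distance function and compares the result against the module this diff adds. It assumes a build of PyTorch that includes this change (exposing torch.nn.TripletMarginWithDistanceLoss); the min-of-distances reading of the swap option follows the cited Balntas et al. paper rather than anything spelled out in the diff itself.

import torch

# L-infinity distance over the last dimension; any callable mapping two
# (N, D) tensors to an (N,)-shaped tensor of nonnegative values would do.
def l_infinity(x1, x2):
    return torch.max(torch.abs(x1 - x2), dim=1).values

torch.manual_seed(0)
anchor, positive, negative = (torch.randn(16, 8) for _ in range(3))
margin = 1.5

# Unreduced loss from the docstring: l_i = max{d(a_i, p_i) - d(a_i, n_i) + margin, 0}
d_pos = l_infinity(anchor, positive)
d_neg = l_infinity(anchor, negative)
manual = torch.clamp(d_pos - d_neg + margin, min=0.0)

loss_fn = torch.nn.TripletMarginWithDistanceLoss(
    distance_function=l_infinity, margin=margin, reduction='none')
print(torch.allclose(loss_fn(anchor, positive, negative), manual))  # expected: True

# Swap (assumption, following Balntas et al.): use the smaller of d(a, n) and
# d(p, n) as the negative distance, so negatives that sit close to the
# positive are penalized as well.
d_neg_swapped = torch.min(d_neg, l_infinity(positive, negative))
manual_swap = torch.clamp(d_pos - d_neg_swapped + margin, min=0.0)

With the default reduction='mean', the module would instead return the mean of the per-triplet values computed above.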