Skip to content

Commit af652ca

Browse files
author
Caroline Chen
authoredAug 3, 2021
Improve RNNT Loss docstrings (#1620)
1 parent d74d060 commit af652ca

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed
 

‎torchaudio/prototype/rnnt_loss.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ def rnnt_loss(
2424
dependencies.
2525
2626
Args:
27-
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
27+
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
28+
containing output from joiner
2829
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
2930
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
3031
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
31-
blank (int, opt): blank label (Default: ``-1``)
32-
clamp (float): clamp for gradients (Default: ``-1``)
32+
blank (int, optional): blank label (Default: ``-1``)
33+
clamp (float, optional): clamp for gradients (Default: ``-1``)
3334
reduction (string, optional): Specifies the reduction to apply to the output:
3435
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
3536
@@ -69,8 +70,8 @@ class RNNTLoss(torch.nn.Module):
6970
dependencies.
7071
7172
Args:
72-
blank (int, opt): blank label (Default: ``-1``)
73-
clamp (float): clamp for gradients (Default: ``-1``)
73+
blank (int, optional): blank label (Default: ``-1``)
74+
clamp (float, optional): clamp for gradients (Default: ``-1``)
7475
reduction (string, optional): Specifies the reduction to apply to the output:
7576
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
7677
"""
@@ -95,7 +96,8 @@ def forward(
9596
):
9697
"""
9798
Args:
98-
logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
99+
logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
100+
containing output from joiner
99101
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
100102
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
101103
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence

0 commit comments

Comments
 (0)
Please sign in to comment.