@@ -24,12 +24,13 @@ def rnnt_loss(
24
24
dependencies.
25
25
26
26
Args:
27
- logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
27
+ logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
28
+ containing output from joiner
28
29
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
29
30
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
30
31
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
31
- blank (int, opt ): blank label (Default: ``-1``)
32
- clamp (float): clamp for gradients (Default: ``-1``)
32
+ blank (int, optional ): blank label (Default: ``-1``)
33
+ clamp (float, optional ): clamp for gradients (Default: ``-1``)
33
34
reduction (string, optional): Specifies the reduction to apply to the output:
34
35
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
35
36
@@ -69,8 +70,8 @@ class RNNTLoss(torch.nn.Module):
69
70
dependencies.
70
71
71
72
Args:
72
- blank (int, opt ): blank label (Default: ``-1``)
73
- clamp (float): clamp for gradients (Default: ``-1``)
73
+ blank (int, optional ): blank label (Default: ``-1``)
74
+ clamp (float, optional ): clamp for gradients (Default: ``-1``)
74
75
reduction (string, optional): Specifies the reduction to apply to the output:
75
76
``'none'`` | ``'mean'`` | ``'sum'``. (Default: ``'mean'``)
76
77
"""
@@ -95,7 +96,8 @@ def forward(
95
96
):
96
97
"""
97
98
Args:
98
- logits (Tensor): Tensor of dimension (batch, time, target, class) containing output from joiner
99
+ logits (Tensor): Tensor of dimension (batch, max seq length, max target length + 1, class)
100
+ containing output from joiner
99
101
targets (Tensor): Tensor of dimension (batch, max target length) containing targets with zero padded
100
102
logit_lengths (Tensor): Tensor of dimension (batch) containing lengths of each sequence from encoder
101
103
target_lengths (Tensor): Tensor of dimension (batch) containing lengths of targets for each sequence
0 commit comments