Commit c3bf402

liqunfu authored and facebook-github-bot committed Sep 28, 2020
handle onnx nll with default ignore index (pytorch#44816)
Summary: In the ONNX NegativeLogLikelihoodLoss specification, ignore_index is optional and has no default value. Therefore, when converting the nll op to ONNX, we need to set the ignore_index attribute even if it was not specified on the PyTorch side (i.e. the default ignore_index=-100).

Pull Request resolved: pytorch#44816
Reviewed By: ezyang
Differential Revision: D23880354
Pulled By: bzinodev
fbshipit-source-id: d0bdd58d0a4507ed9ce37133e68533fe6d1bdf2b
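For context, a minimal sketch (not part of the commit; the module and file names are illustrative) of the scenario this change targets: exporting a loss module that relies on the default ignore_index. With this fix the exported SoftmaxCrossEntropyLoss / NegativeLogLikelihoodLoss node carries ignore_index = -100 explicitly, since the ONNX spec defines no default for that attribute.

import torch

class CrossEntropyModel(torch.nn.Module):
    def __init__(self):
        super(CrossEntropyModel, self).__init__()
        # default ignore_index=-100 is not passed explicitly here
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, input, target):
        return self.loss(input, target)

x = torch.randn(3, 5)
y = torch.empty(3, dtype=torch.long).random_(5)
# opset 12 is the first opset with SoftmaxCrossEntropyLoss / NegativeLogLikelihoodLoss
torch.onnx.export(CrossEntropyModel(), (x, y), "ce_loss.onnx", opset_version=12)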
1 parent 8bdbedd commit c3bf402

7 files changed: +93 -95 lines

test/onnx/expect/TestOperators.test_softmaxcrossentropy.expect (+5)

@@ -8,6 +8,11 @@ graph {
     output: "2"
     name: "SoftmaxCrossEntropyLoss_0"
     op_type: "SoftmaxCrossEntropyLoss"
+    attribute {
+      name: "ignore_index"
+      i: -100
+      type: INT
+    }
     attribute {
       name: "reduction"
       s: "mean"

test/onnx/expect/TestOperators.test_softmaxcrossentropy_3d.expect (+5)

@@ -8,6 +8,11 @@ graph {
     output: "2"
     name: "SoftmaxCrossEntropyLoss_0"
     op_type: "SoftmaxCrossEntropyLoss"
+    attribute {
+      name: "ignore_index"
+      i: -100
+      type: INT
+    }
     attribute {
       name: "reduction"
       s: "mean"

test/onnx/expect/TestOperators.test_softmaxcrossentropy_3d_none.expect (+5)

@@ -8,6 +8,11 @@ graph {
     output: "2"
     name: "SoftmaxCrossEntropyLoss_0"
     op_type: "SoftmaxCrossEntropyLoss"
+    attribute {
+      name: "ignore_index"
+      i: -100
+      type: INT
+    }
     attribute {
       name: "reduction"
       s: "none"

test/onnx/expect/TestOperators.test_softmaxcrossentropy_4d.expect (+5)

@@ -8,6 +8,11 @@ graph {
     output: "2"
     name: "SoftmaxCrossEntropyLoss_0"
     op_type: "SoftmaxCrossEntropyLoss"
+    attribute {
+      name: "ignore_index"
+      i: -100
+      type: INT
+    }
     attribute {
       name: "reduction"
       s: "mean"

test/onnx/expect/TestOperators.test_softmaxcrossentropy_weights.expect (+5)

@@ -9,6 +9,11 @@ graph {
     output: "3"
     name: "SoftmaxCrossEntropyLoss_0"
     op_type: "SoftmaxCrossEntropyLoss"
+    attribute {
+      name: "ignore_index"
+      i: -100
+      type: INT
+    }
     attribute {
       name: "reduction"
       s: "mean"

test/onnx/test_pytorch_onnx_onnxruntime.py (+66 -87)

@@ -3742,138 +3742,102 @@ def forward(self, *tensor_list):
     @skipIfUnsupportedMinOpsetVersion(12)
     @disableScriptTest()
     def test_crossentropyloss(self):
-        x = torch.randn(3, 5)
-        y = torch.empty(3, dtype=torch.long).random_(5)
-        self._crossentropyloss(x, y)
+        for ignore_index in [-100, 1]:
+            x = torch.randn(3, 5)
+            y = torch.empty(3, dtype=torch.long).random_(5)
+            y[y == 1] = ignore_index
 
-        x = torch.randn(3, 5, 2)
-        y = torch.empty(3, 2, dtype=torch.long).random_(5)
-        self._crossentropyloss(x, y)
+            self._crossentropyloss(x, y, ignore_index)
 
-        x = torch.randn(3, 5, 2, 7)
-        y = torch.empty(3, 2, 7, dtype=torch.long).random_(5)
-        self._crossentropyloss(x, y)
+            x = torch.randn(3, 5, 2)
+            y = torch.empty(3, 2, dtype=torch.long).random_(5)
+            y[y == 1] = ignore_index
+            self._crossentropyloss(x, y, ignore_index)
 
-    def _crossentropyloss(self, x, y):
+            x = torch.randn(3, 5, 2, 7)
+            y = torch.empty(3, 2, 7, dtype=torch.long).random_(5)
+            y[y == 1] = ignore_index
+            self._crossentropyloss(x, y, ignore_index)
+
+    def _crossentropyloss(self, x, y, ignore_index):
         class CrossEntropyLossNone(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, ignore_index):
                 super(CrossEntropyLossNone, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(reduction='none')
+                if ignore_index == -100:
+                    self.loss = torch.nn.CrossEntropyLoss(reduction='none')
+                else:
+                    self.loss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=ignore_index)
 
             def forward(self, input, target):
                 return self.loss(input, target)
 
-        self.run_test(CrossEntropyLossNone(), input=(x, y))
+        self.run_test(CrossEntropyLossNone(ignore_index), input=(x, y))
 
         class CrossEntropyLossNoneWeight(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, ignore_index):
                 super(CrossEntropyLossNoneWeight, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5))
+                if ignore_index == -100:
+                    self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5))
+                else:
+                    self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5), ignore_index=ignore_index)
 
             def forward(self, input, target):
                 return self.loss(input, target)
 
-        self.run_test(CrossEntropyLossNoneWeight(), input=(x, y))
+        self.run_test(CrossEntropyLossNoneWeight(ignore_index), input=(x, y))
 
         class CrossEntropyLossSum(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, ignore_index):
                 super(CrossEntropyLossSum, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(reduction='sum')
+                if ignore_index == -100:
+                    self.loss = torch.nn.CrossEntropyLoss(reduction='sum')
+                else:
+                    self.loss = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=ignore_index)
 
             def forward(self, input, target):
                 return self.loss(input, target)
 
-        self.run_test(CrossEntropyLossSum(), input=(x, y))
+        self.run_test(CrossEntropyLossSum(ignore_index), input=(x, y))
 
         class CrossEntropyLossSumWeight(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, ignore_index):
                 super(CrossEntropyLossSumWeight, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5))
+                if ignore_index == -100:
+                    self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5))
+                else:
+                    self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5), ignore_index=ignore_index)
 
             def forward(self, input, target):
                 return self.loss(input, target)
 
-        self.run_test(CrossEntropyLossSumWeight(), input=(x, y))
+        self.run_test(CrossEntropyLossSumWeight(ignore_index), input=(x, y))
 
         class CrossEntropyLossMean(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, ignore_index):
                 super(CrossEntropyLossMean, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss()
+                if ignore_index == -100:
+                    self.loss = torch.nn.CrossEntropyLoss()
+                else:
+                    self.loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)
 
             def forward(self, input, target):
                 return self.loss(input, target)
 
-        self.run_test(CrossEntropyLossMean(), input=(x, y))
+        self.run_test(CrossEntropyLossMean(ignore_index), input=(x, y))
 
         class CrossEntropyLossMeanWeight(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, ignore_index):
                 super(CrossEntropyLossMeanWeight, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5))
-
-            def forward(self, input, target):
-                return self.loss(input, target)
-
-        self.run_test(CrossEntropyLossMeanWeight(), input=(x, y))
-
-        class CrossEntropyLossNoneIgnoreIndex(torch.nn.Module):
-            def __init__(self):
-                super(CrossEntropyLossNoneIgnoreIndex, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(reduction='none', ignore_index=1)
-
-            def forward(self, input, target):
-                return self.loss(input, target)
-
-        self.run_test(CrossEntropyLossNoneIgnoreIndex(), input=(x, y))
-
-        class CrossEntropyLossNoneWeightIgnoreIndex(torch.nn.Module):
-            def __init__(self):
-                super(CrossEntropyLossNoneWeightIgnoreIndex, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.randn(5), ignore_index=1)
-
-            def forward(self, input, target):
-                return self.loss(input, target)
-
-        self.run_test(CrossEntropyLossNoneWeightIgnoreIndex(), input=(x, y))
-
-        class CrossEntropyLossSumIgnoreIndex(torch.nn.Module):
-            def __init__(self):
-                super(CrossEntropyLossSumIgnoreIndex, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(reduction='sum', ignore_index=1)
-
-            def forward(self, input, target):
-                return self.loss(input, target)
-
-        self.run_test(CrossEntropyLossSumIgnoreIndex(), input=(x, y))
-
-        class CrossEntropyLossSumWeightIgnoreIndex(torch.nn.Module):
-            def __init__(self):
-                super(CrossEntropyLossSumWeightIgnoreIndex, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(reduction='sum', weight=torch.randn(5), ignore_index=1)
-
-            def forward(self, input, target):
-                return self.loss(input, target)
-
-        self.run_test(CrossEntropyLossSumWeightIgnoreIndex(), input=(x, y))
-
-        class CrossEntropyLossMeanIgnoreIndex(torch.nn.Module):
-            def __init__(self):
-                super(CrossEntropyLossMeanIgnoreIndex, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(ignore_index=1)
+                if ignore_index == -100:
+                    self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5))
+                else:
+                    self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5), ignore_index=ignore_index)
 
             def forward(self, input, target):
                 return self.loss(input, target)
 
-        self.run_test(CrossEntropyLossMeanIgnoreIndex(), input=(x, y))
+        self.run_test(CrossEntropyLossMeanWeight(ignore_index), input=(x, y))
 
-        class CrossEntropyLossMeanWeightIgnoreIndex(torch.nn.Module):
-            def __init__(self):
-                super(CrossEntropyLossMeanWeightIgnoreIndex, self).__init__()
-                self.loss = torch.nn.CrossEntropyLoss(weight=torch.randn(5), ignore_index=1)
-
-            def forward(self, input, target):
-                return self.loss(input, target)
-
-        self.run_test(CrossEntropyLossMeanWeightIgnoreIndex(), input=(x, y))
 
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_kldiv_loss(self):

@@ -3957,6 +3921,9 @@ def forward(self, input, target):
         N, C = 5, 4
         input = torch.randn(N, 16)
         target = torch.empty(N, dtype=torch.long).random_(0, C)
+
+        # using test data containing default ignore_index=-100
+        target[target == 1] = -100
         self.run_test(NLLModel(), (input, target))
 
     @skipIfUnsupportedMinOpsetVersion(12)

@@ -3976,6 +3943,9 @@ def forward(self, input, target):
         N, C = 5, 4
         input = torch.randn(N, 16, 10, 10)
         target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
+
+        # using test data containing default ignore_index=-100
+        target[target == 1] = -100
         self.run_test(NLLModel(), (input, target))
 
     @skipIfUnsupportedMinOpsetVersion(12)

@@ -3995,6 +3965,9 @@ def forward(self, input, target):
         N, C = 5, 4
         input = torch.randn(N, 16, 10, 10)
         target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
+
+        # using test data containing default ignore_index=-100
+        target[target == 1] = -100
         self.run_test(NLLModel(), (input, target))
 
     @skipIfUnsupportedMinOpsetVersion(12)

@@ -4014,6 +3987,9 @@ def forward(self, input, target):
         N, C = 5, 4
         input = torch.randn(N, 16, 10, 10)
         target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
+
+        # using test data containing default ignore_index=-100
+        target[target == 1] = -100
         self.run_test(NLLModel(), (input, target))
 
     @skipIfUnsupportedMinOpsetVersion(12)

@@ -4033,6 +4009,9 @@ def forward(self, input, target):
         N, C = 5, 4
         input = torch.randn(N, 16, 10, 10)
         target = torch.empty(N, 8, 8, dtype=torch.long).random_(0, C)
+
+        # using test data containing default ignore_index=-100
+        target[target == 1] = -100
         self.run_test(NLLModel(), (input, target))
 
     @skipIfUnsupportedMinOpsetVersion(12)
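The reworked tests above funnel every variant through run_test, which compares PyTorch eager output against ONNX Runtime. A hand-rolled sketch of that comparison for the default ignore_index case (file name, input names, and tolerance are illustrative, not taken from the test suite):

import numpy as np
import onnxruntime
import torch

x = torch.randn(3, 5)
y = torch.empty(3, dtype=torch.long).random_(5)
y[y == 1] = -100  # the target really contains the default ignore_index

loss = torch.nn.CrossEntropyLoss(reduction='none')
torch.onnx.export(loss, (x, y), "ce_none.onnx", opset_version=12,
                  input_names=["input", "target"])

sess = onnxruntime.InferenceSession("ce_none.onnx", providers=["CPUExecutionProvider"])
ort_out = sess.run(None, {"input": x.numpy(), "target": y.numpy()})[0]
# ignored positions come back as 0 from both backends once the attribute is exported
assert np.allclose(loss(x, y).numpy(), ort_out, atol=1e-5)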

torch/onnx/symbolic_opset12.py (+2 -8)

@@ -36,15 +36,9 @@ def nll_loss(g, self, target, weight, reduction, ignore_index):
     reduction_vals = ['none', 'mean', 'sum']
     reduction = reduction_vals[reduction]
 
-    # when ignore_index is not specified, ignore_index == onnx::Constant[value={-100}]
+    # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value.
+    # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
     ignore_index = sym_help._maybe_get_const(ignore_index, 'i')
-    if ignore_index == -100:
-        if weight.node().mustBeNone():
-            return g.op("NegativeLogLikelihoodLoss", self, target, reduction_s=reduction)
-        else:
-            return g.op("NegativeLogLikelihoodLoss", self, target, weight, reduction_s=reduction)
-
-    # if ignore_index is specified, compute nllloss with no reduction and apply the reduction afterwards
     if weight.node().mustBeNone():
         nllloss = g.op("NegativeLogLikelihoodLoss", self, target, reduction_s=reduction, ignore_index_i=ignore_index)
     else:
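Why dropping the old early-return branch is safe: when the target never actually contains -100, ignoring index -100 ignores nothing, so always emitting the attribute cannot change the result. A tiny numpy check of that reasoning (illustrative only, not part of the PyTorch or ONNX sources):

import numpy as np

np.random.seed(0)
logits = np.random.randn(4, 3)               # (N, C) scores
target = np.random.randint(0, 3, size=4)     # valid class indices only, never -100

log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
nll = -log_probs[np.arange(4), target]       # per-sample NLL, no ignore_index

mask = target != -100                        # what ignore_index=-100 would mask out
nll_ignored = np.where(mask, nll, 0.0)       # NLL with ignore_index applied

assert np.allclose(nll, nll_ignored)         # identical when -100 never occurs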
