@@ -3742,138 +3742,102 @@ def forward(self, *tensor_list):
3742
3742
@skipIfUnsupportedMinOpsetVersion (12 )
3743
3743
@disableScriptTest ()
3744
3744
def test_crossentropyloss (self ):
3745
- x = torch .randn (3 , 5 )
3746
- y = torch .empty (3 , dtype = torch .long ).random_ (5 )
3747
- self ._crossentropyloss (x , y )
3745
+ for ignore_index in [- 100 , 1 ]:
3746
+ x = torch .randn (3 , 5 )
3747
+ y = torch .empty (3 , dtype = torch .long ).random_ (5 )
3748
+ y [y == 1 ] = ignore_index
3748
3749
3749
- x = torch .randn (3 , 5 , 2 )
3750
- y = torch .empty (3 , 2 , dtype = torch .long ).random_ (5 )
3751
- self ._crossentropyloss (x , y )
3750
+ self ._crossentropyloss (x , y , ignore_index )
3752
3751
3753
- x = torch .randn (3 , 5 , 2 , 7 )
3754
- y = torch .empty (3 , 2 , 7 , dtype = torch .long ).random_ (5 )
3755
- self ._crossentropyloss (x , y )
3752
+ x = torch .randn (3 , 5 , 2 )
3753
+ y = torch .empty (3 , 2 , dtype = torch .long ).random_ (5 )
3754
+ y [y == 1 ] = ignore_index
3755
+ self ._crossentropyloss (x , y , ignore_index )
3756
3756
3757
- def _crossentropyloss (self , x , y ):
3757
+ x = torch .randn (3 , 5 , 2 , 7 )
3758
+ y = torch .empty (3 , 2 , 7 , dtype = torch .long ).random_ (5 )
3759
+ y [y == 1 ] = ignore_index
3760
+ self ._crossentropyloss (x , y , ignore_index )
3761
+
3762
+ def _crossentropyloss (self , x , y , ignore_index ):
3758
3763
class CrossEntropyLossNone (torch .nn .Module ):
3759
- def __init__ (self ):
3764
+ def __init__ (self , ignore_index ):
3760
3765
super (CrossEntropyLossNone , self ).__init__ ()
3761
- self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' )
3766
+ if ignore_index == - 100 :
3767
+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' )
3768
+ else :
3769
+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , ignore_index = ignore_index )
3762
3770
3763
3771
def forward (self , input , target ):
3764
3772
return self .loss (input , target )
3765
3773
3766
- self .run_test (CrossEntropyLossNone (), input = (x , y ))
3774
+ self .run_test (CrossEntropyLossNone (ignore_index ), input = (x , y ))
3767
3775
3768
3776
class CrossEntropyLossNoneWeight (torch .nn .Module ):
3769
- def __init__ (self ):
3777
+ def __init__ (self , ignore_index ):
3770
3778
super (CrossEntropyLossNoneWeight , self ).__init__ ()
3771
- self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , weight = torch .randn (5 ))
3779
+ if ignore_index == - 100 :
3780
+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , weight = torch .randn (5 ))
3781
+ else :
3782
+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , weight = torch .randn (5 ), ignore_index = ignore_index )
3772
3783
3773
3784
def forward (self , input , target ):
3774
3785
return self .loss (input , target )
3775
3786
3776
- self .run_test (CrossEntropyLossNoneWeight (), input = (x , y ))
3787
+ self .run_test (CrossEntropyLossNoneWeight (ignore_index ), input = (x , y ))
3777
3788
3778
3789
class CrossEntropyLossSum (torch .nn .Module ):
3779
- def __init__ (self ):
3790
+ def __init__ (self , ignore_index ):
3780
3791
super (CrossEntropyLossSum , self ).__init__ ()
3781
- self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' )
3792
+ if ignore_index == - 100 :
3793
+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' )
3794
+ else :
3795
+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , ignore_index = ignore_index )
3782
3796
3783
3797
def forward (self , input , target ):
3784
3798
return self .loss (input , target )
3785
3799
3786
- self .run_test (CrossEntropyLossSum (), input = (x , y ))
3800
+ self .run_test (CrossEntropyLossSum (ignore_index ), input = (x , y ))
3787
3801
3788
3802
class CrossEntropyLossSumWeight (torch .nn .Module ):
3789
- def __init__ (self ):
3803
+ def __init__ (self , ignore_index ):
3790
3804
super (CrossEntropyLossSumWeight , self ).__init__ ()
3791
- self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , weight = torch .randn (5 ))
3805
+ if ignore_index == - 100 :
3806
+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , weight = torch .randn (5 ))
3807
+ else :
3808
+ self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , weight = torch .randn (5 ), ignore_index = ignore_index )
3792
3809
3793
3810
def forward (self , input , target ):
3794
3811
return self .loss (input , target )
3795
3812
3796
- self .run_test (CrossEntropyLossSumWeight (), input = (x , y ))
3813
+ self .run_test (CrossEntropyLossSumWeight (ignore_index ), input = (x , y ))
3797
3814
3798
3815
class CrossEntropyLossMean (torch .nn .Module ):
3799
- def __init__ (self ):
3816
+ def __init__ (self , ignore_index ):
3800
3817
super (CrossEntropyLossMean , self ).__init__ ()
3801
- self .loss = torch .nn .CrossEntropyLoss ()
3818
+ if ignore_index == - 100 :
3819
+ self .loss = torch .nn .CrossEntropyLoss ()
3820
+ else :
3821
+ self .loss = torch .nn .CrossEntropyLoss (ignore_index = ignore_index )
3802
3822
3803
3823
def forward (self , input , target ):
3804
3824
return self .loss (input , target )
3805
3825
3806
- self .run_test (CrossEntropyLossMean (), input = (x , y ))
3826
+ self .run_test (CrossEntropyLossMean (ignore_index ), input = (x , y ))
3807
3827
3808
3828
class CrossEntropyLossMeanWeight (torch .nn .Module ):
3809
- def __init__ (self ):
3829
+ def __init__ (self , ignore_index ):
3810
3830
super (CrossEntropyLossMeanWeight , self ).__init__ ()
3811
- self .loss = torch .nn .CrossEntropyLoss (weight = torch .randn (5 ))
3812
-
3813
- def forward (self , input , target ):
3814
- return self .loss (input , target )
3815
-
3816
- self .run_test (CrossEntropyLossMeanWeight (), input = (x , y ))
3817
-
3818
- class CrossEntropyLossNoneIgnoreIndex (torch .nn .Module ):
3819
- def __init__ (self ):
3820
- super (CrossEntropyLossNoneIgnoreIndex , self ).__init__ ()
3821
- self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , ignore_index = 1 )
3822
-
3823
- def forward (self , input , target ):
3824
- return self .loss (input , target )
3825
-
3826
- self .run_test (CrossEntropyLossNoneIgnoreIndex (), input = (x , y ))
3827
-
3828
- class CrossEntropyLossNoneWeightIgnoreIndex (torch .nn .Module ):
3829
- def __init__ (self ):
3830
- super (CrossEntropyLossNoneWeightIgnoreIndex , self ).__init__ ()
3831
- self .loss = torch .nn .CrossEntropyLoss (reduction = 'none' , weight = torch .randn (5 ), ignore_index = 1 )
3832
-
3833
- def forward (self , input , target ):
3834
- return self .loss (input , target )
3835
-
3836
- self .run_test (CrossEntropyLossNoneWeightIgnoreIndex (), input = (x , y ))
3837
-
3838
- class CrossEntropyLossSumIgnoreIndex (torch .nn .Module ):
3839
- def __init__ (self ):
3840
- super (CrossEntropyLossSumIgnoreIndex , self ).__init__ ()
3841
- self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , ignore_index = 1 )
3842
-
3843
- def forward (self , input , target ):
3844
- return self .loss (input , target )
3845
-
3846
- self .run_test (CrossEntropyLossSumIgnoreIndex (), input = (x , y ))
3847
-
3848
- class CrossEntropyLossSumWeightIgnoreIndex (torch .nn .Module ):
3849
- def __init__ (self ):
3850
- super (CrossEntropyLossSumWeightIgnoreIndex , self ).__init__ ()
3851
- self .loss = torch .nn .CrossEntropyLoss (reduction = 'sum' , weight = torch .randn (5 ), ignore_index = 1 )
3852
-
3853
- def forward (self , input , target ):
3854
- return self .loss (input , target )
3855
-
3856
- self .run_test (CrossEntropyLossSumWeightIgnoreIndex (), input = (x , y ))
3857
-
3858
- class CrossEntropyLossMeanIgnoreIndex (torch .nn .Module ):
3859
- def __init__ (self ):
3860
- super (CrossEntropyLossMeanIgnoreIndex , self ).__init__ ()
3861
- self .loss = torch .nn .CrossEntropyLoss (ignore_index = 1 )
3831
+ if ignore_index == - 100 :
3832
+ self .loss = torch .nn .CrossEntropyLoss (weight = torch .randn (5 ))
3833
+ else :
3834
+ self .loss = torch .nn .CrossEntropyLoss (weight = torch .randn (5 ), ignore_index = ignore_index )
3862
3835
3863
3836
def forward (self , input , target ):
3864
3837
return self .loss (input , target )
3865
3838
3866
- self .run_test (CrossEntropyLossMeanIgnoreIndex ( ), input = (x , y ))
3839
+ self .run_test (CrossEntropyLossMeanWeight ( ignore_index ), input = (x , y ))
3867
3840
3868
- class CrossEntropyLossMeanWeightIgnoreIndex (torch .nn .Module ):
3869
- def __init__ (self ):
3870
- super (CrossEntropyLossMeanWeightIgnoreIndex , self ).__init__ ()
3871
- self .loss = torch .nn .CrossEntropyLoss (weight = torch .randn (5 ), ignore_index = 1 )
3872
-
3873
- def forward (self , input , target ):
3874
- return self .loss (input , target )
3875
-
3876
- self .run_test (CrossEntropyLossMeanWeightIgnoreIndex (), input = (x , y ))
3877
3841
3878
3842
@skipIfUnsupportedMinOpsetVersion (9 )
3879
3843
def test_kldiv_loss (self ):
@@ -3957,6 +3921,9 @@ def forward(self, input, target):
3957
3921
N , C = 5 , 4
3958
3922
input = torch .randn (N , 16 )
3959
3923
target = torch .empty (N , dtype = torch .long ).random_ (0 , C )
3924
+
3925
+ # using test data containing default ignore_index=-100
3926
+ target [target == 1 ] = - 100
3960
3927
self .run_test (NLLModel (), (input , target ))
3961
3928
3962
3929
@skipIfUnsupportedMinOpsetVersion (12 )
@@ -3976,6 +3943,9 @@ def forward(self, input, target):
3976
3943
N , C = 5 , 4
3977
3944
input = torch .randn (N , 16 , 10 , 10 )
3978
3945
target = torch .empty (N , 8 , 8 , dtype = torch .long ).random_ (0 , C )
3946
+
3947
+ # using test data containing default ignore_index=-100
3948
+ target [target == 1 ] = - 100
3979
3949
self .run_test (NLLModel (), (input , target ))
3980
3950
3981
3951
@skipIfUnsupportedMinOpsetVersion (12 )
@@ -3995,6 +3965,9 @@ def forward(self, input, target):
3995
3965
N , C = 5 , 4
3996
3966
input = torch .randn (N , 16 , 10 , 10 )
3997
3967
target = torch .empty (N , 8 , 8 , dtype = torch .long ).random_ (0 , C )
3968
+
3969
+ # using test data containing default ignore_index=-100
3970
+ target [target == 1 ] = - 100
3998
3971
self .run_test (NLLModel (), (input , target ))
3999
3972
4000
3973
@skipIfUnsupportedMinOpsetVersion (12 )
@@ -4014,6 +3987,9 @@ def forward(self, input, target):
4014
3987
N , C = 5 , 4
4015
3988
input = torch .randn (N , 16 , 10 , 10 )
4016
3989
target = torch .empty (N , 8 , 8 , dtype = torch .long ).random_ (0 , C )
3990
+
3991
+ # using test data containing default ignore_index=-100
3992
+ target [target == 1 ] = - 100
4017
3993
self .run_test (NLLModel (), (input , target ))
4018
3994
4019
3995
@skipIfUnsupportedMinOpsetVersion (12 )
@@ -4033,6 +4009,9 @@ def forward(self, input, target):
4033
4009
N , C = 5 , 4
4034
4010
input = torch .randn (N , 16 , 10 , 10 )
4035
4011
target = torch .empty (N , 8 , 8 , dtype = torch .long ).random_ (0 , C )
4012
+
4013
+ # using test data containing default ignore_index=-100
4014
+ target [target == 1 ] = - 100
4036
4015
self .run_test (NLLModel (), (input , target ))
4037
4016
4038
4017
@skipIfUnsupportedMinOpsetVersion (12 )
0 commit comments