10
10
import traceback
11
11
import unittest
12
12
from abc import ABC , abstractmethod
13
- from typing import Any , ClassVar , Collection , Dict , Mapping , MutableMapping , Optional , Sequence , Tuple , Type , TypeVar
13
+ from typing import (
14
+ Any , ClassVar , Collection , Dict , Iterable , Mapping , MutableMapping , Optional , Sequence , Tuple , Type , TypeVar ,
15
+ )
14
16
from unittest .case import SkipTest
15
17
from unittest .mock import patch
16
18
40
42
from pykeen .training import LCWATrainingLoop , SLCWATrainingLoop , TrainingLoop
41
43
from pykeen .triples import TriplesFactory
42
44
from pykeen .typing import HeadRepresentation , MappedTriples , RelationRepresentation , TailRepresentation
43
- from pykeen .utils import all_in_bounds , resolve_device , set_random_seed , unpack_singletons
45
+ from pykeen .utils import all_in_bounds , get_batchnorm_modules , resolve_device , set_random_seed , unpack_singletons
44
46
from tests .constants import EPSILON
45
47
from tests .mocks import CustomRepresentations
46
48
from tests .utils import rand
@@ -380,25 +382,34 @@ def _check_scores(self, scores: torch.FloatTensor, exp_shape: Tuple[int, ...]):
380
382
def _additional_score_checks (self , scores ):
381
383
"""Additional checks for scores."""
382
384
385
+ @property
386
+ def _score_batch_sizes (self ) -> Iterable [int ]:
387
+ """Return the list of batch sizes to test."""
388
+ if get_batchnorm_modules (self .instance ):
389
+ return [self .batch_size ]
390
+ return [1 , self .batch_size ]
391
+
383
392
def test_score_hrt (self ):
384
393
"""Test score_hrt."""
385
- h , r , t = self ._get_hrt (
386
- (self .batch_size ,),
387
- (self .batch_size ,),
388
- (self .batch_size ,),
389
- )
390
- scores = self .instance .score_hrt (h = h , r = r , t = t )
391
- self ._check_scores (scores = scores , exp_shape = (self .batch_size , 1 ))
394
+ for batch_size in self ._score_batch_sizes :
395
+ h , r , t = self ._get_hrt (
396
+ (batch_size ,),
397
+ (batch_size ,),
398
+ (batch_size ,),
399
+ )
400
+ scores = self .instance .score_hrt (h = h , r = r , t = t )
401
+ self ._check_scores (scores = scores , exp_shape = (batch_size , 1 ))
392
402
393
403
def test_score_h (self ):
394
404
"""Test score_h."""
395
- h , r , t = self ._get_hrt (
396
- (self .num_entities ,),
397
- (self .batch_size ,),
398
- (self .batch_size ,),
399
- )
400
- scores = self .instance .score_h (all_entities = h , r = r , t = t )
401
- self ._check_scores (scores = scores , exp_shape = (self .batch_size , self .num_entities ))
405
+ for batch_size in self ._score_batch_sizes :
406
+ h , r , t = self ._get_hrt (
407
+ (self .num_entities ,),
408
+ (batch_size ,),
409
+ (batch_size ,),
410
+ )
411
+ scores = self .instance .score_h (all_entities = h , r = r , t = t )
412
+ self ._check_scores (scores = scores , exp_shape = (batch_size , self .num_entities ))
402
413
403
414
def test_score_h_slicing (self ):
404
415
"""Test score_h with slicing."""
@@ -415,17 +426,18 @@ def test_score_h_slicing(self):
415
426
416
427
def test_score_r (self ):
417
428
"""Test score_r."""
418
- h , r , t = self ._get_hrt (
419
- (self .batch_size ,),
420
- (self .num_relations ,),
421
- (self .batch_size ,),
422
- )
423
- scores = self .instance .score_r (h = h , all_relations = r , t = t )
424
- if len (self .cls .relation_shape ) == 0 :
425
- exp_shape = (self .batch_size , 1 )
426
- else :
427
- exp_shape = (self .batch_size , self .num_relations )
428
- self ._check_scores (scores = scores , exp_shape = exp_shape )
429
+ for batch_size in self ._score_batch_sizes :
430
+ h , r , t = self ._get_hrt (
431
+ (batch_size ,),
432
+ (self .num_relations ,),
433
+ (batch_size ,),
434
+ )
435
+ scores = self .instance .score_r (h = h , all_relations = r , t = t )
436
+ if len (self .cls .relation_shape ) == 0 :
437
+ exp_shape = (batch_size , 1 )
438
+ else :
439
+ exp_shape = (batch_size , self .num_relations )
440
+ self ._check_scores (scores = scores , exp_shape = exp_shape )
429
441
430
442
def test_score_r_slicing (self ):
431
443
"""Test score_r with slicing."""
@@ -444,13 +456,14 @@ def test_score_r_slicing(self):
444
456
445
457
def test_score_t (self ):
446
458
"""Test score_t."""
447
- h , r , t = self ._get_hrt (
448
- (self .batch_size ,),
449
- (self .batch_size ,),
450
- (self .num_entities ,),
451
- )
452
- scores = self .instance .score_t (h = h , r = r , all_entities = t )
453
- self ._check_scores (scores = scores , exp_shape = (self .batch_size , self .num_entities ))
459
+ for batch_size in self ._score_batch_sizes :
460
+ h , r , t = self ._get_hrt (
461
+ (batch_size ,),
462
+ (batch_size ,),
463
+ (self .num_entities ,),
464
+ )
465
+ scores = self .instance .score_t (h = h , r = r , all_entities = t )
466
+ self ._check_scores (scores = scores , exp_shape = (batch_size , self .num_entities ))
454
467
455
468
def test_score_t_slicing (self ):
456
469
"""Test score_t with slicing."""
0 commit comments