Skip to content

Commit 0d396bb

Browse files
mberr and cthoyt authored Jun 10, 2021
🩹 👾 Fix ERMLP functional form (pykeen#444)
* Fix ERMLP functional form
* Add tests for batch_size=1 for all score_* methods
* skip batch_size=1 for BatchNorm interaction modules
* Add mypy typestub external packages
  This is now necessary as of mypy 0.900

Co-authored-by: Charles Tapley Hoyt <[email protected]>
1 parent 9b7adc6 commit 0d396bb

File tree

5 files changed: +70 -93 lines changed
 

‎src/pykeen/models/resolve.py

+13-5
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,7 @@
6363
from .nbase import ERModel, EmbeddingSpecificationHint
6464
from ..nn.emb import EmbeddingSpecification, RepresentationModule
6565
from ..nn.modules import Interaction, interaction_resolver
66+
from ..typing import HeadRepresentation, RelationRepresentation, TailRepresentation
6667

6768
__all__ = [
6869
'make_model',
@@ -74,7 +75,11 @@
7475

7576
def make_model(
7677
dimensions: Union[int, Mapping[str, int]],
77-
interaction: Union[str, Interaction, Type[Interaction]],
78+
interaction: Union[
79+
str,
80+
Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation],
81+
Type[Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation]],
82+
],
7883
interaction_kwargs: Optional[Mapping[str, Any]] = None,
7984
entity_representations: EmbeddingSpecificationHint = None,
8085
relation_representations: EmbeddingSpecificationHint = None,
@@ -104,7 +109,10 @@ def __str__(self):
104109

105110
def make_model_cls(
106111
dimensions: Union[int, Mapping[str, int]],
107-
interaction: Union[str, Interaction, Type[Interaction]],
112+
interaction: Union[
113+
str, Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation],
114+
Type[Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation]],
115+
],
108116
interaction_kwargs: Optional[Mapping[str, Any]] = None,
109117
entity_representations: EmbeddingSpecificationHint = None,
110118
relation_representations: EmbeddingSpecificationHint = None,
@@ -117,15 +125,15 @@ def make_model_cls(
117125

118126
entity_representations, relation_representations = _normalize_entity_representations(
119127
dimensions=dimensions,
120-
interaction=interaction_instance.__class__,
128+
interaction=interaction_instance.__class__, # type: ignore
121129
entity_representations=entity_representations,
122130
relation_representations=relation_representations,
123131
)
124132

125133
# TODO pack/unpack dimensions as default kwargs such that they don't actually need to be used
126134
# to create the class
127135

128-
class ChildERModel(ERModel):
136+
class ChildERModel(ERModel[HeadRepresentation, RelationRepresentation, TailRepresentation]):
129137
def __init__(self, **kwargs) -> None:
130138
"""Initialize the model."""
131139
super().__init__(
@@ -142,7 +150,7 @@ def __init__(self, **kwargs) -> None:
142150

143151
def _normalize_entity_representations(
144152
dimensions: Union[int, Mapping[str, int]],
145-
interaction: Type[Interaction],
153+
interaction: Type[Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation]],
146154
entity_representations: EmbeddingSpecificationHint,
147155
relation_representations: EmbeddingSpecificationHint,
148156
) -> Tuple[

‎src/pykeen/nn/functional.py

+2-51
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,6 @@
1111
from __future__ import annotations
1212

1313
import functools
14-
from dataclasses import dataclass
1514
from typing import Optional, Tuple, Union
1615

1716
import numpy
@@ -54,52 +53,6 @@
5453
]
5554

5655

57-
@dataclass
58-
class SizeInformation:
59-
"""Size information of generic score function."""
60-
61-
#: The batch size of the head representations.
62-
bh: int
63-
64-
#: The number of head representations per batch
65-
nh: int
66-
67-
#: The batch size of the relation representations.
68-
br: int
69-
70-
#: The number of relation representations per batch
71-
nr: int
72-
73-
#: The batch size of the tail representations.
74-
bt: int
75-
76-
#: The number of tail representations per batch
77-
nt: int
78-
79-
@property
80-
def same(self) -> bool:
81-
"""Whether all representations have the same shape."""
82-
return (
83-
self.bh == self.br
84-
and self.bh == self.bt
85-
and self.nh == self.nr
86-
and self.nh == self.nt
87-
)
88-
89-
@classmethod
90-
def extract(
91-
cls,
92-
h: torch.Tensor,
93-
r: torch.Tensor,
94-
t: torch.Tensor,
95-
) -> SizeInformation:
96-
"""Extract size information from tensors."""
97-
bh, nh = h.shape[:2]
98-
br, nr = r.shape[:2]
99-
bt, nt = t.shape[:2]
100-
return cls(bh=bh, nh=nh, br=br, nr=nr, bt=bt, nt=nt)
101-
102-
10356
def _extract_sizes(
10457
h: torch.Tensor,
10558
r: torch.Tensor,
@@ -347,13 +300,11 @@ def ermlp_interaction(
347300
:return: shape: (batch_size, num_heads, num_relations, num_tails)
348301
The scores.
349302
"""
350-
sizes = SizeInformation.extract(h, r, t)
351-
352303
# same shape
353-
if sizes.same:
304+
if h.shape == r.shape and h.shape == t.shape:
354305
return final(activation(
355306
hidden(torch.cat([h, r, t], dim=-1).view(-1, 3 * h.shape[-1]))),
356-
).view(sizes.bh, sizes.nh, sizes.nr, sizes.nt)
307+
).view(*h.shape[:-1])
357308

358309
hidden_dim = hidden.weight.shape[0]
359310
# split, shape: (embedding_dim, hidden_dim)

‎src/pykeen/pipeline/api.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@
175175
import pickle
176176
import time
177177
from dataclasses import dataclass, field
178-
from typing import Any, Collection, Dict, Iterable, List, Mapping, MutableMapping, Optional, Type, Union
178+
from typing import Any, Collection, Dict, Iterable, List, Mapping, MutableMapping, Optional, Type, Union, cast
179179

180180
import pandas as pd
181181
import torch
@@ -881,7 +881,7 @@ def pipeline( # noqa: C901
881881
)
882882

883883
if isinstance(model, Model):
884-
model_instance = model
884+
model_instance = cast(Model, model)
885885
# TODO should training be reset?
886886
# TODO should kwargs for loss and regularizer be checked and raised for?
887887
else:

‎tests/cases.py

+47-34
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,9 @@
1010
import traceback
1111
import unittest
1212
from abc import ABC, abstractmethod
13-
from typing import Any, ClassVar, Collection, Dict, Mapping, MutableMapping, Optional, Sequence, Tuple, Type, TypeVar
13+
from typing import (
14+
Any, ClassVar, Collection, Dict, Iterable, Mapping, MutableMapping, Optional, Sequence, Tuple, Type, TypeVar,
15+
)
1416
from unittest.case import SkipTest
1517
from unittest.mock import patch
1618

@@ -40,7 +42,7 @@
4042
from pykeen.training import LCWATrainingLoop, SLCWATrainingLoop, TrainingLoop
4143
from pykeen.triples import TriplesFactory
4244
from pykeen.typing import HeadRepresentation, MappedTriples, RelationRepresentation, TailRepresentation
43-
from pykeen.utils import all_in_bounds, resolve_device, set_random_seed, unpack_singletons
45+
from pykeen.utils import all_in_bounds, get_batchnorm_modules, resolve_device, set_random_seed, unpack_singletons
4446
from tests.constants import EPSILON
4547
from tests.mocks import CustomRepresentations
4648
from tests.utils import rand
@@ -380,25 +382,34 @@ def _check_scores(self, scores: torch.FloatTensor, exp_shape: Tuple[int, ...]):
380382
def _additional_score_checks(self, scores):
381383
"""Additional checks for scores."""
382384

385+
@property
386+
def _score_batch_sizes(self) -> Iterable[int]:
387+
"""Return the list of batch sizes to test."""
388+
if get_batchnorm_modules(self.instance):
389+
return [self.batch_size]
390+
return [1, self.batch_size]
391+
383392
def test_score_hrt(self):
384393
"""Test score_hrt."""
385-
h, r, t = self._get_hrt(
386-
(self.batch_size,),
387-
(self.batch_size,),
388-
(self.batch_size,),
389-
)
390-
scores = self.instance.score_hrt(h=h, r=r, t=t)
391-
self._check_scores(scores=scores, exp_shape=(self.batch_size, 1))
394+
for batch_size in self._score_batch_sizes:
395+
h, r, t = self._get_hrt(
396+
(batch_size,),
397+
(batch_size,),
398+
(batch_size,),
399+
)
400+
scores = self.instance.score_hrt(h=h, r=r, t=t)
401+
self._check_scores(scores=scores, exp_shape=(batch_size, 1))
392402

393403
def test_score_h(self):
394404
"""Test score_h."""
395-
h, r, t = self._get_hrt(
396-
(self.num_entities,),
397-
(self.batch_size,),
398-
(self.batch_size,),
399-
)
400-
scores = self.instance.score_h(all_entities=h, r=r, t=t)
401-
self._check_scores(scores=scores, exp_shape=(self.batch_size, self.num_entities))
405+
for batch_size in self._score_batch_sizes:
406+
h, r, t = self._get_hrt(
407+
(self.num_entities,),
408+
(batch_size,),
409+
(batch_size,),
410+
)
411+
scores = self.instance.score_h(all_entities=h, r=r, t=t)
412+
self._check_scores(scores=scores, exp_shape=(batch_size, self.num_entities))
402413

403414
def test_score_h_slicing(self):
404415
"""Test score_h with slicing."""
@@ -415,17 +426,18 @@ def test_score_h_slicing(self):
415426

416427
def test_score_r(self):
417428
"""Test score_r."""
418-
h, r, t = self._get_hrt(
419-
(self.batch_size,),
420-
(self.num_relations,),
421-
(self.batch_size,),
422-
)
423-
scores = self.instance.score_r(h=h, all_relations=r, t=t)
424-
if len(self.cls.relation_shape) == 0:
425-
exp_shape = (self.batch_size, 1)
426-
else:
427-
exp_shape = (self.batch_size, self.num_relations)
428-
self._check_scores(scores=scores, exp_shape=exp_shape)
429+
for batch_size in self._score_batch_sizes:
430+
h, r, t = self._get_hrt(
431+
(batch_size,),
432+
(self.num_relations,),
433+
(batch_size,),
434+
)
435+
scores = self.instance.score_r(h=h, all_relations=r, t=t)
436+
if len(self.cls.relation_shape) == 0:
437+
exp_shape = (batch_size, 1)
438+
else:
439+
exp_shape = (batch_size, self.num_relations)
440+
self._check_scores(scores=scores, exp_shape=exp_shape)
429441

430442
def test_score_r_slicing(self):
431443
"""Test score_r with slicing."""
@@ -444,13 +456,14 @@ def test_score_r_slicing(self):
444456

445457
def test_score_t(self):
446458
"""Test score_t."""
447-
h, r, t = self._get_hrt(
448-
(self.batch_size,),
449-
(self.batch_size,),
450-
(self.num_entities,),
451-
)
452-
scores = self.instance.score_t(h=h, r=r, all_entities=t)
453-
self._check_scores(scores=scores, exp_shape=(self.batch_size, self.num_entities))
459+
for batch_size in self._score_batch_sizes:
460+
h, r, t = self._get_hrt(
461+
(batch_size,),
462+
(batch_size,),
463+
(self.num_entities,),
464+
)
465+
scores = self.instance.score_t(h=h, r=r, all_entities=t)
466+
self._check_scores(scores=scores, exp_shape=(batch_size, self.num_entities))
454467

455468
def test_score_t_slicing(self):
456469
"""Test score_t with slicing."""

‎tox.ini

+6-1
Original file line number | Diff line number | Diff line change
@@ -148,7 +148,12 @@ commands =
148148
description = Check all python files do not have mistaken trailing commas
149149

150150
[testenv:mypy]
151-
deps = mypy
151+
deps =
152+
mypy
153+
types-click
154+
types-pkg_resources
155+
types-requests
156+
types-tabulate
152157
skip_install = true
153158
commands = mypy --ignore-missing-imports \
154159
src/pykeen/datasets \

0 commit comments

Comments
 (0)