Skip to content

Commit c754fcb

Browse files
cthoyt and mberr
authored Oct 25, 2024
↔️🎱 Merge QuatE (#1472)
Merge QuatE functional form into module, and re-organize documentation. See also: #1102 --------- Co-authored-by: Max Berrendorf <[email protected]>
1 parent 35c5756 commit c754fcb

File tree

8 files changed

+155
-143
lines changed

8 files changed

+155
-143
lines changed
 

‎docs/source/reference/nn/utils.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ Utilities
33
.. automodule:: pykeen.nn.utils
44
:members:
55

6-
.. automodule:: pykeen.nn.algebra
6+
.. automodule:: pykeen.nn.quaternion
77
:members:

‎src/pykeen/models/unimodal/quate.py

+7-31
Original file line numberDiff line numberDiff line change
@@ -4,50 +4,22 @@
44
from typing import Any, ClassVar, Optional
55

66
import torch
7-
from torch.nn import functional
87

98
from ..nbase import ERModel
109
from ...constants import DEFAULT_EMBEDDING_HPO_EMBEDDING_DIM_RANGE
1110
from ...losses import BCEWithLogitsLoss, Loss
11+
from ...nn import quaternion
1212
from ...nn.init import init_quaternions
1313
from ...nn.modules import QuatEInteraction
1414
from ...regularizers import LpRegularizer, Regularizer
15-
from ...typing import Constrainer, FloatTensor, Hint, Initializer
15+
from ...typing import Constrainer, Hint, Initializer
1616
from ...utils import get_expected_norm
1717

1818
__all__ = [
1919
"QuatE",
2020
]
2121

2222

23-
def quaternion_normalizer(x: FloatTensor) -> FloatTensor:
24-
r"""
25-
Normalize the length of relation vectors, if the forward constraint has not been applied yet.
26-
27-
Absolute value of a quaternion
28-
29-
.. math::
30-
31-
|a + bi + cj + dk| = \sqrt{a^2 + b^2 + c^2 + d^2}
32-
33-
L2 norm of quaternion vector:
34-
35-
.. math::
36-
\|x\|^2 = \sum_{i=1}^d |x_i|^2
37-
= \sum_{i=1}^d (x_i.re^2 + x_i.im_1^2 + x_i.im_2^2 + x_i.im_3^2)
38-
:param x:
39-
The vector.
40-
41-
:return:
42-
The normalized vector.
43-
"""
44-
# Normalize relation embeddings
45-
shape = x.shape
46-
x = x.view(*shape[:-1], -1, 4)
47-
x = functional.normalize(x, p=2, dim=-1)
48-
return x.view(*shape)
49-
50-
5123
class QuatE(ERModel):
5224
r"""An implementation of QuatE from [zhang2019]_.
5325
@@ -56,13 +28,17 @@ class QuatE(ERModel):
5628
$\textbf{e}_i, \textbf{r}_i \in \mathbb{H}^d$, and the plausibility score is computed using the
5729
quaternion inner product.
5830
31+
The representations are stored in an :class:`~pykeen.nn.representation.Embedding`.
32+
Scores are calculated with :class:`~pykeen.nn.modules.QuatEInteraction`.
33+
5934
.. seealso ::
6035
6136
Official implementation: https://github.com/cheungdaven/QuatE/blob/master/models/QuatE.py
6237
---
6338
citation:
6439
author: Zhang
6540
year: 2019
41+
arxiv: 1904.10281
6642
link: https://arxiv.org/abs/1904.10281
6743
github: cheungdaven/quate
6844
"""
@@ -92,7 +68,7 @@ def __init__(
9268
relation_initializer: Hint[Initializer] = init_quaternions,
9369
relation_regularizer: Hint[Regularizer] = LpRegularizer,
9470
relation_regularizer_kwargs: Optional[Mapping[str, Any]] = None,
95-
relation_normalizer: Hint[Constrainer] = quaternion_normalizer,
71+
relation_normalizer: Hint[Constrainer] = quaternion.normalize,
9672
**kwargs,
9773
) -> None:
9874
"""Initialize QuatE.

‎src/pykeen/nn/algebra.py

-48
This file was deleted.

‎src/pykeen/nn/functional.py

-28
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@
99
import torch
1010

1111
from ..typing import FloatTensor
12-
from ..utils import einsum
1312

1413
__all__ = [
1514
"circular_correlation",
16-
"quat_e_interaction",
1715
]
1816

1917

@@ -44,29 +42,3 @@ def circular_correlation(
4442
p_fft = a_fft * b_fft
4543
# inverse real FFT
4644
return torch.fft.irfft(p_fft, n=a.shape[-1], dim=-1)
47-
48-
49-
def quat_e_interaction(
50-
h: FloatTensor,
51-
r: FloatTensor,
52-
t: FloatTensor,
53-
table: FloatTensor,
54-
):
55-
"""Evaluate the interaction function of QuatE for given embeddings.
56-
57-
The embeddings have to be in a broadcastable shape.
58-
59-
:param h: shape: (`*batch_dims`, dim, 4)
60-
The head representations.
61-
:param r: shape: (`*batch_dims`, dim, 4)
62-
The head representations.
63-
:param t: shape: (`*batch_dims`, dim, 4)
64-
The tail representations.
65-
:param table:
66-
the quaternion multiplication table.
67-
68-
:return: shape: (...)
69-
The scores.
70-
"""
71-
# TODO: this sign is in the official code, too, but why do we need it?
72-
return -einsum("...di, ...dj, ...dk, ijk -> ...", h, r, t, table)

‎src/pykeen/nn/modules.py

+40-15
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
from typing_extensions import Self
3131

3232
from . import functional as pkf
33-
from . import init
34-
from .algebra import quaterion_multiplication_table
33+
from . import init, quaternion
3534
from .compute_kernel import batched_dot
3635
from .sim import KG2ESimilarity, kg2e_similarity_resolver
3736
from .utils import apply_optional_bn
@@ -2782,16 +2781,23 @@ def forward(self, h: FloatTensor, r: tuple[FloatTensor, FloatTensor], t: FloatTe
27822781

27832782

27842783
@parse_docdata
2785-
class QuatEInteraction(
2786-
FunctionalInteraction[
2787-
FloatTensor,
2788-
FloatTensor,
2789-
FloatTensor,
2790-
],
2791-
):
2792-
"""A module wrapper for the QuatE interaction function.
2784+
class QuatEInteraction(Interaction[FloatTensor, FloatTensor, FloatTensor]):
2785+
r"""The state-less QuatE interaction function.
2786+
2787+
It is given as
2788+
2789+
.. math ::
2790+
\langle \mathbf{h} \otimes \mathbf{r}, \mathbf{t} \rangle
2791+
2792+
where $\mathbf{h}, \mathbf{r}, \mathbf{t} \in \mathbb{H}^d$ are quaternion representations,
2793+
$\otimes$ denotes the Hamilton product, and $\langle \cdot, \cdot \rangle$ the inner product.
27932794
2794-
.. seealso:: :func:`pykeen.nn.functional.quat_e_interaction`
2795+
.. warning ::
2796+
In order to represent a rotation, $\mathbf{r}$ must be normalized to unit length,
2797+
cf. :func:`pykeen.nn.quaternion.normalize`.
2798+
2799+
.. seealso::
2800+
- https://en.wikipedia.org/wiki/Quaternion
27952801
27962802
---
27972803
citation:
@@ -2805,15 +2811,34 @@ class QuatEInteraction(
28052811
# with k=4
28062812
entity_shape: Sequence[str] = ("dk",)
28072813
relation_shape: Sequence[str] = ("dk",)
2808-
func = pkf.quat_e_interaction
28092814

28102815
def __init__(self) -> None:
28112816
"""Initialize the interaction module."""
28122817
super().__init__()
2813-
self.register_buffer(name="table", tensor=quaterion_multiplication_table())
2818+
self.register_buffer(name="table", tensor=quaternion.multiplication_table())
28142819

2815-
def _prepare_state_for_functional(self) -> MutableMapping[str, Any]:
2816-
return dict(table=self.table)
2820+
def forward(self, h: FloatTensor, r: tuple[FloatTensor, FloatTensor], t: FloatTensor) -> FloatTensor:
2821+
"""Evaluate the interaction function of QuatE for given embeddings.
2822+
2823+
The embeddings have to be in a broadcastable shape.
2824+
2825+
.. seealso::
2826+
:meth:`Interaction.forward <pykeen.nn.modules.Interaction.forward>` for a detailed description about
2827+
the generic batched form of the interaction function.
2828+
2829+
:param h: shape: (`*batch_dims`, dim, 4)
2830+
The head representations.
2831+
:param r: shape: (`*batch_dims`, dim, 4)
2832+
The relation representations.
2833+
:param t: shape: (`*batch_dims`, dim, 4)
2834+
The tail representations.
2835+
2836+
:return: shape: (...)
2837+
The scores.
2838+
"""
2839+
# TODO: this sign is in the official code, too, but why do we need it?
2840+
# note: this is a fused kernel for computing the Hamilton product and the inner product at once
2841+
return -einsum("...di, ...dj, ...dk, ijk -> ...", h, r, t, self.table)
28172842

28182843

28192844
class MonotonicAffineTransformationInteraction(

‎src/pykeen/nn/quaternion.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
"""Utilities for quaternions."""
2+
3+
from functools import lru_cache
4+
5+
import torch
6+
7+
from ..typing import FloatTensor
8+
9+
__all__ = [
10+
"normalize",
11+
"hamiltonian_product",
12+
"multiplication_table",
13+
]
14+
15+
16+
def normalize(x: FloatTensor) -> FloatTensor:
17+
r"""
18+
Normalize the length of relation vectors, if the forward constraint has not been applied yet.
19+
20+
Absolute value of a quaternion
21+
22+
.. math::
23+
24+
|a + bi + cj + dk| = \sqrt{a^2 + b^2 + c^2 + d^2}
25+
26+
L2 norm of quaternion vector:
27+
28+
.. math::
29+
\|x\|^2 = \sum_{i=1}^d |x_i|^2
30+
= \sum_{i=1}^d (x_i.re^2 + x_i.im_1^2 + x_i.im_2^2 + x_i.im_3^2)
31+
32+
:param x: shape: ``(*batch_dims, 4 \cdot d)``
33+
The vector in flat form.
34+
35+
:return: shape: ``(*batch_dims, 4 \cdot d)``
36+
The normalized vector.
37+
"""
38+
# Normalize relation embeddings
39+
shape = x.shape
40+
x = x.view(*shape[:-1], -1, 4)
41+
x = torch.nn.functional.normalize(x, p=2, dim=-1)
42+
return x.view(*shape)
43+
44+
45+
def hamiltonian_product(qa: FloatTensor, qb: FloatTensor) -> FloatTensor:
46+
"""Compute the Hamilton product of two quaternions (which enables rotation)."""
47+
return torch.stack(
48+
[
49+
qa[0] * qb[0] - qa[1] * qb[1] - qa[2] * qb[2] - qa[3] * qb[3],
50+
qa[0] * qb[1] + qa[1] * qb[0] + qa[2] * qb[3] - qa[3] * qb[2],
51+
qa[0] * qb[2] - qa[1] * qb[3] + qa[2] * qb[0] + qa[3] * qb[1],
52+
qa[0] * qb[3] + qa[1] * qb[2] - qa[2] * qb[1] + qa[3] * qb[0],
53+
],
54+
dim=-1,
55+
)
56+
57+
58+
@lru_cache(1)
59+
def multiplication_table() -> FloatTensor:
60+
"""
61+
Create the quaternion basis multiplication table.
62+
63+
:return: shape: (4, 4, 4)
64+
the table of products of basis elements.
65+
66+
.. seealso:: https://en.wikipedia.org/wiki/Quaternion#Multiplication_of_basis_elements
67+
"""
68+
_1, _i, _j, _k = 0, 1, 2, 3
69+
table = torch.zeros(4, 4, 4)
70+
for i, j, k, v in [
71+
# 1 * ? = ?; ? * 1 = ?
72+
(_1, _1, _1, 1),
73+
(_1, _i, _i, 1),
74+
(_1, _j, _j, 1),
75+
(_1, _k, _k, 1),
76+
(_i, _1, _i, 1),
77+
(_j, _1, _j, 1),
78+
(_k, _1, _k, 1),
79+
# i**2 = j**2 = k**2 = -1
80+
(_i, _i, _1, -1),
81+
(_j, _j, _1, -1),
82+
(_k, _k, _1, -1),
83+
# i * j = k; i * k = -j
84+
(_i, _j, _k, 1),
85+
(_i, _k, _j, -1),
86+
# j * i = -k, j * k = i
87+
(_j, _i, _k, -1),
88+
(_j, _k, _i, 1),
89+
# k * i = j; k * j = -i
90+
(_k, _i, _j, 1),
91+
(_k, _j, _i, -1),
92+
]:
93+
table[i, j, k] = v
94+
return table

‎tests/test_nn/test_modules.py

+11-18
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,15 @@
1414
import pykeen.nn.modules
1515
import pykeen.nn.sim
1616
import pykeen.utils
17-
from pykeen.models.unimodal.quate import quaternion_normalizer
17+
from pykeen.nn import quaternion
1818
from pykeen.typing import Representation, Sign
19-
from pykeen.utils import clamp_norm, complex_normalize, einsum, ensure_tuple, project_entity
19+
from pykeen.utils import (
20+
clamp_norm,
21+
complex_normalize,
22+
einsum,
23+
ensure_tuple,
24+
project_entity,
25+
)
2026
from tests import cases
2127

2228
logger = logging.getLogger(__name__)
@@ -226,34 +232,21 @@ def _exp_score(self, h, r, t) -> torch.FloatTensor:
226232
)
227233

228234

229-
def _rotate_quaternion(qa: torch.FloatTensor, qb: torch.FloatTensor) -> torch.FloatTensor:
230-
# Rotate (=Hamilton product in quaternion space).
231-
return torch.stack(
232-
[
233-
qa[0] * qb[0] - qa[1] * qb[1] - qa[2] * qb[2] - qa[3] * qb[3],
234-
qa[0] * qb[1] + qa[1] * qb[0] + qa[2] * qb[3] - qa[3] * qb[2],
235-
qa[0] * qb[2] - qa[1] * qb[3] + qa[2] * qb[0] + qa[3] * qb[1],
236-
qa[0] * qb[3] + qa[1] * qb[2] - qa[2] * qb[1] + qa[3] * qb[0],
237-
],
238-
dim=-1,
239-
)
240-
241-
242235
class QuatETests(cases.InteractionTestCase):
243236
"""Tests for QuatE interaction."""
244237

245238
cls = pykeen.nn.modules.QuatEInteraction
246239
shape_kwargs = dict(k=4) # quaternions
247240
atol = 1.0e-06
248241

249-
def _exp_score(self, h: torch.Tensor, r: torch.Tensor, t: torch.Tensor, table: torch.Tensor) -> torch.FloatTensor: # noqa: D102
242+
def _exp_score(self, h: torch.Tensor, r: torch.Tensor, t: torch.Tensor) -> torch.FloatTensor: # noqa: D102
250243
# we calculate the scores using the hard-coded formula, instead of utilizing table + einsum
251-
x = _rotate_quaternion(*(x.unbind(dim=-1) for x in [h, r]))
244+
x = quaternion.hamiltonian_product(*(x.unbind(dim=-1) for x in [h, r]))
252245
return -(x * t).sum()
253246

254247
def _get_hrt(self, *shapes):
255248
h, r, t = super()._get_hrt(*shapes)
256-
r = quaternion_normalizer(r)
249+
r = quaternion.normalize(r)
257250
return h, r, t
258251

259252

‎tests/test_nn/test_algebra.py ‎tests/test_nn/test_quaternion.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
import pykeen.nn.algebra
5+
from pykeen.nn import quaternion
66

77

88
def _test_multiplication_table(t: torch.Tensor):
@@ -22,4 +22,4 @@ def _test_multiplication_table(t: torch.Tensor):
2222

2323
def test_quaternion_multiplication_table():
2424
"""Test quaternion multiplication table."""
25-
_test_multiplication_table(pykeen.nn.algebra.quaterion_multiplication_table())
25+
_test_multiplication_table(quaternion.multiplication_table())

0 commit comments

Comments
 (0)
Please sign in to comment.