30
30
from typing_extensions import Self
31
31
32
32
from . import functional as pkf
33
- from . import init
34
- from .algebra import quaterion_multiplication_table
33
+ from . import init , quaternion
35
34
from .compute_kernel import batched_dot
36
35
from .sim import KG2ESimilarity , kg2e_similarity_resolver
37
36
from .utils import apply_optional_bn
@@ -2782,16 +2781,23 @@ def forward(self, h: FloatTensor, r: tuple[FloatTensor, FloatTensor], t: FloatTe
2782
2781
2783
2782
2784
2783
@parse_docdata
2785
- class QuatEInteraction (
2786
- FunctionalInteraction [
2787
- FloatTensor ,
2788
- FloatTensor ,
2789
- FloatTensor ,
2790
- ],
2791
- ):
2792
- """A module wrapper for the QuatE interaction function.
2784
+ class QuatEInteraction (Interaction [FloatTensor , FloatTensor , FloatTensor ]):
2785
+ r"""The state-less QuatE interaction function.
2786
+
2787
+ It is given as
2788
+
2789
+ .. math ::
2790
+ \langle \mathbf{h} \otimes \mathbf{r}, \mathbf{t} \rangle
2791
+
2792
+ where $\mathbf{h}, \mathbf{r}, \mathbf{t} \in \mathbb{H}^d$ are quanternion representations,
2793
+ $\otimes$ denotes the Hamilton product, and $\langle \cdot, \cdot \rangle$ the inner product.
2793
2794
2794
- .. seealso:: :func:`pykeen.nn.functional.quat_e_interaction`
2795
+ .. warning ::
2796
+ In order to representation a rotation, $\mathbf{r}$ must be normalized to unit length,
2797
+ cf. :func:`pykeen.nn.quaternion.normalize`.
2798
+
2799
+ .. seealso::
2800
+ - https://en.wikipedia.org/wiki/Quaternion
2795
2801
2796
2802
---
2797
2803
citation:
@@ -2805,15 +2811,34 @@ class QuatEInteraction(
2805
2811
# with k=4
2806
2812
entity_shape : Sequence [str ] = ("dk" ,)
2807
2813
relation_shape : Sequence [str ] = ("dk" ,)
2808
- func = pkf .quat_e_interaction
2809
2814
2810
2815
def __init__ (self ) -> None :
2811
2816
"""Initialize the interaction module."""
2812
2817
super ().__init__ ()
2813
- self .register_buffer (name = "table" , tensor = quaterion_multiplication_table ())
2818
+ self .register_buffer (name = "table" , tensor = quaternion . multiplication_table ())
2814
2819
2815
- def _prepare_state_for_functional (self ) -> MutableMapping [str , Any ]:
2816
- return dict (table = self .table )
2820
+ def forward (self , h : FloatTensor , r : tuple [FloatTensor , FloatTensor ], t : FloatTensor ) -> FloatTensor :
2821
+ """Evaluate the interaction function of QuatE for given embeddings.
2822
+
2823
+ The embeddings have to be in a broadcastable shape.
2824
+
2825
+ .. seealso::
2826
+ :meth:`Interaction.forward <pykeen.nn.modules.Interaction.forward>` for a detailed description about
2827
+ the generic batched form of the interaction function.
2828
+
2829
+ :param h: shape: (`*batch_dims`, dim, 4)
2830
+ The head representations.
2831
+ :param r: shape: (`*batch_dims`, dim, 4)
2832
+ The head representations.
2833
+ :param t: shape: (`*batch_dims`, dim, 4)
2834
+ The tail representations.
2835
+
2836
+ :return: shape: (...)
2837
+ The scores.
2838
+ """
2839
+ # TODO: this sign is in the official code, too, but why do we need it?
2840
+ # note: this is a fused kernel for computing the Hamilton product and the inner product at once
2841
+ return - einsum ("...di, ...dj, ...dk, ijk -> ..." , h , r , t , self .table )
2817
2842
2818
2843
2819
2844
class MonotonicAffineTransformationInteraction (
0 commit comments