"""Test cases for PyKEEN."""

import inspect
import itertools
import logging
import os
import pathlib
import tempfile
import timeit
import traceback
import unittest
from abc import ABC, abstractmethod
from collections import ChainMap, Counter
from collections.abc import Collection, Iterable, Mapping, MutableMapping, Sequence
from typing import (
    Any,
    Callable,
    ClassVar,
    Optional,
    TypeVar,
    Union,
)
from unittest.case import SkipTest
from unittest.mock import Mock, patch

import numpy
import numpy.random
import pandas
import pytest
import torch
import torch.utils.data
import unittest_templates
from class_resolver import HintOrType
from click.testing import CliRunner, Result
from docdata import get_docdata
from torch import optim
from torch.nn import functional
from torch.optim import SGD, Adagrad

import pykeen.evaluation.evaluation_loop
import pykeen.models
import pykeen.nn.combination
import pykeen.nn.message_passing
import pykeen.nn.node_piece
import pykeen.nn.representation
import pykeen.nn.text
import pykeen.nn.weighting
import pykeen.predict
from pykeen.datasets import Nations
from pykeen.datasets.base import LazyDataset
from pykeen.datasets.ea.combination import GraphPairCombinator
from pykeen.datasets.kinships import KINSHIPS_TRAIN_PATH
from pykeen.datasets.mocks import create_inductive_dataset
from pykeen.datasets.nations import NATIONS_TEST_PATH, NATIONS_TRAIN_PATH
from pykeen.evaluation import Evaluator, MetricResults, evaluator_resolver
from pykeen.losses import Loss, PairwiseLoss, PointwiseLoss, SetwiseLoss, UnsupportedLabelSmoothingError, loss_resolver
from pykeen.metrics import rank_based_metric_resolver
from pykeen.metrics.ranking import (
    DerivedRankBasedMetric,
    NoClosedFormError,
    RankBasedMetric,
    generate_num_candidates_and_ranks,
)
from pykeen.models import RESCAL, ERModel, Model, TransE
from pykeen.models.cli import build_cli_from_cls
from pykeen.models.meta.filtered import CooccurrenceFilteredModel
from pykeen.models.mocks import FixedModel
from pykeen.nn.modules import DistMultInteraction, FunctionalInteraction, Interaction
from pykeen.nn.representation import Representation
from pykeen.nn.utils import adjacency_tensor_to_stacked_matrix
from pykeen.optimizers import optimizer_resolver
from pykeen.pipeline import pipeline
from pykeen.regularizers import LpRegularizer, Regularizer
from pykeen.stoppers.early_stopping import EarlyStopper
from pykeen.trackers import ResultTracker
from pykeen.training import LCWATrainingLoop, SLCWATrainingLoop, TrainingCallback, TrainingLoop
from pykeen.triples import Instances, TriplesFactory, generation
from pykeen.triples.instances import BaseBatchedSLCWAInstances, SLCWABatch
from pykeen.triples.splitting import Cleaner, Splitter
from pykeen.triples.triples_factory import CoreTriplesFactory
from pykeen.triples.utils import get_entities
from pykeen.typing import (
    EA_SIDE_LEFT,
    EA_SIDE_RIGHT,
    LABEL_HEAD,
    LABEL_TAIL,
    RANK_REALISTIC,
    SIDE_BOTH,
    TRAINING,
    HeadRepresentation,
    InductiveMode,
    Initializer,
    MappedTriples,
    RelationRepresentation,
    TailRepresentation,
    Target,
)
from pykeen.utils import (
    all_in_bounds,
    get_batchnorm_modules,
    getattr_or_docdata,
    is_triple_tensor_subset,
    resolve_device,
    set_random_seed,
    triple_tensor_to_set,
    unpack_singletons,
)
from tests.constants import EPSILON
from tests.mocks import MockEvaluator
from tests.utils import needs_packages, rand

T = TypeVar("T")

logger = logging.getLogger(__name__)


class GenericTestCase(unittest_templates.GenericTestCase[T]):
    """Generic tests."""

    generator: torch.Generator

    def pre_setup_hook(self) -> None:
        """Instantiate a generator for usage in the test case."""
        self.generator = set_random_seed(seed=42)[1]


class DatasetTestCase(unittest.TestCase):
    """A test case for quickly defining common tests for datasets."""

    #: The expected number of entities
    exp_num_entities: ClassVar[int]
    #: The expected number of relations
    exp_num_relations: ClassVar[int]
    #: The expected number of triples
    exp_num_triples: ClassVar[int]
    #: The tolerance on expected number of triples, for randomized situations
    exp_num_triples_tolerance: ClassVar[Optional[int]] = None

    #: The dataset to test
    dataset_cls: ClassVar[type[LazyDataset]]
    #: The instantiated dataset
    dataset: LazyDataset

    #: Should the validation be assumed to have been loaded with train/test?
    autoloaded_validation: ClassVar[bool] = False

    def test_dataset(self):
        """Generic test for datasets."""
        self.assertIsInstance(self.dataset, LazyDataset)

        # Not loaded
        self.assertIsNone(self.dataset._training)
        self.assertIsNone(self.dataset._testing)
        self.assertIsNone(self.dataset._validation)
        self.assertFalse(self.dataset._loaded)
        self.assertFalse(self.dataset._loaded_validation)

        # Load
        self.dataset._load()

        self.assertIsInstance(self.dataset.training, TriplesFactory)
        self.assertIsInstance(self.dataset.testing, TriplesFactory)
        self.assertTrue(self.dataset._loaded)

        if self.autoloaded_validation:
            self.assertTrue(self.dataset._loaded_validation)
        else:
            self.assertFalse(self.dataset._loaded_validation)
            self.dataset._load_validation()

        self.assertIsInstance(self.dataset.validation, TriplesFactory)

        self.assertIsNotNone(self.dataset._training)
        self.assertIsNotNone(self.dataset._testing)
        self.assertIsNotNone(self.dataset._validation)
        self.assertTrue(self.dataset._loaded)
        self.assertTrue(self.dataset._loaded_validation)

        self.assertEqual(self.dataset.num_entities, self.exp_num_entities)
        self.assertEqual(self.dataset.num_relations, self.exp_num_relations)

        num_triples = sum(
            triples_factory.num_triples
            for triples_factory in (self.dataset._training, self.dataset._testing, self.dataset._validation)
        )
        if self.exp_num_triples_tolerance is None:
            self.assertEqual(self.exp_num_triples, num_triples)
        else:
            self.assertAlmostEqual(self.exp_num_triples, num_triples, delta=self.exp_num_triples_tolerance)

        # Test caching
        start = timeit.default_timer()
        _ = self.dataset.training
        end = timeit.default_timer()
        # assert (end - start) < 1.0e-02
        self.assertAlmostEqual(start, end, delta=1.0e-02, msg="Caching should have made this operation fast")

        # Test consistency of training / validation / testing mapping
        training = self.dataset.training
        for part, factory in self.dataset.factory_dict.items():
            if not isinstance(factory, TriplesFactory):
                logger.warning("Skipping mapping consistency checks since triples factory does not provide mappings.")
                continue
            if part == "training":
                continue
            assert training.entity_to_id == factory.entity_to_id
            assert training.num_entities == factory.num_entities
            assert training.relation_to_id == factory.relation_to_id
            assert training.num_relations == factory.num_relations


class LocalDatasetTestCase(DatasetTestCase):
    """A test case for datasets that don't need a cache directory."""

    def setUp(self):
        """Set up the test case."""
        self.dataset = self.dataset_cls()


class CachedDatasetCase(DatasetTestCase):
    """A test case for datasets that need a cache directory."""

    #: The directory, if there is caching
    directory: Optional[tempfile.TemporaryDirectory]

    def setUp(self):
        """Set up the test with a temporary cache directory."""
        self.directory = tempfile.TemporaryDirectory()
        self.dataset = self.dataset_cls(cache_root=self.directory.name)

    def tearDown(self) -> None:
        """Tear down the test case by cleaning up the temporary cache directory."""
        self.directory.cleanup()


def iter_from_space(key: str, space: Mapping[str, Any]) -> Iterable[MutableMapping[str, Any]]:
    """Iterate over some configurations for a single hyperparameter."""
    typ = space["type"]
    if typ == "categorical":
        for choice in space["choices"]:
            yield {key: choice}
    elif typ in (float, int):
        yield {key: space["low"]}
        yield {key: space["high"]}
    else:
        raise NotImplementedError(typ)


def iter_hpo_configs(hpo_default: Mapping[str, Mapping[str, Any]]) -> Iterable[Mapping[str, Any]]:
    """Iterate over some representative configurations from the HPO default."""
    for combination in itertools.product(
        *(iter_from_space(key=key, space=space) for key, space in hpo_default.items())
    ):
        yield ChainMap(*combination)


class LossTestCase(GenericTestCase[Loss]):
    """Base unittest for loss functions."""

    #: The batch size
    batch_size: ClassVar[int] = 3

    #: The number of negatives per positive for sLCWA training loop.
    num_neg_per_pos: ClassVar[int] = 7

    #: The number of entities for the LCWA training loop / label smoothing.
    num_entities: ClassVar[int] = 7

    def _check_loss_value(self, loss_value: torch.FloatTensor) -> None:
        """Check loss value dimensionality, and ability for backward."""
        # test reduction
        self.assertEqual(0, loss_value.ndim)

        # test finite loss value
        self.assertTrue(torch.isfinite(loss_value))

        # Test backward
        loss_value.backward()

    def help_test_process_slcwa_scores(
        self,
        positive_scores: torch.FloatTensor,
        negative_scores: torch.FloatTensor,
        batch_filter: Optional[torch.BoolTensor] = None,
    ):
        """Help test processing scores from SLCWA training loop."""
        loss_value = self.instance.process_slcwa_scores(
            positive_scores=positive_scores,
            negative_scores=negative_scores,
            label_smoothing=None,
            batch_filter=batch_filter,
            num_entities=self.num_entities,
        )
        self._check_loss_value(loss_value=loss_value)

    def test_process_slcwa_scores(self):
        """Test processing scores from SLCWA training loop."""
        positive_scores = torch.rand(self.batch_size, 1, requires_grad=True)
        negative_scores = torch.rand(self.batch_size, self.num_neg_per_pos, requires_grad=True)
        self.help_test_process_slcwa_scores(positive_scores=positive_scores, negative_scores=negative_scores)

    def test_process_slcwa_scores_filtered(self):
        """Test processing scores from SLCWA training loop with filtering."""
        positive_scores = torch.rand(self.batch_size, 1, requires_grad=True)
        negative_scores = torch.rand(self.batch_size, self.num_neg_per_pos, requires_grad=True)
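        # create a random boolean filter that keeps roughly half of the negative scores,
        # simulating the filtering of false-negative samples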
        batch_filter = torch.rand(self.batch_size, self.num_neg_per_pos) < 0.5
        self.help_test_process_slcwa_scores(
            positive_scores=positive_scores,
            negative_scores=negative_scores[batch_filter],
            batch_filter=batch_filter,
        )

    def test_process_lcwa_scores(self):
        """Test processing scores from LCWA training loop without smoothing."""
        self.help_test_process_lcwa_scores(label_smoothing=None)

    def test_process_lcwa_scores_smooth(self):
        """Test processing scores from LCWA training loop with smoothing."""
        try:
            self.help_test_process_lcwa_scores(label_smoothing=0.01)
        except UnsupportedLabelSmoothingError as error:
            raise SkipTest from error

    def help_test_process_lcwa_scores(self, label_smoothing):
        """Help test processing scores from LCWA training loop."""
        predictions = torch.rand(self.batch_size, self.num_entities, requires_grad=True)
        labels = (torch.rand(self.batch_size, self.num_entities, requires_grad=True) > 0.8).float()
        loss_value = self.instance.process_lcwa_scores(
            predictions=predictions,
            labels=labels,
            label_smoothing=label_smoothing,
            num_entities=self.num_entities,
        )
        self._check_loss_value(loss_value=loss_value)

    def test_optimization_direction_lcwa(self):
        """Test whether the loss leads to increasing positive scores, and decreasing negative scores."""
        labels = torch.as_tensor(data=[0, 1], dtype=torch.get_default_dtype()).view(1, -1)
        predictions = torch.zeros(1, 2, requires_grad=True)
        optimizer = optimizer_resolver.make(query=None, params=[predictions])
        for _ in range(10):
            optimizer.zero_grad()
            loss = self.instance.process_lcwa_scores(predictions=predictions, labels=labels)
            loss.backward()
            optimizer.step()

        # negative scores decreased compared to positive ones
        assert predictions[0, 0] < predictions[0, 1] - 1.0e-06

    def test_optimization_direction_slcwa(self):
        """Test whether the loss leads to increasing positive scores, and decreasing negative scores."""
        positive_scores = torch.zeros(self.batch_size, 1, requires_grad=True)
        negative_scores = torch.zeros(self.batch_size, self.num_neg_per_pos, requires_grad=True)
        optimizer = optimizer_resolver.make(query=None, params=[positive_scores, negative_scores])
        for _ in range(10):
            optimizer.zero_grad()
            loss = self.instance.process_slcwa_scores(
                positive_scores=positive_scores,
                negative_scores=negative_scores,
            )
            loss.backward()
            optimizer.step()

        # negative scores decreased compared to positive ones
        assert (negative_scores < positive_scores.unsqueeze(dim=1) - 1.0e-06).all()

    def test_hpo_defaults(self):
        """Test hpo defaults."""
        signature = inspect.signature(self.cls.__init__)
        # check for invalid keys
        invalid_keys = set(self.cls.hpo_default.keys()).difference(signature.parameters)
        assert not invalid_keys
        # check that each parameter without a default occurs
        required_parameters = {
            parameter_name
            for parameter_name, parameter in signature.parameters.items()
            if parameter.default == inspect.Parameter.empty
            and parameter.kind not in {inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL}
        }.difference({"self"})
        missing_required = required_parameters.difference(self.cls.hpo_default.keys())
        assert not missing_required
        # try to instantiate loss for some configurations in the HPO search space
        for config in iter_hpo_configs(self.cls.hpo_default):
            assert loss_resolver.make(query=self.cls, pos_kwargs=config)


class PointwiseLossTestCase(LossTestCase):
    """Base unit test for label-based losses."""

    #: The number of entities.
    num_entities: int = 17

    def test_type(self):
        """Test the loss is the right type."""
        self.assertIsInstance(self.instance, PointwiseLoss)

    def test_label_loss(self):
        """Test ``forward(logits, labels)``."""
        logits = torch.rand(self.batch_size, self.num_entities, requires_grad=True)
        labels = functional.normalize(torch.rand(self.batch_size, self.num_entities, requires_grad=False), p=1, dim=-1)
        loss_value = self.instance(
            logits,
            labels,
        )
        self._check_loss_value(loss_value)


class PairwiseLossTestCase(LossTestCase):
    """Base unit test for pair-wise losses."""

    #: The number of negative samples
    num_negatives: int = 5

    def test_type(self):
        """Test the loss is the right type."""
        self.assertIsInstance(self.instance, PairwiseLoss)

    def test_pair_loss(self):
        """Test ``forward(pos_scores, neg_scores)``."""
        pos_scores = torch.rand(self.batch_size, 1, requires_grad=True)
        neg_scores = torch.rand(self.batch_size, self.num_negatives, requires_grad=True)
        loss_value = self.instance(
            pos_scores,
            neg_scores,
        )
        self._check_loss_value(loss_value)


class GMRLTestCase(PairwiseLossTestCase):
    """Tests for generalized margin ranking loss."""

    def test_label_smoothing_raise(self):
        """Test errors are raised if label smoothing is given."""
        with self.assertRaises(UnsupportedLabelSmoothingError):
            self.instance.process_lcwa_scores(..., ..., label_smoothing=5)
        with self.assertRaises(UnsupportedLabelSmoothingError):
            self.instance.process_slcwa_scores(..., ..., label_smoothing=5)


class SetwiseLossTestCase(LossTestCase):
    """Unit tests for setwise losses."""

    #: The number of entities.
    num_entities: int = 13

    def test_type(self):
        """Test the loss is the right type."""
        self.assertIsInstance(self.instance, SetwiseLoss)


class InteractionTestCase(
    GenericTestCase[Interaction[HeadRepresentation, RelationRepresentation, TailRepresentation]],
    ABC,
):
    """Generic test for interaction functions."""

    dim: int = 2
    batch_size: int = 3
    num_relations: int = 5
    num_entities: int = 7
    dtype: torch.dtype = torch.get_default_dtype()
    # the relative tolerance for checking close results, cf. torch.allclose
    rtol: float = 1.0e-5
    # the absolute tolerance for checking close results, cf. torch.allclose
    atol: float = 1.0e-8

    shape_kwargs = dict()

    def post_instantiation_hook(self) -> None:
        """Initialize parameters."""
        self.instance.reset_parameters()

    def _get_hrt(
        self,
        *shapes: tuple[int, ...],
    ):
        shape_kwargs = dict(self.shape_kwargs)
        shape_kwargs.setdefault("d", self.dim)
        result = tuple(
            tuple(
                torch.rand(
                    size=tuple(prefix_shape) + tuple(shape_kwargs[dim] for dim in weight_shape),
                    requires_grad=True,
                    dtype=self.dtype,
                )
                for weight_shape in weight_shapes
            )
            for prefix_shape, weight_shapes in zip(
                shapes,
                [self.instance.entity_shape, self.instance.relation_shape, self.instance.tail_entity_shape],
            )
        )
        return unpack_singletons(*result)

    def _check_scores(self, scores: torch.FloatTensor, exp_shape: tuple[int, ...]):
        """Check shape, dtype and gradients of scores."""
        assert torch.is_tensor(scores)
        assert scores.dtype == torch.float32
        assert scores.ndimension() == len(exp_shape)
        assert scores.shape == exp_shape
        assert scores.requires_grad
        self._additional_score_checks(scores)

    def _additional_score_checks(self, scores):
        """Additional checks for scores."""

    @property
    def _score_batch_sizes(self) -> Iterable[int]:
        """Return the list of batch sizes to test."""
        if get_batchnorm_modules(self.instance):
            return [self.batch_size]
        return [1, self.batch_size]

    def test_score_hrt(self):
        """Test score_hrt."""
        for batch_size in self._score_batch_sizes:
            h, r, t = self._get_hrt(
                (batch_size,),
                (batch_size,),
                (batch_size,),
            )
            scores = self.instance.score_hrt(h=h, r=r, t=t)
            self._check_scores(scores=scores, exp_shape=(batch_size, 1))

    def test_score_h(self):
        """Test score_h."""
        for batch_size in self._score_batch_sizes:
            h, r, t = self._get_hrt(
                (self.num_entities,),
                (batch_size,),
                (batch_size,),
            )
            scores = self.instance.score_h(all_entities=h, r=r, t=t)
            self._check_scores(scores=scores, exp_shape=(batch_size, self.num_entities))

    def test_score_h_slicing(self):
        """Test score_h with slicing."""
        #: The equivalence for models with batch norm only holds in evaluation mode
        self.instance.eval()
        h, r, t = self._get_hrt(
            (self.num_entities,),
            (self.batch_size,),
            (self.batch_size,),
        )
        scores = self.instance.score_h(all_entities=h, r=r, t=t, slice_size=self.num_entities // 2 + 1)
        scores_no_slice = self.instance.score_h(all_entities=h, r=r, t=t, slice_size=None)
        self._check_close_scores(scores=scores, scores_no_slice=scores_no_slice)

    def test_score_r(self):
        """Test score_r."""
        for batch_size in self._score_batch_sizes:
            h, r, t = self._get_hrt(
                (batch_size,),
                (self.num_relations,),
                (batch_size,),
            )
            scores = self.instance.score_r(h=h, all_relations=r, t=t)
            if len(self.cls.relation_shape) == 0:
                exp_shape = (batch_size, 1)
            else:
                exp_shape = (batch_size, self.num_relations)
            self._check_scores(scores=scores, exp_shape=exp_shape)

    def test_score_r_slicing(self):
        """Test score_r with slicing."""
        if len(self.cls.relation_shape) == 0:
            raise unittest.SkipTest("No use in slicing relations for models without relation information.")
        #: The equivalence for models with batch norm only holds in evaluation mode
        self.instance.eval()
        h, r, t = self._get_hrt(
            (self.batch_size,),
            (self.num_relations,),
            (self.batch_size,),
        )
        scores = self.instance.score_r(h=h, all_relations=r, t=t, slice_size=self.num_relations // 2 + 1)
        scores_no_slice = self.instance.score_r(h=h, all_relations=r, t=t, slice_size=None)
        self._check_close_scores(scores=scores, scores_no_slice=scores_no_slice)

    def test_score_t(self):
        """Test score_t."""
        for batch_size in self._score_batch_sizes:
            h, r, t = self._get_hrt(
                (batch_size,),
                (batch_size,),
                (self.num_entities,),
            )
            scores = self.instance.score_t(h=h, r=r, all_entities=t)
            self._check_scores(scores=scores, exp_shape=(batch_size, self.num_entities))

    def test_score_t_slicing(self):
        """Test score_t with slicing."""
        #: The equivalence for models with batch norm only holds in evaluation mode
        self.instance.eval()
        h, r, t = self._get_hrt(
            (self.batch_size,),
            (self.batch_size,),
            (self.num_entities,),
        )
        scores = self.instance.score_t(h=h, r=r, all_entities=t, slice_size=self.num_entities // 2 + 1)
        scores_no_slice = self.instance.score_t(h=h, r=r, all_entities=t, slice_size=None)
        self._check_close_scores(scores=scores, scores_no_slice=scores_no_slice)

    def _check_close_scores(self, scores, scores_no_slice):
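        """Check that sliced and non-sliced scores are finite and close to each other."""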
        self.assertTrue(torch.isfinite(scores).all(), msg=f"Normal scores had nan:\n\t{scores}")
        self.assertTrue(torch.isfinite(scores_no_slice).all(), msg=f"Slice scores had nan\n\t{scores}")
        self.assertTrue(torch.allclose(scores, scores_no_slice), msg=f"Differences: {scores - scores_no_slice}")

    def _get_test_shapes(self) -> Collection[tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]]:
        """Return a set of test shapes for (h, r, t)."""
        return (
            (  # single score
                tuple(),
                tuple(),
                tuple(),
            ),
            (  # score_r with multi-t
                (self.batch_size, 1, 1),
                (1, self.num_relations, 1),
                (self.batch_size, 1, self.num_entities // 2 + 1),
            ),
            (  # score_r with multi-t and broadcasted head
                (1, 1, 1),
                (1, self.num_relations, 1),
                (self.batch_size, 1, self.num_entities),
            ),
            (  # full cwa
                (self.num_entities, 1, 1),
                (1, self.num_relations, 1),
                (1, 1, self.num_entities),
            ),
        )

    def _get_output_shape(
        self,
        hs: tuple[int, ...],
        rs: tuple[int, ...],
        ts: tuple[int, ...],
    ) -> tuple[int, ...]:
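        """Compute the expected broadcast output shape for the given (h, r, t) prefix shapes."""
        # the result is the element-wise maximum over the prefix shapes of all used components,
        # i.e. the usual broadcasting rule for shapes that only differ in singleton dimensions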
        components = []
        if self.instance.entity_shape:
            components.extend((hs, ts))
        if self.instance.relation_shape:
            components.append(rs)
        return tuple(max(ds) for ds in zip(*components))

    def test_forward(self):
        """Test forward."""
        for hs, rs, ts in self._get_test_shapes():
            if get_batchnorm_modules(self.instance) and any(numpy.prod(s) == 1 for s in (hs, rs, ts)):
                logger.warning(
                    f"Skipping test for shapes {hs}, {rs}, {ts} because too small batch size for batch norm",
                )
                continue
            h, r, t = self._get_hrt(hs, rs, ts)
            scores = self.instance(h=h, r=r, t=t)
            expected_shape = self._get_output_shape(hs, rs, ts)
            self._check_scores(scores=scores, exp_shape=expected_shape)

    def test_forward_consistency_with_functional(self):
        """Test forward's consistency with functional."""
        if not isinstance(self.instance, FunctionalInteraction):
            self.skipTest("Not a functional interaction")

        # set in eval mode (otherwise there are non-deterministic factors like dropout)
        self.instance.eval()
        for hs, rs, ts in self._get_test_shapes():
            h, r, t = self._get_hrt(hs, rs, ts)
            scores = self.instance(h=h, r=r, t=t)
            kwargs = self.instance._prepare_for_functional(h=h, r=r, t=t)
            scores_f = self.cls.func(**kwargs)
            assert torch.allclose(scores, scores_f)

    def test_scores(self):
        """Test individual scores."""
        # set in eval mode (otherwise there are non-deterministic factors like dropout)
        self.instance.eval()
        for _ in range(10):
            # test multiple different initializations
            self.instance.reset_parameters()
            h, r, t = self._get_hrt(tuple(), tuple(), tuple())

            if isinstance(self.instance, FunctionalInteraction):
                kwargs = self.instance._prepare_for_functional(h=h, r=r, t=t)
                # calculate by functional
                scores_f = self.cls.func(**kwargs).view(-1)
            else:
                kwargs = dict(h=h, r=r, t=t)
                scores_f = self.instance(h=h, r=r, t=t)

            # calculate manually
            scores_f_manual = self._exp_score(**kwargs).view(-1)
            if not torch.allclose(scores_f, scores_f_manual, rtol=self.rtol, atol=self.atol):
                # allclose checks: | input - other | < atol + rtol * |other|
                a_delta = (scores_f_manual - scores_f).abs()
                r_delta = (scores_f_manual - scores_f).abs() / scores_f.abs().clamp_min(1.0e-08)
                raise AssertionError(
                    f"Abs. Diff: {a_delta.item()} (tol.: {self.atol}); Rel. Diff: {r_delta.item()} (tol. {self.rtol})",
                )

    @abstractmethod
    def _exp_score(self, **kwargs) -> torch.FloatTensor:
        """Compute the expected score for a single-score batch."""
        raise NotImplementedError(f"{self.cls.__name__}({sorted(kwargs.keys())})")


class TranslationalInteractionTests(InteractionTestCase, ABC):
    """Common tests for translational interaction."""

    kwargs = dict(
        p=2,
    )

    def _additional_score_checks(self, scores):
        assert (scores <= 0).all()


class ResultTrackerTests(GenericTestCase[ResultTracker], unittest.TestCase):
    """Common tests for result trackers."""

    def test_start_run(self):
        """Test start_run."""
        self.instance.start_run(run_name="my_test.run")

    def test_end_run(self):
        """Test end_run."""
        self.instance.end_run()

    def test_log_metrics(self):
        """Test log_metrics."""
        for metrics, step, prefix in (
            (
                # simple
                {"a": 1.0},
                0,
                None,
            ),
            (
                # nested
                {"a": {"b": 5.0}, "c": -1.0},
                2,
                "test",
            ),
        ):
            self.instance.log_metrics(metrics=metrics, step=step, prefix=prefix)

    def test_log_params(self):
        """Test log_params."""
        # nested
        params = {
            "num_epochs": 12,
            "loss": {
                "margin": 2.0,  # a number
                "normalize": True,  # a bool
                "activation": "relu",  # a string
            },
        }
        prefix = None
        self.instance.log_params(params=params, prefix=prefix)


class FileResultTrackerTests(ResultTrackerTests):
    """Tests for FileResultTracker."""

    def setUp(self) -> None:
        """Set up the file result tracker test."""
        self.temporary_directory = tempfile.TemporaryDirectory()
        self.path = pathlib.Path(self.temporary_directory.name).joinpath("test.log")
        super().setUp()

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        # prepare a temporary test directory
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        kwargs["path"] = self.path
        return kwargs

    def tearDown(self) -> None:  # noqa: D102
        # check that file was created
        assert self.path.is_file()
        # make sure to close file before trying to delete it
        self.instance.end_run()
        # delete intermediate files
        self.path.unlink()
        self.temporary_directory.cleanup()


class RegularizerTestCase(GenericTestCase[Regularizer]):
    """A test case for quickly defining common tests for regularizers."""

    #: The batch size
    batch_size: int = 16
    #: The device
    device: torch.device

    def post_instantiation_hook(self) -> None:
        """Move instance to device."""
        self.device = resolve_device()
        # move test instance to device
        self.instance = self.instance.to(self.device)

    def test_model(self) -> None:
        """Test whether the regularizer can be passed to a model."""
        triples_factory = Nations().training
        positive_batch = triples_factory.mapped_triples[: self.batch_size, :].to(device=self.device)

        # Use RESCAL as it regularizes multiple tensors of different shape.
        model = RESCAL(
            triples_factory=triples_factory,
            regularizer=self.instance,
        ).to(self.device)

        # verify that the regularizer is stored for both, entity and relation representations
        for r in (model.entity_representations, model.relation_representations):
            assert len(r) == 1
            self.assertEqual(r[0].regularizer, self.instance)

        # Forward pass (should update regularizer)
        model.score_hrt(hrt_batch=positive_batch)

        # Call post_parameter_update (should reset regularizer)
        model.post_parameter_update()

        # Check if regularization term is reset
        self.assertEqual(0.0, self.instance.term)

    def _check_reset(self, instance: Optional[Regularizer] = None):
        """Verify that the regularizer is in resetted state."""
        if instance is None:
            instance = self.instance
        # regularization term should be zero
        self.assertEqual(0.0, instance.regularization_term.item())
        # updated should be set to false
        self.assertFalse(instance.updated)

    def test_reset(self) -> None:
        """Test method `reset`."""
        # call method
        self.instance.reset()
        self._check_reset()

    def _generate_update_input(self, requires_grad: bool = False) -> Sequence[torch.FloatTensor]:
        """Generate input for update."""
        # generate random tensors
        return (
            rand(self.batch_size, 10, generator=self.generator, device=self.device).requires_grad_(requires_grad),
            rand(self.batch_size, 20, generator=self.generator, device=self.device).requires_grad_(requires_grad),
        )

    def _expected_updated_term(self, inputs: Sequence[torch.FloatTensor]) -> torch.FloatTensor:
        """Calculate the expected updated regularization term."""
        exp_penalties = torch.stack([self._expected_penalty(x) for x in inputs])
        expected_term = torch.sum(exp_penalties).view(1) * self.instance.weight
        assert expected_term.shape == (1,)
        return expected_term

    def test_update(self) -> None:
        """Test method `update`."""
        # generate inputs
        inputs = self._generate_update_input()

        # call update
        self.instance.update(*inputs)

        # check shape
        self.assertEqual((1,), self.instance.term.shape)

        # check result
        expected_term = self._expected_updated_term(inputs=inputs)
        self.assertAlmostEqual(self.instance.regularization_term.item(), expected_term.item())

    def test_forward(self) -> None:
        """Test the regularizer's `forward` method."""
        # generate single random tensor
        x = rand(self.batch_size, 10, generator=self.generator, device=self.device)

        # calculate penalty
        penalty = self.instance(x=x)

        # check shape
        assert penalty.numel() == 1

        # check value
        expected_penalty = self._expected_penalty(x=x)
        if expected_penalty is None:
            logging.warning(f"{self.__class__.__name__} did not override `_expected_penalty`.")
        else:
            assert (expected_penalty == penalty).all()

    def _expected_penalty(self, x: torch.FloatTensor) -> Optional[torch.FloatTensor]:
        """Compute expected penalty for given tensor."""
        return None

    def test_pop_regularization_term(self):
        """Verify popping a regularization term."""
        # update term
        inputs = self._generate_update_input(requires_grad=True)
        self.instance.update(*inputs)

        # check that the expected term is returned
        exp = (self.instance.weight * self._expected_updated_term(inputs)).item()
        self.assertEqual(exp, self.instance.pop_regularization_term().item())

        # check that the regularizer is now reset
        self._check_reset()

    def test_apply_only_once(self):
        """Test apply-only-once support."""
        # create another instance with apply_only_once enabled
        instance = self.cls(**ChainMap(dict(apply_only_once=True), self.instance_kwargs)).to(self.device)

        # test initial state
        self._check_reset(instance=instance)

        # after first update, should change the term
        first_tensors = self._generate_update_input()
        instance.update(*first_tensors)
        self.assertTrue(instance.updated)
        self.assertNotEqual(0.0, instance.regularization_term.item())
        term = instance.regularization_term.clone()

        # after second update, no change should happen
        second_tensors = self._generate_update_input()
        instance.update(*second_tensors)
        self.assertTrue(instance.updated)
        self.assertEqual(term, instance.regularization_term)


class LpRegularizerTest(RegularizerTestCase):
    """Common test for L_p regularizers."""

    cls = LpRegularizer

    def _expected_penalty(self, x: torch.FloatTensor) -> torch.FloatTensor:  # noqa: D102
        kwargs = self.kwargs
        if kwargs is None:
            kwargs = {}
        p = kwargs.get("p", self.instance.p)
        value = x.norm(p=p, dim=-1).mean()
        if kwargs.get("normalize", False):
            dim = torch.as_tensor(x.shape[-1], dtype=torch.float, device=x.device)
            # FIXME isn't any finite number allowed now?
            if p == 2:
                value = value / dim.sqrt()
            elif p == 1:
                value = value / dim
            else:
                raise NotImplementedError
        return value


class ModelTestCase(unittest_templates.GenericTestCase[Model]):
    """A test case for quickly defining common tests for KGE models."""

    #: Additional arguments passed to the training loop's constructor method
    training_loop_kwargs: ClassVar[Optional[Mapping[str, Any]]] = None

    #: The triples factory instance
    factory: TriplesFactory

    #: The batch size for use for forward_* tests
    batch_size: int = 20

    #: The embedding dimensionality
    embedding_dim: int = 3

    #: Whether to create inverse triples (needed e.g. by ConvE)
    create_inverse_triples: bool = False

    #: The sampler to use for sLCWA (different e.g. for R-GCN)
    sampler: Optional[str] = None

    #: The batch size for use when testing training procedures
    train_batch_size = 400

    #: The number of epochs to train the model
    train_num_epochs = 2

    #: A random number generator from torch
    generator: torch.Generator

    #: The number of parameters which receive a constant (i.e. non-randomized) initialization
    num_constant_init: int = 0

    #: Static extras to append to the CLI
    cli_extras: Sequence[str] = tuple()

    #: the model's device
    device: torch.device

    #: the inductive mode
    mode: ClassVar[Optional[InductiveMode]] = None

    def pre_setup_hook(self) -> None:  # noqa: D102
        # for reproducible testing
        _, self.generator, _ = set_random_seed(42)
        self.device = resolve_device()

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        dataset = Nations(create_inverse_triples=self.create_inverse_triples)
        self.factory = dataset.training
        # insert shared parameters
        kwargs["triples_factory"] = self.factory
        kwargs["embedding_dim"] = self.embedding_dim
        return kwargs

    def post_instantiation_hook(self) -> None:  # noqa: D102
        # move model to correct device
        self.instance = self.instance.to(self.device)

    def test_get_grad_parameters(self):
        """Test the model's ``get_grad_params()`` method."""
        self.assertLess(
            0, len(list(self.instance.get_grad_params())), msg="The model has no trainable parameters"
        )

        # Check that all the parameters actually require a gradient
        for parameter in self.instance.get_grad_params():
            assert parameter.requires_grad

        # Try to initialize an optimizer
        optimizer = SGD(params=self.instance.get_grad_params(), lr=1.0)
        assert optimizer is not None

    def test_reset_parameters_(self):
        """Test :func:`Model.reset_parameters_`."""
        # get model parameters
        params = list(self.instance.parameters())
        old_content = {id(p): p.data.detach().clone() for p in params}

        # re-initialize
        self.instance.reset_parameters_()

        # check that the operation works in-place
        new_params = list(self.instance.parameters())
        assert set(id(np) for np in new_params) == set(id(p) for p in params)

        # check that the parameters were modified
        num_equal_weights_after_re_init = sum(
            1 for new_param in new_params if (new_param.data == old_content[id(new_param)]).all()
        )
        self.assertEqual(num_equal_weights_after_re_init, self.num_constant_init)

    def _check_scores(self, batch, scores) -> None:
        """Check the scores produced by a forward function."""
        # check for finite values by default
        self.assertTrue(torch.all(torch.isfinite(scores)).item(), f"Some scores were not finite:\n{scores}")

        # check whether a gradient can be back-propagated
        scores.mean().backward()

    def test_save(self) -> None:
        """Test that the model can be saved properly."""
        with tempfile.TemporaryDirectory() as temp_directory:
            torch.save(self.instance, os.path.join(temp_directory, "model.pickle"))

    def _test_score(
        self, score: Callable, columns: Union[Sequence[int], slice], shape: tuple[int, ...], **kwargs
    ) -> None:
        """Test score functions."""
        batch = self.factory.mapped_triples[: self.batch_size, columns].to(self.instance.device)
        try:
            scores = score(batch, mode=self.mode, **kwargs)
        except ValueError as error:
            raise SkipTest() from error
        except NotImplementedError:
            self.fail(msg=f"{score} not yet implemented")
        except RuntimeError as e:
            if str(e) == "fft: ATen not compiled with MKL support":
                self.skipTest(str(e))
            else:
                raise e
        if score is self.instance.score_r and self.create_inverse_triples:
            # TODO: look into score_r for inverse relations
            logger.warning("score_r's shape is not clear yet for models with inverse relations")
        else:
            self.assertTupleEqual(tuple(scores.shape), shape)
        self._check_scores(batch, scores)
        # clear buffers for message passing models
        self.instance.post_parameter_update()

    def _test_score_multi(self, name: str, max_id: int, **kwargs):
        """Test score functions with multi scoring."""
        k = max_id // 2
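        # test both shared candidates (a single 1-D id tensor used for the whole batch)
        # and per-example candidates (a 2-D id tensor with one row per batch element)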
        for ids in (
            torch.randperm(max_id)[:k],
            torch.randint(max_id, size=(self.batch_size, k)),
        ):
            with self.subTest(shape=ids.shape):
                self._test_score(shape=(self.batch_size, k), **kwargs, **{name: ids.to(device=self.instance.device)})

    def test_score_hrt(self) -> None:
        """Test the model's ``score_hrt()`` function."""
        self._test_score(score=self.instance.score_hrt, columns=slice(None), shape=(self.batch_size, 1))

    def test_score_t(self) -> None:
        """Test the model's ``score_t()`` function."""
        self._test_score(
            score=self.instance.score_t, columns=slice(0, 2), shape=(self.batch_size, self.instance.num_entities)
        )

    def test_score_t_multi(self) -> None:
        """Test the model's ``score_t()`` function with custom tail candidates."""
        self._test_score_multi(
            name="tails", max_id=self.factory.num_entities, score=self.instance.score_t, columns=slice(0, 2)
        )

    def test_score_r(self) -> None:
        """Test the model's ``score_r()`` function."""
        self._test_score(
            score=self.instance.score_r,
            columns=[0, 2],
            shape=(self.batch_size, self.instance.num_relations),
        )

    def test_score_r_multi(self) -> None:
        """Test the model's ``score_r()`` function with custom relation candidates."""
        self._test_score_multi(
            name="relations", max_id=self.factory.num_relations, score=self.instance.score_r, columns=[0, 2]
        )

    def test_score_h(self) -> None:
        """Test the model's ``score_h()`` function."""
        self._test_score(
            score=self.instance.score_h, columns=slice(1, None), shape=(self.batch_size, self.instance.num_entities)
        )

    def test_score_h_multi(self) -> None:
        """Test the model's ``score_h()`` function with custom head candidates."""
        self._test_score_multi(
            name="heads", max_id=self.factory.num_entities, score=self.instance.score_h, columns=slice(1, None)
        )

    @pytest.mark.slow
    def test_train_slcwa(self) -> None:
        """Test that sLCWA training does not fail."""
        loop = SLCWATrainingLoop(
            model=self.instance,
            triples_factory=self.factory,
            optimizer=Adagrad(params=self.instance.get_grad_params(), lr=0.001),
            **(self.training_loop_kwargs or {}),
        )
        losses = self._safe_train_loop(
            loop,
            num_epochs=self.train_num_epochs,
            batch_size=self.train_batch_size,
            sampler=self.sampler,
        )
        self.assertIsInstance(losses, list)

    @pytest.mark.slow
    def test_train_lcwa(self) -> None:
        """Test that LCWA training does not fail."""
        loop = LCWATrainingLoop(
            model=self.instance,
            triples_factory=self.factory,
            optimizer=Adagrad(params=self.instance.get_grad_params(), lr=0.001),
            **(self.training_loop_kwargs or {}),
        )
        losses = self._safe_train_loop(
            loop,
            num_epochs=self.train_num_epochs,
            batch_size=self.train_batch_size,
            sampler=None,
        )
        self.assertIsInstance(losses, list)

    def _safe_train_loop(self, loop: TrainingLoop, num_epochs, batch_size, sampler):
        try:
            losses = loop.train(
                triples_factory=self.factory,
                num_epochs=num_epochs,
                batch_size=batch_size,
                sampler=sampler,
                use_tqdm=False,
            )
        except RuntimeError as e:
            if str(e) == "fft: ATen not compiled with MKL support":
                self.skipTest(str(e))
            else:
                raise e
        else:
            return losses

    def test_save_load_model_state(self):
        """Test whether a saved model state can be re-loaded."""
        original_model = self.cls(
            random_seed=42,
            **self.instance_kwargs,
        )

        loaded_model = self.cls(
            random_seed=21,
            **self.instance_kwargs,
        )

        def _equal_embeddings(a: Representation, b: Representation) -> bool:
            """Test whether two embeddings are equal."""
            return (a(indices=None) == b(indices=None)).all()

        with tempfile.TemporaryDirectory() as tmpdirname:
            file_path = os.path.join(tmpdirname, "test.pt")
            original_model.save_state(path=file_path)
            loaded_model.load_state(path=file_path)

    @property
    def _cli_extras(self):
        """Return a list of extra flags for the CLI."""
        kwargs = self.kwargs or {}
        extras = [
            "--silent",
        ]
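        # translate constructor kwargs into CLI options, e.g. embedding_dim=3 becomes "--embedding-dim 3"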
        for k, v in kwargs.items():
            extras.append("--" + k.replace("_", "-"))
            extras.append(str(v))

        # For the high/low memory test cases of NTN, SE, etc.
        if self.training_loop_kwargs and "automatic_memory_optimization" in self.training_loop_kwargs:
            automatic_memory_optimization = self.training_loop_kwargs.get("automatic_memory_optimization")
            if automatic_memory_optimization is True:
                extras.append("--automatic-memory-optimization")
            elif automatic_memory_optimization is False:
                extras.append("--no-automatic-memory-optimization")
            # else, leave to default

        extras += [
            "--number-epochs",
            self.train_num_epochs,
            "--embedding-dim",
            self.embedding_dim,
            "--batch-size",
            self.train_batch_size,
        ]
        extras.extend(self.cli_extras)

        # Make sure that inverse triples are created if create_inverse_triples=True
        if self.create_inverse_triples:
            extras.append("--create-inverse-triples")

        extras = [str(e) for e in extras]
        return extras

    @pytest.mark.slow
    def test_cli_training_nations(self):
        """Test running the pipeline on almost all models with only training data."""
        self._help_test_cli(["-t", NATIONS_TRAIN_PATH] + self._cli_extras)

    @pytest.mark.slow
    def test_pipeline_nations_early_stopper(self):
        """Test running the pipeline with early stopping."""
        model_kwargs = dict(self.instance_kwargs)
        # triples factory is added by the pipeline
        model_kwargs.pop("triples_factory")
        pipeline(
            model=self.cls,
            model_kwargs=model_kwargs,
            dataset="nations",
            dataset_kwargs=dict(create_inverse_triples=self.create_inverse_triples),
            stopper="early",
            training_loop_kwargs=self.training_loop_kwargs,
            stopper_kwargs=dict(frequency=1),
            training_kwargs=dict(
                batch_size=self.train_batch_size,
                num_epochs=self.train_num_epochs,
            ),
        )

    @pytest.mark.slow
    def test_cli_training_kinships(self):
        """Test running the pipeline on almost all models with only training data."""
        self._help_test_cli(["-t", KINSHIPS_TRAIN_PATH] + self._cli_extras)

    @pytest.mark.slow
    def test_cli_training_nations_testing(self):
        """Test running the pipeline on almost all models with only training data."""
        self._help_test_cli(["-t", NATIONS_TRAIN_PATH, "-q", NATIONS_TEST_PATH] + self._cli_extras)

    def _help_test_cli(self, args):
        """Test running the pipeline on all models."""
        if (
            issubclass(self.cls, (pykeen.models.RGCN, pykeen.models.CooccurrenceFilteredModel))
            or self.cls is pykeen.models.ERModel
        ):
            self.skipTest(f"Cannot choose interaction via CLI for {self.cls}.")
        runner = CliRunner()
        cli = build_cli_from_cls(self.cls)
        # TODO: Catch HolE MKL error?
        result: Result = runner.invoke(cli, args)

        self.assertEqual(
            0,
            result.exit_code,
            msg=f"""
Command
=======
$ pykeen train {self.cls.__name__.lower()} {' '.join(map(str, args))}

Output
======
{result.output}

Exception
=========
{result.exc_info[1]}

Traceback
=========
{''.join(traceback.format_tb(result.exc_info[2]))}
            """,
        )

    def test_has_hpo_defaults(self):
        """Test that there are defaults for HPO."""
        try:
            d = self.cls.hpo_default
        except AttributeError:
            self.fail(msg=f"{self.cls.__name__} is missing hpo_default class attribute")
        else:
            self.assertIsInstance(d, dict)

    def test_post_parameter_update_regularizer(self):
        """Test whether post_parameter_update resets the regularization term."""
        if not hasattr(self.instance, "regularizer"):
            self.skipTest("no regularizer")

        # set regularizer term to something that isn't zero
        self.instance.regularizer.regularization_term = torch.ones(1, dtype=torch.float, device=self.instance.device)

        # call post_parameter_update
        self.instance.post_parameter_update()

        # assert that the regularization term has been reset
        expected_term = torch.zeros(1, dtype=torch.float, device=self.instance.device)
        assert self.instance.regularizer.regularization_term == expected_term

    def test_post_parameter_update(self):
        """Test whether post_parameter_update correctly enforces model constraints."""
        # do one optimization step
        opt = optim.SGD(params=self.instance.parameters(), lr=1.0)
        batch = self.factory.mapped_triples[: self.batch_size, :].to(self.instance.device)
        scores = self.instance.score_hrt(hrt_batch=batch, mode=self.mode)
        fake_loss = scores.mean()
        fake_loss.backward()
        opt.step()

        # call post_parameter_update
        self.instance.post_parameter_update()

        # check model constraints
        self._check_constraints()

    def _check_constraints(self):
        """Check model constraints."""

    def _test_score_equality(self, columns: Union[slice, list[int]], name: str) -> None:
        """Migration tests for non-ERModel models testing for consistent optimized score implementations."""
        if isinstance(self.instance, ERModel):
            raise SkipTest("ERModel fulfils this by design.")
        if isinstance(self.instance, CooccurrenceFilteredModel):
            raise SkipTest("CooccurrenceFilteredModel fulfils this if its base model fulfils it.")
        batch = self.factory.mapped_triples[: self.batch_size, columns].to(self.instance.device)
        self.instance.eval()
        try:
            scores = getattr(self.instance, name)(batch)
            scores_super = getattr(super(self.instance.__class__, self.instance), name)(batch)
        except NotImplementedError:
            self.fail(msg=f"{name} not yet implemented")
        except RuntimeError as e:
            if str(e) == "fft: ATen not compiled with MKL support":
                self.skipTest(str(e))
            else:
                raise e

        self.assertIsNotNone(scores)
        self.assertIsNotNone(scores_super)
        assert torch.allclose(scores, scores_super, atol=1e-06)

    def test_score_h_with_score_hrt_equality(self) -> None:
        """Test the equality of the model's  ``score_h()`` and ``score_hrt()`` function."""
        self._test_score_equality(columns=slice(1, None), name="score_h")

    def test_score_r_with_score_hrt_equality(self) -> None:
        """Test the equality of the model's  ``score_r()`` and ``score_hrt()`` function."""
        self._test_score_equality(columns=[0, 2], name="score_r")

    def test_score_t_with_score_hrt_equality(self) -> None:
        """Test the equality of the model's  ``score_t()`` and ``score_hrt()`` function."""
        self._test_score_equality(columns=slice(2), name="score_t")

    def test_reset_parameters_constructor_call(self):
        """Tests whether reset_parameters is called in the constructor."""
        with patch.object(self.cls, "reset_parameters_", return_value=None) as mock_method:
            try:
                self.cls(**self.instance_kwargs)
            except TypeError as error:
                assert error.args == ("'NoneType' object is not callable",)
            mock_method.assert_called_once()


class DistanceModelTestCase(ModelTestCase):
    """A test case for distance-based models."""

    def _check_scores(self, batch, scores) -> None:
        super()._check_scores(batch=batch, scores=scores)
        # Distance-based model
        assert (scores <= 0.0).all()


class BaseKG2ETest(ModelTestCase):
    """General tests for the KG2E model."""

    cls = pykeen.models.KG2E
    c_min: float = 0.01
    c_max: float = 1.0

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        kwargs["c_min"] = self.c_min
        kwargs["c_max"] = self.c_max
        return kwargs

    def _check_constraints(self):
        """Check model constraints.

        * Entity and relation embeddings have to have at most unit L2 norm.
        * Covariances have to have values between c_min and c_max
        """
        self.instance: ERModel
        (e_mean, e_cov), (r_mean, r_cov) = self.instance.entity_representations, self.instance.relation_representations
        for embedding in (e_mean, r_mean):
            assert all_in_bounds(embedding(indices=None).norm(p=2, dim=-1), high=1.0, a_tol=EPSILON)
        for cov in (e_cov, r_cov):
            assert all_in_bounds(
                cov(indices=None), low=self.instance_kwargs["c_min"], high=self.instance_kwargs["c_max"]
            )


class BaseRGCNTest(ModelTestCase):
    """Test the R-GCN model."""

    cls = pykeen.models.RGCN
    sampler = "schlichtkrull"

    def _check_constraints(self):
        """Check model constraints.

        Enriched embeddings have to be reset.
        """
        assert self.instance.entity_representations[0].enriched_embeddings is None


class BaseNodePieceTest(ModelTestCase):
    """Test the NodePiece model."""

    cls = pykeen.models.NodePiece
    create_inverse_triples = True

    def _help_test_cli(self, args):  # noqa: D102
        if self.instance_kwargs.get("tokenizers_kwargs"):
            raise SkipTest("No support for tokenizers_kwargs via CLI.")
        return super()._help_test_cli(args)


class InductiveModelTestCase(ModelTestCase):
    """Tests for inductive models."""

    mode = TRAINING
    num_relations: ClassVar[int] = 7
    num_entities_transductive: ClassVar[int] = 13
    num_entities_inductive: ClassVar[int] = 5
    num_triples_training: ClassVar[int] = 33
    num_triples_inference: ClassVar[int] = 31
    num_triples_testing: ClassVar[int] = 37

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        dataset = create_inductive_dataset(
            num_relations=self.num_relations,
            num_entities_transductive=self.num_entities_transductive,
            num_entities_inductive=self.num_entities_inductive,
            num_triples_training=self.num_triples_training,
            num_triples_inference=self.num_triples_inference,
            num_triples_testing=self.num_triples_testing,
            create_inverse_triples=self.create_inverse_triples,
        )
        training_loop_kwargs = dict(self.training_loop_kwargs or dict())
        training_loop_kwargs["mode"] = self.mode
        InductiveModelTestCase.training_loop_kwargs = training_loop_kwargs
        # dataset = InductiveFB15k237(create_inverse_triples=self.create_inverse_triples)
        kwargs["triples_factory"] = self.factory = dataset.transductive_training
        kwargs["inference_factory"] = dataset.inductive_inference
        return kwargs

    def _help_test_cli(self, args):  # noqa: D102
        raise SkipTest("Inductive models are not compatible the CLI.")

    def test_pipeline_nations_early_stopper(self):  # noqa: D102
        raise SkipTest("Inductive models are not compatible the pipeline.")


class RepresentationTestCase(GenericTestCase[Representation]):
    """Common tests for representation modules."""

    batch_size: ClassVar[int] = 2
    num_negatives: ClassVar[int] = 3
    max_id: ClassVar[int] = 7

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
        kwargs = super()._pre_instantiation_hook(kwargs)
        kwargs.update(dict(max_id=self.max_id))
        return kwargs

    def _check_result(self, x: torch.FloatTensor, prefix_shape: tuple[int, ...]):
        """Check the result."""
        # check type
        assert torch.is_tensor(x)
        assert x.dtype == torch.get_default_dtype()

        # check shape
        expected_shape = prefix_shape + self.instance.shape
        self.assertEqual(x.shape, expected_shape)

    def _test_forward(self, indices: Optional[torch.LongTensor]):
        """Test forward method."""
        representations = self.instance(indices=indices)
        prefix_shape = (self.instance.max_id,) if indices is None else tuple(indices.shape)
        self._check_result(x=representations, prefix_shape=prefix_shape)

    def _test_indices(self, indices: Optional[torch.LongTensor]):
        """Test forward and canonical shape for indices."""
        self._test_forward(indices=indices)

    def test_max_id(self):
        """Test maximum id."""
        self.assertEqual(self.max_id, self.instance.max_id)

    def test_no_indices(self):
        """Test without indices."""
        self._test_indices(indices=None)

    def test_1d_indices(self):
        """Test with 1-dimensional indices."""
        self._test_indices(indices=torch.randint(self.instance.max_id, size=(self.batch_size,)))

    def test_2d_indices(self):
        """Test with 2-dimensional indices."""
        self._test_indices(indices=torch.randint(self.instance.max_id, size=(self.batch_size, self.num_negatives)))

    def test_all_indices(self):
        """Test with all indices."""
        self._test_indices(indices=torch.arange(self.instance.max_id))

    def test_dropout(self):
        """Test dropout layer."""
        # create a new instance with guaranteed dropout
        kwargs = dict(self.instance_kwargs)
        kwargs.pop("dropout", None)
        dropout_instance = self.cls(**kwargs, dropout=0.1)
        # set to training mode
        dropout_instance.train()
        # check for different output
        indices = torch.arange(2)
        # use more samples to make sure that enough values can be dropped
        a = torch.stack([dropout_instance(indices) for _ in range(20)])
        assert not (a[0:1] == a).all()

    def test_str(self):
        """Test generating the string representation."""
        # this implicitly tests extra_repr / iter_extra_repr
        assert isinstance(str(self.instance), str)


class TriplesFactoryRepresentationTestCase(RepresentationTestCase):
    """Tests for representations requiring triples factories."""

    num_entities: ClassVar[int]
    num_relations: ClassVar[int] = 7
    num_triples: ClassVar[int] = 31
    create_inverse_triples: bool = False

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        self.num_entities = self.max_id
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        kwargs["triples_factory"] = generation.generate_triples_factory(
            num_entities=self.max_id,
            num_relations=self.num_relations,
            num_triples=self.num_triples,
            create_inverse_triples=self.create_inverse_triples,
        )
        return kwargs


@needs_packages("torch_geometric")
class MessagePassingRepresentationTests(TriplesFactoryRepresentationTestCase):
    """Tests for message passing representations."""

    def test_consistency_k_hop(self):
        """Test consistency of results between using only k-hop and using the full graph."""
        # select random indices
        indices = torch.randint(self.num_entities, size=(self.num_entities // 2,), generator=self.generator)
        assert isinstance(self.instance, pykeen.nn.pyg.MessagePassingRepresentation)
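        # restrict_k_hop toggles between message passing on the full graph and on the
        # k-hop neighborhood of the queried indices; both variants should yield the same result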
        # forward pass with full graph
        self.instance.restrict_k_hop = False
        x_full = self.instance(indices=indices)
        # forward pass with restricted graph
        self.instance.restrict_k_hop = True
        x_restrict = self.instance(indices=indices)
        # verify the results are similar
        assert torch.allclose(x_full, x_restrict)


class EdgeWeightingTestCase(GenericTestCase[pykeen.nn.weighting.EdgeWeighting]):
    """Tests for message weighting."""

    #: The number of entities
    num_entities: int = 16

    #: The number of triples
    num_triples: int = 101

    #: the message dim
    message_dim: int = 3

    def post_instantiation_hook(self):  # noqa: D102
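        # draw a (2, num_triples) tensor of random node indices and unpack its rows
        # into source and target node indices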
        self.source, self.target = torch.randint(self.num_entities, size=(2, self.num_triples))
        self.message = torch.rand(self.num_triples, self.message_dim, requires_grad=True)
        # TODO: separation message vs. entity dim?
        self.x_e = torch.rand(self.num_entities, self.message_dim)

    def _test(self, weights: torch.FloatTensor, shape: tuple[int, ...]):
        """Perform common tests."""
        # check shape
        assert weights.shape == shape

        # check dtype
        assert weights.dtype == torch.float32

        # check finite values (e.g. due to division by zero)
        assert torch.isfinite(weights).all()

        # check non-negativity
        assert (weights >= 0.0).all()

    def test_message_weighting(self):
        """Test message weighting with message."""
        self._test(
            weights=self.instance(source=self.source, target=self.target, message=self.message, x_e=self.x_e),
            shape=self.message.shape,
        )

    def test_message_weighting_no_message(self):
        """Test message weighting without message."""
        if self.instance.needs_message:
            raise SkipTest(f"{self.cls} needs messages for weighting them.")
        self._test(weights=self.instance(source=self.source, target=self.target), shape=self.source.shape)


class DecompositionTestCase(GenericTestCase[pykeen.nn.message_passing.Decomposition]):
    """Tests for relation-specific weight decomposition message passing classes."""

    #: the input dimension
    input_dim: int = 8
    #: the output dimension
    output_dim: int = 4

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        self.factory = Nations().training
        self.source, self.edge_type, self.target = self.factory.mapped_triples.t()
        self.x = torch.rand(self.factory.num_entities, self.input_dim, requires_grad=True)
        kwargs["input_dim"] = self.input_dim
        kwargs["output_dim"] = self.output_dim
        kwargs["num_relations"] = self.factory.num_relations
        return kwargs

    def test_forward(self):
        """Test the :meth:`Decomposition.forward` function."""
        for edge_weights in [None, torch.rand_like(self.source, dtype=torch.get_default_dtype())]:
            y = self.instance(
                x=self.x,
                source=self.source,
                target=self.target,
                edge_type=self.edge_type,
                edge_weights=edge_weights,
            )
            assert y.shape == (self.x.shape[0], self.output_dim)

    def prepare_adjacency(self, horizontal: bool) -> torch.Tensor:
        """
        Prepare adjacency matrix for the given stacking direction.

        :param horizontal:
            whether to stack horizontally or vertically

        :return:
            the adjacency matrix
        """
        return adjacency_tensor_to_stacked_matrix(
            num_relations=self.factory.num_relations,
            num_entities=self.factory.num_entities,
            source=self.source,
            target=self.target,
            edge_type=self.edge_type,
            horizontal=horizontal,
        )

    def check_output(self, x: torch.Tensor):
        """Check the output tensor."""
        assert torch.is_tensor(x)
        assert x.shape == (self.factory.num_entities, self.output_dim)
        assert x.requires_grad

    def test_horizontal(self):
        """Test processing of horizontally stacked matrix."""
        adj = self.prepare_adjacency(horizontal=True)
        x = self.instance.forward_horizontally_stacked(x=self.x, adj=adj)
        self.check_output(x=x)

    def test_vertical(self):
        """Test processing of vertically stacked matrix."""
        adj = self.prepare_adjacency(horizontal=False)
        x = self.instance.forward_vertically_stacked(x=self.x, adj=adj)
        self.check_output(x=x)


class InitializerTestCase(unittest.TestCase):
    """A test case for initializers."""

    #: the number of entities
    num_entities: ClassVar[int] = 33

    #: the shape of the tensor to initialize
    shape: ClassVar[tuple[int, ...]] = (3,)

    #: to be initialized / set in subclass
    initializer: Initializer

    #: the interaction to use for testing a model
    interaction: ClassVar[HintOrType[Interaction]] = DistMultInteraction
    dtype: ClassVar[torch.dtype] = torch.get_default_dtype()

    def test_initialization(self):
        """Test whether the initializer returns a modified tensor."""
        shape = (self.num_entities, *self.shape)
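        # complex parameters are (presumably) initialized via a real-valued tensor with a
        # trailing dimension of size two holding the real and imaginary parts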
        if self.dtype.is_complex:
            shape = shape + (2,)
        x = torch.rand(size=shape)
        # initializers *may* work in-place => clone
        y = self.initializer(x.clone())
        assert not (x == y).all()
        self._verify_initialization(y)

    def _verify_initialization(self, x: torch.FloatTensor) -> None:
        """Verify properties of initialization."""
        pass

    def test_model(self):
        """Test whether initializer can be used for a model."""
        triples_factory = generation.generate_triples_factory(num_entities=self.num_entities)
        # actual number may be different...
        self.num_entities = triples_factory.num_entities
        model = pykeen.models.ERModel(
            triples_factory=triples_factory,
            interaction=self.interaction,
            entity_representations_kwargs=dict(shape=self.shape, initializer=self.initializer, dtype=self.dtype),
            relation_representations_kwargs=dict(shape=self.shape),
            random_seed=0,
        ).to(resolve_device())
        model.reset_parameters_()


class PredictBaseTestCase(unittest.TestCase):
    """Base test for prediction workflows."""

    batch_size: ClassVar[int] = 2
    model_cls: ClassVar[type[Model]]
    model_kwargs: ClassVar[Mapping[str, Any]]

    factory: TriplesFactory
    batch: MappedTriples
    model: Model

    def setUp(self) -> None:
        """Prepare model."""
        self.factory = Nations().training
        self.batch = self.factory.mapped_triples[: self.batch_size, :]
        self.model = self.model_cls(
            triples_factory=self.factory,
            **self.model_kwargs,
        )


class CleanerTestCase(GenericTestCase[Cleaner]):
    """Test cases for cleaner."""

    def post_instantiation_hook(self) -> None:
        """Prepare triples."""
        self.dataset = Nations()
        self.all_entities = set(range(self.dataset.num_entities))
        self.mapped_triples = self.dataset.training.mapped_triples
        # unfavourable split to ensure that cleanup is necessary
        self.reference, self.other = torch.split(
            self.mapped_triples,
            split_size_or_sections=[24, self.mapped_triples.shape[0] - 24],
            dim=0,
        )
        # check for unclean split
        assert get_entities(self.reference) != self.all_entities

    def test_cleanup_pair(self):
        """Test cleanup_pair."""
        reference_clean, other_clean = self.instance.cleanup_pair(
            reference=self.reference,
            other=self.other,
            random_state=42,
        )
        # check that no triple got lost
        assert triple_tensor_to_set(self.mapped_triples) == triple_tensor_to_set(
            torch.cat(
                [
                    reference_clean,
                    other_clean,
                ],
                dim=0,
            )
        )
        # check that triples were only moved from other to reference
        assert is_triple_tensor_subset(self.reference, reference_clean)
        assert is_triple_tensor_subset(other_clean, self.other)
        # check that all entities occur in reference
        assert get_entities(reference_clean) == self.all_entities

    def test_call(self):
        """Test call."""
        triples_groups = [self.reference] + list(torch.split(self.other, split_size_or_sections=3, dim=0))
        clean_groups = self.instance(triples_groups=triples_groups, random_state=42)
        assert all(torch.is_tensor(triples) and triples.dtype == torch.long for triples in clean_groups)


class SplitterTestCase(GenericTestCase[Splitter]):
    """Test cases for triples splitter."""

    def post_instantiation_hook(self) -> None:
        """Prepare data."""
        dataset = Nations()
        self.all_entities = set(range(dataset.num_entities))
        self.mapped_triples = dataset.training.mapped_triples

    def _test_split(self, ratios: Union[float, Sequence[float]], exp_parts: int):
        """Test splitting."""
        splitted = self.instance.split(
            mapped_triples=self.mapped_triples,
            ratios=ratios,
            random_state=None,
        )
        assert len(splitted) == exp_parts
        # check that no triple got lost
        assert triple_tensor_to_set(self.mapped_triples) == set().union(
            *(triple_tensor_to_set(triples) for triples in splitted)
        )
        # check that all entities are covered in first part
        assert get_entities(splitted[0]) == self.all_entities


class EvaluatorTestCase(unittest_templates.GenericTestCase[Evaluator]):
    """A test case for quickly defining common tests for evaluators models."""

    # the model
    model: Model

    # Settings
    batch_size: int = 8
    embedding_dim: int = 7

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        self.dataset = Nations()
        return super()._pre_instantiation_hook(kwargs=kwargs)

    @property
    def factory(self) -> CoreTriplesFactory:
        """Return the evaluation factory."""
        return self.dataset.validation

    def post_instantiation_hook(self) -> None:  # noqa: D102
        # Use small model (untrained)
        self.model = TransE(triples_factory=self.factory, embedding_dim=self.embedding_dim)

    def _get_input(
        self,
        inverse: bool = False,
    ) -> tuple[torch.LongTensor, torch.FloatTensor, Optional[torch.BoolTensor]]:
        # Get batch
        hrt_batch = self.factory.mapped_triples[: self.batch_size].to(self.model.device)

        # Compute scores
        if inverse:
            scores = self.model.score_h(rt_batch=hrt_batch[:, 1:])
        else:
            scores = self.model.score_t(hr_batch=hrt_batch[:, :2])

        # Compute mask only if required
        if self.instance.requires_positive_mask:
            # TODO: Re-use filtering code
            triples = self.factory.mapped_triples.to(self.model.device)
            if inverse:
                sel_col, start_col = 0, 1
            else:
                sel_col, start_col = 2, 0
            stop_col = start_col + 2
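            # sel_col is the column of the entity to be predicted; columns [start_col, stop_col)
            # hold the fixed (head, relation) resp. (relation, tail) context of the batch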

            # shape: (batch_size, num_triples)
            triple_mask = (triples[None, :, start_col:stop_col] == hrt_batch[:, None, start_col:stop_col]).all(dim=-1)
            batch_indices, triple_indices = triple_mask.nonzero(as_tuple=True)
            entity_indices = triples[triple_indices, sel_col]

            # shape: (batch_size, num_entities)
            mask = torch.zeros_like(scores, dtype=torch.bool)
            mask[batch_indices, entity_indices] = True
        else:
            mask = None

        return hrt_batch, scores, mask

    def test_process_tail_scores_(self) -> None:
        """Test the evaluator's ``process_tail_scores_()`` function."""
        hrt_batch, scores, mask = self._get_input()
        true_scores = scores[torch.arange(0, hrt_batch.shape[0]), hrt_batch[:, 2]][:, None]
        self.instance.process_scores_(
            hrt_batch=hrt_batch,
            target=LABEL_TAIL,
            true_scores=true_scores,
            scores=scores,
            dense_positive_mask=mask,
        )

    def test_process_head_scores_(self) -> None:
        """Test the evaluator's ``process_head_scores_()`` function."""
        hrt_batch, scores, mask = self._get_input(inverse=True)
        true_scores = scores[torch.arange(0, hrt_batch.shape[0]), hrt_batch[:, 0]][:, None]
        self.instance.process_scores_(
            hrt_batch=hrt_batch,
            target=LABEL_HEAD,
            true_scores=true_scores,
            scores=scores,
            dense_positive_mask=mask,
        )

    def _process_batches(self):
        """Process one batch per side."""
        hrt_batch, scores, mask = self._get_input()
        true_scores = scores[torch.arange(0, hrt_batch.shape[0]), hrt_batch[:, 2]][:, None]
        for target in (LABEL_HEAD, LABEL_TAIL):
            self.instance.process_scores_(
                hrt_batch=hrt_batch,
                target=target,
                true_scores=true_scores,
                scores=scores,
                dense_positive_mask=mask,
            )
        return hrt_batch, scores, mask

    def test_finalize(self) -> None:
        """Test the finalize() function."""
        # Process one batch
        hrt_batch, scores, mask = self._process_batches()

        result = self.instance.finalize()
        assert isinstance(result, MetricResults)

        self._validate_result(
            result=result,
            data={"batch": hrt_batch, "scores": scores, "mask": mask},
        )

    def _validate_result(
        self,
        result: MetricResults,
        data: dict[str, torch.Tensor],
    ):
        logger.warning(f"{self.__class__.__name__} did not overwrite _validate_result.")

    def test_pipeline(self):
        """Test interaction with pipeline."""
        pipeline(
            training=self.factory,
            testing=self.factory,
            model="distmult",
            evaluator=evaluator_resolver.normalize_cls(self.cls),
            evaluator_kwargs=self.instance_kwargs,
            training_kwargs=dict(
                num_epochs=1,
            ),
        )


class AnchorSelectionTestCase(GenericTestCase[pykeen.nn.node_piece.AnchorSelection]):
    """Tests for anchor selection."""

    num_anchors: int = 7
    num_entities: int = 33
    num_triples: int = 101
    edge_index: numpy.ndarray

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
        """Prepare kwargs."""
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        kwargs["num_anchors"] = self.num_anchors
        return kwargs

    def post_instantiation_hook(self) -> None:
        """Prepare edge index."""
        generator = numpy.random.default_rng(seed=42)
        self.edge_index = generator.integers(low=0, high=self.num_entities, size=(2, self.num_triples))

    def test_call(self):
        """Test __call__."""
        anchors = self.instance(edge_index=self.edge_index)
        # shape
        assert len(anchors) == self.num_anchors
        # value range
        assert (0 <= anchors).all()
        assert (anchors < self.num_entities).all()
        # no duplicates
        assert len(set(anchors.tolist())) == len(anchors)


class AnchorSearcherTestCase(GenericTestCase[pykeen.nn.node_piece.AnchorSearcher]):
    """Tests for anchor search."""

    num_entities = 33
    k: int = 2
    edge_index: numpy.ndarray
    anchors: numpy.ndarray

    def post_instantiation_hook(self) -> None:
        """Prepare circular edge index."""
        self.edge_index = numpy.stack(
            [numpy.arange(self.num_entities), (numpy.arange(self.num_entities) + 1) % self.num_entities],
            axis=0,
        )
        self.anchors = numpy.arange(0, self.num_entities, 10)

    def test_call(self):
        """Test __call__."""
        tokens = self.instance(edge_index=self.edge_index, anchors=self.anchors, k=self.k)
        # shape
        assert tokens.shape == (self.num_entities, self.k)
        # value range
        assert (tokens >= -1).all()
        assert (tokens < len(self.anchors)).all()
        # no duplicates
        for row in tokens.tolist():
            self.assertDictEqual({k: v for k, v in Counter(row).items() if k >= 0 and v > 1}, {}, msg="duplicate token")


class TokenizerTestCase(GenericTestCase[pykeen.nn.node_piece.Tokenizer]):
    """Tests for tokenization."""

    num_tokens: int = 2
    factory: CoreTriplesFactory

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
        """Prepare triples."""
        self.factory = Nations().training
        return {}

    def test_call(self):
        """Test __call__."""
        vocabulary_size, tokens = self.instance(
            mapped_triples=self.factory.mapped_triples,
            num_tokens=self.num_tokens,
            num_entities=self.factory.num_entities,
            num_relations=self.factory.num_relations,
        )
        assert isinstance(vocabulary_size, int)
        assert vocabulary_size > 0
        # shape
        assert tokens.shape == (self.factory.num_entities, self.num_tokens)
        # value range
        assert (tokens >= -1).all()
        # no repetition, except padding idx
        for row in tokens.tolist():
            self.assertDictEqual({k: v for k, v in Counter(row).items() if k >= 0 and v > 1}, {}, msg="duplicate token")


class NodePieceTestCase(RepresentationTestCase):
    """General test case for node piece representations."""

    cls = pykeen.nn.node_piece.NodePieceRepresentation
    num_relations: ClassVar[int] = 7
    num_triples: ClassVar[int] = 31

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        kwargs["triples_factory"] = generation.generate_triples_factory(
            num_entities=self.max_id,
            num_relations=self.num_relations,
            num_triples=self.num_triples,
            create_inverse_triples=False,
        )
        # inferred from triples factory
        kwargs.pop("max_id")
        return kwargs

    def test_estimate_diversity(self):
        """Test estimating diversity."""
        diversity = self.instance.estimate_diversity()
        assert len(diversity.uniques_per_representation) == len(self.instance.base)
        assert 0.0 <= diversity.uniques_total <= 1.0


class EvaluationLoopTestCase(GenericTestCase[pykeen.evaluation.evaluation_loop.EvaluationLoop]):
    """Tests for evaluation loops."""

    batch_size: int = 2
    factory: CoreTriplesFactory

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        self.factory = Nations().training
        kwargs["model"] = FixedModel(triples_factory=self.factory)
        return kwargs

    @torch.inference_mode()
    def test_process_batch(self):
        """Test processing a single batch."""
        batch = next(iter(self.instance.get_loader(batch_size=self.batch_size)))
        self.instance.process_batch(batch=batch)


class EvaluationOnlyModelTestCase(unittest_templates.GenericTestCase[pykeen.models.EvaluationOnlyModel]):
    """Test case for evaluation only models."""

    #: The batch size
    batch_size: int = 3

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
        dataset = Nations()
        self.factory = kwargs["triples_factory"] = dataset.training
        return kwargs

    def _verify(self, scores: torch.FloatTensor):
        """Verify scores."""

    def test_score_t(self):
        """Test score_t."""
        hr_batch = self.factory.mapped_triples[torch.randint(self.factory.num_triples, size=(self.batch_size,))][:, :2]
        scores = self.instance.score_t(hr_batch=hr_batch)
        assert scores.shape == (self.batch_size, self.factory.num_entities)
        self._verify(scores)

    def test_score_h(self):
        """Test score_h."""
        rt_batch = self.factory.mapped_triples[torch.randint(self.factory.num_triples, size=(self.batch_size,))][:, 1:]
        scores = self.instance.score_h(rt_batch=rt_batch)
        assert scores.shape == (self.batch_size, self.factory.num_entities)
        self._verify(scores)


class RankBasedMetricTestCase(unittest_templates.GenericTestCase[RankBasedMetric]):
    """A test for rank-based metrics."""

    #: the maximum number of candidates
    max_num_candidates: int = 17

    #: the number of ranks
    num_ranks: int = 33

    #: the number of samples to use for monte-carlo estimation
    num_samples: int = 1_000

    #: the number of candidates for each individual ranking task
    num_candidates: numpy.ndarray

    #: the ranks for each individual ranking task
    ranks: numpy.ndarray

    def post_instantiation_hook(self) -> None:
        """Generate a coherent rank & candidate pair."""
        self.ranks, self.num_candidates = generate_num_candidates_and_ranks(
            num_ranks=self.num_ranks,
            max_num_candidates=self.max_num_candidates,
            seed=42,
        )

    def test_docdata(self):
        """Test the docdata contents of the metric."""
        self.assertTrue(hasattr(self.instance, "increasing"))
        self.assertNotEqual(
            "", self.cls.__doc__.splitlines()[0].strip(), msg="First line of docstring should not be blank"
        )
        self.assertIsNotNone(get_docdata(self.instance), msg="No docdata available")
        self.assertIsNotNone(getattr_or_docdata(self.cls, "link"))
        self.assertIsNotNone(getattr_or_docdata(self.cls, "name"))
        self.assertIsNotNone(getattr_or_docdata(self.cls, "description"))
        self.assertIsNotNone(self.instance.key)

    def _test_call(self, ranks: numpy.ndarray, num_candidates: Optional[numpy.ndarray]):
        """Verify call."""
        x = self.instance(ranks=ranks, num_candidates=num_candidates)
        # data type
        assert isinstance(x, float)
        # value range
        self.assertIn(x, self.instance.value_range.approximate(epsilon=1.0e-08))

    def test_call(self):
        """Test __call__."""
        self._test_call(ranks=self.ranks, num_candidates=self.num_candidates)

    def test_call_best(self):
        """Test __call__ with optimal ranks."""
        self._test_call(ranks=numpy.ones(shape=(self.num_ranks,)), num_candidates=self.num_candidates)

    def test_call_worst(self):
        """Test __call__ with worst ranks."""
        self._test_call(ranks=self.num_candidates, num_candidates=self.num_candidates)

    def test_call_no_candidates(self):
        """Test __call__ without candidates."""
        if self.instance.needs_candidates:
            raise SkipTest(f"{self.instance} requires candidates.")
        self._test_call(ranks=self.ranks, num_candidates=None)

    def test_increasing(self):
        """Test correct increasing annotation."""
        x, y = (
            self.instance(ranks=ranks, num_candidates=self.num_candidates)
            for ranks in [
                # original ranks
                self.ranks,
                # better ranks
                numpy.clip(self.ranks - 1, a_min=1, a_max=None),
            ]
        )
        if self.instance.increasing:
            self.assertLessEqual(x, y)
        else:
            self.assertLessEqual(y, x)

    def _test_expectation(self, weights: Optional[numpy.ndarray]):
        """Test the numeric expectation is close to the closed form one."""
        try:
            closed = self.instance.expected_value(num_candidates=self.num_candidates, weights=weights)
        except NoClosedFormError as error:
            raise SkipTest("no implementation of closed-form expectation") from error

        generator = numpy.random.default_rng(seed=0)
        low, simulated, high = self.instance.numeric_expected_value_with_ci(
            num_candidates=self.num_candidates,
            num_samples=self.num_samples,
            generator=generator,
            weights=weights,
        )
        self.assertLessEqual(low, closed)
        self.assertLessEqual(closed, high)

    def test_expectation(self):
        """Test the numeric expectation is close to the closed form one."""
        self._test_expectation(weights=None)

    def test_expectation_weighted(self):
        """Test for weighted expectation."""
        self._test_expectation(weights=self._generate_weights())

    def _test_variance(self, weights: Optional[numpy.ndarray]):
        """Test the numeric variance is close to the closed form one."""
        try:
            closed = self.instance.variance(num_candidates=self.num_candidates, weights=weights)
        except NoClosedFormError as error:
            raise SkipTest("no implementation of closed-form variance") from error

        # variances are non-negative
        self.assertLessEqual(0, closed)

        generator = numpy.random.default_rng(seed=0)
        low, simulated, high = self.instance.numeric_variance_with_ci(
            num_candidates=self.num_candidates,
            num_samples=self.num_samples,
            generator=generator,
            weights=weights,
        )
        self.assertLessEqual(low, closed)
        self.assertLessEqual(closed, high)

    def test_variance(self):
        """Test the numeric variance is close to the closed form one."""
        self._test_variance(weights=None)

    def test_variance_weighted(self):
        """Test the weighted numeric variance is close to the closed form one."""
        self._test_variance(weights=self._generate_weights())

    def _generate_weights(self):
        """Generate weights."""
        if not self.instance.supports_weights:
            raise SkipTest(f"{self.instance} does not support weights")
        # generate random weights whose sum equals the number of ranks
        generator = numpy.random.default_rng(seed=21)
        weights = generator.random(size=self.num_candidates.shape)
        weights = self.num_ranks * weights / weights.sum()
        return weights

    def test_different_to_base_metric(self):
        """Check whether the value is different from the base metric (relevant for adjusted metrics)."""
        if not isinstance(self.instance, DerivedRankBasedMetric):
            self.skipTest("no base metric")
        base_instance = rank_based_metric_resolver.make(self.instance.base_cls)
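        # compare against the (sign-adjusted) base metric to ensure the derived metric is not
        # a trivial copy of it; the sign flip accounts for decreasing base metrics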
        base_factor = 1 if base_instance.increasing else -1
        self.assertNotEqual(
            self.instance(ranks=self.ranks, num_candidates=self.num_candidates),
            base_factor * base_instance(ranks=self.ranks, num_candidates=self.num_candidates),
        )

    def test_weights_direction(self):
        """Test monotonicity of weighting."""
        if not self.instance.supports_weights:
            raise SkipTest(f"{self.instance} does not support weights")

        # for sanity checking: give the largest weight to best rank => should improve
        idx = self.ranks.argmin()
        weights = numpy.ones_like(self.ranks, dtype=float)
        weights[idx] = 2.0
        weighted = self.instance(ranks=self.ranks, num_candidates=self.num_candidates, weights=weights)
        unweighted = self.instance(ranks=self.ranks, num_candidates=self.num_candidates, weights=None)
        if self.instance.increasing:  # increasing = larger is better => weighted should be better
            self.assertLessEqual(unweighted, weighted)
        else:
            self.assertLessEqual(weighted, unweighted)

    def test_weights_coherence(self):
        """Test coherence for weighted metrics & metric in repeated array."""
        if not self.instance.supports_weights:
            raise SkipTest(f"{self.instance} does not support weights")

        # generate two versions
        generator = numpy.random.default_rng(seed=21)
        repeats = generator.integers(low=1, high=10, size=self.ranks.shape)

        # 1. repeat each rank/candidate pair a random number of times
        repeated_ranks, repeated_num_candidates = [], []
        for rank, num_candidates, repeat in zip(self.ranks, self.num_candidates, repeats):
            repeated_ranks.append(numpy.full(shape=(repeat,), fill_value=rank))
            repeated_num_candidates.append(numpy.full(shape=(repeat,), fill_value=num_candidates))
        repeated_ranks = numpy.concatenate(repeated_ranks)
        repeated_num_candidates = numpy.concatenate(repeated_num_candidates)
        value_repeat = self.instance(ranks=repeated_ranks, num_candidates=repeated_num_candidates, weights=None)

        # 2. do not repeat, but assign a corresponding weight
        weights = repeats.astype(float)
        value_weighted = self.instance(ranks=self.ranks, num_candidates=self.num_candidates, weights=weights)

        self.assertAlmostEqual(value_repeat, value_weighted, delta=2)


class MetricResultTestCase(unittest_templates.GenericTestCase[MetricResults]):
    """Test for metric results."""

    def test_flat_dict(self):
        """Test to_flat_dict."""
        flat_dict = self.instance.to_flat_dict()
        # check flatness
        self.assertIsInstance(flat_dict, dict)
        for key, value in flat_dict.items():
            self.assertIsInstance(key, str)
            # TODO: does this suffice, or do we really need float as datatype?
            self.assertIsInstance(value, (float, int), msg=key)
        self._verify_flat_dict(flat_dict)

    def _verify_flat_dict(self, flat_dict: Mapping[str, Any]):
        pass


class TrainingInstancesTestCase(unittest_templates.GenericTestCase[Instances]):
    """Test for training instances."""

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        self.factory = Nations().training
        return {}

    @abstractmethod
    def _get_expected_length(self) -> int:
        raise NotImplementedError

    def test_getitem(self):
        """Test __getitem__."""
        self.instance: Instances
        assert self.instance[0] is not None

    def test_len(self):
        """Test __len__."""
        self.assertEqual(len(self.instance), self._get_expected_length())

    def test_data_loader(self):
        """Test usage with data loader."""
        for batch in torch.utils.data.DataLoader(
            dataset=self.instance, batch_size=2, shuffle=True, collate_fn=self.instance.get_collator()
        ):
            assert batch is not None


class BatchSLCWATrainingInstancesTestCase(unittest_templates.GenericTestCase[BaseBatchedSLCWAInstances]):
    """Test for batched sLCWA training instances."""

    batch_size: int = 2
    num_negatives_per_positive: int = 3
    kwargs = dict(
        batch_size=batch_size,
        negative_sampler_kwargs=dict(
            num_negs_per_pos=num_negatives_per_positive,
        ),
    )

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        self.factory = Nations().training
        kwargs["mapped_triples"] = self.factory.mapped_triples
        return kwargs

    def test_data_loader(self):
        """Test data loader."""
        for batch in torch.utils.data.DataLoader(dataset=self.instance, batch_size=None):
            assert isinstance(batch, SLCWABatch)
            assert batch.positives.shape == (self.batch_size, 3)
            assert batch.negatives.shape == (self.batch_size, self.num_negatives_per_positive, 3)
            assert batch.masks is None

    def test_length(self):
        """Test length."""
        assert len(self.instance) == len(list(iter(self.instance)))

    def test_data_loader_multiprocessing(self):
        """Test data loader with multiple workers."""
        self.assertEqual(
            sum(
                batch.positives.shape[0]
                for batch in torch.utils.data.DataLoader(dataset=self.instance, batch_size=None, num_workers=2)
            ),
            self.factory.num_triples,
        )


class TrainingCallbackTestCase(unittest_templates.GenericTestCase[TrainingCallback]):
    """Base test case for training callbacks."""

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:  # noqa: D102
        kwargs = super()._pre_instantiation_hook(kwargs)
        self.dataset = Nations()
        return kwargs

    def test_pipeline(self):
        """Test running a small pipeline with the provided callback."""
        pipeline(
            dataset=self.dataset,
            model="distmult",
            training_kwargs=dict(
                callbacks=self.instance,
            ),
        )


class GraphPairCombinatorTestCase(unittest_templates.GenericTestCase[GraphPairCombinator]):
    """Base test for graph pair combination methods."""

    def _add_labels(self, tf: CoreTriplesFactory) -> TriplesFactory:
        """Add artificial labels to a triples factory."""
        entity_to_id = {f"e_{i}": i for i in range(tf.num_entities)}
        relation_to_id = {f"r_{i}": i for i in range(tf.num_relations)}
        return TriplesFactory(
            mapped_triples=tf.mapped_triples, entity_to_id=entity_to_id, relation_to_id=relation_to_id
        )

    def _test_combination(self, labels: bool):
        """Test combining two randomly generated triples factories with a partial entity alignment."""
        # generate random triples factories
        left, right = (generation.generate_triples_factory(random_state=random_state) for random_state in (0, 1))
        # generate random alignment
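        # pair each left entity with a randomly permuted right entity and keep only the first
        # half of these pairs as the (partial) alignment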
        left_idx, right_idx = torch.stack([torch.arange(left.num_entities), torch.randperm(left.num_entities)])[
            :, : left.num_entities // 2
        ].numpy()
        # add label information if necessary
        if labels:
            left, right = (self._add_labels(tf) for tf in (left, right))
            left_idx = [left.entity_id_to_label[i] for i in left_idx]
            right_idx = [right.entity_id_to_label[i] for i in right_idx]
        # prepare alignment data frame
        alignment = pandas.DataFrame(data={EA_SIDE_LEFT: left_idx, EA_SIDE_RIGHT: right_idx})
        # call
        tf_both, alignment_t = self.instance(left=left, right=right, alignment=alignment)
        # check
        assert type(tf_both) is type(left)
        assert alignment_t.ndim == 2
        assert alignment_t.shape[0] == 2
        assert alignment_t.shape[1] <= len(alignment)

    def test_combination_label(self):
        """Test combination with labels."""
        self._test_combination(labels=True)

    def test_combination_id(self):
        """Test combination without labels."""
        self._test_combination(labels=False)

    def test_manual(self):
        """
        Smoke-test on a manual example.

        cf. https://github.com/pykeen/pykeen/pull/893#discussion_r861553903
        """
        left_tf = TriplesFactory.from_labeled_triples(
            pandas.DataFrame(
                [
                    ["la", "0", "lb"],
                    ["lb", "0", "lc"],
                    ["la", "1", "ld"],
                    ["le", "1", "lg"],
                    ["lh", "1", "lg"],
                ],
            ).values
        )
        right_tf = TriplesFactory.from_labeled_triples(
            pandas.DataFrame(
                [
                    ["ra", "2", "rb"],
                    ["ra", "2", "rc"],
                    ["rc", "3", "rd"],
                    ["re", "3", "rg"],
                    ["rh", "3", "rg"],
                ],
            ).values
        )
        test_links = pandas.DataFrame(
            [
                ["ld", "rd"],
                ["le", "re"],
                ["lg", "rg"],
                ["lh", "rh"],
            ],
            columns=[EA_SIDE_LEFT, EA_SIDE_RIGHT],
        )
        combined_tf, alignment_t = self.instance(left=left_tf, right=right_tf, alignment=test_links)
        self._verify_manual(combined_tf=combined_tf)

    @abstractmethod
    def _verify_manual(self, combined_tf: CoreTriplesFactory):
        """Verify the result of the combination of the manual example."""


class EarlyStopperTestCase(unittest_templates.GenericTestCase[EarlyStopper]):
    """Base test for early stopper."""

    cls = EarlyStopper

    #: The window size used by the early stopper
    patience: int = 2
    #: The mock losses the mock evaluator will return
    mock_losses: list[float] = [10.0, 9.0, 8.0, 9.0, 8.0, 8.0]
    #: The (zero-based) epoch index at which stopping will occur
    stop_constant: int = 4
    #: The minimum improvement
    delta: float = 0.0
    #: The best metric value seen after each evaluation (running minimum of the mock losses, since lower is better)
    best_results: list[float] = [10.0, 9.0, 8.0, 8.0, 8.0]

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
        kwargs = super()._pre_instantiation_hook(kwargs)
        nations = Nations()
        kwargs.update(
            dict(
                evaluator=MockEvaluator(
                    key=("hits_at_10", SIDE_BOTH, RANK_REALISTIC),
                    values=self.mock_losses,
                    # Set automatic_memory_optimization to false for tests
                    automatic_memory_optimization=False,
                ),
                model=FixedModel(triples_factory=nations.training),
                training_triples_factory=nations.training,
                evaluation_triples_factory=nations.validation,
                patience=self.patience,
                relative_delta=self.delta,
                larger_is_better=False,
                best_model_path=pathlib.Path(tempfile.gettempdir(), "test-best-model-weights.pt"),
            )
        )
        return kwargs

    def test_initialization(self):
        """Test warm-up phase."""
        for epoch in range(self.patience):
            should_stop = self.instance.should_stop(epoch=epoch)
            assert not should_stop

    def test_result_processing(self):
        """Test that the mock evaluation of the early stopper always gives the right loss."""
        for epoch in range(len(self.mock_losses)):
            # Step early stopper
            should_stop = self.instance.should_stop(epoch=epoch)

            if should_stop:
                break
            else:
                # check storing of results
                assert self.instance.results == self.mock_losses[: epoch + 1]
                assert self.instance.best_metric == self.best_results[epoch]

    def test_should_stop(self):
        """Test that the stopper knows when to stop."""
        for epoch in range(self.stop_constant):
            self.assertFalse(self.instance.should_stop(epoch=epoch))
        self.assertTrue(self.instance.should_stop(epoch=self.stop_constant))

    def test_result_logging(self):
        """Test whether result logger is called properly."""
        self.instance.result_tracker = mock_tracker = Mock()
        self.instance.should_stop(epoch=0)
        log_metrics = mock_tracker.log_metrics
        self.assertIsInstance(log_metrics, Mock)
        log_metrics.assert_called_once()
        _, call_args = log_metrics.call_args_list[0]
        self.assertIn("step", call_args)
        self.assertEqual(0, call_args["step"])
        self.assertIn("prefix", call_args)
        self.assertEqual("validation", call_args["prefix"])

    def test_serialization(self):
        """Test for serialization."""
        summary = self.instance.get_summary_dict()
        new_stopper = EarlyStopper(
            # not needed for test
            model=...,
            evaluator=...,
            training_triples_factory=...,
            evaluation_triples_factory=...,
        )
        new_stopper._write_from_summary_dict(**summary)
        for key in summary.keys():
            assert getattr(self.instance, key) == getattr(new_stopper, key)


class CombinationTestCase(unittest_templates.GenericTestCase[pykeen.nn.combination.Combination]):
    """Test for combinations."""

    input_dims: Sequence[Sequence[int]] = [[5, 7], [5, 7, 11]]

    def _iter_input_shapes(self) -> Iterable[Sequence[tuple[int, ...]]]:
        """Iterate over test input shapes."""
        for prefix_shape in [tuple(), (2,), (2, 3)]:
            for input_dims in self.input_dims:
                yield [prefix_shape + (input_dim,) for input_dim in input_dims]

    def _create_input(self, input_shapes: Sequence[tuple[int, ...]]) -> Sequence[torch.FloatTensor]:
        return [torch.empty(size=size) for size in input_shapes]

    def test_inputs(self):
        """Test that the test uses at least one input shape."""
        assert list(self._iter_input_shapes())

    def test_forward(self):
        """Test forward call."""
        for input_shapes in self._iter_input_shapes():
            xs = self._create_input(input_shapes=input_shapes)

            # verify that the input is valid
            assert len(xs) == len(input_shapes)
            assert all(x.shape == shape for x, shape in zip(xs, input_shapes))

            # combine
            x = self.instance(xs=xs)
            self.assertIsInstance(x, torch.Tensor)

            # verify shape
            output_shape = self.instance.output_shape(input_shapes)
            self.assertTupleEqual(x.shape, output_shape)


class TextEncoderTestCase(unittest_templates.GenericTestCase[pykeen.nn.text.TextEncoder]):
    """Base tests for text encoders."""

    def test_encode(self):
        """Test encoding of texts."""
        labels = ["A first sentence", "some other label"]
        x = self.instance.encode_all(labels=labels)
        assert torch.is_tensor(x)
        assert x.shape[0] == len(labels)


class PredictionTestCase(unittest_templates.GenericTestCase[pykeen.predict.Predictions]):
    """Tests for prediction post-processing."""

    # to be initialized in subclass
    df: pandas.DataFrame

    def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]:
        kwargs = super()._pre_instantiation_hook(kwargs)
        self.dataset = Nations()
        kwargs["factory"] = self.dataset.training
        return kwargs

    def test_contains(self):
        """Test contains method."""
        pred_annotated = self.instance.add_membership_columns(**self.dataset.factory_dict)
        assert isinstance(pred_annotated, pykeen.predict.Predictions)
        df_annot = pred_annotated.df
        # no column has been removed
        assert set(df_annot.columns).issuperset(self.df.columns)
        # all old columns are unmodified
        for col in self.df.columns:
            assert (df_annot[col] == self.df[col]).all()
        # new columns are boolean
        for new_col in set(df_annot.columns).difference(self.df.columns):
            assert df_annot[new_col].dtype == bool

    def test_filter(self):
        """Test filter method."""
        pred_filtered = self.instance.filter_triples(*self.dataset.factory_dict.values())
        assert isinstance(pred_filtered, pykeen.predict.Predictions)
        df_filtered = pred_filtered.df
        # no columns have been added
        assert set(df_filtered.columns) == set(self.df.columns)
        # check subset relation
        assert set(df_filtered.itertuples()).issubset(self.df.itertuples())


class ScoreConsumerTests(unittest_templates.GenericTestCase[pykeen.predict.ScoreConsumer]):
    """Tests for score consumers."""

    batch_size: int = 2
    num_entities: int = 3
    target: Target = LABEL_TAIL

    def test_consumption(self):
        """Test calling."""
        generator = torch.manual_seed(seed=42)
        batch = torch.randint(self.num_entities, size=(self.batch_size, 2), generator=generator)
        scores = torch.rand(self.batch_size, self.num_entities)
        self.instance(batch=batch, target=self.target, scores=scores)
        self.check()

    def check(self):
        """Perform additional verification."""
        pass