|
2 | 2 |
|
3 | 3 | """Test cases for PyKEEN."""
|
4 | 4 | import inspect
|
| 5 | +import itertools |
5 | 6 | import logging
|
6 | 7 | import os
|
7 | 8 | import pathlib
|
|
61 | 62 | from pykeen.datasets.mocks import create_inductive_dataset
|
62 | 63 | from pykeen.datasets.nations import NATIONS_TEST_PATH, NATIONS_TRAIN_PATH
|
63 | 64 | from pykeen.evaluation import Evaluator, MetricResults, evaluator_resolver
|
64 |
| -from pykeen.losses import Loss, PairwiseLoss, PointwiseLoss, SetwiseLoss, UnsupportedLabelSmoothingError |
| 65 | +from pykeen.losses import Loss, PairwiseLoss, PointwiseLoss, SetwiseLoss, UnsupportedLabelSmoothingError, loss_resolver |
65 | 66 | from pykeen.metrics import rank_based_metric_resolver
|
66 | 67 | from pykeen.metrics.ranking import (
|
67 | 68 | DerivedRankBasedMetric,
|
@@ -241,6 +242,27 @@ def tearDown(self) -> None:
|
241 | 242 | self.directory.cleanup()
|
242 | 243 |
|
243 | 244 |
|
| 245 | +def iter_from_space(key: str, space: Mapping[str, Any]) -> Iterable[MutableMapping[str, Any]]: |
| 246 | + """Iterate over some configurations for a single hyperparameter.""" |
| 247 | + typ = space["type"] |
| 248 | + if typ == "categorical": |
| 249 | + for choice in space["choices"]: |
| 250 | + yield {key: choice} |
| 251 | + elif typ in (float, int): |
| 252 | + yield {key: space["low"]} |
| 253 | + yield {key: space["high"]} |
| 254 | + else: |
| 255 | + raise NotImplementedError(typ) |
| 256 | + |
| 257 | + |
| 258 | +def iter_hpo_configs(hpo_default: Mapping[str, Mapping[str, Any]]) -> Iterable[Mapping[str, Any]]: |
| 259 | + """Iterate over some representative configurations from the HPO default.""" |
| 260 | + for combination in itertools.product( |
| 261 | + *(iter_from_space(key=key, space=space) for key, space in hpo_default.items()) |
| 262 | + ): |
| 263 | + yield ChainMap(*combination) |
| 264 | + |
| 265 | + |
244 | 266 | class LossTestCase(GenericTestCase[Loss]):
|
245 | 267 | """Base unittest for loss functions."""
|
246 | 268 |
|
@@ -354,8 +376,21 @@ def test_optimization_direction_slcwa(self):
|
354 | 376 | def test_hpo_defaults(self):
|
355 | 377 | """Test hpo defaults."""
|
356 | 378 | signature = inspect.signature(self.cls.__init__)
|
| 379 | + # check for invalid keys |
357 | 380 | invalid_keys = set(self.cls.hpo_default.keys()).difference(signature.parameters)
|
358 | 381 | assert not invalid_keys
|
| 382 | + # check that each parameter without a default occurs |
| 383 | + required_parameters = { |
| 384 | + parameter_name |
| 385 | + for parameter_name, parameter in signature.parameters.items() |
| 386 | + if parameter.default == inspect.Parameter.empty |
| 387 | + and parameter.kind not in {inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL} |
| 388 | + }.difference({"self"}) |
| 389 | + missing_required = required_parameters.difference(self.cls.hpo_default.keys()) |
| 390 | + assert not missing_required |
| 391 | + # try to instantiate loss for some configurations in the HPO search space |
| 392 | + for config in iter_hpo_configs(self.cls.hpo_default): |
| 393 | + assert loss_resolver.make(query=self.cls, pos_kwargs=config) |
359 | 394 |
|
360 | 395 |
|
361 | 396 | class PointwiseLossTestCase(LossTestCase):
|
|
0 commit comments