Skip to content

Commit c94213c

Browse files
authoredOct 7, 2023
☑️💸 Update Losses' HPO default range checks (#1337)
* [x] check instantiation of losses for a few configurations derived from the `hpo_default` * [x] fix an issue with the default values of `DoubleMarginLoss` ([bb88719](bb88719)) * [x] add suitable `hpo_default`s to the general - `MarginPairwiseLoss` - `DeltaPointwiseLoss` Fix #1334
1 parent b603e5a commit c94213c

File tree

3 files changed

+72
-7
lines changed

3 files changed

+72
-7
lines changed
 

‎src/pykeen/losses.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -455,10 +455,18 @@ class MarginPairwiseLoss(PairwiseLoss):
455455
function like the ReLU or softmax, and $\lambda$ is the margin.
456456
"""
457457

458+
hpo_default: ClassVar[Mapping[str, Any]] = dict(
459+
margin=DEFAULT_MARGIN_HPO_STRATEGY,
460+
margin_activation=dict(
461+
type="categorical",
462+
choices=margin_activation_resolver.options,
463+
),
464+
)
465+
458466
def __init__(
459467
self,
460-
margin: float,
461-
margin_activation: Hint[nn.Module],
468+
margin: float = 1.0,
469+
margin_activation: Hint[nn.Module] = None,
462470
reduction: str = "mean",
463471
):
464472
r"""Initialize the margin loss instance.
@@ -730,6 +738,10 @@ def resolve_margin(
730738
:raises ValueError:
731739
In case of an invalid combination.
732740
"""
741+
# 0. default
742+
if all(p is None for p in (positive_margin, negative_margin, offset)):
743+
return 1.0, 0.0
744+
733745
# 1. positive & negative margin
734746
if positive_margin is not None and negative_margin is not None and offset is None:
735747
if negative_margin > positive_margin:
@@ -771,8 +783,8 @@ def resolve_margin(
771783
def __init__(
772784
self,
773785
*,
774-
positive_margin: Optional[float] = 1.0,
775-
negative_margin: Optional[float] = 0.0,
786+
positive_margin: Optional[float] = None,
787+
negative_margin: Optional[float] = None,
776788
offset: Optional[float] = None,
777789
positive_negative_balance: float = 0.5,
778790
margin_activation: Hint[nn.Module] = "relu",
@@ -905,6 +917,14 @@ class DeltaPointwiseLoss(PointwiseLoss):
905917
============================= ========== ====================== ======================================================== =============================================
906918
""" # noqa:E501
907919

920+
hpo_default: ClassVar[Mapping[str, Any]] = dict(
921+
margin=DEFAULT_MARGIN_HPO_STRATEGY,
922+
margin_activation=dict(
923+
type="categorical",
924+
choices=margin_activation_resolver.options,
925+
),
926+
)
927+
908928
def __init__(
909929
self,
910930
margin: Optional[float] = 0.0,

‎tests/cases.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
"""Test cases for PyKEEN."""
44
import inspect
5+
import itertools
56
import logging
67
import os
78
import pathlib
@@ -61,7 +62,7 @@
6162
from pykeen.datasets.mocks import create_inductive_dataset
6263
from pykeen.datasets.nations import NATIONS_TEST_PATH, NATIONS_TRAIN_PATH
6364
from pykeen.evaluation import Evaluator, MetricResults, evaluator_resolver
64-
from pykeen.losses import Loss, PairwiseLoss, PointwiseLoss, SetwiseLoss, UnsupportedLabelSmoothingError
65+
from pykeen.losses import Loss, PairwiseLoss, PointwiseLoss, SetwiseLoss, UnsupportedLabelSmoothingError, loss_resolver
6566
from pykeen.metrics import rank_based_metric_resolver
6667
from pykeen.metrics.ranking import (
6768
DerivedRankBasedMetric,
@@ -241,6 +242,27 @@ def tearDown(self) -> None:
241242
self.directory.cleanup()
242243

243244

245+
def iter_from_space(key: str, space: Mapping[str, Any]) -> Iterable[MutableMapping[str, Any]]:
246+
"""Iterate over some configurations for a single hyperparameter."""
247+
typ = space["type"]
248+
if typ == "categorical":
249+
for choice in space["choices"]:
250+
yield {key: choice}
251+
elif typ in (float, int):
252+
yield {key: space["low"]}
253+
yield {key: space["high"]}
254+
else:
255+
raise NotImplementedError(typ)
256+
257+
258+
def iter_hpo_configs(hpo_default: Mapping[str, Mapping[str, Any]]) -> Iterable[Mapping[str, Any]]:
259+
"""Iterate over some representative configurations from the HPO default."""
260+
for combination in itertools.product(
261+
*(iter_from_space(key=key, space=space) for key, space in hpo_default.items())
262+
):
263+
yield ChainMap(*combination)
264+
265+
244266
class LossTestCase(GenericTestCase[Loss]):
245267
"""Base unittest for loss functions."""
246268

@@ -354,8 +376,21 @@ def test_optimization_direction_slcwa(self):
354376
def test_hpo_defaults(self):
355377
"""Test hpo defaults."""
356378
signature = inspect.signature(self.cls.__init__)
379+
# check for invalid keys
357380
invalid_keys = set(self.cls.hpo_default.keys()).difference(signature.parameters)
358381
assert not invalid_keys
382+
# check that each parameter without a default occurs
383+
required_parameters = {
384+
parameter_name
385+
for parameter_name, parameter in signature.parameters.items()
386+
if parameter.default == inspect.Parameter.empty
387+
and parameter.kind not in {inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL}
388+
}.difference({"self"})
389+
missing_required = required_parameters.difference(self.cls.hpo_default.keys())
390+
assert not missing_required
391+
# try to instantiate loss for some configurations in the HPO search space
392+
for config in iter_hpo_configs(self.cls.hpo_default):
393+
assert loss_resolver.make(query=self.cls, pos_kwargs=config)
359394

360395

361396
class PointwiseLossTestCase(LossTestCase):

‎tests/test_losses.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,18 @@ class SoftMarginrankingLossTestCase(cases.GMRLTestCase):
149149
cls = pykeen.losses.SoftMarginRankingLoss
150150

151151

152+
class MarginPairwiseLossTestCase(cases.GMRLTestCase):
153+
"""Tests for general margin pairwise loss."""
154+
155+
cls = pykeen.losses.MarginPairwiseLoss
156+
157+
158+
class DeltaPointwiseLossTestCase(cases.PointwiseLossTestCase):
159+
"""Tests for general delta point-wise loss."""
160+
161+
cls = pykeen.losses.DeltaPointwiseLoss
162+
163+
152164
class PairwiseLogisticLossTestCase(cases.GMRLTestCase):
153165
"""Tests for the pairwise logistic loss."""
154166

@@ -171,8 +183,6 @@ class TestLosses(unittest_templates.MetaTestCase[Loss]):
171183
pykeen.losses.PairwiseLoss,
172184
pykeen.losses.PointwiseLoss,
173185
pykeen.losses.SetwiseLoss,
174-
pykeen.losses.DeltaPointwiseLoss,
175-
pykeen.losses.MarginPairwiseLoss,
176186
pykeen.losses.AdversarialLoss,
177187
}
178188

0 commit comments

Comments
 (0)
Please sign in to comment.