Skip to content

Commit aadf5a4

Browse files
mberrcthoyt
andauthoredJan 29, 2023
😮🪑 Add OGB Evaluator (pykeen#948)
Add an evaluator to call OGB evaluation from within the pipeline. Currently blocked by `Evaluator.evaluate` receiving multiple `kwargs` (including `additional_filter_triples`) to be passed to `evaluate`, which are not supported by the OGB evaluate method. #### Dependencies * [x] pykeen#1088 --------- Co-authored-by: Charles Tapley Hoyt <[email protected]>
1 parent c15b7a4 commit aadf5a4

File tree

5 files changed

+64
-41
lines changed

5 files changed

+64
-41
lines changed
 

‎README.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -287,12 +287,13 @@ The following 2 stoppers are implemented in PyKEEN.
287287

288288
### Evaluators
289289

290-
The following 4 evaluators are implemented in PyKEEN.
290+
The following 5 evaluators are implemented in PyKEEN.
291291

292292
| Name | Reference | Description |
293293
|------------------|-----------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------|
294294
| classification | [`pykeen.evaluation.ClassificationEvaluator`](https://pykeen.readthedocs.io/en/latest/api/pykeen.evaluation.ClassificationEvaluator.html) | An evaluator that uses a classification metrics. |
295295
| macrorankbased | [`pykeen.evaluation.MacroRankBasedEvaluator`](https://pykeen.readthedocs.io/en/latest/api/pykeen.evaluation.MacroRankBasedEvaluator.html) | Macro-average rank-based evaluation. |
296+
| ogb | [`pykeen.evaluation.OGBEvaluator`](https://pykeen.readthedocs.io/en/latest/api/pykeen.evaluation.OGBEvaluator.html) | A sampled, rank-based evaluator that applies a custom OGB evaluation. |
296297
| rankbased | [`pykeen.evaluation.RankBasedEvaluator`](https://pykeen.readthedocs.io/en/latest/api/pykeen.evaluation.RankBasedEvaluator.html) | A rank-based evaluator for KGE models. |
297298
| sampledrankbased | [`pykeen.evaluation.SampledRankBasedEvaluator`](https://pykeen.readthedocs.io/en/latest/api/pykeen.evaluation.SampledRankBasedEvaluator.html) | A rank-based evaluator using sampled negatives instead of all negatives. |
298299

‎src/pykeen/evaluation/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .classification_evaluator import ClassificationEvaluator, ClassificationMetricResults
88
from .evaluation_loop import LCWAEvaluationLoop
99
from .evaluator import Evaluator, MetricResults, evaluate
10+
from .ogb_evaluator import OGBEvaluator
1011
from .rank_based_evaluator import (
1112
MacroRankBasedEvaluator,
1213
RankBasedEvaluator,
@@ -23,6 +24,7 @@
2324
"MacroRankBasedEvaluator",
2425
"LCWAEvaluationLoop",
2526
"SampledRankBasedEvaluator",
27+
"OGBEvaluator",
2628
"ClassificationEvaluator",
2729
"ClassificationMetricResults",
2830
"evaluator_resolver",

‎src/pykeen/evaluation/ogb_evaluator.py

+41
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,44 @@
1414
from ..typing import RANK_REALISTIC, SIDE_BOTH, ExtendedTarget, MappedTriples, RankType, Target
1515

1616
__all__ = [
17+
"OGBEvaluator",
1718
"evaluate_ogb",
1819
]
1920

2021
logger = logging.getLogger(__name__)
2122

2223

24+
class OGBEvaluator(SampledRankBasedEvaluator):
25+
"""A sampled, rank-based evaluator that applies a custom OGB evaluation."""
26+
27+
# docstr-coverage: inherited
28+
def __init__(self, filtered: bool = False, **kwargs):
29+
if filtered:
30+
raise ValueError(
31+
"OGB evaluator is already filtered, but not dynamically like other evaluators because "
32+
"it requires pre-calculated filtered negative triples. Therefore, it is not allowed to "
33+
"accept filtered=True"
34+
)
35+
super().__init__(**kwargs, filtered=filtered)
36+
37+
def evaluate(
38+
self,
39+
model: Model,
40+
mapped_triples: MappedTriples,
41+
batch_size: Optional[int] = None,
42+
slice_size: Optional[int] = None,
43+
**kwargs,
44+
) -> MetricResults:
45+
"""Run :func:`evaluate_ogb` with this evaluator."""
46+
return evaluate_ogb(
47+
evaluator=self,
48+
model=model,
49+
mapped_triples=mapped_triples,
50+
batch_size=batch_size,
51+
**kwargs,
52+
)
53+
54+
2355
def evaluate_ogb(
2456
evaluator: SampledRankBasedEvaluator,
2557
model: Model,
@@ -52,6 +84,8 @@ def evaluate_ogb(
5284
if ogb is not installed
5385
:raises NotImplementedError:
5486
if `batch_size` is None, i.e., automatic batch size selection is selected
87+
:raises ValueError:
88+
if illegal ``additional_filter_triples`` argument is given in the kwargs
5589
"""
5690
try:
5791
import ogb.linkproppred
@@ -61,6 +95,13 @@ def evaluate_ogb(
6195
if batch_size is None:
6296
raise NotImplementedError("Automatic batch size selection not available for OGB evaluation.")
6397

98+
additional_filter_triples = kwargs.pop("additional_filter_triples", None)
99+
if additional_filter_triples is not None:
100+
raise ValueError(
101+
f"evaluate_ogb received additional_filter_triples={additional_filter_triples}. However, it uses "
102+
f"explicitly given filtered negative triples, and therefore shouldn't be passed any additional ones"
103+
)
104+
64105
class _OGBEvaluatorBridge(ogb.linkproppred.Evaluator):
65106
"""A wrapper around OGB's evaluator to support evaluation on non-OGB datasets."""
66107

‎src/pykeen/evaluation/rank_based_evaluator.py

+2-36
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,8 @@ def __init__(
587587
LABEL_HEAD: head_negatives,
588588
LABEL_TAIL: tail_negatives,
589589
}
590+
if additional_filter_triples is not None:
591+
logger.warning(f"Ignoring parameter additional_filter_triples={additional_filter_triples}")
590592

591593
# verify input
592594
for side, side_negatives in negatives.items():
@@ -630,42 +632,6 @@ def process_scores_(
630632
# TODO: should we give num_entities in the constructor instead of inferring it every time ranks are processed?
631633
self.num_entities = num_entities
632634

633-
def evaluate_ogb(
634-
self,
635-
model,
636-
mapped_triples: MappedTriples,
637-
batch_size: Optional[int] = None,
638-
**kwargs,
639-
) -> MetricResults:
640-
"""
641-
Evaluate a model using OGB's evaluator.
642-
643-
:param model:
644-
the model; will be set to evaluation mode.
645-
:param mapped_triples:
646-
the evaluation triples
647-
648-
.. note ::
649-
the evaluation triples have to match with the stored explicit negatives
650-
651-
:param batch_size:
652-
the batch size
653-
:param kwargs:
654-
additional keyword-based parameters passed to :meth:`pykeen.nn.Model.predict`
655-
656-
:return:
657-
the evaluation results
658-
"""
659-
from .ogb_evaluator import evaluate_ogb
660-
661-
return evaluate_ogb(
662-
evaluator=self,
663-
model=model,
664-
mapped_triples=mapped_triples,
665-
batch_size=batch_size,
666-
**kwargs,
667-
)
668-
669635

670636
class MacroRankBasedEvaluator(RankBasedEvaluator):
671637
"""Macro-average rank-based evaluation."""

‎tests/test_evaluation/test_evaluators.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from pykeen.constants import COLUMN_LABELS
2020
from pykeen.datasets import Nations
21-
from pykeen.evaluation import Evaluator, MetricResults, RankBasedEvaluator, RankBasedMetricResults
21+
from pykeen.evaluation import Evaluator, MetricResults, OGBEvaluator, RankBasedEvaluator, RankBasedMetricResults
2222
from pykeen.evaluation.classification_evaluator import (
2323
CLASSIFICATION_METRICS,
2424
ClassificationEvaluator,
@@ -139,12 +139,25 @@ def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMa
139139
kwargs["additional_filter_triples"] = self.dataset.training.mapped_triples
140140
return kwargs
141141

142-
@needs_packages("ogb")
143-
def test_ogb_evaluate(self):
142+
143+
@needs_packages("ogb")
144+
class OGBEvaluatorTests(RankBasedEvaluatorTests):
145+
"""Unit test for OGB evaluator."""
146+
147+
cls = OGBEvaluator
148+
kwargs = dict(num_negatives=3)
149+
150+
def _pre_instantiation_hook(self, kwargs: MutableMapping[str, Any]) -> MutableMapping[str, Any]: # noqa: D102
151+
kwargs = super()._pre_instantiation_hook(kwargs=kwargs)
152+
kwargs["evaluation_factory"] = self.factory
153+
kwargs["batch_size"] = 1
154+
return kwargs
155+
156+
def test_ogb_evaluate_alternate(self):
144157
"""Test OGB evaluation."""
145158
self.instance: SampledRankBasedEvaluator
146159
model = FixedModel(triples_factory=self.factory)
147-
result = self.instance.evaluate_ogb(model=model, mapped_triples=self.factory.mapped_triples, batch_size=1)
160+
result = self.instance.evaluate(model=model, mapped_triples=self.factory.mapped_triples, batch_size=1)
148161
assert isinstance(result, MetricResults)
149162

150163

0 commit comments

Comments
 (0)
Please sign in to comment.