Skip to content

Commit f9ad580

Browse files
mberrcthoytPyKEEN-bot
authoredJan 4, 2022
🎲 🔮 Uncertainty Estimate via MC dropout (pykeen#688)
* add get_dropout_modules utility * add uncertainty estimate via MC dropout * move uncertain prediction to predict * add predict_h_uncertain * add tests * add predict_{h,r}_uncertain * fix method name * add test for predict_{h,r}_uncertain * run black * trigger ci * update __all__ trigger ci * use named tuple trigger ci * add example use to docstring trigger ci * add missing info to docstring trigger ci * Small refactoring * Add documentation todos * Reorganize tests * Clean up docs * Reorganize * Update uncertainty.py * Update docs * More reorg * Update uncertainty.rst * update some docstrings * make helper public again * rename file * extend module docstring and fix std vs. var * Update references.rst * Update uncertainty.py * add example to module docstring * remove unused reference * fix flake8 trigger ci * Update uncertainty.py * Delete uncertainty.rst * Update index.rst * trigger ci * trigger ci * Update uncertainty.py * Trigger CI Co-authored-by: Charles Tapley Hoyt <[email protected]> Co-authored-by: PyKEEN_bot <[email protected]>
1 parent 3b986a0 commit f9ad580

File tree

10 files changed

+484
-13
lines changed

10 files changed

+484
-13
lines changed
 

‎docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ PyKEEN
5656
reference/ablation
5757
reference/lookup
5858
reference/predict
59+
reference/uncertainty
5960
reference/sealant
6061
reference/constants
6162
reference/nn/index

‎docs/source/reference/uncertainty.rst

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Uncertainty
2+
===========
3+
.. automodapi:: pykeen.models.uncertainty
4+
:no-heading:
5+
:headings: --
6+
:no-inheritance-diagram:

‎docs/source/references.rst

+3
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,6 @@ References
8383
8484
.. [sharifzadeh2019vrd] Sharifzadeh, S., *et al*. (2019). `Improving Visual Relation Detection using Depth
8585
Maps <http://arxiv.org/abs/1905.00966>`_. *arXiv*, 1905.00966.
86+
87+
.. [gal2016] Gal, Y., & Ghahramani, Z. (2016). `Dropout as a Bayesian Approximation: Representing Model Uncertainty in
88+
Deep Learning <https://dl.acm.org/doi/10.5555/3045390.3045502>`_. ICML 2016.

‎src/pykeen/models/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -351,10 +351,10 @@ def predict_t(
351351
352352
.. note::
353353
354-
We only expect the right side-side predictions, i.e., $(h,r,*)$ to change its
354+
We only expect the right side-predictions, i.e., $(h,r,*)$ to change its
355355
default behavior when the model has been trained with inverse relations
356356
(mainly because of the behavior of the LCWA training approach). This is why
357-
the :func:`predict_scores_all_heads()` has different behavior depending on
357+
the :func:`predict_h` has different behavior depending on
358358
if inverse triples were used in training, and why this function has the same
359359
behavior regardless of the use of inverse triples.
360360
"""

‎src/pykeen/models/predict.py

+12-11
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def get_all_prediction_df(
258258
259259
Example usage:
260260
261-
.. code-block:: python
261+
.. code-block::
262262
263263
from pykeen.pipeline import pipeline
264264
from pykeen.models.predict import get_all_prediction_df
@@ -667,16 +667,6 @@ def predict_triples_df(
667667
"""
668668
Predict on labeled or mapped triples.
669669
670-
Example:
671-
>>> from pykeen.pipeline import pipeline
672-
>>> result = pipeline(dataset="nations", model="TransE")
673-
>>> from pykeen.models.predict import predict_triples_df
674-
>>> df = predict_triples_df(
675-
... model=result.model,
676-
... triples=("uk", "conferences", "brazil"),
677-
... triples_factory=result.training,
678-
... )
679-
680670
:param model:
681671
The model.
682672
:param triples: shape: (num_triples, 3)
@@ -699,6 +689,17 @@ def predict_triples_df(
699689
700690
:raises ValueError:
701691
If label-based triples have been provided, but the triples factory does not provide a mapping.
692+
693+
The TransE model can be trained and used to predict a given triple.
694+
695+
>>> from pykeen.pipeline import pipeline
696+
>>> result = pipeline(dataset="nations", model="TransE")
697+
>>> from pykeen.models.predict import predict_triples_df
698+
>>> df = predict_triples_df(
699+
... model=result.model,
700+
... triples=("uk", "conferences", "brazil"),
701+
... triples_factory=result.training,
702+
... )
702703
"""
703704
if triples is None:
704705
if triples_factory is None:

0 commit comments

Comments
 (0)