Skip to content

Commit 038f604

Browse files
mberrcthoyt
andauthoredMay 8, 2022
🔥⚡ POC PyTorch Lightning integration (pykeen#905)
Co-authored-by: Charles Tapley Hoyt <[email protected]>
1 parent 05071f5 commit 038f604

File tree

6 files changed

+179
-1
lines changed

6 files changed

+179
-1
lines changed
 

‎docs/source/contrib/lightning.rst

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
PyTorch Lightning Integration
2+
=============================
3+
.. automodapi:: pykeen.contrib.lightning
4+
:no-heading:
5+
:headings: --
6+
:no-inheritance-diagram:

‎docs/source/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ PyKEEN
2020
tutorial/representations
2121
tutorial/node_piece
2222
tutorial/inductive_lp
23+
contrib/lightning
2324

2425
.. toctree::
2526
:caption: Bring Your Own

‎setup.cfg

+2
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ tensorboard =
101101
tensorboard
102102
transformers =
103103
transformers
104+
lightning =
105+
pytorch_lightning
104106
tests =
105107
unittest-templates>=0.0.5
106108
coverage

‎src/pykeen/contrib/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""This module contains code for non-standard dependencies in PyKEEN."""

‎src/pykeen/contrib/lightning.py

+165
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""PyTorch Lightning integration.
4+
5+
PyTorch Lightning poses an alternative way to implement a training
6+
loop and evaluation loop for knowledge graph embedding models that
7+
has some nice features:
8+
9+
- mixed precision training
10+
- multi-gpu training
11+
12+
.. code-block:: python
13+
14+
model = LitLCWAModule(
15+
dataset="fb15k237",
16+
dataset_kwargs=dict(create_inverse_triples=True),
17+
model="mure",
18+
model_kwargs=dict(embedding_dim=128, loss="bcewithlogits"),
19+
batch_size=128,
20+
)
21+
trainer = pytorch_lightning.Trainer(
22+
accelerator="auto", # automatically choose accelerator
23+
logger=False, # defaults to TensorBoard; explicitly disabled here
24+
precision=16, # mixed precision training
25+
)
26+
trainer.fit(model=model)
27+
28+
"""
29+
30+
import pytorch_lightning
31+
import torch
32+
import torch.utils.data
33+
from class_resolver import HintOrType, OptionalKwargs
34+
35+
from pykeen.datasets import get_dataset
36+
from pykeen.datasets.base import Dataset
37+
from pykeen.models import Model, model_resolver
38+
from pykeen.optimizers import optimizer_resolver
39+
from pykeen.triples.triples_factory import CoreTriplesFactory
40+
41+
__all__ = [
42+
"LitLCWAModule",
43+
]
44+
45+
46+
class LitLCWAModule(pytorch_lightning.LightningModule):
47+
"""A PyTorch Lightning module for training a model with LCWA training loop.
48+
49+
.. seealso:: https://github.com/pykeen/pykeen/pull/905
50+
"""
51+
52+
def __init__(
53+
self,
54+
# dataset
55+
dataset: HintOrType[Dataset] = "nations",
56+
dataset_kwargs: OptionalKwargs = None,
57+
# model
58+
model: HintOrType[Model] = "distmult",
59+
model_kwargs: OptionalKwargs = None,
60+
# stored outside of the training loop / optimizer to give access to auto-tuning from Lightning
61+
batch_size: int = 32,
62+
learning_rate: float = 1.0e-03,
63+
# optimizer
64+
optimizer: HintOrType[torch.optim.Optimizer] = None,
65+
optimizer_kwargs: OptionalKwargs = None,
66+
):
67+
"""
68+
Create the lightning module.
69+
70+
:param dataset:
71+
the dataset, or a hint thereof
72+
:param dataset_kwargs:
73+
additional keyword-based parameters passed to the dataset
74+
75+
:param model:
76+
the model, or a hint thereof
77+
:param model_kwargs:
78+
additional keyword-based parameters passed to the model
79+
80+
:param batch_size:
81+
the training batch size
82+
:param learning_rate:
83+
the learning rate
84+
85+
:param optimizer:
86+
the optimizer, or a hint thereof
87+
:param optimizer_kwargs:
88+
additional keyword-based parameters passed to the optimizer. should not contain `lr`, or `params`.
89+
"""
90+
super().__init__()
91+
self.dataset = get_dataset(dataset=dataset, dataset_kwargs=dataset_kwargs)
92+
self.model = model_resolver.make(model, model_kwargs, triples_factory=self.dataset.training)
93+
self.loss = self.model.loss
94+
self.optimizer = optimizer
95+
self.optimizer_kwargs = optimizer_kwargs
96+
self.learning_rate = learning_rate
97+
self.batch_size = batch_size
98+
99+
def forward(self, x):
100+
"""
101+
Perform the prediction or inference step.
102+
103+
.. note::
104+
in lightning, forward defines the prediction/inference actions
105+
"""
106+
return self.model.predict_t(x)
107+
108+
def _step(self, batch, prefix: str):
109+
"""Refactored step."""
110+
hr_batch, labels = batch
111+
scores = self.model.score_t(hr_batch=hr_batch)
112+
loss = self.loss.process_lcwa_scores(predictions=scores, labels=labels)
113+
self.log(f"{prefix}_loss", loss)
114+
return loss
115+
116+
def training_step(self, batch, batch_idx):
117+
"""Perform a training step."""
118+
return self._step(batch, prefix="train")
119+
120+
def validation_step(self, batch, batch_idx, *args, **kwargs):
121+
"""Perform a validation step."""
122+
return self._step(batch, prefix="val")
123+
124+
def _dataloader(self, triples_factory: CoreTriplesFactory, shuffle: bool = False) -> torch.utils.data.DataLoader:
125+
"""Create a data loader."""
126+
return torch.utils.data.DataLoader(
127+
dataset=triples_factory.create_lcwa_instances(),
128+
batch_size=self.batch_size,
129+
shuffle=shuffle,
130+
)
131+
132+
def train_dataloader(self):
133+
"""Create the training data loader."""
134+
return self._dataloader(triples_factory=self.dataset.training, shuffle=True)
135+
136+
def val_dataloader(self):
137+
"""Create the validation data loader."""
138+
return self._dataloader(triples_factory=self.dataset.validation, shuffle=False)
139+
140+
def configure_optimizers(self):
141+
"""Configure the optimizers."""
142+
return optimizer_resolver.make(
143+
self.optimizer, self.optimizer_kwargs, params=self.parameters(), lr=self.learning_rate
144+
)
145+
146+
147+
def _main():
148+
"""Run PyTorch lightning model."""
149+
model = LitLCWAModule(
150+
dataset="fb15k237",
151+
dataset_kwargs=dict(create_inverse_triples=True),
152+
model="mure",
153+
model_kwargs=dict(embedding_dim=128, loss="bcewithlogits"),
154+
batch_size=128,
155+
)
156+
trainer = pytorch_lightning.Trainer(
157+
accelerator="auto", # automatically choose accelerator
158+
logger=False, # defaults to TensorBoard; explicitly disabled here
159+
precision=16, # mixed precision training
160+
)
161+
trainer.fit(model=model)
162+
163+
164+
if __name__ == "__main__":
165+
_main()

‎tox.ini

+2-1
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ description = Test building the documentation in an isolated environment.
190190
changedir = docs
191191
extras =
192192
docs
193+
lightning
193194
commands =
194195
mkdir -p {envtmpdir}
195196
cp -r source {envtmpdir}/source
@@ -209,7 +210,7 @@ whitelist_externals =
209210
[testenv:docs]
210211
description = Build the documentation locally.
211212
extras =
212-
docs
213+
{[testenv:docs-test]extras}
213214
commands =
214215
python -m sphinx -W -b html -d docs/build/doctrees docs/source docs/build/html
215216

0 commit comments

Comments
 (0)
Please sign in to comment.