Commit a582d04: Move to NNX (#48)

1 parent 4254bac

23 files changed, +2005 -501 lines

README.md (+21 -21)

@@ -3,7 +3,6 @@
 [![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
 [![ci](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml/badge.svg)](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml)
 [![coverage](https://codecov.io/gh/ramsey-devs/ramsey/branch/main/graph/badge.svg?token=dn1xNBSalZ)](https://codecov.io/gh/ramsey-devs/ramsey)
-[![quality](https://app.codacy.com/project/badge/Grade/ed13460537fd4ac099c8534b1d9a0202)](https://app.codacy.com/gh/ramsey-devs/ramsey/dashboard?utm_source=gh&utm_medium=referral&utm_content=&utm_campaign=Badge_grade)
 [![documentation](https://readthedocs.org/projects/ramsey/badge/?version=latest)](https://ramsey.readthedocs.io/en/latest/?badge=latest)
 [![version](https://img.shields.io/pypi/v/ramsey.svg?colorB=black&style=flat)](https://pypi.org/project/ramsey/)
 
@@ -12,13 +11,10 @@
 ## About
 
 Ramsey is a library for probabilistic deep learning using [JAX](https://github.com/google/jax),
-[Flax](https://github.com/google/flax) and [NumPyro](https://github.com/pyro-ppl/numpyro).
-
-Ramsey's scope covers
+[Flax](https://github.com/google/flax) and [NumPyro](https://github.com/pyro-ppl/numpyro). Its scope covers
 
 - neural processes (vanilla, attentive, Markovian, convolutional, ...),
 - neural Laplace and Fourier operator models,
-- flow matching and denoising diffusion models,
 - etc.
 
 ## Example usage
@@ -29,35 +25,44 @@ You can, for instance, construct a simple neural process like this:
 from flax import nnx
 
 from ramsey import NP
-from ramsey.nn import MLP
+from ramsey.nn import MLP  # just a flax.nnx module
 
 def get_neural_process(in_features, out_features):
     dim = 128
     np = NP(
-        decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(0)),
         latent_encoder=(
-            MLP(in_features, [dim, dim], rngs=nnx.Rngs(1)),
-            MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(2))
-        )
+            MLP(in_features, [dim, dim], rngs=nnx.Rngs(0)),
+            MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(1))
+        ),
+        decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(2))
     )
     return np
 
 neural_process = get_neural_process(1, 1)
 ```
 
-The neural process takes a decoder and a set of two latent encoders as arguments. All of these are typically `flax.nnx` MLPs, but
-Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs. Once the model is defined, you can train
-it by accessing the ELBO given input-output pairs via
+The neural process above takes a decoder and a set of two latent encoders as arguments. All of these are typically `flax.nnx` MLPs, but
+Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs.
+
+Ramsey provides a unified interface where each method implements (at least) `__call__` and `loss`
+functions to transform a set of inputs and compute a training loss, respectively:
 
 ```python
 from jax import random as jr
 from ramsey.data import sample_from_sine_function
 
-key = jr.PRNGKey(0)
-data = sample_from_sine_function(key)
-
+data = sample_from_sine_function(jr.key(0))
 x_context, y_context = data.x[:, :20, :], data.y[:, :20, :]
 x_target, y_target = data.x, data.y
+
+# make a prediction
+pred = neural_process(
+    x_context=x_context,
+    y_context=y_context,
+    x_target=x_target,
+)
+
+# compute the loss
 loss = neural_process.loss(
     x_context=x_context,
     y_context=y_context,
@@ -66,11 +71,6 @@ loss = neural_process.loss(
 )
 ```
 
-Making predictions can be done like this:
-```python
-pred = neural_process(x_context=x_context, y_context=y_context, x_target=x_target)
-```
-
 ## Installation
 
 To install from PyPI, call:
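A note on using the new interface: the `loss` method shown in the README snippet plugs directly into a standard `flax.nnx` training step. The sketch below is illustrative and not part of this commit; it assumes `optax` for optimization, and the exact `nnx.Optimizer` semantics vary across Flax releases:

```python
import optax
from flax import nnx

# one optimizer state wrapping the model's parameters
# (assumption: optax.adam; this commit does not prescribe an optimizer)
optimizer = nnx.Optimizer(neural_process, optax.adam(3e-4))

@nnx.jit
def train_step(model, optimizer, x_context, y_context, x_target, y_target):
    def loss_fn(model):
        return model.loss(
            x_context=x_context,
            y_context=y_context,
            x_target=x_target,
            y_target=y_target,
        )

    # differentiate the loss w.r.t. the model's nnx parameters
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # in-place step; signature varies by Flax version
    return loss

for _ in range(1000):
    loss = train_step(
        neural_process, optimizer, x_context, y_context, x_target, y_target
    )
```

In practice, Ramsey's own train functions (documented under docs/ramsey.rst below) cover this loop.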

docs/conf.py (+4 -2)

@@ -44,8 +44,10 @@
 
 autosummary_generate = True
 autodoc_typehints = 'none'
-typehints_fully_qualified = True
-always_document_param_types = True
+# typehints_fully_qualified = True
+# always_document_param_types = True
+# autodoc_inherit_docstrings = False
+# typehints_document_rtype= False
 
 html_theme = "sphinx_book_theme"
 html_theme_options = {

docs/index.rst (+21 -20)

@@ -7,12 +7,10 @@
 
 Ramsey is a library for probabilistic modelling using `JAX <https://github.com/google/jax>`_ ,
 `Flax <https://github.com/google/flax>`_ and `NumPyro <https://github.com/pyro-ppl/numpyro>`_.
-
 Ramsey's scope covers
 
 - neural processes (vanilla, attentive, Markovian, convolutional, ...),
 - neural Laplace and Fourier operator models,
-- flow matching and denoising diffusion models,
 - etc.
 
 Example
@@ -25,48 +23,52 @@ You can, for instance, construct a simple neural process like this:
 from flax import nnx
 
 from ramsey import NP
-from ramsey.nn import MLP
+from ramsey.nn import MLP  # just a flax.nnx module
 
 def get_neural_process(in_features, out_features):
     dim = 128
     np = NP(
-        decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(0)),
         latent_encoder=(
-            MLP(in_features, [dim, dim], rngs=nnx.Rngs(1)),
-            MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(2))
-        )
+            MLP(in_features, [dim, dim], rngs=nnx.Rngs(0)),
+            MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(1))
+        ),
+        decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(2)),
     )
     return np
 
 neural_process = get_neural_process(1, 1)
 
-The neural process takes a decoder and a set of two latent encoders as arguments. All of these are typically `flax.nnx` MLPs, but
-Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs. Once the model is defined, you can train
-the model by accessing the ELBO given input-output pairs via
+The neural process above takes a decoder and a set of two latent encoders as arguments.
+All of these are typically ``flax.nnx`` MLPs, but Ramsey is flexible enough that you can
+change them, for instance, to CNNs or RNNs.
+
+Ramsey provides a unified interface where each method implements (at least) ``__call__`` and ``loss``
+functions to transform a set of inputs and compute a training loss, respectively:
 
 .. code-block:: python
 
 from jax import random as jr
 from ramsey.data import sample_from_sine_function
 
-key = jr.PRNGKey(0)
-data = sample_from_sine_function(key)
-
+data = sample_from_sine_function(jr.key(0))
 x_context, y_context = data.x[:, :20, :], data.y[:, :20, :]
 x_target, y_target = data.x, data.y
+
+# make a prediction
+pred = neural_process(
+    x_context=x_context,
+    y_context=y_context,
+    x_target=x_target,
+)
+
+# compute the loss
 loss = neural_process.loss(
     x_context=x_context,
     y_context=y_context,
     x_target=x_target,
     y_target=y_target
 )
 
-Making predictions can be done like this:
-
-.. code-block:: python
-
-pred = neural_process(x_context=x_context, y_context=y_context, x_target=x_target)
-
 
 Why Ramsey
 ----------
@@ -119,7 +121,6 @@ Ramsey is licensed under the Apache 2.0 License.
    :hidden:
 
    🏠 Home <self>
-   📰 News <news>
   📚 References <references>
 
 .. toctree::

docs/news.rst (-16)

This file was deleted.

docs/notebooks/neural_processes.ipynb (+59 -55)

Large diffs are not rendered by default.

docs/ramsey.experimental.rst (+6 -6)

@@ -25,23 +25,23 @@ Covariance functions
 ExponentiatedQuadratic
 ~~~~~~~~~~~~~~~~~~~~~~
 
-.. autoclass:: ExponentiatedQuadratic
-   :members: __call__
+.. autoclass:: ExponentiatedQuadratic
+   :members: __call__
 
 .. autofunction:: exponentiated_quadratic
 
 Linear
 ~~~~~~
 
-.. autoclass:: Linear
-   :members: __call__
+.. autoclass:: Linear
+   :members: __call__
 
 .. autofunction:: linear
 
 Periodic
 ~~~~~~~~~
 
-.. autoclass:: Periodic
-   :members: __call__
+.. autoclass:: Periodic
+   :members: __call__
 
 .. autofunction:: periodic

docs/ramsey.rst (+3 -3)

@@ -17,13 +17,13 @@ Neural processes
 ~~~~~~~~~~~~~~~~
 
 .. autoclass:: NP
-   :members: __call__
+   :members: __call__, loss
 
 .. autoclass:: ANP
-   :members: __call__
+   :members: __call__, loss
 
 .. autoclass:: DANP
-   :members: __call__
+   :members: __call__, loss
 
 Train functions
 ---------------
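The upshot of these three `:members:` edits is a uniform documented surface: `NP`, `ANP`, and `DANP` each expose `__call__` for prediction and `loss` for training. A minimal check of that shared interface (illustrative only; just the class and method names come from this commit):

```python
# NP, ANP, and DANP are the classes whose docs gain `loss` above;
# they expose the same two documented entry points
from ramsey import ANP, DANP, NP

for cls in (NP, ANP, DANP):
    assert callable(cls.__call__)  # prediction
    assert callable(cls.loss)      # training objective
```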
