[![active](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
[![ci](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml/badge.svg)](https://github.com/ramsey-devs/ramsey/actions/workflows/ci.yaml)
[![coverage](https://codecov.io/gh/ramsey-devs/ramsey/branch/main/graph/badge.svg?token=dn1xNBSalZ)](https://codecov.io/gh/ramsey-devs/ramsey)
[![documentation](https://readthedocs.org/projects/ramsey/badge/?version=latest)](https://ramsey.readthedocs.io/en/latest/?badge=latest)
[![version](https://img.shields.io/pypi/v/ramsey.svg?colorB=black&style=flat)](https://pypi.org/project/ramsey/)
## About
Ramsey is a library for probabilistic deep learning using [JAX](https://github.com/google/jax),
[Flax](https://github.com/google/flax) and [NumPyro](https://github.com/pyro-ppl/numpyro). Its scope covers

- neural processes (vanilla, attentive, Markovian, convolutional, ...),
- neural Laplace and Fourier operator models,
- etc.
## Example usage

You can, for instance, construct a simple neural process like this:

```python
from flax import nnx

from ramsey import NP
from ramsey.nn import MLP  # just a flax.nnx module


def get_neural_process(in_features, out_features):
    dim = 128
    np = NP(
        latent_encoder=(
            MLP(in_features, [dim, dim], rngs=nnx.Rngs(0)),
            MLP(dim, [dim, dim * 2], rngs=nnx.Rngs(1))
        ),
        decoder=MLP(in_features, [dim, dim, out_features * 2], rngs=nnx.Rngs(2))
    )
    return np


neural_process = get_neural_process(1, 1)
```

The neural process above takes a decoder and a set of two latent encoders as arguments. All of these are typically `flax.nnx` MLPs, but
Ramsey is flexible enough that you can change them, for instance, to CNNs or RNNs.
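
For instance, here is a minimal sketch of what such a swap could look like. It assumes the replacement only needs to be a `flax.nnx` module whose `__call__` maps its inputs to `out_features * 2` outputs, mirroring the default MLP decoder above; `CustomDecoder` and `decoder_in_features` are hypothetical names, and the exact input size depends on how `NP` assembles its decoder inputs:

```python
import jax
from flax import nnx


class CustomDecoder(nnx.Module):
    """Hypothetical drop-in replacement for the MLP decoder."""

    def __init__(self, decoder_in_features, out_features, *, rngs):
        # decoder_in_features is a placeholder for whatever feature size
        # NP passes to its decoder in your configuration
        self.hidden = nnx.Linear(decoder_in_features, 128, rngs=rngs)
        self.out = nnx.Linear(128, out_features * 2, rngs=rngs)

    def __call__(self, x):
        # map inputs to twice out_features, e.g., mean/scale parameters
        return self.out(jax.nn.relu(self.hidden(x)))
```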

Ramsey provides a unified interface where each method implements (at least) `__call__` and `loss`
functions to transform a set of inputs and compute a training loss, respectively:

```python
from jax import random as jr
from ramsey.data import sample_from_sine_function

data = sample_from_sine_function(jr.key(0))
x_context, y_context = data.x[:, :20, :], data.y[:, :20, :]
x_target, y_target = data.x, data.y

# make a prediction
pred = neural_process(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
)

# compute the loss
loss = neural_process.loss(
    x_context=x_context,
    y_context=y_context,
    x_target=x_target,
    y_target=y_target,
)
```
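
The `loss` above is the quantity you minimize during training. Below is a minimal training-loop sketch, assuming the usual `flax.nnx` plus `optax` pattern; the optimizer wiring (`nnx.Optimizer` and its `update` signature) is not part of Ramsey and may differ slightly across Flax versions:

```python
import optax
from flax import nnx

# assumed optimizer setup; adapt to your Flax/optax versions
optimizer = nnx.Optimizer(neural_process, optax.adam(3e-4))


@nnx.jit
def train_step(model, optimizer, x_context, y_context, x_target, y_target):
    def loss_fn(model):
        return model.loss(
            x_context=x_context,
            y_context=y_context,
            x_target=x_target,
            y_target=y_target,
        )

    # differentiate the loss w.r.t. the model parameters and take one step
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)
    return loss


for _ in range(1000):
    loss = train_step(
        neural_process, optimizer, x_context, y_context, x_target, y_target
    )
```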
## Installation
To install from PyPI, call: