Temporary Wrappers to fix MADE #1398

Merged: 2 commits, Feb 20, 2025
4 changes: 2 additions & 2 deletions sbi/neural_nets/net_builders/flow.py
@@ -15,7 +15,7 @@
 from torch import Tensor, nn, relu, tanh, tensor, uint8

 from sbi.neural_nets.estimators import NFlowsFlow, ZukoFlow
-from sbi.utils.nn_utils import get_numel
+from sbi.utils.nn_utils import MADEMoGWrapper, get_numel
 from sbi.utils.sbiutils import (
     standardizing_net,
     standardizing_transform,
@@ -77,7 +77,7 @@ def build_made(
             standardizing_net(batch_y, structured_y), embedding_net
         )

-    distribution = distributions_.MADEMoG(
+    distribution = MADEMoGWrapper(
         features=x_numel,
         hidden_features=hidden_features,
         context_features=y_numel,
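
The wrapper keeps the constructor signature of `distributions_.MADEMoG`, so only the class name changes at this call site. A minimal illustrative sketch of the equivalent direct construction (argument values are made up; `features` and `context_features` correspond to `x_numel` and `y_numel` in `build_made`):

```python
from sbi.utils.nn_utils import MADEMoGWrapper

# Drop-in replacement for distributions_.MADEMoG: same keyword arguments,
# but internally one dummy feature is prepended and later discarded.
distribution = MADEMoGWrapper(
    features=3,
    hidden_features=50,
    context_features=2,
    num_mixture_components=10,
)
```
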
129 changes: 129 additions & 0 deletions sbi/utils/nn_utils.py
@@ -2,6 +2,11 @@
from typing import Optional
from warnings import warn

import nflows.nn.nde.made as made
import numpy as np
import torch
import torch.nn.functional as F
from pyknos.nflows import distributions as distributions_
from torch import Tensor, nn


@@ -62,3 +67,127 @@
        return net.to(device)
    else:
        return net


"""
Temporary Patches to fix nflows MADE bug. Remove once upstream bug is fixed.
"""


class MADEWrapper(made.MADE):
    """Wrapper around nflows' MADE that prepends a dummy input feature.

    The dummy feature ensures that every real output dimension is conditioned
    on the context. Like the wrapped MADE, it can use either feedforward or
    residual blocks (default: residual) and, optionally, batch norm or dropout
    within blocks (default: neither).
    """

    def __init__(
        self,
        features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        output_multiplier=1,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
    ):
        if use_residual_blocks and random_mask:
            raise ValueError("Residual blocks can't be used with random masks.")
        super().__init__(
            features + 1,  # one extra (dummy) feature is prepended in forward()
            hidden_features,
            context_features,
            num_blocks,
            output_multiplier,
            use_residual_blocks,
            random_mask,
            activation,
            dropout_probability,
            use_batch_norm,
        )

    def forward(self, inputs, context=None):
        # Add a dummy input so that all real dims are conditioned on the context.
        dummy_input = torch.zeros(inputs.shape[:-1] + (1,))
        concat_input = torch.cat((dummy_input, inputs), dim=-1)
        outputs = super().forward(concat_input, context)
        # Drop the outputs belonging to the dummy feature.
        return outputs[..., self.output_multiplier :]
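
# Why the dummy feature helps (informal reading of the upstream nflows issue):
# the context enters MADE through its hidden layers, but the outputs assigned
# to the very first input feature have no incoming hidden-layer connections,
# so they cannot depend on the context. Prepending an all-zero dummy feature
# shifts every real feature by one position, which makes all real outputs
# context-dependent; the dummy feature's leading `output_multiplier` outputs
# are then sliced off again in `forward`.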

"""
Temporary Patches to fix nflows MADE bug. Remove once upstream bug is fixed.
"""


class MADEMoGWrapper(distributions_.MADEMoG):
    """Wrapper around pyknos' MADEMoG that prepends a dummy input feature.

    Same trick as `MADEWrapper`: the dummy feature ensures that all real
    dimensions are conditioned on the context; its outputs are discarded.
    """

    def __init__(
        self,
        features,
        hidden_features,
        context_features,
        num_blocks=2,
        num_mixture_components=1,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
        custom_initialization=False,
    ):
        super().__init__(
            features + 1,  # account for the prepended dummy feature
            hidden_features,
            context_features,
            num_blocks,
            num_mixture_components,
            use_residual_blocks,
            random_mask,
            activation,
            dropout_probability,
            use_batch_norm,
            custom_initialization,
        )

    def _log_prob(self, inputs, context=None):
        # Prepend the dummy feature, exactly as in MADEWrapper.forward.
        dummy_input = torch.zeros(inputs.shape[:-1] + (1,))
        concat_inputs = torch.cat((dummy_input, inputs), dim=-1)

        outputs = self._made.forward(concat_inputs, context=context)
        # Reshape to (..., features + 1, num_mixture_components, 3); the last
        # axis holds logits, means and unconstrained stds per component.
        outputs = outputs.reshape(
            *concat_inputs.shape, self._made.num_mixture_components, 3
        )

        logits, means, unconstrained_stds = (
            outputs[..., 0],
            outputs[..., 1],
            outputs[..., 2],
        )
        # Drop the entries belonging to the dummy feature (index 0 on the
        # feature axis) from logits, means and unconstrained_stds.
        logits = logits[..., 1:, :]
        means = means[..., 1:, :]
        unconstrained_stds = unconstrained_stds[..., 1:, :]

        log_mixture_coefficients = torch.log_softmax(logits, dim=-1)
        stds = F.softplus(unconstrained_stds) + self._made.epsilon

        # Per-dimension Gaussian mixture log-density, summed over dimensions.
        log_prob = torch.sum(
            torch.logsumexp(
                log_mixture_coefficients
                - 0.5
                * (
                    np.log(2 * np.pi)
                    + 2 * torch.log(stds)
                    + ((inputs[..., None] - means) / stds) ** 2
                ),
                dim=-1,  # logsumexp over mixture components
            ),
            dim=-1,  # sum over data dimensions
        )
        return log_prob
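
    # For reference, the expression above evaluates (in informal notation)
    #
    #   log p(x | c) = sum_d logsumexp_k [ log pi_{d,k}(c)
    #                     - 0.5 * ( log(2*pi) + 2*log(sigma_{d,k}(c))
    #                               + ((x_d - mu_{d,k}(c)) / sigma_{d,k}(c))**2 ) ]
    #
    # i.e. a diagonal Gaussian mixture over each dimension d with K components,
    # whose parameters are produced autoregressively by the underlying MADE.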

    def _sample(self, num_samples, context=None):
        # Sample including the dummy feature, then drop its column.
        samples = self._made.sample(num_samples, context=context)
        return samples[..., 1:]

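For reference, a minimal sketch of exercising `MADEMoGWrapper` directly, assuming the standard `nflows`/`pyknos` distribution interface (`log_prob(inputs, context)` and `sample(num_samples, context)`); sizes and values are made up:

```python
import torch

from sbi.utils.nn_utils import MADEMoGWrapper

wrapper = MADEMoGWrapper(
    features=3,               # dimensionality of the modelled variable
    hidden_features=32,
    context_features=2,       # dimensionality of the conditioning variable
    num_mixture_components=5,
)

x = torch.randn(10, 3)        # batch of inputs
context = torch.randn(10, 2)  # matching batch of conditions

log_probs = wrapper.log_prob(x, context=context)  # one log-prob per batch element
samples = wrapper.sample(4, context=context)      # 4 samples per context entry
```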