diff --git a/nflows/transforms/splines/rational_quadratic.py b/nflows/transforms/splines/rational_quadratic.py index d36b53e..99251ce 100644 --- a/nflows/transforms/splines/rational_quadratic.py +++ b/nflows/transforms/splines/rational_quadratic.py @@ -29,34 +29,36 @@ def unconstrained_rational_quadratic_spline( logabsdet = torch.zeros_like(inputs) if tails == "linear": - unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) constant = np.log(np.exp(1 - min_derivative) - 1) - unnormalized_derivatives[..., 0] = constant - unnormalized_derivatives[..., -1] = constant + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1), + value=constant) - outputs[outside_interval_mask] = inputs[outside_interval_mask] - logabsdet[outside_interval_mask] = 0 + outputs += outside_interval_mask*inputs + # logabsdet += outside_interval_mask*0 else: raise RuntimeError("{} tails are not implemented.".format(tails)) - if torch.any(inside_interval_mask): - ( - outputs[inside_interval_mask], - logabsdet[inside_interval_mask], - ) = rational_quadratic_spline( - inputs=inputs[inside_interval_mask], - unnormalized_widths=unnormalized_widths[inside_interval_mask, :], - unnormalized_heights=unnormalized_heights[inside_interval_mask, :], - unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], - inverse=inverse, - left=-tail_bound, - right=tail_bound, - bottom=-tail_bound, - top=tail_bound, - min_bin_width=min_bin_width, - min_bin_height=min_bin_height, - min_derivative=min_derivative, - ) + ( + inside_outputs, + inside_logabsdet, + ) = rational_quadratic_spline( + # Clamp inputs to the domain to prevent out of domain errors + inputs=inputs.clamp(-tail_bound, tail_bound), + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + left=-tail_bound, + right=tail_bound, + bottom=-tail_bound, + top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + ) + + outputs += inside_interval_mask*inside_outputs + logabsdet += inside_interval_mask*inside_logabsdet return outputs, logabsdet @@ -87,23 +89,21 @@ def rational_quadratic_spline( widths = F.softmax(unnormalized_widths, dim=-1) widths = min_bin_width + (1 - min_bin_width * num_bins) * widths - cumwidths = torch.cumsum(widths, dim=-1) - cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0) - cumwidths = (right - left) * cumwidths + left - cumwidths[..., 0] = left + widths *= right - left + cumwidths = F.pad(widths, pad=(1, 0), mode="constant", value=left) + cumwidths = torch.cumsum(cumwidths, dim=-1) + # Make right-most knot at the right boundary cumwidths[..., -1] = right - widths = cumwidths[..., 1:] - cumwidths[..., :-1] derivatives = min_derivative + F.softplus(unnormalized_derivatives) heights = F.softmax(unnormalized_heights, dim=-1) heights = min_bin_height + (1 - min_bin_height * num_bins) * heights - cumheights = torch.cumsum(heights, dim=-1) - cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0) - cumheights = (top - bottom) * cumheights + bottom - cumheights[..., 0] = bottom + heights *= top - bottom + cumheights = F.pad(heights, pad=(1, 0), mode="constant", value=bottom) + cumheights = torch.cumsum(cumheights, dim=-1) + # Make top-most knot at the top boundary cumheights[..., -1] = top - heights = cumheights[..., 1:] - cumheights[..., :-1] if inverse: bin_idx = torchutils.searchsorted(cumheights, inputs)[..., None]