Description
errors in examples/time_series/bayesian_var_model.ipynb
Error in bayesian_var_model.ipynb:
Notebook url: https://github.com/pymc-devs/pymc-examples/tree/main/examples/time_series/bayesian_var_model.ipynb
Issue description
Error in imports: replace from pymc.sampling_jax import sample_blackjax_nuts
with from pymc.sampling.jax import sample_blackjax_nuts
Error on creating betaX in make_model and make_hierarchical_model:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[44], line 1
----> 1 make_model(n_lags, n_eqs, df, priors)
Cell In[43], line 44, in make_model(n_lags, n_eqs, df, priors, mv_norm, prior_checks)
41 data_obs = pm.Data("data_obs", df.values[n_lags:], dims=["time", "equations"])
43 betaX = calc_ar_step(lag_coefs, n_eqs, n_lags, df)
---> 44 betaX = pm.Deterministic(
45 "betaX",
46 betaX,
47 dims=[
48 "time",
49 ],
50 )
51 mean = alpha + betaX
53 if mv_norm:
File c:\Users\Ivan\anaconda3\envs\pymc_env\Lib\site-packages\pymc\model\core.py:2254, in Deterministic(name, var, model, dims)
2252 var = var.copy(model.name_for(name))
2253 model.deterministics.append(var)
-> 2254 model.add_named_variable(var, dims)
2256 from pymc.printing import str_for_potential_or_deterministic
2258 var.str_repr = types.MethodType(
2259 functools.partial(str_for_potential_or_deterministic, dist_name="Deterministic"), var
2260 )
File c:\Users\Ivan\anaconda3\envs\pymc_env\Lib\site-packages\pymc\model\core.py:1472, in Model.add_named_variable(self, var, dims)
1470 # This check implicitly states that only vars with .ndim attribute can have dims
1471 if var.ndim != len(dims):
-> 1472 raise ValueError(
1473 f"{var} has {var.ndim} dims but {len(dims)} dim labels were provided."
1474 )
1475 self.named_vars_to_dims[var.name] = dims
1477 self.named_vars[var.name] = var
ValueError: betaX has 2 dims but 1 dim labels were provided.
Proposed solution
Adding the missing dimension label (namely "equations") to the failing call solves the problem:
betaX = pm.Deterministic(
"betaX",
betaX,
dims=[
"time",
"equations",
],
)
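The rule behind this error can be sketched without PyMC: the number of dim labels has to equal the number of array dimensions. The helper check_dim_labels below is hypothetical. It only mirrors the check visible in the traceback and is not part of the PyMC API. The shape (48, 2) is an assumed, illustrative size for a (time, equations) array.

```python
def check_dim_labels(name, shape, dims):
    """Mimic the ndim-vs-labels check from the traceback (hypothetical helper)."""
    ndim = len(shape)
    if ndim != len(dims):
        raise ValueError(
            f"{name} has {ndim} dims but {len(dims)} dim labels were provided."
        )
    # Pair each label with the length of the axis it describes
    return dict(zip(dims, shape))


# Assumed, illustrative shape: 48 time steps, 2 equations
betaX_shape = (48, 2)

# One label for a 2-D array reproduces the message from the traceback
try:
    check_dim_labels("betaX", betaX_shape, ["time"])
except ValueError as err:
    message = str(err)

# Two labels match the two axes, so the check passes
labelled = check_dim_labels("betaX", betaX_shape, ["time", "equations"])
```

With one label the call raises the same message as in the notebook. With both labels each axis gets a name.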
Another issue
This error, as far as I can tell, is purely Windows-related (see here). In the function make_hierarchical_model, this line breaks:
idata.extend(sample_blackjax_nuts(2000, random_seed=120))
The same error occurs with sample_numpyro_nuts.
The first error is RuntimeError: Incorrect output dtype for return value #0: Expected: int64, Actual: int32
It can be fixed as described in this issue.
But then another error appears:
TypeError: true_fun and false_fun output must have identical types, got
Proposal(state=IntegratorState(position=['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. 
ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. 
ShapedArray(float32[6])'], momentum=['ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[6])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[3,2,3])', 'ShapedArray(float64[3])', 'ShapedArray(float64[6])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[3,2,3])', 'ShapedArray(float64[3])', 'ShapedArray(float64[6])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[3,2,3])', 'ShapedArray(float64[3])', 'ShapedArray(float64[6])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[3,2,3])', 'ShapedArray(float64[3])', 'ShapedArray(float64[6])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[3,2,3])', 'ShapedArray(float64[3])', 'ShapedArray(float64[6])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[3,2,3])', 'ShapedArray(float64[3])', 'ShapedArray(float64[6])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[3,2,3])', 'ShapedArray(float64[3])', 'ShapedArray(float64[6])', 'ShapedArray(float64[])', 'ShapedArray(float64[])', 'ShapedArray(float64[3,2,3])', 'ShapedArray(float64[3])', 'ShapedArray(float64[6])'], logdensity='ShapedArray(float64[])', logdensity_grad=['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. 
ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. 
ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[3,2,3]) vs. ShapedArray(float32[3,2,3])', 'DIFFERENT ShapedArray(float64[3]) vs. ShapedArray(float32[3])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])']), energy='ShapedArray(float64[])', weight='ShapedArray(float64[])', sum_log_p_accept='ShapedArray(float64[])').
and I have no idea how to solve it.
Possible solution
The workaround (if this is not an issue on Linux systems) is to use plain pm.sample instead of sample_blackjax_nuts when the code runs on Windows (this can be checked with if os.name == 'nt', for example).
This behavior was also fixed in NumPy 2.0 (link to release notes), so this workaround may only be needed temporarily.
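A minimal sketch of that workaround, assuming the sampling call happens inside a model context as in the notebook. The names choose_sampler and sample_on_platform are hypothetical helpers introduced here for illustration. Only pm.sample and sample_blackjax_nuts come from PyMC, and they are imported lazily so the platform check itself has no dependencies.

```python
import os


def choose_sampler(os_name=os.name):
    """Return which sampler to use: plain PyMC on Windows, BlackJAX elsewhere."""
    return "pymc" if os_name == "nt" else "blackjax"


def sample_on_platform(draws=2000, random_seed=120, os_name=os.name):
    """Sample with pm.sample on Windows and sample_blackjax_nuts otherwise.

    Must be called inside a `with model:` block (assumption taken from the
    notebook, where sampling happens within the model context).
    """
    if choose_sampler(os_name) == "pymc":
        import pymc as pm  # imported lazily, only when actually sampling

        return pm.sample(draws, random_seed=random_seed)

    from pymc.sampling.jax import sample_blackjax_nuts

    return sample_blackjax_nuts(draws, random_seed=random_seed)


# The platform decision can be checked without running any sampler
windows_choice = choose_sampler("nt")
linux_choice = choose_sampler("posix")
```

In the notebook the failing line would then read idata.extend(sample_on_platform(2000, random_seed=120)). Whether the switch is still needed with NumPy 2.0 has not been verified here.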