NumPy/SciPy
This example demonstrates how to use BlackJAX nested sampling with NumPy and SciPy functions. This approach is useful when you have existing likelihood functions written in NumPy/SciPy that you want to use with BlackJAX.
Prerequisites
Install the required packages:
pip install git+https://github.com/handley-lab/blackjax
pip install scipy numpy tqdm
Run nested sampling with NumPy/SciPy functions
import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
from scipy import stats
# Base PRNG key; it is split further down before any random draw is made.
rng_key = jax.random.PRNGKey(0)

# Both densities are frozen SciPy distributions; only their `logpdf` bound
# methods are kept, so they are ordinary host-side NumPy callables.
# Prior: zero-mean Gaussian with identity covariance in 5 dimensions.
logprior_fn = stats.multivariate_normal(
    mean=np.zeros(5),
    cov=np.identity(5),
).logpdf
# Likelihood: Gaussian centred on 1 in every coordinate, covariance 0.01 * I.
loglikelihood_fn = stats.multivariate_normal(
    mean=np.ones(5),
    cov=0.01 * np.identity(5),
).logpdf
def wrap_fn(fn, vmap_method='legacy_vectorized'):
def jax_wrapper(x):
out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
return jax.pure_callback(fn, out_shape, x, vmap_method=vmap_method)
return jax_wrapper
# Build the nested sampler from the wrapped NumPy/SciPy densities.
algo = blackjax.nss(
    logprior_fn=wrap_fn(logprior_fn),
    loglikelihood_fn=wrap_fn(loglikelihood_fn),
    num_delete=50,
    num_inner_steps=20,
)

# Derive independent keys from the base key; `initialization_key` seeds the
# initial particles drawn below.
rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)

# 1000 initial particles in 5 dimensions, drawn from a standard normal.
initial_particles = jax.random.normal(initialization_key, (1000, 5))
live = algo.init(initial_particles)

# Compile the update step once; it is invoked repeatedly in the loop below.
step = jax.jit(algo.step)

dead_points = []
with tqdm.tqdm(desc="Dead points", unit=" dead points") as progress:
    # Keep stepping until logZ_live - logZ falls below -3 (presumably: the
    # evidence left in the live points is small relative to what has been
    # accumulated -- confirm against the BlackJAX nested sampling docs).
    while not (live.logZ_live - live.logZ < -3):
        rng_key, step_key = jax.random.split(rng_key)
        live, dead = step(step_key, live)
        dead_points.append(dead)
        progress.update(len(dead.particles))

# Combine the final live state with every batch of dead points collected.
ns_run = finalise(live, dead_points)