# Numba
This example demonstrates how to use BlackJAX nested sampling with Numba JIT-compiled functions. Numba compiles Python functions to machine code at runtime, providing significant speedups for numerical computations while keeping the code in pure Python.
## Prerequisites
Install the required packages:
pip install git+https://github.com/handley-lab/blackjax
pip install numba numpy tqdm
## Run nested sampling with Numba functions
import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import numba
# Root PRNG key built from a fixed seed (0); later keys are split off from it.
rng_key = jax.random.PRNGKey(0)
@numba.jit(nopython=True)
def loglikelihood_fn(theta):
    """Gaussian log-likelihood evaluated on a batch of parameter vectors.

    Every row of ``theta`` is scored under an isotropic normal distribution
    with mean 1 and variance 0.01 in each dimension.

    Parameters
    ----------
    theta : 2-D array
        One parameter vector per row; the reduction runs over axis 1.

    Returns
    -------
    1-D array
        One log-likelihood value per row of ``theta``.
    """
    ndim = theta.shape[1]
    # Normalisation is (ndim / 2) * log(2 * pi * sigma^2); taking ndim from
    # the input keeps it right if the dimensionality of the problem changes
    # (for five dimensions this is the familiar 2.5 * log(...)).
    log_norm = 0.5 * ndim * np.log(2 * np.pi * 0.01)
    # -50 = -1 / (2 * sigma^2) with sigma^2 = 0.01.
    return -50.0 * np.sum((theta - 1) ** 2, axis=1) - log_norm
@numba.jit(nopython=True)
def logprior_fn(theta):
    """Standard-normal log-prior evaluated on a batch of parameter vectors.

    Every row of ``theta`` is scored under an isotropic normal distribution
    with zero mean and unit variance in each dimension.

    Parameters
    ----------
    theta : 2-D array
        One parameter vector per row; the reduction runs over axis 1.

    Returns
    -------
    1-D array
        One log-prior value per row of ``theta``.
    """
    ndim = theta.shape[1]
    # Normalisation is (ndim / 2) * log(2 * pi); taking ndim from the input
    # keeps it right if the dimensionality of the problem changes (for five
    # dimensions this is the familiar 2.5 * log(2 * pi)).
    log_norm = 0.5 * ndim * np.log(2 * np.pi)
    return -0.5 * np.sum(theta ** 2, axis=1) - log_norm
def wrap_fn(fn, vmap_method='legacy_vectorized'):
def numpy_wrapper(theta):
return fn(np.asarray(theta))
def jax_wrapper(x):
out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
return jax.pure_callback(numpy_wrapper, out_shape, x, vmap_method=vmap_method)
return jax_wrapper
# Configure the sampler with the Numba-backed densities, each wrapped so it
# can be called from JAX.
algo = blackjax.nss(
    logprior_fn=wrap_fn(logprior_fn),
    loglikelihood_fn=wrap_fn(loglikelihood_fn),
    num_delete=50,
    num_inner_steps=20,
)

# Draw the initial particles from a standard normal: 1000 points in 5-D.
rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)
initial_particles = jax.random.normal(initialization_key, (1000, 5))
live = algo.init(initial_particles)
step = jax.jit(algo.step)

# Keep stepping, collecting the discarded points, until logZ_live falls more
# than 3 below logZ.
# NOTE(review): presumably logZ_live is the evidence still held by the live
# points, so this is a "remaining evidence is negligible" stop - confirm.
dead_points = []
with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
    while True:
        log_evidence_gap = live.logZ_live - live.logZ
        if log_evidence_gap < -3:
            break
        rng_key, subkey = jax.random.split(rng_key)
        live, dead = step(subkey, live)
        dead_points.append(dead)
        pbar.update(len(dead.particles))

# Combine the final live state with the accumulated dead points.
ns_run = finalise(live, dead_points)