CuPy
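This example demonstrates how to use BlackJAX nested sampling with CuPy, a GPU-accelerated, NumPy-compatible array library. CuPy lets the likelihood and prior run on the GPU while keeping NumPy-like syntax. Because JAX cannot trace CuPy code, these functions are called through `jax.pure_callback`: each evaluation copies the particles into a CuPy array and copies the results back to NumPy on the host.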
Prerequisites
Install the required packages:
pip install git+https://github.com/handley-lab/blackjax
pip install cupy-cuda12x numpy tqdm # Or cupy-cuda11x for CUDA 11.x
Note: CuPy requires a CUDA-capable NVIDIA GPU. Install the CuPy package that matches your CUDA version (`cupy-cuda12x` for CUDA 12.x, `cupy-cuda11x` for CUDA 11.x).
Run nested sampling with CuPy functions
import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import cupy as cp
# Root PRNG key for the example; the initial particles and the per-step keys
# are all obtained from it with jax.random.split.
rng_key = jax.random.PRNGKey(seed=0)
def loglikelihood_fn(theta):
    """Return the normalised log-density of N(theta | 1, 0.01 * I), computed with CuPy.

    Each coordinate has mean 1 and variance 0.01 (standard deviation 0.1).

    Args:
        theta: Array whose last axis holds the parameters: a single point of
            shape ``(d,)`` or a batch of points of shape ``(n, d)``.

    Returns:
        Array of log-likelihood values with the last axis reduced, i.e. a
        0-d value for a single point or shape ``(n,)`` for a batch.
    """
    ndim = theta.shape[-1]
    # -0.5 * ||theta - 1||^2 / sigma^2 with sigma^2 = 0.01, plus the Gaussian
    # normalisation -(d / 2) * log(2 * pi * sigma^2). Reducing over the last
    # axis (instead of axis=1) works for both single points and batches, and
    # deriving d from the shape avoids hard-coding d = 5.
    return (
        -50.0 * cp.sum((theta - 1) ** 2, axis=-1)
        - 0.5 * ndim * cp.log(2 * cp.pi * 0.01)
    )
def logprior_fn(theta):
    """Return the normalised log-density of a standard normal prior N(0, I), computed with CuPy.

    Args:
        theta: Array whose last axis holds the parameters: a single point of
            shape ``(d,)`` or a batch of points of shape ``(n, d)``.

    Returns:
        Array of log-prior values with the last axis reduced, i.e. a 0-d
        value for a single point or shape ``(n,)`` for a batch.
    """
    ndim = theta.shape[-1]
    # -0.5 * ||theta||^2 plus the normalisation -(d / 2) * log(2 * pi).
    # Reducing over the last axis (instead of axis=1) works for both single
    # points and batches, and deriving d avoids hard-coding d = 5.
    return -0.5 * cp.sum(theta ** 2, axis=-1) - 0.5 * ndim * cp.log(2 * cp.pi)
def wrap_fn(fn, vmap_method='legacy_vectorized'):
    """Wrap a CuPy log-density so it can be called from JAX code (including under jit).

    Args:
        fn: Function that takes a CuPy array of points (parameters on the
            last axis) and returns a CuPy array with that last axis reduced.
        vmap_method: Forwarded unchanged to ``jax.pure_callback``; it selects
            how the callback is batched under ``jax.vmap`` (see the
            ``jax.pure_callback`` documentation for the available modes).

    Returns:
        A function of a JAX array ``x`` of shape ``(..., d)`` that returns a
        JAX array of shape ``x.shape[:-1]`` with the same dtype as ``x``.
    """

    def numpy_wrapper(theta):
        # Runs on the host: copy the input into a CuPy array, evaluate fn,
        # then copy the result back to a NumPy array for JAX.
        theta_gpu = cp.asarray(theta)
        result_gpu = fn(theta_gpu)
        # pure_callback requires the returned dtype to equal the declared one
        # (x.dtype), but CuPy arithmetic can promote, e.g. float32 particles
        # combined with a float64 constant. Cast back to the input dtype.
        return np.asarray(cp.asnumpy(result_gpu), dtype=theta.dtype)

    def jax_wrapper(x):
        # One scalar per point: the parameter (last) axis is reduced away.
        out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
        return jax.pure_callback(numpy_wrapper, out_shape, x, vmap_method=vmap_method)

    return jax_wrapper
# Build the sampler. The CuPy prior and likelihood are wrapped with wrap_fn so
# that JAX can call them through jax.pure_callback.
algo = blackjax.nss(
    logprior_fn=wrap_fn(logprior_fn),
    loglikelihood_fn=wrap_fn(loglikelihood_fn),
    num_delete=50,
    num_inner_steps=20,
)

num_live = 1000  # number of initial live points
ndim = 5  # dimensionality of the parameter space
# Stop once logZ_live - logZ drops below this value. Presumably this means
# the live points' remaining contribution to the evidence is negligible
# (confirm against blackjax's nested-sampling state docs).
termination = -3

# Use separate keys for initialisation and for the sampling loop.
rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)
live = algo.init(jax.random.normal(initialization_key, (num_live, ndim)))
step = jax.jit(algo.step)

dead_points = []
with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
    # `not` binds more loosely than `<`; the parentheses make explicit that
    # the loop continues while the difference is NOT below the threshold.
    while not (live.logZ_live - live.logZ < termination):
        sampling_key, subkey = jax.random.split(sampling_key)
        live, dead = step(subkey, live)
        dead_points.append(dead)
        pbar.update(len(dead.particles))

ns_run = finalise(live, dead_points)