PyTorch

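This example demonstrates how to use BlackJAX nested sampling with PyTorch distributions and likelihood functions. The key step is wrapping each PyTorch function with `jax.pure_callback` so that JAX can call it. Note that `jax.pure_callback` is not differentiable by default, so this approach is only suitable for gradient-free samplers such as the nested slice sampler (`blackjax.nss`) used below.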

Prerequisites

Install the required packages:

pip install git+https://github.com/handley-lab/blackjax
pip install torch numpy tqdm

Run nested sampling with PyTorch functions

import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import torch

rng_key = jax.random.PRNGKey(0)

# Precision used for every torch computation in this example. The
# distributions below are built in this same dtype, so changing this one line
# keeps their parameters consistent with the tensors passed to log_prob.
dtype = torch.float32

# Likelihood: 5-D Gaussian centred on (1, ..., 1) with covariance 0.01 * I.
# Prior: 5-D standard normal. ``log_prob`` of a MultivariateNormal reduces the
# last (event) axis, so both functions map (..., 5) -> (...).
loglikelihood_fn = torch.distributions.MultivariateNormal(
    torch.ones(5, dtype=dtype), 0.01 * torch.eye(5, dtype=dtype)
).log_prob
logprior_fn = torch.distributions.MultivariateNormal(
    torch.zeros(5, dtype=dtype), torch.eye(5, dtype=dtype)
).log_prob

def wrap_fn(fn, vmap_method='legacy_vectorized'):
    def numpy_wrapper(theta):
        x = torch.from_numpy(np.asarray(theta).copy()).to(dtype)
        result = fn(x)
        return result.numpy()
    
    def jax_wrapper(x):
        out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
        return jax.pure_callback(numpy_wrapper, out_shape, x, vmap_method=vmap_method)
    
    return jax_wrapper

# Nested slice sampler over the torch-defined prior and likelihood.
# num_delete is the number of live points replaced per step (it matches the
# dead-point count reported below); num_inner_steps presumably sets the
# slice-sampling steps per replacement -- confirm against the BlackJAX docs.
algo = blackjax.nss(
    logprior_fn=wrap_fn(logprior_fn),
    loglikelihood_fn=wrap_fn(loglikelihood_fn),
    num_delete=50,
    num_inner_steps=20,
)

rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)
# 1000 initial live points drawn from a 5-D standard normal, which is the
# prior defined above.
live = algo.init(jax.random.normal(initialization_key, (1000, 5)))
step = jax.jit(algo.step)

dead_points = []

# Keep sampling until the log-evidence remaining in the live points
# (logZ_live) is more than 3 below the accumulated log-evidence (logZ).
# Written as `not (... < -3)` so a NaN difference keeps the loop running.
with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
    while not (live.logZ_live - live.logZ < -3):
        sampling_key, subkey = jax.random.split(sampling_key)
        live, dead = step(subkey, live)
        dead_points.append(dead)
        pbar.update(len(dead.particles))

ns_run = finalise(live, dead_points)