C++#

This example demonstrates how to use BlackJAX nested sampling with C++ implementations of likelihood and prior functions. The C++ code is compiled using pybind11 to create a Python module, with JAX’s pure_callback providing the bridge.
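
The key piece is jax.pure_callback, which lets jitted JAX code call back into a regular Python-level function (here, the compiled C++ module). As a minimal standalone illustration of the pattern (using a plain NumPy function as a stand-in for the C++ module built below):

import jax
import jax.numpy as jnp
import numpy as np

def numpy_fn(x):
    # stand-in for a compiled C++ function: batched sum of squares
    return np.sum(x**2, axis=-1)

def wrapped(x):
    # one output per row of x, same dtype as x
    out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
    return jax.pure_callback(numpy_fn, out_shape, x, vmap_method='legacy_vectorized')

print(jax.jit(wrapped)(jnp.ones((4, 3))))  # callable from inside jit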

Prerequisites#

Install the required Python packages:

pip install git+https://github.com/handley-lab/blackjax
pip install pybind11 numpy tqdm

You’ll also need a C++ compiler (such as g++) installed on your system.

Setup Instructions#

1. Create the C++ implementation#

First, create a file model.cpp with your likelihood and prior functions using pybind11:

/* model.cpp */
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <cmath>

namespace py = pybind11;

constexpr double LOG_2PI = 1.8378770664093454;

// Scalar (single-point) functions exposed to Python
double loglikelihood_scalar(py::array_t<double> theta_array) {
    auto theta = theta_array.unchecked<1>();
    const int d = theta.shape(0);
    
    const double inv_var = 1.0 / 0.01;
    const double log_det = d * std::log(0.01);
    const double mu = 1.0;
    
    double q = 0.0;
    for (int i = 0; i < d; ++i) {
        double diff = theta(i) - mu;
        q += diff * diff * inv_var;
    }
    return -0.5 * (d * LOG_2PI + log_det + q);
}

double logprior_scalar(py::array_t<double> theta_array) {
    auto theta = theta_array.unchecked<1>();
    const int d = theta.shape(0);
    
    double q = 0.0;
    for (int i = 0; i < d; ++i) {
        q += theta(i) * theta(i);
    }
    return -0.5 * (d * LOG_2PI + q);
}

// Raw-pointer helpers reused by the batched wrappers below
static double loglikelihood_scalar_impl(const double* theta, size_t d) {
    const double inv_var = 1.0 / 0.01;
    const double log_det = static_cast<double>(d) * std::log(0.01);
    const double mu = 1.0;
    
    double q = 0.0;
    for (size_t i = 0; i < d; ++i) {
        double diff = theta[i] - mu;
        q += diff * diff * inv_var;
    }
    return -0.5 * (d * LOG_2PI + log_det + q);
}

static double logprior_scalar_impl(const double* theta, size_t d) {
    double q = 0.0;
    for (size_t i = 0; i < d; ++i) {
        q += theta[i] * theta[i];
    }
    return -0.5 * (d * LOG_2PI + q);
}

// Batched wrappers with GIL release
py::array_t<double> loglikelihood(py::array_t<double, py::array::c_style | py::array::forcecast> theta) {
    py::buffer_info info = theta.request();
    if (info.ndim != 2)
        throw py::value_error("theta must be 2D (batch, dim)");
    
    const size_t batch = static_cast<size_t>(info.shape[0]);
    const size_t d = static_cast<size_t>(info.shape[1]);
    const double* data = static_cast<const double*>(info.ptr);
    
    py::array_t<double> out(batch);
    double* out_ptr = out.mutable_data();
    
    {
        py::gil_scoped_release release;
        for (size_t b = 0; b < batch; ++b) {
            out_ptr[b] = loglikelihood_scalar_impl(data + b * d, d);
        }
    }
    return out;
}

py::array_t<double> logprior(py::array_t<double, py::array::c_style | py::array::forcecast> theta) {
    py::buffer_info info = theta.request();
    if (info.ndim != 2)
        throw py::value_error("theta must be 2D (batch, dim)");
    
    const size_t batch = static_cast<size_t>(info.shape[0]);
    const size_t d = static_cast<size_t>(info.shape[1]);
    const double* data = static_cast<const double*>(info.ptr);
    
    py::array_t<double> out(batch);
    double* out_ptr = out.mutable_data();
    
    {
        py::gil_scoped_release release;
        for (size_t b = 0; b < batch; ++b) {
            out_ptr[b] = logprior_scalar_impl(data + b * d, d);
        }
    }
    return out;
}

PYBIND11_MODULE(model, m) {
    m.doc() = "Sequential C++ likelihood and prior functions with batching wrapper";
    m.def("loglikelihood", &loglikelihood, "Log likelihood function (batched)");
    m.def("logprior", &logprior, "Log prior function (batched)");
    m.def("loglikelihood_scalar", &loglikelihood_scalar, "Log likelihood function (scalar)");
    m.def("logprior_scalar", &logprior_scalar, "Log prior function (scalar)");
}

Note: This implementation defines scalar likelihood and prior functions, then provides simple sequential batched wrappers. Although the batching is sequential rather than parallel, it is still significantly faster than a pure Python (or non-JIT-compiled) implementation, because it reduces the number of Python callbacks by a factor of num_delete (typically 50-100); for fast likelihoods, callback overhead is the dominant cost. If the likelihood computation is expensive enough to benefit from parallelization, the batched functions could instead process the parameter vectors in parallel (e.g. using std::execution::par, OpenMP, or SIMD instructions); a sketch of an OpenMP variant follows below.

Save this as model.cpp in your working directory.
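
The parallel option mentioned in the note is not included in model.cpp above, but as an illustrative sketch (assuming the extension is compiled and linked with -fopenmp), an OpenMP version of the batched log-likelihood wrapper could reuse loglikelihood_scalar_impl like this:

// Hypothetical OpenMP variant -- not part of model.cpp above; requires -fopenmp
py::array_t<double> loglikelihood_parallel(py::array_t<double, py::array::c_style | py::array::forcecast> theta) {
    py::buffer_info info = theta.request();
    if (info.ndim != 2)
        throw py::value_error("theta must be 2D (batch, dim)");

    const size_t batch = static_cast<size_t>(info.shape[0]);
    const size_t d = static_cast<size_t>(info.shape[1]);
    const double* data = static_cast<const double*>(info.ptr);

    py::array_t<double> out(batch);
    double* out_ptr = out.mutable_data();

    {
        py::gil_scoped_release release;
        // each batch element is independent, so the loop parallelises trivially
        #pragma omp parallel for
        for (long b = 0; b < static_cast<long>(batch); ++b) {
            out_ptr[b] = loglikelihood_scalar_impl(data + b * d, d);
        }
    }
    return out;
}

To build such a variant you would also pass extra_compile_args=["-fopenmp"] and extra_link_args=["-fopenmp"] (standard setuptools Extension arguments) to the Pybind11Extension in the setup script from the next step.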

2. Create the setup script#

# setup_model_cpp.py
from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import setup

ext_modules = [
    Pybind11Extension(
        "model",
        ["model.cpp"],
        cxx_std=11,
    ),
]

setup(
    name="model",
    ext_modules=ext_modules,
    cmdclass={"build_ext": build_ext},
    zip_safe=False,
    python_requires=">=3.7",
)

Save this as setup_model_cpp.py in your working directory.

3. Compile the C++ module#

Install pybind11 and compile the module:

pip install pybind11
python setup_model_cpp.py build_ext --inplace

This will create a model module that can be imported directly in Python.
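
As an optional sanity check (not part of the original example), you can verify that the compiled module loads and that the batched and scalar entry points agree:

import numpy as np
import model

theta = np.zeros((3, 5))           # batch of 3 points in 5 dimensions
print(model.loglikelihood(theta))  # shape (3,) array of log-likelihoods
print(model.logprior(theta))       # shape (3,) array of log-priors

# the batched and scalar entry points should give the same answer
assert np.isclose(model.loglikelihood(theta)[0], model.loglikelihood_scalar(theta[0]))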

4. Run nested sampling with C++ functions#

import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import model

rng_key = jax.random.PRNGKey(0)

def wrap_fn(fn, vmap_method='legacy_vectorized'):
    """Wrap a batched NumPy-compatible callable for use inside jitted JAX code."""
    def jax_wrapper(x):
        # the callback returns one log-density per row of x, so the trailing axis is dropped
        out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
        return jax.pure_callback(fn, out_shape, x, vmap_method=vmap_method)

    return jax_wrapper

algo = blackjax.nss(
    logprior_fn=wrap_fn(model.logprior),
    loglikelihood_fn=wrap_fn(model.loglikelihood),
    num_delete=50,
    num_inner_steps=20,
)

rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)
live = algo.init(jax.random.normal(initialization_key, (1000, 5)))  # 1000 live points drawn from the standard-normal prior
step = jax.jit(algo.step)

dead_points = []

with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
    while live.logZ_live - live.logZ >= -3:  # run until the evidence remaining in the live points is negligible
        rng_key, subkey = jax.random.split(rng_key)
        live, dead = step(subkey, live)
        dead_points.append(dead)
        pbar.update(len(dead.particles))

ns_run = finalise(live, dead_points)