C#

This example demonstrates how to use BlackJAX nested sampling with C implementations of likelihood and prior functions. The C code is compiled to a shared library and accessed via Python’s ctypes library, with JAX’s pure_callback providing the bridge.
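The key ingredient is jax.pure_callback, which lets JIT-compiled JAX code call back into ordinary Python (and, through ctypes, into C). As a minimal sketch of the pattern, with a plain NumPy function standing in for the C library:

import jax
import jax.numpy as jnp
import numpy as np

def numpy_fn(x):
    # Arbitrary host-side code runs here (NumPy, ctypes, ...).
    return np.sum(np.asarray(x) ** 2, axis=-1)

@jax.jit
def jax_fn(x):
    # Tell JAX the shape/dtype of the callback's output: one value per row.
    out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
    return jax.pure_callback(numpy_fn, out_shape, x)

print(jax_fn(jnp.ones((3, 5))))  # [5. 5. 5.]

The full example below follows exactly this pattern, with the NumPy function replaced by ctypes calls into compiled C.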

Prerequisites#

Install the required Python packages:

pip install git+https://github.com/handley-lab/blackjax
pip install numpy tqdm

You’ll also need a C compiler (gcc) installed on your system.
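You can confirm that a compiler is available with:

gcc --version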

Setup Instructions#

1. Create the C implementation#

First, create a file model.c with your likelihood and prior functions:

/* model.c */
#include <math.h>

static const double LOG_2PI = 1.8378770664093454;

// Sequential (scalar) function
static double loglikelihood_scalar(const double* theta, int d) {
    const double inv_var = 1.0 / 0.01;
    const double log_det = d * log(0.01);
    const double mu = 1.0;
    
    double q = 0.0;
    for (int i = 0; i < d; i++) {
        double diff = theta[i] - mu;
        q += diff * diff * inv_var;
    }
    return -0.5 * (d * LOG_2PI + log_det + q);
}

static double logprior_scalar(const double* theta, int d) {
    double q = 0.0;
    for (int i = 0; i < d; i++) {
        q += theta[i] * theta[i];
    }
    return -0.5 * (d * LOG_2PI + q);
}

// Batched wrappers
void loglikelihood(const double* theta, double* result, int batch, int d) {
    for (int b = 0; b < batch; b++) {
        result[b] = loglikelihood_scalar(theta + b * d, d);
    }
}

void logprior(const double* theta, double* result, int batch, int d) {
    for (int b = 0; b < batch; b++) {
        result[b] = logprior_scalar(theta + b * d, d);
    }
}

Note: This implementation defines scalar likelihood and prior functions, then wraps them in simple sequential batched versions. Although the batching is sequential rather than parallel, it is still significantly faster than a pure Python (or non-JIT-compiled) implementation, because it cuts the number of Python callbacks by a factor of num_delete (typically 50-100), and callback overhead is the dominant cost when the likelihood itself is cheap. If the likelihood is expensive enough to benefit from parallelism, the batched functions could instead process multiple parameter vectors in parallel (e.g., with OpenMP or SIMD instructions).

Save this as model.c in your working directory.

2. Compile the C library#

gcc -shared -fPIC -O3 -o libmodel.so model.c -lm
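If the Python interface later fails to find the functions, you can check (on Linux) that the library exports both symbols:

nm -D libmodel.so | grep -E 'loglikelihood|logprior'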

3. Create the Python interface#

Create model.py to interface with the compiled C library:

# model.py
import ctypes
import numpy as np
from numpy.ctypeslib import ndpointer

lib = ctypes.CDLL("./libmodel.so")

lib.loglikelihood.argtypes = [
    ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
    ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
    ctypes.c_int,
    ctypes.c_int
]
lib.loglikelihood.restype = None

lib.logprior.argtypes = [
    ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
    ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
    ctypes.c_int,
    ctypes.c_int
]
lib.logprior.restype = None

def loglikelihood(theta):
    theta = np.ascontiguousarray(theta, dtype=np.float64)
    batch, d = theta.shape
    result = np.empty(batch, dtype=np.float64)
    lib.loglikelihood(theta, result, batch, d)
    return result

def logprior(theta):
    theta = np.ascontiguousarray(theta, dtype=np.float64)
    batch, d = theta.shape
    result = np.empty(batch, dtype=np.float64)
    lib.logprior(theta, result, batch, d)
    return result

Save this as model.py in your working directory.
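Before wiring these functions into JAX, it is worth checking that the compiled library agrees with a direct NumPy implementation of the same Gaussians. A minimal sketch (run from the directory containing libmodel.so; the file name check_model.py is just a suggestion):

# check_model.py
import numpy as np
import model

rng = np.random.default_rng(0)
theta = rng.normal(size=(4, 5))
batch, d = theta.shape

# NumPy references for the Gaussians defined in model.c:
# likelihood N(theta; 1, 0.01 I), prior N(theta; 0, I)
ref_ll = -0.5 * (d * np.log(2 * np.pi) + d * np.log(0.01)
                 + np.sum((theta - 1.0) ** 2, axis=1) / 0.01)
ref_lp = -0.5 * (d * np.log(2 * np.pi) + np.sum(theta ** 2, axis=1))

assert np.allclose(model.loglikelihood(theta), ref_ll)
assert np.allclose(model.logprior(theta), ref_lp)
print("C library matches the NumPy reference")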

4. Run nested sampling with C functions#

import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import model  # the ctypes interface from step 3

rng_key = jax.random.PRNGKey(0)

def wrap_fn(fn, vmap_method='legacy_vectorized'):
    def jax_wrapper(x):
        # One scalar output per parameter vector: the output shape drops
        # the trailing parameter dimension.
        out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
        # 'legacy_vectorized' hands the whole batch to the callback in a
        # single call, matching the batched C wrappers above.
        return jax.pure_callback(fn, out_shape, x, vmap_method=vmap_method)

    return jax_wrapper

# Each step deletes the num_delete lowest-likelihood live points and
# regenerates them using num_inner_steps inner sampling steps.
algo = blackjax.nss(
    logprior_fn=wrap_fn(model.logprior),
    loglikelihood_fn=wrap_fn(model.loglikelihood),
    num_delete=50,
    num_inner_steps=20,
)

rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)
# Initialise 1000 live points in 5 dimensions, drawn from the
# standard-normal prior defined in model.c.
live = algo.init(jax.random.normal(initialization_key, (1000, 5)))
step = jax.jit(algo.step)

dead_points = []

with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
    # Stop once the evidence remaining in the live points is a negligible
    # fraction (less than e^-3) of the accumulated evidence.
    while live.logZ_live - live.logZ >= -3:
        rng_key, subkey = jax.random.split(rng_key)
        live, dead = step(subkey, live)
        dead_points.append(dead)
        pbar.update(len(dead.particles))

# Fold the remaining live points into the run to obtain the final samples.
ns_run = finalise(live, dead_points)
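For this toy model the evidence is known in closed form: in each dimension the integral of the likelihood N(θ; 1, 0.01) against the prior N(θ; 0, 1) is N(1; 0, 1.01), which makes for a convenient end-to-end check. Continuing the session above (live.logZ excludes the final live-point contribution that finalise folds in, so expect agreement only at roughly the 0.1 level for these settings):

# Per dimension: Z = ∫ N(x; 1, 0.01) N(x; 0, 1) dx = N(1; 0, 1.01)
d = 5
logZ_analytic = d * (-0.5 * np.log(2 * np.pi * 1.01) - 0.5 / 1.01)
print(f"analytic logZ  = {logZ_analytic:.3f}")      # ≈ -7.095
print(f"estimated logZ = {float(live.logZ):.3f}")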