R#
This example demonstrates how to use BlackJAX nested sampling with R implementations of likelihood and prior functions. The R code is called via rpy2 bridge, with JAX’s pure_callback providing the interface.
Prerequisites#
Install the required Python packages:
pip install git+https://github.com/handley-lab/blackjax
pip install rpy2 numpy tqdm
Install the required R package:
install.packages('mvtnorm')
Setup Instructions#
1. Create the R implementation#
Create a file model.R with your likelihood and prior functions:
library(mvtnorm)
loglikelihood <- function(theta) {
dmvnorm(theta, mean = rep(1, 5), sigma = 0.01 * diag(5), log = TRUE)
}
logprior <- function(theta) {
dmvnorm(theta, mean = rep(0, 5), sigma = diag(5), log = TRUE)
}
2. Run nested sampling with R functions#
import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import rpy2.robjects as ro
from rpy2.robjects import numpy2ri
from rpy2.robjects.conversion import localconverter
rng_key = jax.random.PRNGKey(0)
ro.r('source("model.R")')
loglikelihood_fn = ro.globalenv['loglikelihood']
logprior_fn = ro.globalenv['logprior']
def wrap_fn(fn, vmap_method='legacy_vectorized'):
def numpy_wrapper(theta):
with localconverter(ro.default_converter + numpy2ri.converter):
return fn(np.asarray(theta))
def jax_wrapper(x):
out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
return jax.pure_callback(numpy_wrapper, out_shape, x, vmap_method=vmap_method)
return jax_wrapper
algo = blackjax.nss(
logprior_fn=wrap_fn(logprior_fn),
loglikelihood_fn=wrap_fn(loglikelihood_fn),
num_delete=50,
num_inner_steps=20,
)
rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)
live = algo.init(jax.random.normal(initialization_key, (1000, 5)))
step = jax.jit(algo.step)
dead_points = []
with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
while (not live.logZ_live - live.logZ < -3):
rng_key, subkey = jax.random.split(rng_key)
live, dead = step(subkey, live)
dead_points.append(dead)
pbar.update(len(dead.particles))
ns_run = finalise(live, dead_points)