Fortran
This example demonstrates how to use BlackJAX nested sampling with likelihood and prior functions implemented in Fortran. The Fortran code is compiled to a shared library and loaded from Python with the standard-library ctypes module, and JAX’s jax.pure_callback bridges the JAX program and the compiled code.
Prerequisites
Install the required Python packages:
pip install git+https://github.com/handley-lab/blackjax
pip install numpy tqdm
You’ll also need a Fortran compiler (gfortran) installed on your system.
Setup Instructions
1. Create the Fortran implementation
First, create a file model.f90 with your likelihood and prior functions:
! model.f90
module model_mod
  use iso_c_binding
  implicit none

  !> ln(2*pi): per-dimension normalisation constant of a Gaussian log-density.
  real(c_double), parameter :: LOG_2PI = 1.8378770664093454d0
  !> Mean of every coordinate of the Gaussian likelihood.
  real(c_double), parameter :: LIKE_MEAN = 1.0d0
  !> Variance of every coordinate of the Gaussian likelihood.
  real(c_double), parameter :: LIKE_VAR = 0.01d0

contains

  ! Sequential (scalar) functions: each evaluates one parameter vector theta(1:d).

  !> Log-density of the isotropic Gaussian N(LIKE_MEAN, LIKE_VAR * I) at theta(1:d).
  !> theta is assumed-size so that callers can pass the first element of a
  !> vector stored inside a larger flat array.
  pure function loglikelihood_scalar(theta, d) result(logp)
    real(c_double), intent(in) :: theta(*)
    integer(c_int), intent(in), value :: d
    real(c_double) :: logp
    real(c_double) :: nd, q

    nd = real(d, c_double)
    ! Quadratic form (theta - mu)^T Sigma^{-1} (theta - mu) for Sigma = LIKE_VAR * I;
    ! each squared residual is scaled by 1/variance.
    q = sum((theta(1:d) - LIKE_MEAN)**2 * (1.0d0 / LIKE_VAR))
    ! log det(Sigma) = d * log(LIKE_VAR).
    logp = -0.5d0 * (nd * LOG_2PI + nd * log(LIKE_VAR) + q)
  end function loglikelihood_scalar

  !> Log-density of the standard normal prior N(0, I) at theta(1:d).
  pure function logprior_scalar(theta, d) result(logp)
    real(c_double), intent(in) :: theta(*)
    integer(c_int), intent(in), value :: d
    real(c_double) :: logp

    logp = -0.5d0 * (real(d, c_double) * LOG_2PI + sum(theta(1:d)**2))
  end function logprior_scalar

end module model_mod
! C-compatible wrappers: theta is a C row-major (batch, d) array, i.e. batch parameter vectors of length d stored back to back
!> Batched log-likelihood, callable from C as loglikelihood(theta, result, batch, d).
!> theta holds `batch` parameter vectors of length `d` stored one after another;
!> result(k) receives the log-likelihood of the k-th vector.
subroutine loglikelihood(theta, result, batch, d) bind(c, name="loglikelihood")
  use iso_c_binding, only: c_int, c_double
  use model_mod, only: loglikelihood_scalar
  implicit none
  integer(c_int), intent(in), value :: batch, d
  real(c_double), intent(in) :: theta(batch * d)
  real(c_double), intent(out) :: result(batch)
  integer :: row, first, last

  do row = 1, batch
    ! Vector `row` occupies the contiguous slice theta(first:last).
    first = (row - 1) * d + 1
    last = first + d - 1
    result(row) = loglikelihood_scalar(theta(first:last), d)
  end do
end subroutine loglikelihood
!> Batched log-prior, callable from C as logprior(theta, result, batch, d).
!> theta holds `batch` parameter vectors of length `d` stored one after another;
!> result(k) receives the log-prior of the k-th vector.
subroutine logprior(theta, result, batch, d) bind(c, name="logprior")
  use iso_c_binding, only: c_int, c_double
  use model_mod, only: logprior_scalar
  implicit none
  integer(c_int), intent(in), value :: batch, d
  real(c_double), intent(in) :: theta(batch * d)
  real(c_double), intent(out) :: result(batch)
  integer :: row, first, last

  do row = 1, batch
    ! Vector `row` occupies the contiguous slice theta(first:last).
    first = (row - 1) * d + 1
    last = first + d - 1
    result(row) = logprior_scalar(theta(first:last), d)
  end do
end subroutine logprior
Note: This implementation defines scalar likelihood and prior functions and then wraps them in simple sequential batched versions. The batch is processed sequentially rather than in parallel. Even so, this approach is significantly faster than a pure-Python (or non-JIT-compiled) implementation, because it reduces the number of Python callbacks by a factor of num_delete (typically 50–100), and callback overhead is the dominant cost for fast likelihoods. If the likelihood is expensive enough to benefit from parallelization, the batched wrappers can instead process the parameter vectors in parallel. Options include OpenMP directives (e.g. !$omp parallel do, compiled with -fopenmp) or do concurrent loops over pure scalar functions.
Save this as model.f90 in your working directory.
2. Compile the Fortran library
gfortran -shared -fPIC -O3 -o libmodel.so model.f90
3. Create the Python interface
Create model.py to interface with the compiled Fortran library:
# model.py
import ctypes
import os

import numpy as np
from numpy.ctypeslib import ndpointer
# Prefer the libmodel.so that sits next to this file, so `import model` works
# from any working directory; otherwise fall back to "./libmodel.so" relative
# to the current working directory (the original behaviour).
_LIB_NAME = "libmodel.so"
_lib_next_to_module = os.path.join(os.path.dirname(os.path.abspath(__file__)), _LIB_NAME)
lib = ctypes.CDLL(
    _lib_next_to_module if os.path.exists(_lib_next_to_module) else "./" + _LIB_NAME
)

# Both Fortran entry points share the C signature
#   void f(double *theta, double *result, int batch, int d)
# where theta is a C-contiguous (batch, d) float64 array and result is a
# float64 array of length batch that the routine fills in.
for _cfunc in (lib.loglikelihood, lib.logprior):
    _cfunc.argtypes = [
        ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
        ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
        ctypes.c_int,
        ctypes.c_int,
    ]
    _cfunc.restype = None
del _cfunc
def loglikelihood(theta):
    """Evaluate the Fortran log-likelihood for one or more parameter vectors.

    Parameters
    ----------
    theta : array_like, shape (..., d)
        Parameter vectors along the last axis; any leading axes are batch axes
        (a single 1-D vector is also accepted).

    Returns
    -------
    numpy.ndarray, shape ``theta.shape[:-1]``, dtype float64
        Log-likelihood of each parameter vector.
    """
    theta = np.ascontiguousarray(theta, dtype=np.float64)
    batch_shape, d = theta.shape[:-1], theta.shape[-1]
    # Collapse all leading axes into rows: the Fortran routine expects a
    # C-contiguous (batch, d) block (reshaping a contiguous array is a view).
    batch = int(np.prod(batch_shape))
    rows = theta.reshape(batch, d)
    result = np.empty(batch, dtype=np.float64)
    lib.loglikelihood(rows, result, batch, d)
    return result.reshape(batch_shape)
def logprior(theta):
    """Evaluate the Fortran log-prior for one or more parameter vectors.

    Parameters
    ----------
    theta : array_like, shape (..., d)
        Parameter vectors along the last axis; any leading axes are batch axes
        (a single 1-D vector is also accepted).

    Returns
    -------
    numpy.ndarray, shape ``theta.shape[:-1]``, dtype float64
        Log-prior of each parameter vector.
    """
    theta = np.ascontiguousarray(theta, dtype=np.float64)
    batch_shape, d = theta.shape[:-1], theta.shape[-1]
    # Collapse all leading axes into rows: the Fortran routine expects a
    # C-contiguous (batch, d) block (reshaping a contiguous array is a view).
    batch = int(np.prod(batch_shape))
    rows = theta.reshape(batch, d)
    result = np.empty(batch, dtype=np.float64)
    lib.logprior(rows, result, batch, d)
    return result.reshape(batch_shape)
Save this as model.py in your working directory.
4. Run nested sampling with Fortran functions
import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import model
rng_key = jax.random.PRNGKey(0)
def wrap_fn(fn, vmap_method='legacy_vectorized'):
def jax_wrapper(x):
out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
return jax.pure_callback(fn, out_shape, x, vmap_method=vmap_method)
return jax_wrapper
# Nested sampler whose log-prior and log-likelihood are evaluated in Fortran
# through host callbacks.
nss_options = dict(num_delete=50, num_inner_steps=20)
algo = blackjax.nss(
    logprior_fn=wrap_fn(model.logprior),
    loglikelihood_fn=wrap_fn(model.loglikelihood),
    **nss_options,
)

# Only the initialisation key is consumed here; the sampling loop below keeps
# drawing fresh subkeys from rng_key.
rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)

# Initial live points: 1000 standard-normal draws in 5 dimensions.
initial_particles = jax.random.normal(initialization_key, (1000, 5))
live = algo.init(initial_particles)
step = jax.jit(algo.step)


def _should_continue(state):
    """Keep sampling until logZ_live - logZ < -3, i.e. log(Z_live / Z) < -3."""
    return not (state.logZ_live - state.logZ < -3)


dead_points = []
with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
    while _should_continue(live):
        rng_key, subkey = jax.random.split(rng_key)
        live, dead = step(subkey, live)
        dead_points.append(dead)
        pbar.update(len(dead.particles))

ns_run = finalise(live, dead_points)