Fortran#

This example demonstrates how to use BlackJAX nested sampling with likelihood and prior functions implemented in Fortran. The Fortran code is compiled to a shared library and accessed via Python’s ctypes module, with JAX’s pure_callback providing the bridge.

Prerequisites#

Install the required Python packages:

pip install git+https://github.com/handley-lab/blackjax
pip install numpy tqdm

You’ll also need a Fortran compiler (gfortran) installed on your system.

Setup Instructions#

1. Create the Fortran implementation#

First, create a file model.f90 with your likelihood and prior functions:

! model.f90
module model_mod
    use iso_c_binding
    implicit none
    real(c_double), parameter :: LOG_2PI = 1.8378770664093454d0
    
contains
    ! Sequential (scalar) functions
    ! Gaussian log-likelihood: theta ~ N(mu = 1, sigma^2 = 0.01) in each dimension
    function loglikelihood_scalar(theta, d) result(logp)
        real(c_double), intent(in) :: theta(*)
        integer(c_int), intent(in), value :: d
        real(c_double) :: logp
        
        real(c_double) :: inv_var, log_det, mu, q, diff
        integer :: i
        
        inv_var = 1.0d0 / 0.01d0
        log_det = d * log(0.01d0)
        mu = 1.0d0
        
        q = 0.0d0
        do i = 1, d
            diff = theta(i) - mu
            q = q + diff * diff * inv_var
        end do
        
        logp = -0.5d0 * (d * LOG_2PI + log_det + q)
    end function loglikelihood_scalar
    
    ! Standard normal log-prior: theta ~ N(0, I)
    function logprior_scalar(theta, d) result(logp)
        real(c_double), intent(in) :: theta(*)
        integer(c_int), intent(in), value :: d
        real(c_double) :: logp
        
        real(c_double) :: q
        integer :: i
        
        q = 0.0d0
        do i = 1, d
            q = q + theta(i) * theta(i)
        end do
        
        logp = -0.5d0 * (d * LOG_2PI + q)
    end function logprior_scalar
    
end module model_mod

! C-compatible wrapper functions (theta arrives flattened in C row-major order)
subroutine loglikelihood(theta, result, batch, d) bind(c)
    use iso_c_binding
    use model_mod
    implicit none
    
    integer(c_int), intent(in), value :: batch, d
    real(c_double), intent(in) :: theta(batch * d)
    real(c_double), intent(out) :: result(batch)
    
    integer :: b, offset
    
    do b = 1, batch
        offset = (b - 1) * d + 1
        result(b) = loglikelihood_scalar(theta(offset), d)
    end do
end subroutine loglikelihood

subroutine logprior(theta, result, batch, d) bind(c)
    use iso_c_binding
    use model_mod
    implicit none
    
    integer(c_int), intent(in), value :: batch, d
    real(c_double), intent(in) :: theta(batch * d)
    real(c_double), intent(out) :: result(batch)
    
    integer :: b, offset
    
    do b = 1, batch
        offset = (b - 1) * d + 1
        result(b) = logprior_scalar(theta(offset), d)
    end do
end subroutine logprior

Note: This implementation defines scalar likelihood and prior functions, then wraps them in simple sequential batched versions. Although the batching is sequential rather than parallel, it is still significantly faster than a pure Python (or non-JIT-compiled) implementation, because it reduces the number of Python callbacks by a factor of num_delete (typically 50-100), and callback overhead is the dominant cost for fast likelihoods; the sketch after this step demonstrates the effect. If the likelihood computation is expensive enough to benefit from parallelisation, one could instead define batched functions that process multiple parameter vectors in parallel (e.g. using OpenMP directives or coarrays).

Save this as model.f90 in your working directory.
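
The batching claim above is easy to verify in isolation. The following self-contained sketch (with a toy host function, host_sum, standing in for the Fortran callback; both names are illustrative) shows that under jax.vmap, pure_callback with vmap_method='legacy_vectorized' crosses the Python boundary once per batch rather than once per point:

import jax
import jax.numpy as jnp
import numpy as np

calls = {"n": 0}

def host_sum(x):
    # Host-side NumPy stand-in for the Fortran callback.
    calls["n"] += 1
    return np.asarray(x).sum(axis=-1)

def wrapped(x):
    out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
    return jax.pure_callback(host_sum, out_shape, x, vmap_method='legacy_vectorized')

jax.vmap(wrapped)(jnp.ones((8, 5)))
print(calls["n"])  # 1: a single host call carried the whole (8, 5) batch

This single-call behaviour is what lets the sequential Fortran loop amortise the Python callback overhead across all num_delete points.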

2. Compile the Fortran library#

gfortran -shared -fPIC -O3 -o libmodel.so model.f90
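
Optionally, sanity-check that the build succeeded and that both bind(c) symbols are exported (a standalone snippet, not part of the example files):

import ctypes

lib = ctypes.CDLL("./libmodel.so")  # raises OSError if the build failed
for name in ("loglikelihood", "logprior"):
    # Attribute access on a CDLL performs the dynamic symbol lookup.
    assert hasattr(lib, name), f"missing symbol: {name}"
print("libmodel.so loaded; symbols found")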

3. Create the Python interface#

Create model.py to interface with the compiled Fortran library:

# model.py
import ctypes
import numpy as np
from numpy.ctypeslib import ndpointer

lib = ctypes.CDLL("./libmodel.so")

lib.loglikelihood.argtypes = [
    ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
    ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
    ctypes.c_int,
    ctypes.c_int
]
lib.loglikelihood.restype = None

lib.logprior.argtypes = [
    ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
    ndpointer(ctypes.c_double, flags="C_CONTIGUOUS"),
    ctypes.c_int,
    ctypes.c_int
]
lib.logprior.restype = None

def loglikelihood(theta):
    """Evaluate the Fortran log-likelihood on a (batch, d) array."""
    theta = np.ascontiguousarray(theta, dtype=np.float64)
    batch, d = theta.shape
    result = np.empty(batch, dtype=np.float64)
    lib.loglikelihood(theta, result, batch, d)
    return result

def logprior(theta):
    """Evaluate the Fortran log-prior on a (batch, d) array."""
    theta = np.ascontiguousarray(theta, dtype=np.float64)
    batch, d = theta.shape
    result = np.empty(batch, dtype=np.float64)
    lib.logprior(theta, result, batch, d)
    return result

Save this as model.py in your working directory.
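
As an optional check, you can compare the Fortran results against NumPy reference implementations of the densities hard-coded in model.f90 (a mean-1 Gaussian likelihood with per-dimension variance 0.01, and a standard normal prior); the reference expressions below are illustrative and not part of model.py:

import numpy as np
import model

rng = np.random.default_rng(0)
theta = rng.normal(size=(4, 5))  # small test batch: 4 points in d = 5
d = theta.shape[-1]
LOG_2PI = np.log(2 * np.pi)

ref_loglikelihood = -0.5 * (d * LOG_2PI + d * np.log(0.01)
                            + ((theta - 1.0) ** 2 / 0.01).sum(axis=-1))
ref_logprior = -0.5 * (d * LOG_2PI + (theta ** 2).sum(axis=-1))

assert np.allclose(model.loglikelihood(theta), ref_loglikelihood)
assert np.allclose(model.logprior(theta), ref_logprior)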

4. Run nested sampling with Fortran functions#

import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import model  # ctypes interface to libmodel.so defined above

rng_key = jax.random.PRNGKey(0)

def wrap_fn(fn, vmap_method='legacy_vectorized'):
    """Expose a NumPy host function to JAX via pure_callback."""
    def jax_wrapper(x):
        # One scalar output per input row: drop the trailing parameter axis.
        out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
        # 'legacy_vectorized' hands the whole batch to fn in a single host
        # call when the wrapper is vmapped.
        return jax.pure_callback(fn, out_shape, x, vmap_method=vmap_method)

    return jax_wrapper

algo = blackjax.nss(
    logprior_fn=wrap_fn(model.logprior),
    loglikelihood_fn=wrap_fn(model.loglikelihood),
    num_delete=50,       # points replaced per step, i.e. per Python callback
    num_inner_steps=20,  # inner MCMC steps used to generate each replacement
)

rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)
# 1000 live points in d = 5, drawn from the standard normal prior.
live = algo.init(jax.random.normal(initialization_key, (1000, 5)))
step = jax.jit(algo.step)

dead_points = []

with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
    # Run until the live points' remaining evidence contribution is negligible.
    while live.logZ_live - live.logZ >= -3:
        rng_key, subkey = jax.random.split(rng_key)
        live, dead = step(subkey, live)
        dead_points.append(dead)
        pbar.update(len(dead.particles))

ns_run = finalise(live, dead_points)
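
Because the model is a conjugate Gaussian, the evidence has a closed form against which the run can be checked: each dimension contributes integral N(theta; 1, 0.01) N(theta; 0, 1) dtheta = N(1; 0, 1.01), so logZ = -(d/2) * (log(2 * pi * 1.01) + 1/1.01), about -7.09 for d = 5. A sketch of the comparison, reusing the logZ and logZ_live attributes from the loop above; combining them with logaddexp for the total estimate is an assumption of this sketch, not a documented part of the API:

d = 5
# Analytic evidence: each dimension contributes N(1; 0, 1.01).
logZ_true = -0.5 * d * (np.log(2 * np.pi * 1.01) + 1.0 / 1.01)

# Accumulated (dead-point) evidence plus the remaining live-point contribution.
logZ_est = np.logaddexp(live.logZ, live.logZ_live)
print(f"estimated logZ = {logZ_est:.2f}, analytic logZ = {logZ_true:.2f}")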