Julia

This example demonstrates how to use BlackJAX nested sampling with likelihood and prior functions implemented in Julia. The Julia code runs in a separate process and is called through a simple line-delimited JSON protocol over stdin/stdout (one request and one reply per line), which avoids threading conflicts.

Prerequisites

Install the required Python packages:

# Install BlackJAX from the handley-lab repository together with NumPy and tqdm
pip install "git+https://github.com/handley-lab/blackjax" numpy tqdm

Install Julia and the required Julia packages:

using Pkg
# Add each package used by the example, one after another
for pkg in ("LinearAlgebra", "Distributions", "JSON")
    Pkg.add(pkg)
end

Setup Instructions

1. Create the model file

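Create a file model.jl with your likelihood and prior functions. Each function receives a matrix with one sample per row and returns a vector containing one log-density per row: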

using LinearAlgebra
using Distributions

"""
    loglikelihood(theta)

Return the log-likelihood of every sample in `theta`.

`theta` is a matrix holding one 5-dimensional sample per row; it is converted to
`Matrix{Float64}` first. The likelihood is a multivariate normal with mean `ones(5)`
and covariance `0.01 * I(5)`. The result is a vector with one log-density per row.
"""
function loglikelihood(theta)
    theta = convert(Matrix{Float64}, theta)
    dist = MvNormal(ones(5), 0.01 * I(5))
    # `eachrow` iterates over row views, so no copy of each sample is allocated.
    return [logpdf(dist, row) for row in eachrow(theta)]
end

"""
    logprior(theta)

Return the log-prior density of every sample in `theta`.

`theta` is a matrix holding one 5-dimensional sample per row; it is converted to
`Matrix{Float64}` first. The prior is a multivariate normal with mean `zeros(5)` and
covariance `I(5)`. The result is a vector with one log-density per row.
"""
function logprior(theta)
    theta = convert(Matrix{Float64}, theta)
    dist = MvNormal(zeros(5), I(5))
    # `eachrow` iterates over row views, so no copy of each sample is allocated.
    return [logpdf(dist, row) for row in eachrow(theta)]
end

2. Create the RPC server

Create a file julia_server.jl to handle RPC communication:

using JSON
using Base64

# Include the model functions
include("model.jl")

# Rebuild the sample matrix from a request.
# The payload is base64-encoded Float64 data laid out row by row, so it is
# reshaped as cols × rows and then transposed to obtain a rows × cols matrix.
function decode_theta(request)
    raw = base64decode(request["data"])
    flat = reinterpret(Float64, raw)
    nrows = request["rows"]
    ncols = request["cols"]
    return reshape(flat, ncols, nrows)'
end

# Run the model function selected by `func_name` on `theta`.
function dispatch_request(func_name, theta)
    func_name == "loglikelihood" && return loglikelihood(theta)
    func_name == "logprior" && return logprior(theta)
    error("Unknown function: $func_name")
end

# Request/response loop: one JSON request per input line, one JSON reply per
# output line. An empty line (which is also what `readline` yields at end of
# input) stops the server. Any failure is reported back as an "error" reply.
while true
    try
        line = readline()
        isempty(line) && break

        request = JSON.parse(line)
        func_name = request["function"]
        theta = decode_theta(request)

        reply = Dict("result" => dispatch_request(func_name, theta))
        println(JSON.json(reply))
        flush(stdout)
    catch err
        println(JSON.json(Dict("error" => string(err))))
        flush(stdout)
    end
end

3. Run nested sampling with Julia functions

import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import subprocess
import json
import base64

# Start the Julia server as a child process. Requests are written to its stdin and
# replies are read from its stdout, one line each (text mode, line-buffered).
# stderr is deliberately not piped: nothing in this script reads it, so a pipe could
# fill up and block the Julia process. Left unset, stderr is inherited from this
# process, so Julia warnings and stack traces appear directly in the terminal.
julia_proc = subprocess.Popen(
    ['julia', 'julia_server.jl'],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    text=True,
    bufsize=1,
)

def call_julia_rpc(func_name, theta):
    """Evaluate a Julia function on a batch of samples via the server process.

    The samples are sent as base64-encoded float64 bytes in C (row-major) order,
    together with the row and column counts, as a single JSON line on the server's
    stdin. The reply is read as a single JSON line from its stdout.

    Args:
        func_name: Name of the function the server should call
            (julia_server.jl handles 'loglikelihood' and 'logprior').
        theta: Array of samples, one sample per row. A 1-D array is treated as a
            single sample.

    Returns:
        float64 NumPy array built from the 'result' field of the reply.

    Raises:
        RuntimeError: If the Julia process is not running, closes its output
            without replying, or replies with an 'error' field.
    """
    # atleast_2d lets a single 1-D sample through as a (1, n) batch instead of
    # failing on theta.shape[1].
    theta = np.ascontiguousarray(np.atleast_2d(theta), dtype=np.float64)

    exit_code = julia_proc.poll()
    if exit_code is not None:
        raise RuntimeError(f"Julia process is not running (exit code {exit_code})")

    theta_b64 = base64.b64encode(theta.tobytes()).decode('ascii')
    request = {
        'function': func_name,
        'data': theta_b64,
        'rows': theta.shape[0],
        'cols': theta.shape[1]
    }
    try:
        julia_proc.stdin.write(json.dumps(request) + '\n')
        julia_proc.stdin.flush()
    except BrokenPipeError as exc:
        raise RuntimeError("Julia process closed its input pipe") from exc

    line = julia_proc.stdout.readline()
    if not line:
        # readline() returns '' at end of file, i.e. the server has gone away.
        raise RuntimeError(
            "Julia process closed its output without replying "
            f"(exit code {julia_proc.poll()})"
        )

    response = json.loads(line)
    if 'error' in response:
        raise RuntimeError(f"Julia error: {response['error']}")

    # Force a float array: a JSON null in the result becomes NaN rather than
    # turning the whole array into dtype=object.
    return np.array(response['result'], dtype=np.float64)

def wrap_fn(func_name, vmap_method='legacy_vectorized'):
    def numpy_wrapper(theta):
        return call_julia_rpc(func_name, np.asarray(theta))
    
    def jax_wrapper(x):
        out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
        return jax.pure_callback(numpy_wrapper, out_shape, x, vmap_method=vmap_method)
    
    return jax_wrapper

rng_key = jax.random.PRNGKey(0)

# Everything that talks to the Julia server runs inside try/finally so the child
# process is shut down even if the sampling run raises.
try:
    # Smoke-test one RPC round trip before starting the sampling run.
    print("Testing Julia RPC...")
    test_theta = np.ones((5, 5))
    result = call_julia_rpc('loglikelihood', test_theta)
    print(f"Test successful: {result[:3]}")

    algo = blackjax.nss(
        logprior_fn=wrap_fn('logprior'),
        loglikelihood_fn=wrap_fn('loglikelihood'),
        num_delete=50,
        num_inner_steps=20,
    )

    # sampling_key is not used below; the three-way split is kept so the keys
    # drawn are the same as before.
    rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)
    # 1000 initial points in 5 dimensions drawn from a standard normal, which is
    # the prior defined in model.jl.
    live = algo.init(jax.random.normal(initialization_key, (1000, 5)))
    step = jax.jit(algo.step)

    dead_points = []

    with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
        # Keep stepping until logZ_live - logZ drops below -3.
        while not live.logZ_live - live.logZ < -3:
            rng_key, subkey = jax.random.split(rng_key)
            live, dead = step(subkey, live)
            dead_points.append(dead)
            pbar.update(len(dead.particles))

    ns_run = finalise(live, dead_points)
finally:
    # Stop the Julia server and reap it so no orphaned process is left behind.
    julia_proc.terminate()
    julia_proc.wait()