Julia
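This example demonstrates how to use BlackJAX nested sampling with likelihood and prior functions implemented in Julia. The Julia code runs in a separate process and is called through a minimal JSON-based RPC protocol (one JSON request line on the process's stdin, one JSON response line on its stdout). Keeping Julia out of the Python process avoids threading conflicts between the two runtimes.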
Prerequisites
Install the required Python packages:
# BlackJAX from the handley-lab repository (provides blackjax.nss used below),
# plus numpy for array marshalling and tqdm for the progress bar.
pip install "git+https://github.com/handley-lab/blackjax" numpy tqdm
Install Julia, then install the required Julia packages by running the following in the Julia REPL:
using Pkg
# LinearAlgebra is a standard library bundled with Julia; listing it is harmless.
# Base64, used by the server, is also a bundled standard library.
Pkg.add(["LinearAlgebra", "Distributions", "JSON"])
Setup Instructions
1. Create the model file
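Create a file model.jl with your likelihood and prior functions. Each function receives an N × d matrix with one parameter vector (particle) per row and must return a vector of N log-densities, one per row: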
using LinearAlgebra
using Distributions
"""
    _rowwise_logpdf(dist, theta)

Evaluate `logpdf(dist, row)` for every row of `theta` (one particle per row) and
return the values as a vector with one entry per row.
"""
function _rowwise_logpdf(dist, theta)
    theta = convert(Matrix{Float64}, theta)
    return [logpdf(dist, row) for row in eachrow(theta)]
end

"""
    loglikelihood(theta)

Gaussian log-likelihood with mean `ones(d)` and covariance `0.01 * I`, evaluated
for each row of the `N × d` matrix `theta`. The dimension `d` is taken from the
number of columns of `theta`. Returns a length-`N` vector.
"""
function loglikelihood(theta)
    d = size(theta, 2)
    return _rowwise_logpdf(MvNormal(ones(d), 0.01 * I(d)), theta)
end

"""
    logprior(theta)

Standard-normal log-prior (mean `zeros(d)`, identity covariance), evaluated for
each row of the `N × d` matrix `theta`. The dimension `d` is taken from the
number of columns of `theta`. Returns a length-`N` vector.
"""
function logprior(theta)
    d = size(theta, 2)
    return _rowwise_logpdf(MvNormal(zeros(d), I(d)), theta)
end
2. Create the RPC server
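Create a file julia_server.jl in the same directory as model.jl to handle RPC communication. The Python script below starts it with `julia julia_server.jl`, so run that script from this directory as well: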
using JSON
using Base64
# Include the model functions (defines `loglikelihood` and `logprior`).
include("model.jl")

# Model functions the client may call, keyed by the request's "function" field.
const MODEL_FUNCTIONS = Dict{String,Function}(
    "loglikelihood" => loglikelihood,
    "logprior" => logprior,
)

"""
    decode_theta(request) -> Matrix{Float64}

Rebuild the `rows × cols` matrix sent by the Python client. `request["data"]` holds
the base64-encoded raw bytes of a C-ordered (row-major) float64 NumPy array. Read
column-major, those bytes form a `cols × rows` matrix whose transpose is the
original array. Both ends are assumed to share the same byte order.
"""
function decode_theta(request)
    values = reinterpret(Float64, base64decode(request["data"]))
    return permutedims(reshape(values, request["cols"], request["rows"]))
end

"""
    handle_request(line) -> String

Answer one request line with one JSON document (no trailing newline): either
`{"result": [...]}` or `{"error": "..."}` if parsing, decoding, dispatch, the model
call or serialization fails. `InterruptException` is re-thrown instead of being
reported, so Ctrl-C still stops the server.
"""
function handle_request(line::AbstractString)
    try
        request = JSON.parse(line)
        func_name = request["function"]
        theta = decode_theta(request)
        f = get(MODEL_FUNCTIONS, func_name, nothing)
        f === nothing && error("Unknown function: $func_name")
        # NOTE(review): how non-finite results (e.g. -Inf from a bounded prior) are
        # serialized depends on the installed JSON.jl version — confirm before relying on them.
        return JSON.json(Dict("result" => f(theta)))
    catch e
        e isa InterruptException && rethrow()
        return JSON.json(Dict("error" => sprint(showerror, e)))
    end
end

"""
    serve(input::IO=stdin, output::IO=stdout)

Read newline-delimited JSON requests from `input` until it is closed (EOF) and
write one JSON response line per request to `output`. Blank lines are skipped.
"""
function serve(input::IO=stdin, output::IO=stdout)
    for line in eachline(input)
        isempty(strip(line)) && continue
        println(output, handle_request(line))
        # The client waits on readline() for every reply; never leave one buffered.
        flush(output)
    end
    return nothing
end

serve()
3. Run nested sampling with the Julia functions
import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import subprocess
import json
import base64
# Start the Julia RPC server as a child process. Requests are written to its
# stdin and replies read from its stdout, one JSON document per line (text
# mode, line-buffered).
# stderr is inherited rather than piped. Nothing in this script reads a stderr
# pipe, so piping would hide Julia's error messages (e.g. a failing model.jl).
# A chatty child could also block once the OS pipe buffer fills.
julia_proc = subprocess.Popen(
    ['julia', 'julia_server.jl'],
    stdin=subprocess.PIPE,
    stdout=subprocess.PIPE,
    stderr=None,
    text=True,
    bufsize=1,
)
def call_julia_rpc(func_name, theta):
    """Evaluate a Julia model function on a batch of parameter vectors.

    The array is sent to the Julia server as one JSON line holding its raw
    C-ordered float64 bytes (base64-encoded) plus the row/column counts. The
    server answers with one JSON line containing either ``result`` or ``error``.

    Args:
        func_name: Name of the Julia function to call, e.g. ``'loglikelihood'``.
        theta: Array of shape ``(..., d)``. The last axis is one parameter
            vector, and all leading axes are flattened into rows. A 1-D array
            is a single parameter vector.

    Returns:
        float64 ndarray of shape ``theta.shape[:-1]``. JSON ``null`` entries in
        the reply become NaN.

    Raises:
        ValueError: if ``theta`` is 0-dimensional, or the reply has the wrong
            number of values.
        RuntimeError: if Julia reports an error or closes its stdout.
    """
    theta = np.asarray(theta, dtype=np.float64)
    if theta.ndim == 0:
        raise ValueError("theta must have at least one dimension")
    batch_shape = theta.shape[:-1]
    # Collapse leading axes into rows; ascontiguousarray guarantees C order so
    # tobytes() matches the row-major layout the server decodes.
    flat = np.ascontiguousarray(
        theta.reshape(int(np.prod(batch_shape)), theta.shape[-1])
    )
    rows, cols = flat.shape
    request = {
        'function': func_name,
        'data': base64.b64encode(flat.tobytes()).decode('ascii'),
        'rows': rows,
        'cols': cols,
    }
    julia_proc.stdin.write(json.dumps(request) + '\n')
    julia_proc.stdin.flush()

    line = julia_proc.stdout.readline()
    if not line:
        # readline() returns '' only at EOF: the server is gone.
        raise RuntimeError(
            f"Julia server closed its output (exit code: {julia_proc.poll()})"
        )
    response = json.loads(line)
    if 'error' in response:
        raise RuntimeError(f"Julia error: {response['error']}")
    # dtype=float64 maps JSON nulls to NaN instead of producing an object array;
    # reshape rejects replies with the wrong number of values.
    return np.array(response['result'], dtype=np.float64).reshape(batch_shape)
def wrap_fn(func_name, vmap_method='legacy_vectorized'):
def numpy_wrapper(theta):
return call_julia_rpc(func_name, np.asarray(theta))
def jax_wrapper(x):
out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
return jax.pure_callback(numpy_wrapper, out_shape, x, vmap_method=vmap_method)
return jax_wrapper
rng_key = jax.random.PRNGKey(0)
try:
    # Smoke-test the RPC channel first so a broken Julia setup fails fast.
    print("Testing Julia RPC...")
    test_theta = np.ones((5, 5))
    result = call_julia_rpc('loglikelihood', test_theta)
    print(f"Test successful: {result[:3]}")

    algo = blackjax.nss(
        logprior_fn=wrap_fn('logprior'),
        loglikelihood_fn=wrap_fn('loglikelihood'),
        num_delete=50,
        num_inner_steps=20,
    )

    # Split into three keys (same RNG stream as before); the middle one is unused.
    rng_key, _, initialization_key = jax.random.split(rng_key, 3)
    live = algo.init(jax.random.normal(initialization_key, (1000, 5)))
    step = jax.jit(algo.step)

    dead_points = []
    with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
        # Run until logZ_live - logZ drops below -3. Reading the field names as
        # log-evidences, this means the live points hold less than e^-3 of the
        # accumulated evidence.
        while not (live.logZ_live - live.logZ < -3):
            rng_key, subkey = jax.random.split(rng_key)
            live, dead = step(subkey, live)
            dead_points.append(dead)
            pbar.update(len(dead.particles))

    ns_run = finalise(live, dead_points)
finally:
    # Always shut the server down, even if sampling fails or is interrupted.
    # Closing stdin lets its read loop hit EOF and exit cleanly; terminate it
    # if it does not exit promptly.
    julia_proc.stdin.close()
    try:
        julia_proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        julia_proc.terminate()
        julia_proc.wait()