TensorFlow
This example demonstrates how to use BlackJAX nested sampling with TensorFlow Probability (TFP) distributions. TFP provides a rich library of probability distributions and statistical tools. Because BlackJAX runs on JAX, the TFP log-density functions are wrapped with `jax.pure_callback` below so that BlackJAX can call them.
Prerequisites
Install the required packages:
# BlackJAX from the handley-lab repository; presumably this fork provides the
# nested-sampling API used below (blackjax.nss, blackjax.ns.utils) — confirm.
pip install git+https://github.com/handley-lab/blackjax
# TensorFlow + TFP supply the distributions; numpy and tqdm are used by the script.
pip install tensorflow tensorflow-probability numpy tqdm
Run nested sampling with TensorFlow functions
import jax
import jax.numpy as jnp
import blackjax
from blackjax.ns.utils import finalise
import tqdm
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
# Fixed seed so repeated runs draw the same random numbers.
rng_key = jax.random.PRNGKey(0)

# Precision used for every TensorFlow tensor in this example.
dtype = tf.float32

# Toy 5-dimensional problem:
#   likelihood: Gaussian centred at (1, ..., 1) with covariance 0.01 * I
#   prior:      standard normal N(0, I)
_likelihood_dist = tfp.distributions.MultivariateNormalFullCovariance(
    tf.ones(5, dtype=dtype),
    0.01 * tf.eye(5, dtype=dtype),
)
_prior_dist = tfp.distributions.MultivariateNormalFullCovariance(
    tf.zeros(5, dtype=dtype),
    tf.eye(5, dtype=dtype),
)

# Bound log-density methods; these take TensorFlow tensors, not JAX arrays.
loglikelihood_fn = _likelihood_dist.log_prob
logprior_fn = _prior_dist.log_prob
def wrap_fn(fn, vmap_method='legacy_vectorized'):
    """Wrap a TensorFlow log-density so it can be called from JAX code.

    The returned function routes its input to ``fn`` through
    ``jax.pure_callback``. This lets it run inside ``jax.jit``-compiled code
    such as ``algo.step``. The callback is opaque to JAX, so JAX cannot
    differentiate through it.

    Args:
        fn: Callable taking a TensorFlow tensor (built with the module-level
            ``dtype``) whose last axis indexes the parameters. It must return
            one value per leading index, e.g. a TFP distribution's
            ``log_prob``.
        vmap_method: Forwarded unchanged to ``jax.pure_callback``; controls
            how the callback is batched under ``jax.vmap``.

    Returns:
        A function ``f(x)`` returning an array of shape ``x.shape[:-1]`` with
        the same dtype as ``x``.
    """
    def jax_wrapper(x):
        # One value per point: the trailing (parameter) axis is reduced away.
        out_shape = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)

        def numpy_wrapper(theta):
            # Hand TensorFlow its own copy of the host buffer JAX passes in
            # (presumably to avoid aliasing a read-only JAX buffer).
            x_tf = tf.convert_to_tensor(np.asarray(theta).copy(), dtype=dtype)
            result = fn(x_tf)
            # TensorFlow computes in `dtype` (float32 here). Cast back to the
            # dtype declared in out_shape; otherwise pure_callback rejects the
            # result whenever x is float64 (e.g. with jax_enable_x64 on).
            return np.asarray(result.numpy(), dtype=out_shape.dtype)

        return jax.pure_callback(numpy_wrapper, out_shape, x, vmap_method=vmap_method)

    return jax_wrapper
# Expose the TensorFlow densities to JAX (prior wrapped first, as before).
wrapped_logprior = wrap_fn(logprior_fn)
wrapped_loglikelihood = wrap_fn(loglikelihood_fn)

# Nested sampler from blackjax.nss. Presumably num_delete is the number of
# points replaced per step and num_inner_steps the number of inner MCMC steps
# used to generate each replacement — confirm against the BlackJAX docs.
algo = blackjax.nss(
    loglikelihood_fn=wrapped_loglikelihood,
    logprior_fn=wrapped_logprior,
    num_inner_steps=20,
    num_delete=50,
)
# Reserve a key for initialisation; the loop below keeps splitting rng_key.
rng_key, sampling_key, initialization_key = jax.random.split(rng_key, 3)

# 1000 initial live points in 5 dimensions, drawn from N(0, I), which matches
# the prior defined above.
initial_particles = jax.random.normal(initialization_key, (1000, 5))
live = algo.init(initial_particles)

# Compile the sampler step once; every iteration reuses it.
step = jax.jit(algo.step)

dead_points = []
with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
    # Run until logZ_live is more than 3 log-units below logZ, i.e. presumably
    # until the live points' remaining evidence contribution is small.
    # Written as `not (... < -3)` so a NaN difference keeps the loop running.
    while not (live.logZ_live - live.logZ < -3):
        rng_key, subkey = jax.random.split(rng_key)
        live, dead = step(subkey, live)
        dead_points.append(dead)
        pbar.update(len(dead.particles))

# Merge the final live points and all collected dead points into one run.
ns_run = finalise(live, dead_points)