JAX-bandflux: differentiable supernovae SALT modelling for cosmological analysis on GPUs

Our latest work, presented in the paper JAX-bandflux: differentiable supernovae SALT modelling for cosmological analysis on GPUs, introduces a new software tool for supernova cosmology. The project, led by group member Samuel Alan Kossoff Leeney, speeds up and modernises the modelling of Type Ia supernova light curves by building it on the JAX framework and GPU computing.
The Cosmological Context: Standard Candles in the Modern Era
Type Ia supernovae are a cornerstone of modern cosmology. They serve as “standard candles” (more precisely, standardisable candles: their peak brightness becomes a reliable distance indicator after empirical corrections for light-curve shape and colour), which lets us measure cosmic distances and map the expansion history of the Universe. The precision of these measurements depends on accurately modelling their observed brightness over time, that is, their light curves. For years, the astrophysics community has relied on robust, well-vetted tools such as SNCosmo for this purpose. However, next-generation surveys will deliver far larger datasets, and the machine-learning revolution has made techniques such as automatic differentiation routine. Cosmological tools therefore need to fully exploit modern hardware and these algorithmic methods.
JAX-bandflux: A Differentiable, GPU-Native Solution
JAX-bandflux is designed to fill this gap. It re-implements the core functionality for modelling supernova flux natively in JAX, making the entire analysis pipeline differentiable and highly parallelisable on GPUs. Together, these properties speed up cosmological analyses and make gradient-based inference methods practical.
At its heart, JAX-bandflux computes the “bandflux”—the integrated model flux across a specific observational filter—using the widely adopted Spectral Adaptive Lightcurve Template (SALT) model. The model flux is described by:
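\(F(p, \lambda) = x_0 \left[ M_0(p, \lambda) + x_1 M_1(p, \lambda) + \ldots \right] \times \exp \left[ c \times CL(\lambda) \right]\)
where \(p\) is the rest-frame phase relative to the time of peak brightness \(t_0\), \(\lambda\) is the rest-frame wavelength, \(M_0\) and \(M_1\) are the model's trained spectral surfaces, \(CL(\lambda)\) is the colour law, and the free parameters \(x_0\), \(x_1\) and \(c\) set the amplitude, light-curve stretch and colour of each supernova.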
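To make the formula concrete, here is a minimal JAX sketch of the SALT flux and of a bandflux, taken as the photon-weighted integral \(\int F(p, \lambda)\, T(\lambda)\, \lambda \, d\lambda / (hc)\) through a filter with transmission \(T(\lambda)\). It works under several assumptions. The functions m0, m1 and colour_law are hypothetical toy stand-ins, not the trained SALT3 surfaces or colour law. The top-hat filters are made up. Everything is evaluated in the rest frame at z = 0, and the constant factor 1/(hc) is dropped. None of these names are the JAX-bandflux API.

```python
import jax.numpy as jnp

# Hypothetical toy stand-ins for the trained SALT surfaces M0, M1 and colour law CL.
def m0(p, lam):
    # Peaks at phase 0 and near 5000 Angstrom
    return jnp.exp(-0.5 * (p / 10.0) ** 2) * jnp.exp(-0.5 * ((lam - 5000.0) / 2000.0) ** 2)

def m1(p, lam):
    # Toy first-order (stretch-like) component: tilts the light curve in phase
    return 0.1 * (p / 10.0) * m0(p, lam)

def colour_law(lam):
    # Toy colour law: negative blueward of 4400 Angstrom, so c > 0 dims blue light
    return 1.0 - 4400.0 / lam

def salt_flux(p, lam, x0, x1, c):
    """F(p, lambda) = x0 [M0 + x1 M1] exp[c CL(lambda)], truncated after the M1 term."""
    return x0 * (m0(p, lam) + x1 * m1(p, lam)) * jnp.exp(c * colour_law(lam))

def trapezoid(y, x):
    # Trapezoidal rule on a (possibly non-uniform) wavelength grid
    return 0.5 * jnp.sum((y[1:] + y[:-1]) * jnp.diff(x))

def bandflux(p, wave, trans, x0, x1, c):
    """Photon-weighted integral of F * T * lambda over the filter (1/(hc) dropped)."""
    return trapezoid(salt_flux(p, wave, x0, x1, c) * trans * wave, wave)

# Hypothetical top-hat filter sampled every 5 Angstrom
wave_b = jnp.linspace(4000.0, 5000.0, 201)
trans_b = jnp.ones_like(wave_b)
print(bandflux(0.0, wave_b, trans_b, x0=1.0, x1=0.0, c=0.0))
```

Every step uses jax.numpy, so the same bandflux function can be passed unchanged to jax.grad, jax.jit or jax.vmap. That property is what makes the pipeline differentiable end to end.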
The goal is to infer the parameters \(x_0, x_1, t_0, c\) that best fit the observational data for each supernova. By implementing the entire pipeline in JAX (Bradbury et al., 2018), we can propagate gradients through the whole model. This enables efficient gradient-based optimisation and supports parameter inference with gradient-based sampling methods.
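Below is a minimal sketch of how differentiability helps fitting. It is not the package's own fitter and rests on these assumptions:

- Toy stand-in surfaces replace the trained SALT components.
- There are two made-up top-hat filters.
- The data are a noiseless synthetic light curve at a fixed, known redshift.
- Rest-frame phase and wavelength are the observer-frame values divided by (1 + z), and the flux is scaled by 1/(1 + z), following the standard redshift convention.

jax.value_and_grad differentiates the \(\chi^2\) with respect to all four parameters \(x_0, x_1, t_0, c\) at once. A plain gradient descent with a backtracking (Armijo) line search then uses that gradient. A real analysis would more likely use a quasi-Newton optimiser or a gradient-based sampler.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 keeps the chi-square well behaved

# Hypothetical toy stand-ins for the trained SALT surfaces and colour law.
def m0(p, lam):
    return jnp.exp(-0.5 * (p / 10.0) ** 2) * jnp.exp(-0.5 * ((lam - 5000.0) / 2000.0) ** 2)

def m1(p, lam):
    return 0.1 * (p / 10.0) * m0(p, lam)

def colour_law(lam):
    return 1.0 - 4400.0 / lam

# Two hypothetical top-hat filters sampled on one observer-frame wavelength grid.
WAVE = jnp.linspace(3500.0, 7000.0, 701)
FILTERS = jnp.stack([
    ((WAVE >= 3500.0) & (WAVE <= 5000.0)).astype(WAVE.dtype),  # "blue"
    ((WAVE >= 5500.0) & (WAVE <= 7000.0)).astype(WAVE.dtype),  # "red"
])
Z = 0.05  # assumed known redshift

def model_bandflux(params, t, band):
    x0, x1, t0, c = params
    a = 1.0 / (1.0 + Z)
    p, lam = (t - t0) * a, WAVE * a  # rest-frame phase and wavelength
    flux = a * x0 * (m0(p, lam) + x1 * m1(p, lam)) * jnp.exp(c * colour_law(lam))
    y = flux * FILTERS[band] * WAVE  # photon weighting; 1/(hc) dropped
    return 0.5 * jnp.sum((y[1:] + y[:-1]) * jnp.diff(WAVE))

# Evaluate every (time, band) observation of one supernova in a single call.
model_lc = jax.vmap(model_bandflux, in_axes=(None, 0, 0))

# Noiseless synthetic light curve from hypothetical "true" parameters (x0, x1, t0, c).
TRUE = jnp.array([1.0, 0.5, 10.0, 0.1])
times = jnp.tile(jnp.linspace(-10.0, 50.0, 13), 2)
bands = jnp.repeat(jnp.array([0, 1]), 13)
obs = model_lc(TRUE, times, bands)
sigma = 0.02 * jnp.max(obs)  # assumed uniform uncertainty

def chi2(params):
    resid = (model_lc(params, times, bands) - obs) / sigma
    return jnp.sum(resid ** 2)

loss_and_grad = jax.jit(jax.value_and_grad(chi2))

def fit(params, n_steps=100):
    """Gradient descent; each step halves the step size until Armijo's condition holds."""
    loss, g = loss_and_grad(params)
    for _ in range(n_steps):
        step, accepted = 1.0, False
        while step > 1e-12:
            trial = params - step * g
            trial_loss, trial_g = loss_and_grad(trial)
            if trial_loss <= loss - 1e-4 * step * jnp.dot(g, g):
                params, loss, g = trial, trial_loss, trial_g
                accepted = True
                break
            step *= 0.5
        if not accepted:  # no descent step found; stop early
            break
    return params, loss

start = jnp.array([0.8, 0.0, 12.0, 0.0])
best, best_loss = fit(start)
print(best, best_loss)
```

Reverse-mode differentiation returns the full gradient for a small constant multiple of the cost of one model evaluation. That cost does not grow with the number of fitted parameters, which is why the same machinery also suits gradient-based samplers.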
Implementation and Key Features
The software is engineered for both performance and practical usability. It maintains functional parity with key components of SNCosmo while adding just-in-time compilation, GPU execution and automatic differentiation.
- Performance: Key operations are just-in-time (JIT) compiled for execution speed, and JAX’s vmap function is used to parallelise computations across thousands of supernovae simultaneously on a single GPU.
- Modern Models: The package supports the latest SALT models, including SALT3 (10.3847/1538-4357/ac2ea3) and its near-infrared extension, SALT3-NIR (10.3847/1538-4357/ac9229), which is important for next-generation cosmological measurements.
- Usability: It includes comprehensive utilities for handling astronomical bandpass filters and loading supernova data, ensuring a smooth workflow from raw observations to cosmological parameter estimation.
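The Performance bullet above combines two JAX transformations. The sketch below shows that pattern using hypothetical toy components, a single made-up top-hat filter at z = 0 and a simulated population. It does not use the package's own functions. jax.vmap first maps a single-observation bandflux over the times of one light curve, then over a batch of parameter vectors (one per supernova). jax.jit then compiles the whole batched computation into one XLA program.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy stand-ins for the trained SALT surfaces and colour law.
def m0(p, lam):
    return jnp.exp(-0.5 * (p / 10.0) ** 2) * jnp.exp(-0.5 * ((lam - 5000.0) / 2000.0) ** 2)

def m1(p, lam):
    return 0.1 * (p / 10.0) * m0(p, lam)

def colour_law(lam):
    return 1.0 - 4400.0 / lam

WAVE = jnp.linspace(4000.0, 5000.0, 201)  # hypothetical top-hat filter grid
TRANS = jnp.ones_like(WAVE)

def bandflux_one(params, t):
    """Bandflux of one supernova at one time (z = 0, 1/(hc) dropped)."""
    x0, x1, t0, c = params
    p = t - t0
    flux = x0 * (m0(p, WAVE) + x1 * m1(p, WAVE)) * jnp.exp(c * colour_law(WAVE))
    y = flux * TRANS * WAVE
    return 0.5 * jnp.sum((y[1:] + y[:-1]) * jnp.diff(WAVE))

# Inner vmap: all observation times of one supernova -> shape (n_times,)
light_curve = jax.vmap(bandflux_one, in_axes=(None, 0))
# Outer vmap: a batch of supernovae sharing one time grid -> (n_sn, n_times); jit compiles it
all_light_curves = jax.jit(jax.vmap(light_curve, in_axes=(0, None)))

# Hypothetical population of 1000 supernovae with random parameters
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n_sn = 1000
params = jnp.stack([
    jax.random.uniform(k1, (n_sn,), minval=0.5, maxval=2.0),   # x0
    jax.random.normal(k2, (n_sn,)),                            # x1
    jax.random.uniform(k3, (n_sn,), minval=-5.0, maxval=5.0),  # t0
    0.1 * jax.random.normal(k4, (n_sn,)),                      # c
], axis=1)
times = jnp.linspace(-10.0, 40.0, 26)

fluxes = all_light_curves(params, times)
print(fluxes.shape)  # (1000, 26)
```

jax.vmap needs rectangular inputs, so this sketch gives every supernova the same observation times. Real light curves differ in cadence and length. A batched fit over a survey would therefore plausibly need padding plus a mask, though this is an assumption about how one might structure it, not a description of JAX-bandflux internals.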
JAX-bandflux represents a significant step forward in our analytical toolkit. By connecting established astrophysical models with the modern differentiable programming ecosystem, this work by Samuel Alan Kossoff Leeney enables faster, more scalable and more sophisticated analyses of the large datasets expected from upcoming surveys. It also reflects our group’s commitment to developing new computational methods for cosmology.

Content generated by gemini-2.5-pro using this prompt.
Image generated by imagen-4.0-generate-001 using this prompt.