import torch
import arviz as az
from autoemulate.simulations.epidemic import Epidemic
from autoemulate.core.compare import AutoEmulate
from autoemulate.calibration.bayes import BayesianCalibration
from autoemulate.emulators import GaussianProcess
Bayesian calibration
Bayesian calibration is a method for estimating which input parameters were most likely to produce observed data. An advantage over other calibration methods is that it returns a probability distribution over the input parameters rather than just point estimates.
Performing Bayesian calibration requires:
a fitted emulator
observations associated with the simulator output
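The idea of a posterior distribution can be made concrete by evaluating Bayes' rule on a grid for a single parameter. The following is a minimal sketch, not how autoemulate computes anything. The toy simulator `toy_simulator` (output = 2θ), the uniform prior on [0, 1], the noise level and the observations are all hypothetical choices for illustration. The result is a normalized density over θ rather than a single best value.

```python
import numpy as np

def toy_simulator(theta):
    # Hypothetical simulator: the output is simply twice the input.
    return 2.0 * theta

def grid_posterior(observations, sigma, grid):
    # With a uniform prior over the grid, the posterior is proportional to the likelihood.
    preds = toy_simulator(grid)[:, None]                      # shape (n_grid, 1)
    # Gaussian log-likelihood of all observations, for every grid value of theta.
    log_lik = -0.5 * np.sum(((observations[None, :] - preds) / sigma) ** 2, axis=1)
    # Subtract the maximum before exponentiating for numerical stability.
    unnormalized = np.exp(log_lik - log_lik.max())
    # Normalize so the density integrates to 1 over the grid.
    dx = grid[1] - grid[0]
    return unnormalized / (unnormalized.sum() * dx)

grid = np.linspace(0.0, 1.0, 1001)
obs = np.array([0.61, 0.58, 0.62, 0.59])   # hypothetical noisy observations
density = grid_posterior(obs, sigma=0.05, grid=grid)

dx = grid[1] - grid[0]
post_mean = np.sum(grid * density) * dx
post_std = np.sqrt(np.sum((grid - post_mean) ** 2 * density) * dx)
print(f"posterior mean = {post_mean:.3f}, posterior std = {post_std:.3f}")
```

The posterior standard deviation quantifies how uncertain θ remains after seeing the data. A point estimate alone would not provide that information.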
1. Simulate data and train an emulator
In this example, we’ll use the Epidemic simulator, which returns the peak infection rate given two input parameters: beta (the transmission rate per day) and gamma (the recovery rate per day).
simulator = Epidemic(log_level="error")
x = simulator.sample_inputs(100)
y = simulator.forward_batch(x)
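The internals of the Epidemic simulator are not shown in this tutorial. As a rough, hypothetical sketch of what a simulator of this kind can compute, the snippet below integrates the classic SIR equations with a forward-Euler step and returns the peak infected fraction, given the transmission rate beta and recovery rate gamma. The initial infected fraction, time horizon and step size are our own assumptions. The real simulator may differ in its equations, solver and output definition.

```python
def sir_peak_infection(beta, gamma, i0=0.01, t_max=500.0, dt=0.1):
    """Hypothetical SIR sketch: return the peak fraction of the population infected.

    beta  -- transmission rate per day
    gamma -- recovery rate per day
    i0, t_max, dt -- assumed initial infected fraction, time horizon (days) and Euler step
    """
    s, i = 1.0 - i0, i0          # susceptible and infected fractions
    peak = i
    for _ in range(int(t_max / dt)):
        new_infections = beta * s * i * dt
        recoveries = gamma * i * dt
        s -= new_infections
        i += new_infections - recoveries
        peak = max(peak, i)      # track the largest infected fraction seen so far
    return peak

print(sir_peak_infection(0.2, 0.1))
print(sir_peak_infection(0.4, 0.2))
```

In this idealized sketch the peak depends only on the ratio beta/gamma, because rescaling time leaves the dynamics unchanged. For example, (0.4, 0.2) and (0.2, 0.1) give nearly the same peak. If the real simulator behaves similarly, peak infection data alone would constrain that ratio much better than beta and gamma individually.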
For the purposes of this tutorial, we will restrict the model choice to GaussianProcess.
ae = AutoEmulate(x, y, models=[GaussianProcess], log_level="error")
We can verify that the fitted emulator performs well on both the train and test data.
ae.summarise()
| | model_name | x_transforms | y_transforms | params | rmse_test | r2_test | r2_test_std | r2_train | r2_train_std |
|---|---|---|---|---|---|---|---|---|---|
| 0 | GaussianProcess | [StandardizeTransform()] | [StandardizeTransform()] | {'mean_module_fn': <function zero_mean at 0x7f... | 0.079372 | 0.999064 | 0.000632 | 0.999963 | 0.000012 |
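The rmse_test and r2_test columns summarise how closely the emulator's predictions match held-out simulator outputs. As a rough reference for what these metrics measure, here is a NumPy sketch of root-mean-squared error and the coefficient of determination R². The example arrays are made up. AutoEmulate's own implementation may differ, for example in how it aggregates over folds or outputs to produce the *_std columns.

```python
import numpy as np

def rmse(y_true, y_pred):
    # Root-mean-squared error: typical size of a prediction error, in output units.
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def r2(y_true, y_pred):
    # Coefficient of determination: 1 - residual sum of squares / total sum of squares.
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return float(1.0 - ss_res / ss_tot)

y_true = np.array([0.10, 0.20, 0.30, 0.40])   # hypothetical simulator outputs
y_pred = np.array([0.11, 0.19, 0.31, 0.39])   # hypothetical emulator predictions
print(rmse(y_true, y_pred), r2(y_true, y_pred))
```

An R² close to 1 means the emulator explains almost all of the variance in the outputs. RMSE is in the same units as the output, so it should be judged against the output's typical scale.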
gp = ae.best_result().model
2. Calibrate
Calibration requires one or more observations. These can come from experiments or from the literature.
Below we choose the “true” parameter values that we want to infer and simulate the corresponding output. We then add Gaussian noise to generate 10 “observations”.
true_beta = 0.2
true_gamma = 0.1
# simulator expects inputs of shape [1, number of inputs]
params = torch.tensor([true_beta, true_gamma]).view(1, -1)
true_infection_rate = simulator.forward(params)
n_obs = 10
noise = torch.normal(mean=0, std=0.01, size=(n_obs,))
observed_infection_rates = true_infection_rate[0] + noise
print("Observed infection rates:", observed_infection_rates.numpy().round(3))
Observed infection rates: [0.164 0.15 0.174 0.165 0.148 0.141 0.149 0.162 0.156 0.149]
We set up the BayesianCalibration object with the trained emulator, the simulator parameter ranges and the “observed” data we simulated above. The underlying probabilistic model assumes the observations are drawn from a Gaussian distribution whose mean is predicted by the emulator. We also have to specify the observation_noise of this Gaussian likelihood.
observations = {"infection_rate": observed_infection_rates}
observation_noise = 0.01
bc = BayesianCalibration(
gp,
simulator.parameters_range,
observations,
observation_noise
)
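The probabilistic model described above can be written down as an unnormalized log-posterior. This minimal sketch makes the following assumptions:

- A uniform prior over each parameter range.
- Independent Gaussian likelihood terms with standard deviation `sigma` around the emulator's mean prediction.

`toy_emulator` is a hypothetical stand-in for the fitted GP, and the parameter ranges are made up. Treating `observation_noise` as a standard deviation (rather than, say, a variance) is also our assumption. Check the BayesianCalibration documentation for the prior and noise parameterization it actually uses.

```python
import math

def log_posterior(theta, observations, sigma, param_ranges, emulator_mean):
    """Unnormalized log-posterior for a dict of parameter values `theta`.

    Assumes a uniform prior on each range and i.i.d. Gaussian observation noise
    with standard deviation `sigma` around the emulator's mean prediction.
    """
    # Uniform prior: zero density (log-density -inf) outside the parameter box.
    for name, (low, high) in param_ranges.items():
        if not low <= theta[name] <= high:
            return -math.inf
    log_prior = -sum(math.log(high - low) for low, high in param_ranges.values())
    mu = emulator_mean(theta)
    # Sum of Gaussian log-densities, one term per observation.
    log_lik = sum(
        -0.5 * ((y - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))
        for y in observations
    )
    return log_prior + log_lik


# Hypothetical parameter ranges and a toy stand-in for the emulator's mean prediction.
ranges = {"beta": (0.1, 0.5), "gamma": (0.01, 0.2)}


def toy_emulator(theta):
    return theta["beta"] - theta["gamma"]


obs = [0.10, 0.12, 0.09]
lp_good = log_posterior({"beta": 0.2, "gamma": 0.1}, obs, 0.01, ranges, toy_emulator)
lp_bad = log_posterior({"beta": 0.4, "gamma": 0.1}, obs, 0.01, ranges, toy_emulator)
lp_out = log_posterior({"beta": 0.9, "gamma": 0.1}, obs, 0.01, ranges, toy_emulator)
print(lp_good, lp_bad, lp_out)
```

Parameter values whose predictions sit close to the observations get a higher log-posterior. Values outside the ranges are ruled out entirely by the uniform prior.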
Next, we run MCMC using the No-U-Turn Sampler (NUTS).
mcmc = bc.run_mcmc(
warmup_steps=250,
num_samples=1000,
sampler='nuts',
num_chains=2
)
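NUTS, the sampler used above, is an adaptive variant of Hamiltonian Monte Carlo that uses gradients of the log-posterior, and Pyro runs it for us. To illustrate the general MCMC recipe of proposing moves, accepting or rejecting them, and discarding warm-up draws, here is a sketch of a much simpler random-walk Metropolis sampler. This is a different algorithm from NUTS. It runs two chains on a hypothetical 2D Gaussian target. The target, step size, number of draws and seeds are arbitrary choices.

```python
import numpy as np

def log_target(theta):
    # Hypothetical target: independent Gaussians centred at (0.2, 0.1) with sd 0.02.
    mean = np.array([0.2, 0.1])
    return -0.5 * np.sum(((theta - mean) / 0.02) ** 2)

def random_walk_metropolis(log_prob, init, num_samples, warmup_steps, step=0.03, seed=0):
    rng = np.random.default_rng(seed)
    theta = np.asarray(init, dtype=float)
    lp = log_prob(theta)
    draws = []
    for t in range(warmup_steps + num_samples):
        # Propose a Gaussian random-walk move around the current state.
        proposal = theta + step * rng.standard_normal(theta.shape)
        lp_prop = log_prob(proposal)
        # Accept with probability min(1, p(proposal) / p(current)).
        if np.log(rng.uniform()) < lp_prop - lp:
            theta, lp = proposal, lp_prop
        # Keep only post-warm-up draws, analogous to warmup_steps in run_mcmc.
        if t >= warmup_steps:
            draws.append(theta.copy())
    return np.array(draws)

chains = np.stack([
    random_walk_metropolis(log_target, init=[0.3, 0.15], num_samples=5000,
                           warmup_steps=1000, seed=s)
    for s in range(2)
])  # shape (num_chains, num_samples, num_params)
print(chains.shape, chains.reshape(-1, 2).mean(axis=0))
```

Random-walk Metropolis needs no gradients but explores correlated or high-dimensional posteriors slowly. Gradient-based samplers such as NUTS are designed to address this.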
The above returns a Pyro MCMC object, which provides a number of useful methods. One can access all the posterior samples using mcmc.get_samples() or print just the summary statistics using mcmc.summary().
mcmc.summary()
mean std median 5.0% 95.0% n_eff r_hat
beta 0.33 0.05 0.33 0.25 0.39 16.28 1.10
gamma 0.16 0.02 0.16 0.13 0.20 16.12 1.11
Number of divergences: 68
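For this run, the diagnostics indicate sampling problems, so these estimates should be treated with caution:

- r_hat is about 1.1 for both parameters, whereas commonly cited guidance asks for values below about 1.01.
- n_eff is only about 16 out of 2,000 post-warm-up draws (2 chains of 1,000 samples).
- There are 68 divergent transitions, where ideally there would be none.

To show how the columns of this table can be computed, here is a NumPy sketch. It takes per-chain draws of shape (num_chains, num_draws), such as those returned by mcmc.get_samples(group_by_chain=True) once converted to NumPy. It uses equal-tailed percentiles for the 90% interval and a textbook split-R-hat. Both are assumptions: Pyro's own summary may compute its interval and r_hat slightly differently. The example chains are synthetic.

```python
import numpy as np

def split_r_hat(chains):
    """Split-R-hat for one scalar parameter; `chains` has shape (num_chains, num_draws)."""
    chains = np.asarray(chains, dtype=float)
    n = chains.shape[1] // 2
    # Split each chain into two halves so drift within a chain also shows up.
    halves = np.concatenate([chains[:, :n], chains[:, n:2 * n]], axis=0)
    within = halves.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    between = n * halves.mean(axis=1).var(ddof=1)     # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n      # pooled variance estimate
    return float(np.sqrt(var_hat / within))

def summarise(samples_by_chain, prob=0.9):
    lo_q, hi_q = 50 * (1 - prob), 50 * (1 + prob)     # e.g. 5th and 95th percentiles
    rows = {}
    for name, chains in samples_by_chain.items():
        chains = np.asarray(chains, dtype=float)
        flat = chains.ravel()                         # pool the draws from all chains
        rows[name] = {
            "mean": flat.mean(),
            "std": flat.std(ddof=1),
            "median": np.median(flat),
            f"{lo_q:.1f}%": np.percentile(flat, lo_q),
            f"{hi_q:.1f}%": np.percentile(flat, hi_q),
            "r_hat": split_r_hat(chains),
        }
    return rows

rng = np.random.default_rng(0)
well_mixed = rng.normal(0.2, 0.02, size=(2, 1000))        # both chains sample the same distribution
poorly_mixed = well_mixed + np.array([[0.0], [0.04]])    # second chain stuck in a shifted region
summary = summarise({"well_mixed": well_mixed, "poorly_mixed": poorly_mixed})
for name, stats in summary.items():
    print(name, {k: round(float(v), 3) for k, v in stats.items()})
```

Chains that disagree with each other inflate the between-chain variance and push r_hat above 1. This is the pattern the output above hints at.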
3. Plotting with ArviZ
The BayesianCalibration.to_arviz method converts the mcmc object into a format compatible with the ArviZ plotting library. ArviZ makes it easy to produce the standard plots of calibration results as well as MCMC diagnostics.
az_data = bc.to_arviz(mcmc, posterior_predictive=True)
The posterior predictive mean and posterior predictive samples can be plotted alongside the observed data.
_ = az.plot_ppc(az_data, kind='scatter')
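A posterior predictive distribution can be approximated by pushing each posterior draw through the model and adding observation noise. Here is a minimal NumPy sketch of that recipe. `toy_emulator_mean`, the synthetic posterior draws and the noise level are hypothetical stand-ins. This is not necessarily how `to_arviz` generates its posterior predictive samples.

```python
import numpy as np

def toy_emulator_mean(beta, gamma):
    # Hypothetical stand-in for the emulator's mean prediction of the peak infection rate.
    return 0.8 * (beta - gamma) + 0.08

def posterior_predictive(beta_draws, gamma_draws, noise_sd, num_obs, seed=0):
    rng = np.random.default_rng(seed)
    mu = toy_emulator_mean(beta_draws, gamma_draws)   # one mean prediction per posterior draw
    # Add Gaussian observation noise: one simulated dataset of size num_obs per draw.
    return mu[:, None] + noise_sd * rng.standard_normal((mu.size, num_obs))

rng = np.random.default_rng(2)
beta_draws = rng.normal(0.2, 0.01, size=2000)     # synthetic posterior draws
gamma_draws = rng.normal(0.1, 0.005, size=2000)
y_rep = posterior_predictive(beta_draws, gamma_draws, noise_sd=0.01, num_obs=10)
print(y_rep.shape, y_rep.mean(), y_rep.std())
```

The spread of `y_rep` combines parameter uncertainty and observation noise. Observed data that fall well outside this spread would suggest a problem with the model or the calibration.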

To visualize the posterior distribution, the samples can be viewed as a trace for each variable (right-hand plots), alongside a 1D KDE of each chain for each variable (left-hand plots).
_ = az.plot_trace(az_data)
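The left-hand panels of the trace plot are kernel density estimates (KDEs). As a rough illustration of what a 1D Gaussian KDE computes, here is a NumPy sketch that uses Scott's rule of thumb for the bandwidth. ArviZ uses its own KDE implementation and bandwidth selection, so its curves will not match this exactly. The draws are synthetic.

```python
import numpy as np

def gaussian_kde_1d(draws, grid):
    draws = np.asarray(draws, dtype=float)
    n = draws.size
    # Scott's rule of thumb for the bandwidth of a 1D Gaussian kernel.
    h = draws.std(ddof=1) * n ** (-1 / 5)
    # Average a Gaussian bump of width h centred on every draw.
    z = (grid[:, None] - draws[None, :]) / h
    return np.exp(-0.5 * z ** 2).sum(axis=1) / (n * h * np.sqrt(2 * np.pi))

rng = np.random.default_rng(3)
draws = rng.normal(0.2, 0.02, size=1000)   # synthetic posterior draws for one chain
grid = np.linspace(0.1, 0.3, 401)
density = gaussian_kde_1d(draws, grid)
print(grid[np.argmax(density)])
```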

The 2D KDE of the posterior distribution can also be visualized.
_ = az.plot_pair(az_data, kind='kde')

Finally, autocorrelation plots for each chain and each variable can be visualized to assess convergence of the MCMC chains.
_ = az.plot_autocorr(az_data)
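Each autocorrelation panel shows how strongly a chain's draws are correlated with draws a given number of steps (lags) later. Values that decay quickly towards zero indicate good mixing. Here is a minimal NumPy sketch of the sample autocorrelation for one chain, applied to a synthetic, highly autocorrelated AR(1) series that mimics a slowly mixing chain. ArviZ's implementation may differ in details such as how the estimate is computed and normalized.

```python
import numpy as np

def autocorrelation(chain, max_lag):
    x = np.asarray(chain, dtype=float)
    x = x - x.mean()
    var = np.dot(x, x) / x.size
    # Biased (1/n) estimator at each lag, normalized so that lag 0 equals 1.
    return np.array([np.dot(x[: x.size - k], x[k:]) / x.size / var
                     for k in range(max_lag + 1)])

rng = np.random.default_rng(4)
phi = 0.9                                  # AR(1) coefficient: a high value means slow mixing
chain = np.zeros(20000)
for t in range(1, chain.size):
    chain[t] = phi * chain[t - 1] + rng.standard_normal()
acf = autocorrelation(chain, max_lag=20)
print(acf[:4].round(2))
```

For an AR(1) chain the true autocorrelation at lag k is phi**k, so the estimate should decay roughly geometrically. A chain whose autocorrelation stays high for many lags yields far fewer effectively independent draws than its length suggests.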
