import torch

import arviz as az

from autoemulate.simulations.epidemic import Epidemic
from autoemulate.core.compare import AutoEmulate
from autoemulate.calibration.bayes import BayesianCalibration
from autoemulate.emulators import GaussianProcess

Bayesian calibration

Bayesian calibration is a method for estimating which simulator input parameters are most likely to have produced the observed data. An advantage over other calibration methods is that it returns a probability distribution over the input parameters rather than just point estimates.

Performing Bayesian calibration requires:

  • a fitted emulator

  • observations associated with the simulator output

1. Simulate data and train an emulator

In this example, we’ll use the Epidemic simulator, which returns the peak infection rate given two input parameters: beta (the transmission rate per day) and gamma (the recovery rate per day).

simulator = Epidemic(log_level="error")
x = simulator.sample_inputs(100)
y = simulator.forward_batch(x)
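To build intuition for the kind of input-output map the emulator has to learn, the sketch below integrates a standard SIR model with a simple forward Euler scheme and returns the peak infected fraction. It is only an illustration: the model form, the function name `sir_peak_infection_rate`, the initial infected fraction and the step size are all assumptions made here, not the Epidemic simulator's actual implementation.

```python
def sir_peak_infection_rate(beta, gamma, i0=0.001, dt=0.01, t_max=500.0):
    """Peak infected fraction of a basic SIR model (forward Euler).

    Hypothetical helper for illustration only; s and i are population fractions.
    """
    s, i = 1.0 - i0, i0
    peak = i
    for _ in range(int(t_max / dt)):
        new_infections = beta * s * i * dt  # susceptible -> infected
        recoveries = gamma * i * dt         # infected -> recovered
        s -= new_infections
        i += new_infections - recoveries
        peak = max(peak, i)
    return peak


# Two parameter pairs with the same ratio beta / gamma
peak_a = sir_peak_infection_rate(beta=0.2, gamma=0.1)
peak_b = sir_peak_infection_rate(beta=0.4, gamma=0.2)
print(round(peak_a, 3), round(peak_b, 3))
```

In this idealized model the peak depends on beta and gamma essentially only through their ratio, which is why the two calls return nearly the same value. If the real simulator behaves similarly, observations of the peak alone may constrain the ratio much better than either parameter individually.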

For the purposes of this tutorial, we will restrict the model choice to GaussianProcess.

ae = AutoEmulate(x, y, models=[GaussianProcess], log_level="error")

We can verify that the fitted emulator performs well on both the training and test data.

ae.summarise()
   model_name       x_transforms              y_transforms              params                                             rmse_test  r2_test   r2_test_std  r2_train  r2_train_std
0  GaussianProcess  [StandardizeTransform()]  [StandardizeTransform()]  {'mean_module_fn': <function zero_mean at 0x7f...  0.079372   0.999064  0.000632     0.999963  0.000012
gp = ae.best_result().model

2. Calibrate

Calibration requires one or more observations. These can come from experiments or from the literature.

Below we pick the true parameter values we want to infer and simulate the corresponding output. We then add Gaussian noise to generate 10 “observations”.

true_beta = 0.2
true_gamma = 0.1 

# simulator expects inputs of shape [1, number of inputs]
params = torch.tensor([true_beta, true_gamma]).view(1, -1)
true_infection_rate = simulator.forward(params)

n_obs = 10
noise = torch.normal(mean=0, std=0.01, size=(n_obs,))
observed_infection_rates = true_infection_rate[0] + noise

print("Observed infection rates:", observed_infection_rates.numpy().round(3))
Observed infection rates: [0.164 0.15  0.174 0.165 0.148 0.141 0.149 0.162 0.156 0.149]

We set up the BayesianCalibration object with the trained emulator, the simulator parameter ranges and the “observed” data we simulated above. The underlying probabilistic model assumes the observations are drawn from a Gaussian distribution with the mean predicted by the emulator. We also have to specify the observation_noise of this Gaussian likelihood.

observations = {"infection_rate": observed_infection_rates}
observation_noise = 0.01

bc = BayesianCalibration(
    gp, 
    simulator.parameters_range, 
    observations, 
    observation_noise
)
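The Gaussian likelihood described above can be sketched without any library machinery. The example below evaluates the log-likelihood of the ten observations printed earlier for a candidate predicted mean. The helper name `gaussian_log_likelihood` is hypothetical, and treating the noise value as a standard deviation is an assumption made for illustration; the library may parameterize it differently.

```python
import math


def gaussian_log_likelihood(observations, predicted_mean, noise_std):
    """Sum of independent Gaussian log-densities centred on predicted_mean.

    Hypothetical helper; noise_std is assumed to be a standard deviation.
    """
    log_norm = -0.5 * math.log(2.0 * math.pi * noise_std**2)
    return sum(
        log_norm - 0.5 * ((y - predicted_mean) / noise_std) ** 2
        for y in observations
    )


# The ten "observed" infection rates printed earlier in this tutorial
observed = [0.164, 0.15, 0.174, 0.165, 0.148, 0.141, 0.149, 0.162, 0.156, 0.149]

# A predicted mean close to the data scores higher than one far from it
ll_near = gaussian_log_likelihood(observed, predicted_mean=0.155, noise_std=0.01)
ll_far = gaussian_log_likelihood(observed, predicted_mean=0.30, noise_std=0.01)
print(ll_near > ll_far)
```

In the calibration itself the predicted mean comes from the emulator evaluated at candidate parameter values, so parameter values whose predictions sit close to the observations receive higher posterior weight.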

Run MCMC using the NUTS sampler.

mcmc = bc.run_mcmc(
    warmup_steps=250, 
    num_samples=1000,
    sampler='nuts',
    num_chains=2
)

The above returns the Pyro MCMC object which has a number of useful methods associated with it. One can access all the posterior samples using mcmc.get_samples() or just the summary statistics using mcmc.summary().

mcmc.summary()
                mean       std    median      5.0%     95.0%     n_eff     r_hat
      beta      0.33      0.05      0.33      0.25      0.39     16.28      1.10
     gamma      0.16      0.02      0.16      0.13      0.20     16.12      1.11

Number of divergences: 68
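The r_hat column compares between-chain and within-chain variance; values close to 1 suggest that the chains agree with each other. As a rough illustration of the idea, the sketch below computes the basic (non-split) Gelman-Rubin statistic for synthetic chains. The function name `gelman_rubin` and the synthetic data are assumptions for this example, and Pyro's own diagnostic may use a different variant, so the numbers are not expected to match its output exactly.

```python
import random
import statistics


def gelman_rubin(chains):
    """Basic (non-split) potential scale reduction factor.

    Hypothetical helper; expects a list of equal-length chains of floats.
    """
    n = len(chains[0])
    within = statistics.fmean([statistics.variance(c) for c in chains])
    # Variance of the chain means, i.e. the between-chain variance divided by n
    between_over_n = statistics.variance([statistics.fmean(c) for c in chains])
    pooled = (n - 1) / n * within + between_over_n
    return (pooled / within) ** 0.5


rng = random.Random(0)
# Two chains sampling the same distribution
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(2)]
# Two chains stuck in different regions
stuck = [[rng.gauss(mu, 1.0) for _ in range(1000)] for mu in (0.0, 2.0)]
print(round(gelman_rubin(mixed), 2), round(gelman_rubin(stuck), 2))
```

In the summary above, r_hat values around 1.1, n_eff of about 16 and 68 divergences suggest that the two chains have not mixed well, so the posterior summaries should be read with caution. Longer warm-up or more samples may help.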

3. Plotting with ArviZ

The BayesianCalibration.to_arviz method converts the mcmc object into a format compatible with the ArviZ plotting library. ArviZ makes it easy to produce standard plots of the calibration results as well as MCMC diagnostics.

az_data = bc.to_arviz(mcmc, posterior_predictive=True)

The posterior predictive mean and posterior predictive samples can be plotted alongside the observed data.

_ = az.plot_ppc(az_data, kind='scatter')
[Figure: posterior predictive samples and mean plotted against the observed data]

To visualize the posterior distribution, the samples can be viewed as a trace for each variable (right-hand plots), alongside a 1D KDE per chain for each variable (left-hand plots).

_ = az.plot_trace(az_data)
[Figure: trace plots and per-chain 1D KDEs for each variable]

The 2D KDE of the joint posterior distribution of the parameters can also be visualized.

_ = az.plot_pair(az_data, kind='kde')
[Figure: 2D KDE of the joint posterior distribution]

Finally, autocorrelation plots for each chain and each variable can be used to assess how well the MCMC chains are mixing, which helps when judging convergence.

_ = az.plot_autocorr(az_data)
[Figure: autocorrelation plots for each chain and variable]
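For reference, the quantity shown in an autocorrelation plot can be sketched in a few lines. The example below computes the sample autocorrelation at a given lag for an independent series and for a strongly correlated AR(1) series. The function name `autocorrelation` and the synthetic series are assumptions for illustration, and ArviZ's estimator may differ in its normalization details.

```python
import random


def autocorrelation(series, lag):
    """Sample autocorrelation of a 1D series at the given lag (hypothetical helper)."""
    n = len(series)
    mean = sum(series) / n
    centered = [x - mean for x in series]
    variance = sum(c * c for c in centered)
    covariance = sum(centered[t] * centered[t + lag] for t in range(n - lag))
    return covariance / variance


rng = random.Random(1)

# Independent draws: what a well-mixing chain looks like
independent = [rng.gauss(0.0, 1.0) for _ in range(2000)]

# AR(1) series with coefficient 0.9: what a slowly mixing chain looks like
correlated = [0.0]
for _ in range(1999):
    correlated.append(0.9 * correlated[-1] + rng.gauss(0.0, 1.0))

print(
    round(autocorrelation(independent, 1), 2),
    round(autocorrelation(correlated, 1), 2),
)
```

Autocorrelation that decays slowly with increasing lag indicates that successive samples carry little new information, which is consistent with a low effective sample size.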