Adding emulators
In addition to providing a library of core emulators, AutoEmulate is designed to be easily extensible. This tutorial walks you through the steps of adding new emulators to the library. We cover two scenarios: adding new Gaussian Process kernels and adding entirely new models.
1. Adding Gaussian Process kernels
Gaussian Processes (GPs) are primarily defined by their kernel functions, which determine the covariance structure of the data. AutoEmulate includes several built-in GP kernels:
Radial Basis Function (RBF)
Matern 3/2
Matern 5/2
Rational Quadratic (RQ)
Linear
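For reference, the kernel families listed above can be written out for scalar inputs in plain Python. This sketch assumes the standard textbook forms, each with a lengthscale and a variance (output scale) hyperparameter. AutoEmulate's actual kernels may parameterise them differently, and the function names here are illustrative only.

```python
import math

# Textbook forms of the five kernel families for scalar inputs x, x'.
# Assumption: hyperparameter names and parameterisation follow common GP
# conventions; AutoEmulate's own kernels may scale or parameterise them differently.


def rbf(x, xp, lengthscale=1.0, variance=1.0):
    r = abs(x - xp)
    return variance * math.exp(-0.5 * (r / lengthscale) ** 2)


def matern32(x, xp, lengthscale=1.0, variance=1.0):
    s = math.sqrt(3.0) * abs(x - xp) / lengthscale
    return variance * (1.0 + s) * math.exp(-s)


def matern52(x, xp, lengthscale=1.0, variance=1.0):
    s = math.sqrt(5.0) * abs(x - xp) / lengthscale
    return variance * (1.0 + s + s * s / 3.0) * math.exp(-s)


def rational_quadratic(x, xp, lengthscale=1.0, variance=1.0, alpha=1.0):
    r2 = (x - xp) ** 2
    return variance * (1.0 + r2 / (2.0 * alpha * lengthscale**2)) ** (-alpha)


def linear(x, xp, variance=1.0):
    # Not stationary: depends on the inputs themselves, not just their distance
    return variance * x * xp


# Stationary kernels equal the variance at zero distance and decay with distance
for k in (rbf, matern32, matern52, rational_quadratic):
    print(k.__name__, round(k(0.0, 0.0), 3), round(k(0.0, 1.0), 3))
```

The Matern kernels produce rougher functions than the RBF kernel (Matern 3/2 most of all). This shows up as a faster initial drop in correlation with distance.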
You can easily create new kernels by composing any two or more of these existing kernels. For example, you might want to create a kernel that combines the RBF and Linear kernels to capture both smooth variations and linear trends in your data.
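Composition works because sums and products of valid (positive semi-definite) kernels are themselves valid kernels. The sketch below shows the idea with kernels as plain Python functions. add_kernels, mul_kernels and gram are hypothetical helpers. The rbf and linear functions are simplified stand-ins, not AutoEmulate's rbf_kernel and linear_kernel.

```python
import math


# Illustrative stand-ins for two kernels on scalar inputs (not AutoEmulate's
# rbf_kernel / linear_kernel, which build kernel objects rather than plain functions).
def rbf(x, xp, lengthscale=0.2):
    return math.exp(-0.5 * ((x - xp) / lengthscale) ** 2)


def linear(x, xp):
    return x * xp


def add_kernels(k1, k2):
    # The sum of two valid kernels is a valid kernel
    return lambda x, xp: k1(x, xp) + k2(x, xp)


def mul_kernels(k1, k2):
    # So is their product
    return lambda x, xp: k1(x, xp) * k2(x, xp)


def gram(k, xs):
    # Covariance (Gram) matrix K[i][j] = k(xs[i], xs[j])
    return [[k(a, b) for b in xs] for a in xs]


rbf_plus_linear = add_kernels(rbf, linear)
K = gram(rbf_plus_linear, [0.0, 0.5, 1.0])
for row in K:
    print([round(v, 3) for v in row])
```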
In AutoEmulate, each kernel is defined by an initialisation function that takes the number of input features and the number of output features as arguments and returns the kernel. Below we define a custom kernel function following this pattern.
```python
from autoemulate.emulators.gaussian_process.kernel import rbf_kernel, linear_kernel


def rbf_plus_linear_kernel(n_features, n_outputs):
    """
    Example of a custom kernel function that combines RBF and linear kernels.
    """
    # Adding the two kernels captures smooth variation (RBF) and linear trends (Linear)
    return rbf_kernel(n_features, n_outputs) + linear_kernel(n_features, n_outputs)
```
Once this function has been defined, you can create a new GP emulator class using the create_gp_subclass function.
```python
from autoemulate.emulators.gaussian_process.exact import GaussianProcess, create_gp_subclass

GaussianProcessRBFandLinear = create_gp_subclass(
    "GaussianProcessRBFandLinear",
    GaussianProcess,
    # the custom kernel function goes here
    covar_module_fn=rbf_plus_linear_kernel,
    # register the new class with AutoEmulate, replacing any existing registration
    auto_register=True,
    overwrite=True,
)
```
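This tutorial doesn't show how create_gp_subclass works internally. The plain-Python sketch below illustrates the general pattern the call suggests: a subclass created at runtime and optionally registered under its name. REGISTRY, register, make_subclass, BaseGP and toy_kernel are all hypothetical names, not AutoEmulate's API.

```python
# Hypothetical sketch of the pattern: REGISTRY, register, make_subclass and
# BaseGP are illustrative names, not AutoEmulate's API.
REGISTRY = {}


def register(cls=None, *, overwrite=False):
    def decorator(c):
        if c.__name__ in REGISTRY and not overwrite:
            raise ValueError(f"{c.__name__} is already registered")
        REGISTRY[c.__name__] = c
        return c

    # Works both as @register and as @register(overwrite=True)
    return decorator(cls) if cls is not None else decorator


class BaseGP:
    covar_module_fn = None  # filled in by each generated subclass

    def __init__(self, n_features, n_outputs):
        # Build the kernel from the class-level factory function
        self.covar_module = self.covar_module_fn(n_features, n_outputs)


def make_subclass(name, base, covar_module_fn, auto_register=True, overwrite=False):
    # type(name, bases, namespace) creates a new class at runtime;
    # staticmethod stops the kernel function being bound to instances
    cls = type(name, (base,), {"covar_module_fn": staticmethod(covar_module_fn)})
    if auto_register:
        register(cls, overwrite=overwrite)
    return cls


def toy_kernel(n_features, n_outputs):
    # Placeholder for a real kernel factory such as rbf_plus_linear_kernel
    return ("toy-kernel", n_features, n_outputs)


MyGP = make_subclass("MyGP", BaseGP, toy_kernel, overwrite=True)
gp = MyGP(n_features=1, n_outputs=1)
print(gp.covar_module, REGISTRY["MyGP"] is MyGP)
```

In this sketch, registering an existing name a second time raises a ValueError unless overwrite=True is passed. The overwrite argument in the call above suggests it plays a similar role.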
Now we can tell AutoEmulate to use the new GP class by passing it to the models argument when initialising an AutoEmulate object.
```python
from autoemulate import AutoEmulate
import torch

# create some example data: a noisy sine wave on [0, 1]
x = torch.linspace(0, 1, 100).unsqueeze(-1)
y = torch.sin(2 * 3.14 * x) + 0.1 * torch.randn_like(x)

ae = AutoEmulate(x, y, models=[GaussianProcessRBFandLinear])
```
Comparing models: 100%|██████████| 1.00/1.00 [00:30<00:00, 30.7s/model]
```python
ae.summarise()
```
| | model_name | x_transforms | y_transforms | params | r2_test | r2_test_std | rmse_test | rmse_test_std | r2_train | r2_train_std | rmse_train | rmse_train_std |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | GaussianProcessRBFandLinear | [StandardizeTransform()] | [StandardizeTransform()] | {'epochs': 100, 'lr': 0.5, 'likelihood_cls': <... | 0.972941 | 0.014803 | 0.098698 | 0.015649 | 0.987079 | 0.001874 | 0.078599 | 0.005709 |
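The r2 and rmse columns are standard regression metrics. Below is a minimal pure-Python sketch of both. r2_score and rmse are illustrative names. How AutoEmulate produces the *_std columns (for example, across cross-validation folds) is not covered here.

```python
import math


def r2_score(y_true, y_pred):
    # Coefficient of determination: 1 - (residual sum of squares / total sum of squares)
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def rmse(y_true, y_pred):
    # Root mean squared error, in the units of y
    return math.sqrt(sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true))


y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [2.0, 3.0, 4.0, 5.0]  # every prediction off by exactly 1
print(r2_score(y_true, y_pred), rmse(y_true, y_pred))
```

A constant error of 1 gives an RMSE of exactly 1 but an R² of only 0.2. R² is measured relative to the spread of the targets, whereas RMSE is an absolute error.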
2. Adding new models
It is also possible to add entirely new models to AutoEmulate. AutoEmulate has a base Emulator class that handles most of the general functionality required for training and prediction. To implement a new emulator, one must simply subclass Emulator, implement the abstract methods (_fit, _predict and is_multioutput), implement get_tune_params to enable model tuning, and add any model-specific functionality and initialisation.
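This division of labour, where a base class supplies shared training and prediction logic and subclasses fill in abstract hooks, is the template-method pattern. The sketch below illustrates it with Python's abc module. MiniEmulator and MeanEmulator are hypothetical, simplified classes, not AutoEmulate's Emulator, whose real signatures and extra responsibilities are not shown here.

```python
from abc import ABC, abstractmethod


class MiniEmulator(ABC):
    """Hypothetical, simplified stand-in for an emulator base class."""

    def fit(self, x, y):
        # Shared logic lives in the base class; model-specific work is delegated
        if len(x) != len(y):
            raise ValueError("x and y must have the same number of rows")
        self._fit(x, y)
        self.is_fitted_ = True
        return self

    def predict(self, x):
        if not getattr(self, "is_fitted_", False):
            raise RuntimeError("call fit() before predict()")
        return self._predict(x)

    @abstractmethod
    def _fit(self, x, y): ...

    @abstractmethod
    def _predict(self, x): ...

    @staticmethod
    @abstractmethod
    def is_multioutput(): ...

    @staticmethod
    def get_tune_params():
        # Default: nothing to tune; subclasses override with {"param": [values, ...]}
        return {}


class MeanEmulator(MiniEmulator):
    """Toy emulator that always predicts the mean of the training targets."""

    def _fit(self, x, y):
        self.mean_ = sum(y) / len(y)

    def _predict(self, x):
        return [self.mean_ for _ in x]

    @staticmethod
    def is_multioutput():
        return False


toy = MeanEmulator().fit([[0.0], [1.0]], [1.0, 3.0])
print(toy.predict([[5.0], [6.0]]))
```

Because is_multioutput is abstract, Python refuses to instantiate any subclass that forgets to implement it. This catches a missing hook at construction time rather than midway through training.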
Since AutoEmulate supports a variety of models, there are additional Emulator subclasses that handle specific functionality for each model type:
PyTorchBackend for PyTorch models
SklearnBackend for scikit-learn models
GaussianProcess for exact Gaussian Process implementations
Ensemble for ensemble models
Subclassing one of these directly has slightly different requirements. For example, when subclassing PyTorchBackend or GaussianProcess, one must implement the forward method, which defines the model’s forward pass.
There are also some static methods that should be implemented to provide metadata about the model, such as is_multioutput and get_tune_params.
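In the example below, get_tune_params returns a dictionary that maps each hyperparameter name to a list of candidate values. The sketch shows two generic ways to turn such a dictionary into candidate configurations: a full grid and a random sample. grid_configs and random_configs are hypothetical helpers, and AutoEmulate's tuner may search the space differently.

```python
import itertools
import random


def grid_configs(tune_params):
    # Full grid: every combination of the listed values
    names = list(tune_params)
    combos = itertools.product(*(tune_params[n] for n in names))
    return [dict(zip(names, combo)) for combo in combos]


def random_configs(tune_params, n_iter, seed=0):
    # Random search: draw each parameter independently from its list
    rng = random.Random(seed)
    return [{n: rng.choice(v) for n, v in tune_params.items()} for _ in range(n_iter)]


# "hidden_dim" matches the SimpleFNN example; "lr" is a hypothetical extra entry
tune_params = {"hidden_dim": [32, 64, 128], "lr": [0.01, 0.1]}
print(len(grid_configs(tune_params)))
print(random_configs(tune_params, n_iter=3))
```

A full grid grows multiplicatively with each added parameter (here 3 × 2 = 6 configurations). With many parameters, random sampling with a fixed budget is the usual alternative.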
The example below adds a simple feedforward neural network (FNN) using PyTorch. The new class SimpleFNN subclasses PyTorchBackend, which already handles fitting and prediction. Only the constructor, the forward method and the static metadata methods need to be defined.
```python
from autoemulate.core.device import TorchDeviceMixin
from autoemulate.emulators.base import PyTorchBackend
from autoemulate.emulators import register
import torch.nn as nn


@register(overwrite=True)  # decorate to register the emulator, optionally overwrite existing registration
class SimpleFNN(PyTorchBackend):
    def __init__(
        self,
        x,
        y,
        hidden_dim=64,
        device=None,
    ):
        TorchDeviceMixin.__init__(self, device=device)
        nn.Module.__init__(self)

        # infer layer sizes from the training data
        input_dim = x.shape[1]
        output_dim = y.shape[1] if len(y.shape) > 1 else 1

        layers = []
        layers.append(nn.Linear(input_dim, hidden_dim, device=self.device))
        layers.append(nn.ReLU())
        layers.append(nn.Linear(hidden_dim, output_dim, device=self.device))
        self.model = nn.Sequential(*layers)

        self.optimizer = self.optimizer_cls(self.model.parameters(), lr=self.lr)  # type: ignore[call-arg] since all optimizers include lr
        self.scheduler = None
        self.to(self.device)

    def forward(self, x):
        return self.model(x)

    @staticmethod
    def is_multioutput():
        return True

    @staticmethod
    def get_tune_params():
        return {
            "hidden_dim": [32, 64, 128]
        }
```
```python
ae = AutoEmulate(x, y, models=[SimpleFNN])
```
Comparing models: 100%|██████████| 1.00/1.00 [00:01<00:00, 1.87s/model]
```python
ae.summarise()
```
| | model_name | x_transforms | y_transforms | params | r2_test | r2_test_std | rmse_test | rmse_test_std | r2_train | r2_train_std | rmse_train | rmse_train_std |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | SimpleFNN | [StandardizeTransform()] | [StandardizeTransform()] | {'hidden_dim': 32} | 0.642511 | 0.117467 | 0.353908 | 0.044417 | 0.7763 | 0.043331 | 0.336149 | 0.02812 |
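To make the forward method concrete, here is what SimpleFNN's Linear, ReLU, Linear stack computes for a single input row. It is written in plain Python with hand-picked, untrained weights, and the helper names are illustrative. The only convention borrowed from PyTorch is that a linear layer stores its weight with shape (out_features, in_features) and computes x W^T + b.

```python
def linear(x, weight, bias):
    # Same convention as torch.nn.Linear: y = x W^T + b, with weight of shape (out, in)
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(v):
    return [max(0.0, a) for a in v]


def fnn_forward(x, w1, b1, w2, b2):
    # Linear -> ReLU -> Linear, the same layer stack as SimpleFNN.model
    return linear(relu(linear(x, w1, b1)), w2, b2)


# Hand-picked weights for input_dim=1, hidden_dim=2, output_dim=1
w1, b1 = [[1.0], [-1.0]], [0.0, 0.0]
w2, b2 = [[2.0, 3.0]], [0.5]
print(fnn_forward([2.0], w1, b1, w2, b2))
print(fnn_forward([-1.0], w1, b1, w2, b2))
```

ReLU switches off whichever hidden unit receives a negative pre-activation. As a result, each input region is handled by a different linear piece.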
The emulator, having been registered with AutoEmulate using the @register decorator, can also be reinitialized and fitted on a dataset.
```python
em = ae.fit_from_reinitialized(x, y)
print(em.model)
```
```text
SimpleFNN(
  (model): Sequential(
    (0): Linear(in_features=1, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)
```
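The printed architecture also makes the model size easy to work out. Each Linear layer with a bias holds in_features × out_features weights plus out_features biases, and ReLU has no parameters. The quick check below uses count_params, an illustrative helper.

```python
def count_params(input_dim, hidden_dim, output_dim):
    # nn.Linear(in, out) with bias=True holds in*out weights plus out biases;
    # ReLU has no parameters
    first = input_dim * hidden_dim + hidden_dim
    second = hidden_dim * output_dim + output_dim
    return first + second


# Layer sizes from the printed model: Linear(1, 64) -> ReLU -> Linear(64, 1)
print(count_params(1, 64, 1))
# The hidden_dim values listed in SimpleFNN.get_tune_params
print({h: count_params(1, h, 1) for h in [32, 64, 128]})
```

With PyTorch available, sum(p.numel() for p in em.model.parameters()) gives the count directly. It should agree, assuming the model holds no parameters beyond these two layers.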