autoemulate.core.compare

class AutoEmulate(x, y, models=None, x_transforms_list=None, y_transforms_list=None, model_tuning=True, model_params=None, transformed_emulator_params=None, only_pytorch=False, only_probabilistic=False, n_iter=10, n_splits=5, shuffle=True, n_bootstraps=100, max_retries=3, device=None, random_seed=None, log_level='progress_bar')

Bases: ConversionMixin, TorchDeviceMixin, Results

Automated emulator fitting.

The AutoEmulate class is the main class of the AutoEmulate package. It is used to set up and compare different emulator models on a given dataset. It can also be used to summarise and visualise results, and to save and load models.

static all_emulators()

Return a list of all available emulators.

static pytorch_emulators()

Return a list of all available PyTorch emulators.

static probablistic_emulators()

Return a list of all available probabilistic emulators.

static list_emulators()

Return a dataframe with info on all available emulators.

The dataframe includes the model name and whether it has a PyTorch backend, supports multioutput data and provides uncertainty quantification.

Returns:

DataFrame with columns:

['Emulator', 'PyTorch', 'Multioutput', 'Uncertainty_Quantification'].

Return type:

pd.DataFrame

get_models(models=None, only_pytorch=False, only_probabilistic=False)

Return a list of the model classes for comparisons.

Parameters:
  • models (list[type[Emulator] | str] | None) – List of model classes or names to use for comparison. If None, all available emulators are used, or the subset selected by only_pytorch and only_probabilistic.

  • only_pytorch (bool) – If True, only PyTorch emulators are returned. Defaults to False.

  • only_probabilistic (bool) – If True, only probabilistic emulators are returned. Defaults to False.
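The selection rule these parameters describe can be sketched in plain Python. This is an illustration only, not AutoEmulate's implementation: the `EmulatorInfo` record, the `REGISTRY` entries and the `select_models` helper are hypothetical stand-ins, and the sketch assumes that the two flags combine with a logical AND when both are set.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EmulatorInfo:
    """Hypothetical stand-in for an emulator class and its capabilities."""
    name: str
    pytorch: bool
    probabilistic: bool


# Hypothetical registry; the names are placeholders, not real emulators.
REGISTRY = [
    EmulatorInfo("ModelA", pytorch=True, probabilistic=True),
    EmulatorInfo("ModelB", pytorch=True, probabilistic=False),
    EmulatorInfo("ModelC", pytorch=False, probabilistic=False),
]


def select_models(models=None, only_pytorch=False, only_probabilistic=False):
    """Return the registry entries to compare.

    An explicit `models` list (entries or names) wins; otherwise the whole
    registry is used, narrowed by the two flags (assumed to combine with AND).
    """
    if models is not None:
        by_name = {m.name: m for m in REGISTRY}
        # Accept either a name or an entry, mirroring `type[Emulator] | str`.
        return [by_name[m] if isinstance(m, str) else m for m in models]
    selected = REGISTRY
    if only_pytorch:
        selected = [m for m in selected if m.pytorch]
    if only_probabilistic:
        selected = [m for m in selected if m.probabilistic]
    return list(selected)


print([m.name for m in select_models(only_pytorch=True)])
```

Passing names rather than classes keeps configuration files simple, which is presumably why the documented `models` parameter accepts both forms.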

get_transforms(transforms)

Process and return a list of transforms.

filter_models_if_multioutput(models, warn)

Filter models to only include those that support multi-output data.

log_compare(best_model_name, x_transforms, y_transforms, best_params_for_this_model, r2_score, rmse_score)

Log the comparison results.

compare()

Compare different models on the provided dataset.

The method will:

  • Loop over all combinations of x and y transforms and models.

  • Set up the tuner with the training/validation data.

  • Tune hyperparameters for each model.

  • Fit the best model with the tuned hyperparameters.

  • Evaluate the performance of the best model on the test data.

  • Log the results.

  • Save the best model and its parameters.
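The overall shape of such a comparison loop can be sketched in plain Python. Everything here is a hypothetical stand-in rather than AutoEmulate code: the toy `MODELS`, `PARAM_GRID`, `X_TRANSFORMS` and `Y_TRANSFORMS` tables are invented for illustration, and "tuning" is reduced to a grid search scored on the training data instead of a proper training/validation split.

```python
from itertools import product


def rmse(pred, target):
    """Root mean squared error between two equal-length sequences."""
    return (sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)) ** 0.5


# Hypothetical stand-ins: each "model" maps a hyperparameter to a predictor.
MODELS = {
    "linear": lambda slope: (lambda x: slope * x),
    "quadratic": lambda scale: (lambda x: scale * x * x),
}
PARAM_GRID = [0.5, 1.0, 2.0]  # candidate hyperparameter values
X_TRANSFORMS = {"identity": lambda x: x, "double": lambda x: 2 * x}
Y_TRANSFORMS = {"identity": lambda y: y}


def compare(x_train, y_train, x_test, y_test):
    """Try every (x transform, y transform, model) combination."""
    results = []
    for (xt_name, xt), (yt_name, yt), (m_name, make) in product(
        X_TRANSFORMS.items(), Y_TRANSFORMS.items(), MODELS.items()
    ):
        xs = [xt(v) for v in x_train]
        ys = [yt(v) for v in y_train]
        # "Tune": pick the hyperparameter with the lowest training RMSE.
        best_param = min(
            PARAM_GRID, key=lambda p: rmse([make(p)(v) for v in xs], ys)
        )
        # "Fit" the best configuration and evaluate it on held-out data.
        model = make(best_param)
        score = rmse([model(xt(v)) for v in x_test], [yt(v) for v in y_test])
        results.append(
            {"model": m_name, "x_transform": xt_name, "y_transform": yt_name,
             "param": best_param, "rmse": score}
        )
    return results


x_train, x_test = [0.0, 1.0, 2.0, 3.0], [4.0, 5.0]
results = compare(x_train, [2 * v for v in x_train], x_test, [2 * v for v in x_test])
best = min(results, key=lambda r: r["rmse"])
print(best["model"], best["x_transform"], best["param"], best["rmse"])
```

Because the number of fits grows as the product of the three lists, trimming the model or transform lists is the most direct way to shorten a comparison run.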

fit_from_reinitialized(x, y, result_id=None, random_seed=None, transformed_emulator_params=None)

Fit a fresh model with reinitialized parameters using the best configuration.

This method creates a new model instance with the same configuration as the best (or specified) model from the comparison, but with freshly initialized parameters fitted on the provided data.

Parameters:
  • x (InputLike) – Input features for training the fresh model.

  • y (InputLike) – Target values for training the fresh model.

  • result_id (int | None) – The ID of the result to use. If None, uses the best model. Defaults to None.

  • random_seed (int | None) – Random seed for parameter initialization. Defaults to None.

  • transformed_emulator_params (None | TransformedEmulatorParams) – Parameters for the transformed emulator. When None, the same parameters as used when identifying the best model are used. Defaults to None.

Returns:

A new model instance with the same configuration but fresh parameters fitted on the provided data.

Return type:

TransformedEmulator

Notes

Unlike TransformedEmulator.refit(), which retrains an existing model, this method creates a completely new model instance with reinitialized parameters. This ensures that the same initialization conditions are applied when fitting on new data. Initialization can affect the fitted model, for example through kernel initialization in Gaussian Processes or weight initialization in neural networks.
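The distinction can be illustrated with a toy model in plain Python. `ToyEmulator` is hypothetical and unrelated to the library's classes, and the sketch assumes that retraining an existing model continues from its current parameters, which the documentation above does not state explicitly.

```python
import random


class ToyEmulator:
    """Hypothetical one-parameter model, y = w * x, trained by gradient descent."""

    def __init__(self, random_seed=None):
        # The initial weight depends only on the seed.
        self.w = random.Random(random_seed).uniform(-1.0, 1.0)

    def fit(self, x, y, steps=5, lr=0.01):
        # A few gradient steps on mean squared error, starting from self.w.
        for _ in range(steps):
            grad = sum(2 * (self.w * xi - yi) * xi for xi, yi in zip(x, y)) / len(x)
            self.w -= lr * grad
        return self


x, y_old, y_new = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0], [3.0, 6.0, 9.0]

# Retraining the existing instance starts from weights shaped by the old data.
refitted = ToyEmulator(random_seed=0).fit(x, y_old)
w_before_refit = refitted.w
refitted.fit(x, y_new)

# A reinitialized instance starts from the seeded initial weight instead.
fresh = ToyEmulator(random_seed=0)
w_initial = fresh.w
fresh.fit(x, y_new)

# The two training runs on the new data began from different weights,
# while a second instance built with the same seed reproduces the fresh start.
print(w_before_refit != w_initial)
print(ToyEmulator(random_seed=0).w == w_initial)
```

With only a few optimisation steps the two runs end at different weights, which is the kind of initialization dependence the note describes.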

plot(model_obj, input_index=None, output_index=None, input_ranges=None, output_ranges=None, figsize=None, ncols=3, fname=None)

Plot the evaluation of the given model.

Parameters:
  • model_obj (int | Emulator | Result) – The model to plot. Can be an integer ID of a Result, an Emulator instance, or a Result instance.

  • input_index (int) – The index of the input feature to plot against the output.

  • output_index (int) – The index of the output feature to plot against the input.

  • input_ranges (dict | None) – The ranges of the input features to consider for the plot. Ranges are combined so that the final subset is the intersection of the data within all specified ranges. Defaults to None.

  • output_ranges (dict | None) – The ranges of the output features to consider for the plot. Ranges are combined so that the final subset is the intersection of the data within all specified ranges. Defaults to None.

  • figsize (tuple[int, int] | None) – The size of the figure to create. If None, it is set based on the number of input and output features.

  • ncols (int) – The number of columns in the subplot grid. Defaults to 3.

  • fname (str | None) – If provided, the figure will be saved to this file path.
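How range filters combine by intersection can be sketched in plain Python. The dict layout used here (feature index mapped to an inclusive `(low, high)` pair) is an assumption made for illustration, since the parameter descriptions only say `dict`; `in_ranges` and `subset` are hypothetical helpers, not library functions.

```python
def in_ranges(row, ranges):
    """True if every constrained index of `row` lies within its (low, high) range."""
    return all(low <= row[i] <= high for i, (low, high) in (ranges or {}).items())


def subset(x_rows, y_rows, input_ranges=None, output_ranges=None):
    """Keep only samples that satisfy all input AND all output ranges."""
    kept = [
        (x, y)
        for x, y in zip(x_rows, y_rows)
        if in_ranges(x, input_ranges) and in_ranges(y, output_ranges)
    ]
    return [x for x, _ in kept], [y for _, y in kept]


x_rows = [[0.1, 5.0], [0.5, 1.0], [0.9, 2.0]]
y_rows = [[10.0], [20.0], [30.0]]

# Assumed dict layout: feature index -> (low, high), bounds inclusive.
xs, ys = subset(
    x_rows, y_rows, input_ranges={0: (0.0, 0.6)}, output_ranges={0: (15.0, 35.0)}
)
print(xs, ys)
```

Only the second sample survives: the first fails the output range and the third fails the input range, so each added range can only shrink the plotted subset.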

save(model_obj, path=None, use_timestamp=True)

Save model to disk.

Parameters:
  • model_obj (int | Emulator | Result) – The model to save. Can be an integer ID of a Result, an Emulator instance, or a Result instance.

  • path (str) – Path to save the model.

  • use_timestamp (bool) – If True, appends a timestamp to the filename to ensure uniqueness.
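Appending a timestamp to keep filenames unique can be sketched as follows. The `timestamped_path` helper, the `name_YYYYMMDD_HHMMSS.ext` layout and the `.joblib` extension are assumptions for illustration; the format AutoEmulate actually writes is not specified here.

```python
from datetime import datetime
from pathlib import Path


def timestamped_path(path, use_timestamp=True, now=None):
    """Return `path`, optionally with a timestamp inserted before the suffix.

    The `name_YYYYMMDD_HHMMSS.ext` layout is an assumption for illustration.
    """
    path = Path(path)
    if not use_timestamp:
        return path
    # `now` can be injected so the result is reproducible in tests.
    stamp = (now or datetime.now()).strftime("%Y%m%d_%H%M%S")
    return path.with_name(f"{path.stem}_{stamp}{path.suffix}")


print(timestamped_path("models/best.joblib", now=datetime(2025, 1, 2, 3, 4, 5)))
```

A second-resolution stamp avoids overwriting earlier saves in ordinary use, though two saves within the same second would still collide under this scheme.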

load(path)

Load a stored model or result from disk.

Parameters:

path (str) – Path to model.

Returns:

The loaded model or result object.

Return type:

Emulator | Result

static load_model(path)

Load a stored model directly from a given path.

Parameters:

path (str | Path) – Path to the model.

Returns:

The loaded model object.

Return type:

Emulator

Raises:

FileNotFoundError – If the model file does not exist.
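A caller can handle the documented FileNotFoundError with a small wrapper. The sketch below is a caller-side pattern under stated assumptions: `load_or_default` and `fake_load_model` are hypothetical helpers, the latter standing in for a loader such as `AutoEmulate.load_model` so that the example runs without the library.

```python
from pathlib import Path


def load_or_default(loader, path, default=None):
    """Call `loader(path)`, returning `default` if the file is missing."""
    try:
        return loader(path)
    except FileNotFoundError:
        return default


def fake_load_model(path):
    """Hypothetical stand-in for a loader that raises on a missing file."""
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"No model file at {path}")
    return path.read_bytes()


# The path below is assumed not to exist, so the default is returned.
model = load_or_default(fake_load_model, "does/not/exist.bin")
print(model is None)
```

Catching only FileNotFoundError keeps other failures, such as a corrupted file, visible to the caller instead of silently masking them.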