deepsensor.model.model
- class DeepSensorModel(data_processor=None, task_loader=None)
Bases: ProbabilisticModel
Implements DeepSensor prediction functionality of a ProbabilisticModel. Allows for outputting an xarray object containing on-grid predictions or a pandas object containing off-grid predictions.
- Parameters:
data_processor (DataProcessor) – DataProcessor object, used to unnormalise predictions.
task_loader (TaskLoader) – TaskLoader object, used to determine target variables for unnormalising.
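As a rough illustration of what unnormalising a prediction involves, the sketch below assumes the DataProcessor applied a mean/standard-deviation normalisation x_norm = (x - mu) / sigma to the target variable. The actual normalisation method and statistics are whatever the DataProcessor was configured with, and the helper names here are hypothetical. Under that assumption, a predicted mean is scaled and shifted back, while a predicted standard deviation is only scaled:

```python
def unnormalise_mean(mean_norm, mu, sigma):
    # Hypothetical helper. Inverts x_norm = (x - mu) / sigma:
    # scale by sigma, then shift back by mu.
    return mean_norm * sigma + mu


def unnormalise_std(std_norm, sigma):
    # Hypothetical helper. A standard deviation is unaffected by the
    # shift mu, so it is only rescaled by sigma.
    return std_norm * sigma


# Hypothetical normalisation statistics for one target variable.
mu, sigma = 10.0, 5.0
print(unnormalise_mean(0.5, mu, sigma))  # 12.5
print(unnormalise_std(0.25, sigma))      # 1.25
```

If the DataProcessor used a different (for example min-max) scheme, the same idea applies with that scheme's inverse.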
- N_mixture_components = 1
- predict(tasks, X_t, X_t_mask=None, X_t_is_normalised=False, aux_at_targets_override=None, aux_at_targets_override_is_normalised=False, resolution_factor=1, pred_params=('mean', 'std'), n_samples=0, ar_sample=False, ar_subsample_factor=1, unnormalise=True, seed=0, append_indexes=None, progress_bar=0, verbose=False)
Predict on a regular grid or at off-grid locations.
- Parameters:
tasks (List[Task] | Task) – List of tasks containing context data.
X_t (xarray.Dataset | xarray.DataArray | pandas.DataFrame | pandas.Series | pandas.Index | numpy.ndarray) – Target locations to predict at. Can be an xarray object containing on-grid locations or a pandas object containing off-grid locations.
X_t_mask (xarray.Dataset | xarray.DataArray, optional) – 2D mask to apply to gridded X_t (zero/False will be NaNs). Will be interpolated to the same grid as X_t. Default None (no mask).
X_t_is_normalised (bool) – Whether the X_t coords are normalised. If False, the coords will be normalised before being passed to the model. Default False.
aux_at_targets_override (xarray.Dataset | xarray.DataArray) – Optional auxiliary xarray data to use in place of the auxiliary data from the task_loader.
aux_at_targets_override_is_normalised (bool) – Whether the aux_at_targets_override coords are normalised. If False, the DataProcessor will normalise the coords before they are passed to the model. Default False.
pred_params (Tuple[str]) – Tuple of prediction parameters to return. The strings refer to methods of the model class which will be called and stored in the Prediction object. Default (“mean”, “std”).
resolution_factor (float) – Optional factor to increase the resolution of the target grid by. E.g. 2 will double the target resolution, 0.5 will halve it. Applies to on-grid predictions only. Default 1.
n_samples (int) – Number of joint samples to draw from the model. If 0, will not draw samples. Default 0.
ar_sample (bool) – Whether to use autoregressive sampling. Default False.
unnormalise (bool) – Whether to unnormalise the predictions. Only works if self has data_processor and task_loader attributes. Default True.
seed (int) – Random seed for deterministic sampling. Default 0.
append_indexes (dict) – Dictionary of index metadata to append to pandas indexes in the off-grid case. Default None.
progress_bar (int) – Whether to display a progress bar over tasks. Default 0.
verbose (bool) – Whether to print the time taken for prediction. Default False.
- Returns:
Prediction – A dict-like object mapping from target variable IDs to xarray or pandas objects containing model predictions.
  - If X_t is a pandas object, returns pandas objects containing off-grid predictions.
  - If X_t is an xarray object, returns xarray objects containing on-grid predictions.
  - If n_samples == 0, returns only mean and std predictions.
  - If n_samples > 0, returns mean, std and sample predictions.
- Raises:
ValueError – If X_t is not an xarray object and either resolution_factor or ar_subsample_factor is not 1.
ValueError – If X_t is not a pandas object and append_indexes is not None.
ValueError – If X_t is not an xarray, pandas or numpy object.
ValueError – If append_indexes are not all the same length as X_t.
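The ValueError conditions listed for predict can be read as a small set of argument checks. The sketch below restates them in plain Python; check_predict_args and the x_t_kind labels are hypothetical stand-ins for real type checks on X_t, not part of the DeepSensor API, and the real method may perform its checks in a different order:

```python
def check_predict_args(x_t_kind, resolution_factor=1, ar_subsample_factor=1,
                       append_indexes=None, n_targets=None):
    """Hypothetical restatement of the ValueError rules documented for predict().

    x_t_kind stands in for the type of X_t: "xarray", "pandas" or "numpy".
    n_targets is the number of target locations in X_t.
    """
    # X_t must belong to one of the three supported families of objects.
    if x_t_kind not in ("xarray", "pandas", "numpy"):
        raise ValueError("X_t must be an xarray, pandas or numpy object")
    # resolution_factor applies to on-grid predictions only;
    # ar_subsample_factor is documented as restricted in the same way.
    if x_t_kind != "xarray" and (resolution_factor != 1 or ar_subsample_factor != 1):
        raise ValueError("resolution_factor and ar_subsample_factor must be 1 for non-xarray X_t")
    if append_indexes is not None:
        # Extra index metadata is only appended in the off-grid (pandas) case.
        if x_t_kind != "pandas":
            raise ValueError("append_indexes requires a pandas X_t")
        # Each appended index must line up one-to-one with the target locations.
        if any(len(values) != n_targets for values in append_indexes.values()):
            raise ValueError("append_indexes must all be the same length as X_t")


check_predict_args("xarray", resolution_factor=2)  # accepted: gridded targets
try:
    check_predict_args("pandas", resolution_factor=2)
except ValueError as err:
    print(err)
```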
- class ProbabilisticModel
Bases: object
Base class for probabilistic models used by DeepSensor. Ensures that the methods DeepSensor requires are implemented by the specific model classes that inherit from it.
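The contract described above, where the base class declares each method and raises NotImplementedError until a child class overrides it, can be sketched as follows. BaseModelSketch and ConstantGaussianModel are hypothetical toy classes, not DeepSensor classes; the child ignores its task argument and predicts the same Gaussian at every target point, purely to show which methods a subclass supplies:

```python
import numpy as np


class BaseModelSketch:
    """Hypothetical stand-in for the ProbabilisticModel contract: declares methods, implements none."""

    def mean(self, task, *args, **kwargs):
        raise NotImplementedError()

    def variance(self, task, *args, **kwargs):
        raise NotImplementedError()


class ConstantGaussianModel(BaseModelSketch):
    """Toy child class: ignores the task and predicts N(loc, scale**2) at every target point."""

    def __init__(self, n_target, loc=0.0, scale=1.0):
        self.n_target = n_target
        self.loc = loc
        self.scale = scale

    def mean(self, task, *args, **kwargs):
        # Shape (N,): one mean per target point.
        return np.full(self.n_target, self.loc)

    def variance(self, task, *args, **kwargs):
        # Shape (N,): one marginal variance per target point.
        return np.full(self.n_target, self.scale ** 2)


try:
    BaseModelSketch().mean(task=None)
except NotImplementedError:
    print("base class leaves mean() to subclasses")

model = ConstantGaussianModel(n_target=3, loc=1.5, scale=2.0)
print(model.mean(task=None), model.variance(task=None))
```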
- covariance(task, *args, **kwargs)
Computes the model covariance matrix over target points based on given context data. Shape (N, N).
- Parameters:
task (Task) – Task containing context data.
- Returns:
numpy.ndarray – Covariance matrix over target points.
- Raises:
NotImplementedError – If not implemented by child class.
- joint_entropy(task, *args, **kwargs)
Computes the model joint entropy over target points based on given context data.
- Parameters:
task (Task) – Task containing context data.
- Returns:
float – Joint entropy over target points.
- Raises:
NotImplementedError – If not implemented by child class.
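For a model whose predictive distribution over the N target points is a multivariate Gaussian (an assumption made here for illustration; the method signature does not fix the distribution), the joint entropy has the closed form H = 0.5 * log det(2*pi*e*Sigma), where Sigma is the (N, N) covariance matrix. A minimal NumPy sketch, with a hypothetical helper name:

```python
import numpy as np


def gaussian_joint_entropy(cov):
    """Differential entropy, in nats, of a multivariate Gaussian with covariance cov."""
    n = cov.shape[0]
    # slogdet is numerically safer than log(det(cov)) for larger N.
    sign, logdet = np.linalg.slogdet(cov)
    if sign <= 0:
        raise ValueError("cov must be positive definite")
    # 0.5 * log det(2*pi*e*cov) = 0.5 * (n * log(2*pi*e) + log det(cov))
    return 0.5 * (n * np.log(2 * np.pi * np.e) + logdet)


independent = np.eye(2)
correlated = np.array([[1.0, 0.9],
                       [0.9, 1.0]])
print(gaussian_joint_entropy(independent))
print(gaussian_joint_entropy(correlated))
```

With the same marginal variances, the correlated pair has the lower joint entropy, since the two targets share information; a joint quantity can therefore differ from one built from marginals alone.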
- logpdf(task, *args, **kwargs)
Computes the joint model logpdf over target points based on given context data.
- Parameters:
task (Task) – Task containing context data.
- Returns:
float – Joint logpdf over target points.
- Raises:
NotImplementedError – If not implemented by child class.
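Under the same illustrative Gaussian assumption, a joint logpdf evaluates the observed target values y under N(mean, Sigma). The sketch uses a Cholesky factorisation rather than an explicit matrix inverse; the helper name and arguments are hypothetical:

```python
import numpy as np


def gaussian_joint_logpdf(y, mean, cov):
    """log N(y; mean, cov), evaluated jointly over all N target points."""
    n = y.shape[0]
    chol = np.linalg.cholesky(cov)            # cov = chol @ chol.T
    z = np.linalg.solve(chol, y - mean)       # whitened residual: chol @ z = y - mean
    logdet = 2.0 * np.sum(np.log(np.diag(chol)))
    # -0.5 * (n log(2 pi) + log det(cov) + (y - mean)^T cov^{-1} (y - mean))
    return -0.5 * (n * np.log(2 * np.pi) + logdet + z @ z)


y = np.array([0.5, -1.0])
mean = np.zeros(2)
cov = np.diag([1.0, 4.0])
print(gaussian_joint_logpdf(y, mean, cov))
```

With a diagonal covariance, as here, the joint logpdf reduces to the sum of the univariate logpdfs; off-diagonal terms are what make it genuinely joint.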
- loss(task, *args, **kwargs)
Computes the model loss over target points based on given context data.
- Parameters:
task (Task) – Task containing context data.
- Returns:
float – Loss over target points.
- Raises:
NotImplementedError – If not implemented by child class.
- mean(task, *args, **kwargs)
Computes the model mean prediction over target points based on given context data.
- Parameters:
task (Task) – Task containing context data.
- Returns:
numpy.ndarray – Mean prediction over target points.
- Raises:
NotImplementedError – If not implemented by child class.
- mean_marginal_entropy(task, *args, **kwargs)
Computes the mean marginal entropy over target points based on given context data.
Note: Getting a vector of marginal entropies would be useful too.
- Parameters:
task (Task) – Task containing context data.
- Returns:
float – Mean marginal entropy over target points.
- Raises:
NotImplementedError – If not implemented by child class.
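Assuming Gaussian marginals (again purely illustrative), each target point's marginal entropy is 0.5 * log(2*pi*e*sigma_i^2), and the mean marginal entropy is their average. Computing the per-point vector first, in the spirit of the note on this method, makes the mean a one-line reduction. Both helper names are hypothetical:

```python
import numpy as np


def gaussian_marginal_entropies(variance):
    """Per-target differential entropies (nats) for Gaussian marginals with these variances."""
    return 0.5 * np.log(2 * np.pi * np.e * np.asarray(variance, dtype=float))


def mean_marginal_entropy(variance):
    """Average of the per-target marginal entropies; ignores correlations between targets."""
    return float(np.mean(gaussian_marginal_entropies(variance)))


variance = np.array([1.0, 4.0, 0.25])
print(gaussian_marginal_entropies(variance))
print(mean_marginal_entropy(variance))
```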
- sample(task, n_samples=1, *args, **kwargs)
Draws n_samples joint samples over target points based on given context data. Returned shape is (n_samples, n_target).
- Parameters:
task (Task) – Task containing context data.
n_samples (int) – Number of joint samples to draw. Default 1.
- Returns:
tuple[numpy.ndarray] – Joint samples over target points.
- Raises:
NotImplementedError – If not implemented by child class.
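For a Gaussian predictive distribution (an illustrative assumption), joint samples with the documented (n_samples, n_target) layout can be drawn by correlating standard-normal noise with a Cholesky factor of the covariance. The helper below is hypothetical and, for simplicity, returns a single array rather than the tuple listed above:

```python
import numpy as np


def sample_joint_gaussian(mean, cov, n_samples=1, seed=0):
    """Draw joint samples from N(mean, cov); result has shape (n_samples, n_target)."""
    rng = np.random.default_rng(seed)          # seeded for reproducible draws
    chol = np.linalg.cholesky(cov)             # cov = chol @ chol.T
    eps = rng.standard_normal((n_samples, mean.shape[0]))
    # Each row is mean + chol @ eps_row, written in row-vector form.
    return mean + eps @ chol.T


mean = np.array([0.0, 1.0])
cov = np.array([[1.0, 0.8],
                [0.8, 1.0]])
samples = sample_joint_gaussian(mean, cov, n_samples=5)
print(samples.shape)  # (5, 2)
```

Because every row shares one covariance, neighbouring targets move together within a sample, which independent per-point draws from the marginals would not capture.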
- std(task)
Model marginal standard deviation over target points given context points. Shape (N,).
- Parameters:
task (Task) – Task containing context data.
- Returns:
numpy.ndarray – Marginal standard deviation over target points.
- variance(task, *args, **kwargs)
Model marginal variance over target points given context points. Shape (N,).
- Parameters:
task (Task) – Task containing context data.
- Returns:
numpy.ndarray – Marginal variance over target points.
- Raises:
NotImplementedError – If not implemented by child class.
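The shapes documented for covariance (N, N), variance (N,) and std (N,) fit together simply: for any joint distribution with a covariance matrix, the marginal variances are its diagonal and the marginal standard deviations are their square roots. A short NumPy illustration with made-up numbers:

```python
import numpy as np

# Made-up (N, N) covariance over N = 3 target points.
cov = np.array([[4.0, 1.2, 0.0],
                [1.2, 9.0, 0.5],
                [0.0, 0.5, 1.0]])

variance = np.diag(cov)  # marginal variances, shape (N,)
std = np.sqrt(variance)  # marginal standard deviations, shape (N,)
print(variance)  # [4. 9. 1.]
print(std)       # [2. 3. 1.]
```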
- add_valid_time_coord_to_pred_and_move_time_dims(pred)
Add a valid time coordinate “time” to a Prediction object based on the initialisation times “init_time” and lead times “lead_time”, and reorder the time dims from (“lead_time”, “init_time”) to (“init_time”, “lead_time”).
- Parameters:
pred (Prediction) – Prediction object to add valid time coordinate to.
- Returns:
Prediction – Prediction object with valid time coordinate added.
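The valid-time bookkeeping can be illustrated without xarray. The sketch assumes the usual forecasting convention that valid time = init_time + lead_time and represents the two time dimensions as nested Python lists. valid_times and move_init_time_first are hypothetical helpers; the real function operates on a Prediction object:

```python
from datetime import datetime, timedelta


def valid_times(init_times, lead_times):
    """valid_time[i][j] = init_times[i] + lead_times[j], indexed (init_time, lead_time)."""
    return [[init + lead for lead in lead_times] for init in init_times]


def move_init_time_first(values_by_lead_then_init):
    """Reorder a nested list from (lead_time, init_time) to (init_time, lead_time)."""
    return [list(row) for row in zip(*values_by_lead_then_init)]


init_times = [datetime(2020, 1, 1), datetime(2020, 1, 2)]
lead_times = [timedelta(days=0), timedelta(days=1), timedelta(days=2)]

vt = valid_times(init_times, lead_times)
print(vt[1][2])  # 2020-01-04 00:00:00

# Hypothetical values laid out as (3 lead times, 2 init times).
preds = [[10, 20], [11, 21], [12, 22]]
print(move_init_time_first(preds))  # [[10, 11, 12], [20, 21, 22]]
```

Putting init_time first groups all lead times of one forecast run together, so each row reads as a single forecast trajectory.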