Adding models to DeepSensor
To extend DeepSensor with a new model, simply create a new class that inherits from deepsensor.model.DeepSensorModel and implement the low-level prediction methods defined in deepsensor.model.ProbabilisticModel, such as .mean and .stddev.
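A base class that leaves the low-level .mean and .stddev methods to its subclasses is a common Python pattern. The sketch below shows one way to express that contract with the standard library's abc module. ToyProbabilisticModel, MeanOnlyModel and ConstantModel are hypothetical names used only for illustration. DeepSensor's own ProbabilisticModel is not necessarily implemented this way.

```python
from abc import ABC, abstractmethod


class ToyProbabilisticModel(ABC):
    """Hypothetical base class: subclasses must supply mean and stddev."""

    @abstractmethod
    def mean(self, task):
        """Return predictive means at the task's target locations."""

    @abstractmethod
    def stddev(self, task):
        """Return predictive standard deviations at the target locations."""


class MeanOnlyModel(ToyProbabilisticModel):
    """Incomplete subclass: stddev is missing, so it cannot be instantiated."""

    def mean(self, task):
        return [0.0]


class ConstantModel(ToyProbabilisticModel):
    """Complete subclass implementing both low-level methods."""

    def mean(self, task):
        return [1.0]

    def stddev(self, task):
        return [0.1]


try:
    MeanOnlyModel()
except TypeError as err:
    # abc refuses to build a class that still has abstract methods
    print("Cannot instantiate:", err)

model = ConstantModel()
print(model.mean(None), model.stddev(None))  # [1.0] [0.1]
```

Declaring the methods abstract makes an incomplete subclass fail when it is constructed, not later on the first prediction call.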
To demonstrate, we'll create a very basic new model called NewModel that implements .mean and .stddev, and show that it inherits the convenient .predict method from DeepSensorModel.
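NewModel gets .predict without defining it because the high-level method is written once in the base class, using only the low-level methods. Any subclass that implements those methods therefore inherits it. This is the template-method pattern. The dependency-free sketch below illustrates the idea. ToyBaseModel and ContextMeanModel are hypothetical classes, and a plain dict stands in for a Task. DeepSensor's real .predict also handles the target locations and unnormalises the output, and none of that is reproduced here.

```python
class ToyBaseModel:
    """Hypothetical stand-in for a base class that provides predict()."""

    def mean(self, task):
        raise NotImplementedError

    def stddev(self, task):
        raise NotImplementedError

    def predict(self, task):
        # Generic high-level method: it relies only on the low-level
        # methods, so every subclass that implements them gets it for free.
        return {"mean": self.mean(task), "std": self.stddev(task)}


class ContextMeanModel(ToyBaseModel):
    """Predicts the mean of the context values with a fixed stddev."""

    def mean(self, task):
        values = task["context_values"]
        return [sum(values) / len(values)] * task["n_targets"]

    def stddev(self, task):
        return [0.1] * task["n_targets"]


# Hypothetical toy task: three context values, two target locations
toy_task = {"context_values": [1.0, 2.0, 3.0], "n_targets": 2}
pred = ContextMeanModel().predict(toy_task)
print(pred)  # {'mean': [2.0, 2.0], 'std': [0.1, 0.1]}
```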
For an example of a more complex model class, see the ConvNP source code.
import logging
logging.captureWarnings(True)

from deepsensor.model import DeepSensorModel
from deepsensor.data import DataProcessor, TaskLoader, Task
import xarray as xr
import numpy as np


class NewModel(DeepSensorModel):
    """A very naive model that predicts the mean of the first context set with a fixed stddev"""

    def __init__(self, data_processor: DataProcessor, task_loader: TaskLoader):
        # Initialise the parent DeepSensorModel with the data processor and task loader
        super().__init__(data_processor, task_loader)

    def mean(self, task: Task):
        """Compute mean at target locations"""
        # Flatten any gridded data so each array has a single points dimension
        task = task.flatten_gridded_data()
        # Shape of the mean should be (N_dim, N_target). Here we assume the number
        # of dimensions is the same for the first context and the target set.
        shape = (task["Y_c"][0].shape[0], task["X_t"][0].shape[1])
        # Every target location gets the mean of all first-context-set values
        return np.ones(shape) * task["Y_c"][0].mean()

    def stddev(self, task: Task):
        """Compute stddev at target locations"""
        task = task.flatten_gridded_data()
        shape = (task["Y_c"][0].shape[0], task["X_t"][0].shape[1])
        # Fixed stddev of 0.1 (in normalised units) at every target location
        return np.ones(shape) * 0.1
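NewModel.mean and NewModel.stddev both build an array of shape (N_dim, N_target) and fill it with a constant. The pure-Python sketch below reproduces only that shape bookkeeping, on a hypothetical stand-in for a flattened task. It is a plain dict whose "Y_c" and "X_t" entries are lists of nested lists, assumed to follow the same layout as the arrays above: Y_c[0] as (N_dim, N_context) and X_t[0] as (2, N_target). The numbers are invented for illustration. They are not taken from the tutorial dataset.

```python
# Hypothetical flattened task (plain lists instead of DeepSensor arrays)
toy_task = {
    # One context set: 1 variable observed at 4 locations -> shape (1, 4)
    "Y_c": [[[270.0, 272.0, 274.0, 276.0]]],
    # One target set: (x1, x2) coordinates of 3 locations -> shape (2, 3)
    "X_t": [[[15.0, 17.5, 20.0], [200.0, 202.5, 205.0]]],
}


def prediction_shape(task):
    """(N_dim, N_target), mirroring Y_c[0].shape[0] and X_t[0].shape[1]."""
    n_dim = len(task["Y_c"][0])
    n_target = len(task["X_t"][0][0])
    return n_dim, n_target


def context_mean(task):
    """Mean over every value in the first context set, like .mean()."""
    values = [v for row in task["Y_c"][0] for v in row]
    return sum(values) / len(values)


def constant_field(shape, value):
    """Nested list of the given (rows, cols) shape filled with value."""
    n_dim, n_target = shape
    return [[value] * n_target for _ in range(n_dim)]


shape = prediction_shape(toy_task)
mean_pred = constant_field(shape, context_mean(toy_task))
std_pred = constant_field(shape, 0.1)
print(shape)      # (1, 3)
print(mean_pred)  # [[273.0, 273.0, 273.0]]
```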
# Load raw data
ds_raw = xr.tutorial.open_dataset("air_temperature")

# Normalise data
data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = data_processor(ds_raw)

# Set up task loader
task_loader = TaskLoader(context=ds, target=ds)

# Instantiate the new model
model = NewModel(data_processor, task_loader)

# Generate a task for 2014-01-01 with 100 randomly sampled context points
task = task_loader("2014-01-01", 100)

# Predict at the raw dataset's grid points; the output is in unnormalised units
pred = model.predict(task, X_t=ds_raw)
pred["air"]
<xarray.Dataset>
Dimensions:  (time: 1, lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2014-01-01
Data variables:
    mean     (time, lat, lon) float32 275.1 275.1 275.1 ... 275.1 275.1 275.1
    std      (time, lat, lon) float32 1.631 1.631 1.631 ... 1.631 1.631 1.631
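The std values are 1.631 rather than the 0.1 returned by NewModel.stddev. This is because the model works in normalised units and .predict maps its outputs back to the raw data units. The worked arithmetic below assumes that DataProcessor applied its mean/standard-deviation normalisation, x_norm = (x - mu) / sigma. Under that assumption a standard deviation is only rescaled by sigma (the offset mu cancels), while a mean is rescaled and shifted back by mu. The helper functions are hypothetical illustrations, not DeepSensor functions, and mu_hypothetical is an invented value, not the dataset's real mean.

```python
def unnormalise_mean(mean_norm, mu, sigma):
    """Invert x_norm = (x - mu) / sigma for a predicted mean."""
    return mean_norm * sigma + mu


def unnormalise_std(std_norm, sigma):
    """A standard deviation is only rescaled; the offset mu cancels."""
    return std_norm * sigma


# Scale implied by the output above: 0.1 (normalised) -> 1.631 (raw units)
implied_sigma = 1.631 / 0.1
print(round(implied_sigma, 2))                        # 16.31
print(round(unnormalise_std(0.1, implied_sigma), 3))  # 1.631

# Round trip for a mean, using a hypothetical offset (NOT the dataset's mean)
mu_hypothetical = 280.0
z = (275.1 - mu_hypothetical) / implied_sigma
print(round(unnormalise_mean(z, mu_hypothetical, implied_sigma), 1))  # 275.1
```

Under this assumption, the constant 0.1 in normalised units corresponds to a standard deviation of about 16.31 in the raw data's units.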