Adding models to DeepSensor


To extend DeepSensor with a new model, create a class that inherits from deepsensor.model.DeepSensorModel and implements the low-level prediction methods defined in deepsensor.model.ProbabilisticModel, such as .mean and .stddev.

To demonstrate this, we’ll create a very basic new model called NewModel, and show that it inherits the convenient .predict method from DeepSensorModel. To build more complex model classes, you may like to check out the ConvNP source code as an example.

import logging

# Route warnings through the logging system
logging.captureWarnings(True)

from deepsensor.model import DeepSensorModel
from deepsensor.data import DataProcessor, TaskLoader, Task

import xarray as xr
import numpy as np


class NewModel(DeepSensorModel):
    """A very naive model that predicts the mean of the first context set with a fixed stddev"""
    
    def __init__(self, data_processor: DataProcessor, task_loader: TaskLoader):
        super().__init__(data_processor, task_loader)
        
    def mean(self, task: Task):
        """Compute mean at target locations"""
        task = task.flatten_gridded_data()
        # Shape of the mean should be (N_dim, N_target). Here we assume the number
        # of dimensions is the same for the first context and the target set.
        shape = (task["Y_c"][0].shape[0], task["X_t"][0].shape[1])
        return np.ones(shape) * task["Y_c"][0].mean()
    
    def stddev(self, task: Task):
        """Compute stddev at target locations"""
        task = task.flatten_gridded_data()
        shape = (task["Y_c"][0].shape[0], task["X_t"][0].shape[1])
        return np.ones(shape) * 0.1
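
The shape logic used by .mean and .stddev can be checked in isolation, without DeepSensor installed. The sketch below is only an illustration: fake_task is a hypothetical plain dict standing in for a flattened Task, constant_mean and constant_stddev are hypothetical helper names, and the array sizes (1 variable, 100 context points, 50 target points) are assumptions chosen for the example.

```python
import numpy as np

# Hypothetical stand-in for a flattened Task: a plain dict holding only the
# two entries that NewModel reads. The sizes are assumptions for illustration.
rng = np.random.default_rng(0)
fake_task = {
    "Y_c": [rng.normal(size=(1, 100))],  # first context set: (N_dim, N_context)
    "X_t": [rng.uniform(size=(2, 50))],  # first target set: (2, N_target)
}


def constant_mean(task):
    """Mirror NewModel.mean: broadcast the context mean to (N_dim, N_target)."""
    shape = (task["Y_c"][0].shape[0], task["X_t"][0].shape[1])
    return np.full(shape, task["Y_c"][0].mean())


def constant_stddev(task, value=0.1):
    """Mirror NewModel.stddev: a fixed value with shape (N_dim, N_target)."""
    shape = (task["Y_c"][0].shape[0], task["X_t"][0].shape[1])
    return np.full(shape, value)


mean = constant_mean(fake_task)
std = constant_stddev(fake_task)
print(mean.shape, std.shape)  # (1, 50) (1, 50)
```

Both arrays take their first dimension from the context observations and their second from the target locations, which is the (N_dim, N_target) layout described in the comment inside NewModel.mean.
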
# Load raw data
ds_raw = xr.tutorial.open_dataset("air_temperature")

# Normalise data
data_processor = DataProcessor(x1_name="lat", x2_name="lon")
ds = data_processor(ds_raw)

# Set up task loader
task_loader = TaskLoader(context=ds, target=ds)
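
# Instantiate the model; it inherits .predict from DeepSensorModel
model = NewModel(data_processor, task_loader)

# Build a task for one date with 100 randomly sampled context points
task = task_loader("2014-01-01", 100)

# Predict on the grid of the raw dataset
pred = model.predict(task, X_t=ds_raw)
pred["air"]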
<xarray.Dataset>
Dimensions:  (time: 1, lat: 25, lon: 53)
Coordinates:
  * lat      (lat) float32 75.0 72.5 70.0 67.5 65.0 ... 25.0 22.5 20.0 17.5 15.0
  * lon      (lon) float32 200.0 202.5 205.0 207.5 ... 322.5 325.0 327.5 330.0
  * time     (time) datetime64[ns] 2014-01-01
Data variables:
    mean     (time, lat, lon) float32 275.1 275.1 275.1 ... 275.1 275.1 275.1
    std      (time, lat, lon) float32 1.631 1.631 1.631 ... 1.631 1.631 1.631
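
The std values in the output above are 1.631 rather than the 0.1 returned by NewModel.stddev. This is consistent with the model working in normalised units and .predict mapping its outputs back to raw units. The sketch below illustrates that mapping under the assumption of a mean-std normalisation. The raw values are made up for the example, and normalise, unnormalise_mean and unnormalise_stddev are hypothetical helpers, not DeepSensor functions.

```python
import numpy as np

# Made-up raw values in kelvin, for illustration only (not from the dataset)
raw = np.array([260.0, 270.0, 280.0, 290.0, 300.0])

# Assumed mean-std normalisation parameters
offset, scale = raw.mean(), raw.std()


def normalise(y):
    """Map raw values to normalised units."""
    return (y - offset) / scale


def unnormalise_mean(m):
    """Map a normalised mean back to raw units: rescale, then shift."""
    return m * scale + offset


def unnormalise_stddev(s):
    """Map a normalised stddev back to raw units: rescale only, no shift."""
    return s * scale


normalised = normalise(raw)

# Like NewModel: the mean is the average of some context values, and the
# stddev is a fixed 0.1, both expressed in normalised units.
model_mean = normalised[:3].mean()
model_std = 0.1

print(unnormalise_mean(model_mean))   # average of the first three raw values
print(unnormalise_stddev(model_std))  # 0.1 times the raw standard deviation
```

Under this assumption, a std of 1.631 would correspond to a normalisation scale of about 16.31 in the raw units of the air variable.
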