Acquisition functions#

Now that we’ve got the basics of DeepSensor’s active learning functionality from the Active learning page, here we will focus on the various acquisition functions available in the package. Again, we will use the pre-trained ERA5 spatial interpolation ConvNP from the previous Training page.

For an up-to-date list of acquisition functions, see the API documentation for the deepsensor.active_learning.acquisition_fns module.

Set-up#

Hide code cell content
import logging

logging.captureWarnings(True)

import deepsensor.torch
from deepsensor.data import DataProcessor, TaskLoader, construct_circ_time_ds
from deepsensor.data.sources import get_era5_reanalysis_data, get_earthenv_auxiliary_data, \
    get_gldas_land_mask
from deepsensor.model import ConvNP
from deepsensor.train import set_gpu_default_device

import cartopy.crs as ccrs
import pandas as pd
Hide code cell content
# Training/data configuration for the ERA5 spatial interpolation example

# Spatial domain and the variable IDs used for each data source
extent = "usa"
era5_var_IDs = ["2m_temperature"]
station_var_IDs = ["TAVG"]
lowres_auxiliary_var_IDs = ["elevation"]

# Date windows, each a (start, end) pair of date strings
data_range = ("2010-01-01", "2019-12-31")
train_range = ("2010-01-01", "2018-12-31")
val_range = ("2019-01-01", "2019-12-31")
date_subsample_factor = 10  # keep every 10th date

# Folders and download verbosity
cache_dir = "../../.datacache"
deepsensor_folder = "../deepsensor_config/"
model_folder = "../model/"
verbose_download = True

# Daily dates across the validation window, thinned by date_subsample_factor
val_dates = pd.date_range(*val_range)[::date_subsample_factor]
Hide code cell content
# Fetch the raw ERA5, low-resolution auxiliary and land-mask datasets for `extent`.
# NOTE(review): cache=True with cache_dir presumably reuses earlier downloads — confirm.
era5_raw_ds = get_era5_reanalysis_data(era5_var_IDs, extent, date_range=data_range, cache=True, cache_dir=cache_dir, verbose=verbose_download, num_processes=8)
lowres_aux_raw_ds = get_earthenv_auxiliary_data(lowres_auxiliary_var_IDs, extent, "100KM", cache=True, cache_dir=cache_dir, verbose=verbose_download)
land_mask_raw_ds = get_gldas_land_mask(extent, cache=True, cache_dir=cache_dir, verbose=verbose_download)

# Pass the raw data through a DataProcessor configured with lat/lon as its two spatial coordinate names
data_processor = DataProcessor(x1_name="lat", x2_name="lon")
era5_ds = data_processor(era5_raw_ds)
# The auxiliary and land-mask datasets are processed together using the "min_max" method
lowres_aux_ds, land_mask_ds = data_processor([lowres_aux_raw_ds, land_mask_raw_ds], method="min_max")

# Build cos_D/sin_D variables over the daily dates spanned by the processed ERA5 data
# and attach them to the low-resolution auxiliary dataset
dates = pd.date_range(era5_ds.time.values.min(), era5_ds.time.values.max(), freq="D")
doy_ds = construct_circ_time_ds(dates, freq="D")
lowres_aux_ds["cos_D"] = doy_ds["cos_D"]
lowres_aux_ds["sin_D"] = doy_ds["sin_D"]
Downloading ERA5 data from Google Cloud Storage... 
100%|██████████████████████████████████████████████████████████████████| 120/120 [00:02<00:00, 45.88it/s]
1.41 GB loaded in 3.68 s
# Map projection handed to the deepsensor.plot calls further down
crs = ccrs.PlateCarree()
# Date of the single task used in the sequential and oracle examples
test_date = pd.Timestamp("2019-06-25")
# Run on GPU if available by setting GPU as default device
set_gpu_default_device()
# Three context sets (ERA5, land mask, auxiliary + circular time) and ERA5 as the single target set
task_loader = TaskLoader(
    context=[era5_ds, land_mask_ds, lowres_aux_ds],
    target=era5_ds,
)
# NOTE(review): presumably loads any dask-backed data into memory — confirm against the TaskLoader docs
task_loader.load_dask()
print(task_loader)
TaskLoader(3 context sets, 1 target sets)
Context variable IDs: (('2m_temperature',), ('GLDAS_mask',), ('elevation', 'cos_D', 'sin_D'))
Target variable IDs: (('2m_temperature',),)
# Load model
# Constructed from the processor, the task loader and the deepsensor_folder path
model = ConvNP(data_processor, task_loader, deepsensor_folder)
# Number of context points requested for the first entry of the sampling tuple below
X_c = 100
# Single task at test_date; the tuple has one entry per context set
# NOTE(review): presumably X_c points from the ERA5 context set and all points of the other two;
# seed_override=42 presumably fixes the random sampling — confirm
task = task_loader(test_date, (X_c, "all", "all"), seed_override=42)
# xarray object containing a mask to remove ocean points from the search and target points
mask_ds = land_mask_raw_ds

Sequential greedy algorithm#

Sequentially loop over all search points, passing a query observation to the model at that point and computing the change in acquisition function.

These acquisition functions can be computationally expensive because they require one model forward pass per query point, so we will coarsen the search space for the purposes of demonstration.

from deepsensor.active_learning import GreedyAlgorithm

# Greedy search configured to propose N_new_context=3 new context points
greedy_alg = GreedyAlgorithm(
    model=model,
    X_t=era5_raw_ds,  # Target points on the raw ERA5 grid
    X_s=era5_raw_ds.coarsen(lat=15, lon=15, boundary="trim").mean(),  # Coarsen search points to speed up computation
    X_s_mask=mask_ds,  # Mask out ocean from search points
    X_t_mask=mask_ds,  # Mask out ocean from target points
    N_new_context=3,
    progress_bar=True,
)

MeanStddev#

Minimise the model’s mean standard deviation - i.e. minimise the expected MAE under the model.

from deepsensor.active_learning.acquisition_fns import MeanStddev

acquisition_fn = MeanStddev(model)
# Run the greedy search on the single task: yields the proposed locations (X_new_df)
# and the acquisition function values (acquisition_fn_ds)
# NOTE(review): diff=True presumably reports the change in the acquisition function — confirm
X_new_df, acquisition_fn_ds = greedy_alg(acquisition_fn, task, diff=True)

# Plot the proposed placements, then the acquisition function itself
fig = deepsensor.plot.placements(task, X_new_df, data_processor, crs)
fig = deepsensor.plot.acquisition_fn(task, acquisition_fn_ds, X_new_df, data_processor, crs)
100%|██████████████████████████████████████████████████████████████████| 579/579 [00:32<00:00, 17.91it/s]
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Averaging acquisition function over dims for plotting: ['time']
../_images/509795073e05838cebf6c8d793f8a1c4ce34f6c50d0a2f6c91c1d6f0a646357f.png ../_images/f50db2d4af9adf756b379f55bab153c1c600cc9c7ba7f22acd354e7494585c3e.png

pNormStddev#

Computing the p-norm of the standard deviations can be used to place greater emphasis on reducing the largest standard deviations.

from deepsensor.active_learning.acquisition_fns import pNormStddev

# p=6: a larger p puts more weight on the largest standard deviations
acquisition_fn = pNormStddev(model, p=6)
# Same sequential greedy search as before, now driven by the p-norm acquisition function
X_new_df, acquisition_fn_ds = greedy_alg(acquisition_fn, task, diff=True)

# Plot the proposed placements, then the acquisition function itself
fig = deepsensor.plot.placements(task, X_new_df, data_processor, crs)
fig = deepsensor.plot.acquisition_fn(task, acquisition_fn_ds, X_new_df, data_processor, crs)
100%|██████████████████████████████████████████████████████████████████| 579/579 [00:32<00:00, 17.77it/s]
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Averaging acquisition function over dims for plotting: ['time']
../_images/51e09db13ce7191895d405687afb40d85c9cd55b12f34b072bc6a8cdcead41c7.png ../_images/488ae9d16d7a67a5299c9339dd99bfac3ae118110f0de025bbcab0a268b99036.png

Oracle sequential greedy algorithm#

Acquisition functions that inherit from AcquisitionFunctionOracle use the true target values to compute performance metrics. This assumes that the true target values are available at all target points, which will often not be the case.

Using oracle acquisition functions requires that the GreedyAlgorithm is initialised with a task_loader object so that it can load the true target values for each target point.

# Greedy search for oracle acquisition functions: task_loader is supplied so the
# true target values can be loaded for each target point
greedy_alg_with_groundtruth = GreedyAlgorithm(
    model=model,
    X_t=era5_raw_ds,
    X_s=era5_raw_ds.coarsen(lat=10, lon=10, boundary="trim").mean(),  # Coarsen search points to speed up computation
    X_s_mask=mask_ds,  # Mask out ocean from search points
    X_t_mask=mask_ds,  # Mask out ocean from target points
    # NOTE(review): the infill datasets presumably supply observation values at query/proposed
    # points; they are not on the coarsened search grid, so they get interpolated (see output) — confirm
    query_infill=era5_ds,
    proposed_infill=era5_ds,
    N_new_context=3,
    task_loader=task_loader,
    verbose=True,
    progress_bar=True,
)
query_infill not on search grid, interpolating.
proposed_infill not on search grid, interpolating.

OracleMAE#

from deepsensor.active_learning.acquisition_fns import OracleMAE

acquisition_fn = OracleMAE(model)
# Oracle acquisition functions need the greedy algorithm that was built with a task_loader
X_new_df, acquisition_fn_ds = greedy_alg_with_groundtruth(acquisition_fn, task, diff=True)

# Plot the proposed placements, then the acquisition function itself
fig = deepsensor.plot.placements(task, X_new_df, data_processor, crs)
fig = deepsensor.plot.acquisition_fn(task, acquisition_fn_ds, X_new_df, data_processor, crs)
100%|████████████████████████████████████████████████████████████████| 1311/1311 [01:14<00:00, 17.61it/s]
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Averaging acquisition function over dims for plotting: ['time']
../_images/80a3533503df36866685a054f0e8a6ea8fc5844e211b1b5ee3de8466a084bdcd.png ../_images/00e7391285ad0a23d217120f9c20d6926001d3de4f8f5ab35f43e51a94dfeb51.png

OracleRMSE#

from deepsensor.active_learning.acquisition_fns import OracleRMSE

acquisition_fn = OracleRMSE(model)
# Oracle acquisition functions need the greedy algorithm that was built with a task_loader
X_new_df, acquisition_fn_ds = greedy_alg_with_groundtruth(acquisition_fn, task, diff=True)

# Plot the proposed placements, then the acquisition function itself
fig = deepsensor.plot.placements(task, X_new_df, data_processor, crs)
fig = deepsensor.plot.acquisition_fn(task, acquisition_fn_ds, X_new_df, data_processor, crs)
100%|████████████████████████████████████████████████████████████████| 1311/1311 [01:13<00:00, 17.82it/s]
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Averaging acquisition function over dims for plotting: ['time']
../_images/24dc1f2ac24f427dfbe135ee1f7a80db8d7a37d4e0b7194e2c4e41e1028ade16.png ../_images/b7a59c635e7c612edcc8359da18e8f65c53e75b7f1490fb69b4d692eee452c36.png

OracleMarginalNLL#

from deepsensor.active_learning.acquisition_fns import OracleMarginalNLL

acquisition_fn = OracleMarginalNLL(model)
# Oracle acquisition functions need the greedy algorithm that was built with a task_loader
X_new_df, acquisition_fn_ds = greedy_alg_with_groundtruth(acquisition_fn, task, diff=True)

# Plot the proposed placements, then the acquisition function itself
fig = deepsensor.plot.placements(task, X_new_df, data_processor, crs)
fig = deepsensor.plot.acquisition_fn(task, acquisition_fn_ds, X_new_df, data_processor, crs)
100%|████████████████████████████████████████████████████████████████| 1311/1311 [02:26<00:00,  8.96it/s]
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Averaging acquisition function over dims for plotting: ['time']
../_images/178bd07351dab10c117e96bd31c12ebdcb55ec04e301d21bd870aec2c006e739.png ../_images/076e00af7001644d67284ed2d84bfb106d3e30ca9637c23adc8db0fca2c5b9d7.png

Parallel greedy algorithm#

Acquisition functions that inherit from AcquisitionFunctionParallel can be computed over all search points in parallel by running the model forwards once. Parallel acquisition functions are much faster to compute than sequential acquisition functions, which require one forward pass per search point. This enables finer search grids, averaging acquisition functions over more tasks, and more proposed context points.

# Rebind greedy_alg for the parallel examples: the search grid is no longer coarsened
# and 10 new context points are proposed
greedy_alg = GreedyAlgorithm(
    model=model,
    X_t=era5_raw_ds,
    X_s=era5_raw_ds,
    X_s_mask=mask_ds,  # Mask out ocean from search points
    X_t_mask=mask_ds,  # Mask out ocean from target points
    N_new_context=10,
    progress_bar=True,
)
# Generate tasks for the subsampled validation dates, using the same sampling
# arguments as the single task
# NOTE(review): presumably one task per date in val_dates — confirm
dates = val_dates
tasks = task_loader(dates, (X_c, "all", "all"), seed_override=42)

Stddev#

Use the model’s standard deviation at search points as the acquisition function. Maximising this acquisition function will place context points at locations where the model is most uncertain.

from deepsensor.active_learning.acquisition_fns import Stddev
acquisition_fn = Stddev(model)
# Run the greedy search over all validation tasks rather than a single task
X_new_df, acquisition_fn_ds = greedy_alg(acquisition_fn, tasks)

# Plots use the first task; only the iterations selected by slice(0, 4) are shown
fig = deepsensor.plot.placements(tasks[0], X_new_df, data_processor, crs)
fig = deepsensor.plot.acquisition_fn(tasks[0], acquisition_fn_ds.sel(iteration=slice(0, 4)), X_new_df, data_processor, crs, cmap="Greys")
100%|██████████████████████████████████████████████████████████████████| 370/370 [00:21<00:00, 17.29it/s]
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Averaging acquisition function over dims for plotting: ['time']
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
../_images/74407a92cd9c13e7e4ff0789405abbdcf75c9e26963d31986c0ccacfa9dff556.png ../_images/66d2f5ddb8c6b4c88ce7b8cd657b839abd722880fcdc9b7afe6aad22b901dd88.png

ExpectedImprovement#

The ExpectedImprovement acquisition function can be used to hunt for the most positive or negative values in the data.

We will average the acquisition function over tasks sampled from 37 equally spaced dates in 2019 (every 10th day of the validation year) to make the acquisition function more robust to the weather on a single day.

# Rebind greedy_alg_with_groundtruth for ExpectedImprovement: full-resolution search
# grid, 10 proposed context points, and ERA5 as the proposed infill
greedy_alg_with_groundtruth = GreedyAlgorithm(
    model=model,
    X_t=era5_raw_ds,
    X_s=era5_raw_ds,
    X_s_mask=mask_ds,  # Mask out ocean from search points
    X_t_mask=mask_ds,  # Mask out ocean from target points
    proposed_infill=era5_ds,  # EI requires ground truth after proposal
    N_new_context=10,
    progress_bar=True,
    verbose=True,
)
from deepsensor.active_learning.acquisition_fns import ExpectedImprovement

acquisition_fn = ExpectedImprovement(model)
# Run the greedy search over all validation tasks
X_new_df, acquisition_fn_ds = greedy_alg_with_groundtruth(acquisition_fn, tasks)
100%|██████████████████████████████████████████████████████████████████| 370/370 [00:42<00:00,  8.69it/s]
# Plot the ExpectedImprovement placements over the first task
fig = deepsensor.plot.placements(tasks[0], X_new_df, data_processor, crs)
# Only the iterations selected by slice(0, 5) are shown, at most 5 panels per row, without colorbars
fig = deepsensor.plot.acquisition_fn(tasks[0], acquisition_fn_ds.sel(iteration=slice(0, 5)), X_new_df, data_processor, crs, cmap="Greys", add_colorbar=False, max_ncol=5)
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Averaging acquisition function over dims for plotting: ['time']
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
../_images/f708fa3993b7fa909597e8666480a2e0e5667f0a2c51277a76a90fafa6097115.png ../_images/4ea045a3a63de9128e016f22cc479595544073cefa39337d1b6f4deca3b55c94.png

Heuristic baseline acquisition functions#

Acquisition functions that don’t use a model can be used as baselines in sensor placement studies.

ContextDist#

Distance to the closest context point. Maximising this acquisition function will place context points at locations that are furthest from existing context points.

from deepsensor.active_learning.acquisition_fns import ContextDist

# Model-free baseline: distances are measured to the context points of context set 0
acquisition_fn = ContextDist(context_set_idx=0)
# Run the greedy search on the first validation task only
X_new_df, acquisition_fn_ds = greedy_alg(acquisition_fn, tasks[0])

# Plot against the same task the placements were computed from (tasks[0]), so the
# existing context points shown are the ones the distances were measured to
fig = deepsensor.plot.placements(tasks[0], X_new_df, data_processor, crs)
# NOTE(review): with N_new_context=10 the slice(0, 55, 5) selection presumably matches
# only iterations 0 and 5 — confirm the intended step
fig = deepsensor.plot.acquisition_fn(tasks[0], acquisition_fn_ds.sel(iteration=slice(0, 55, 5)), X_new_df, data_processor, crs, cmap="Greys", add_colorbar=False, max_ncol=5)
100%|████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 30.19it/s]
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
Averaging acquisition function over dims for plotting: ['time']
../_images/b65a737d863cf5630de0b9750604b95e15eb4b5952e184c48c50e88cb19d5792.png ../_images/be8faac5487cbb19137388cc8c7d4dbc3d588e8261049c1e9bfa3e3eadf561e3.png

Random#

Random acquisition function leading to random placements - a useful baseline!

from deepsensor.active_learning.acquisition_fns import Random

# Model-free baseline acquisition function that yields random placements
acquisition_fn = Random()
# Run the greedy search on the first validation task only
X_new_df, acquisition_fn_ds = greedy_alg(acquisition_fn, tasks[0])

# Plot against the same task the placements were computed from (tasks[0])
fig = deepsensor.plot.placements(tasks[0], X_new_df, data_processor, crs)
100%|███████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 251.40it/s]
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.
../_images/85b6a9b19e2e86b56cf2104458ba9f787a54896fcd452fe1fd789ef70590280e.png