autocast.utils.plots

plot_spatiotemporal_video(true, pred=None, pred_uq=None, coverage=None, batch_idx=0, fps=5, vmin=None, vmax=None, cmap='viridis', save_path=None, title='Ground Truth vs Prediction', pred_uq_label='Prediction UQ', coverage_label='Coverage', colorbar_mode='none', colorbar_mode_uq='none', channel_names=None, preserve_aspect=False)

Create a video comparing ground truth and predicted spatiotemporal time series.

Parameters:
  • true (Float[Tensor, 'batch time spatial *spatial channel']) – Ground-truth tensor of shape (B, T, W, H, C).

  • pred (Float[Tensor, 'batch time spatial *spatial channel'] | None) – Optional predicted tensor of shape (B, T, W, H, C).

  • pred_uq (Float[Tensor, 'batch time spatial *spatial channel'] | None) – Optional predicted uncertainty tensor of shape (B, T, W, H, C).

  • coverage (Float[Tensor, 'batch time spatial *spatial channel'] | None) – Optional coverage tensor of shape (B, T, W, H, C).

  • batch_idx (int) – Which batch index to visualize (default: 0).

  • fps (int) – Frames per second for the video (default: 5).

  • vmin (float | None) – Minimum value for color scale (default: auto from data).

  • vmax (float | None) – Maximum value for color scale (default: auto from data).

  • cmap (str) – Colormap to use (default: “viridis”).

  • save_path (str | None) – Optional path to save the video (e.g., “output.mp4”).

  • title (str) – Title for the video (default: “Ground Truth vs Prediction”).

  • pred_uq_label (str) – Label for the prediction uncertainty row.

  • coverage_label (str) – Label for the coverage row.

  • colorbar_mode (Literal['none', 'row', 'column', 'all']) – How colorbars (and the underlying color scales) are shared across the first two rows (true vs prediction):
      - “none”: every subplot gets its own colorbar (default).
      - “row”: a single colorbar per row (first two rows only).
      - “column”: a single colorbar per column (true/pred share per channel).
      - “all”: one colorbar shared across the first two rows.

  • colorbar_mode_uq (Literal['none', 'row']) – Colorbar sharing mode for the UQ/coverage rows.

  • channel_names (list[str] | None) – Optional list of channel names for titles.

  • preserve_aspect (bool) – If True, resize each subplot panel to match the spatial WxH ratio of the data so the image fills the panel without distortion. If False (default), panels are square and the image is stretched to fill via aspect='auto'.

Returns:

Animation object that can be displayed in notebooks.
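The colorbar_mode options can be read as different ways of pooling per-panel data ranges. The following is a minimal sketch of that idea, not the library's implementation: shared_limits is a hypothetical helper, and it assumes each panel's range is known up front and that explicit vmin/vmax overrides are not in play.

```python
from typing import List, Tuple

Range = Tuple[float, float]


def shared_limits(panel_ranges: List[List[Range]], mode: str) -> List[List[Range]]:
    """Return the (vmin, vmax) each panel would use under a sharing mode.

    panel_ranges[r][c] is the data range of the panel in row r, column c
    (row 0 = ground truth, row 1 = prediction; one column per channel).
    This helper is hypothetical and only illustrates the documented modes.
    """
    n_rows, n_cols = len(panel_ranges), len(panel_ranges[0])

    def merge(ranges: List[Range]) -> Range:
        # Smallest lower bound and largest upper bound over the given panels.
        return (min(lo for lo, _ in ranges), max(hi for _, hi in ranges))

    if mode == "none":
        # Every panel keeps its own range.
        return [list(row) for row in panel_ranges]
    if mode == "row":
        # One range per row, shared by all channels in that row.
        return [[merge(row)] * n_cols for row in panel_ranges]
    if mode == "column":
        # One range per channel, shared by truth and prediction.
        cols = [merge([panel_ranges[r][c] for r in range(n_rows)]) for c in range(n_cols)]
        return [list(cols) for _ in range(n_rows)]
    if mode == "all":
        # A single range for every panel in the two rows.
        everything = merge([rng for row in panel_ranges for rng in row])
        return [[everything] * n_cols for _ in range(n_rows)]
    raise ValueError(f"unknown colorbar mode: {mode!r}")


ranges = [
    [(0.0, 1.0), (-2.0, 2.0)],  # ground truth: channel 0, channel 1
    [(0.1, 1.5), (-1.0, 3.0)],  # prediction:   channel 0, channel 1
]
row_limits = shared_limits(ranges, "row")
column_limits = shared_limits(ranges, "column")
all_limits = shared_limits(ranges, "all")
```

With "column", truth and prediction for the same channel share one scale, which is usually the mode that makes the two rows directly comparable.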

plot_spatiotemporal_snapshots(true, pred=None, pred_uq=None, *, timesteps, channel=0, batch_idx=0, vmin=None, vmax=None, cmap='viridis', save_path=None, extra_formats=None, title='Ground Truth vs Prediction', pred_uq_label='Std Dev', channel_names=None, preserve_aspect=False, target_width_in=6.3, diff_log=False, uq_log=False)

Create a still panel at selected timesteps for one spatial channel.

Parameters:
  • true (Float[Tensor, 'batch time spatial *spatial channel'])

  • pred (Float[Tensor, 'batch time spatial *spatial channel'] | None)

  • pred_uq (Float[Tensor, 'batch time spatial *spatial channel'] | None)

  • timesteps (Iterable[int])

  • channel (int)

  • batch_idx (int)

  • vmin (float | None)

  • vmax (float | None)

  • cmap (str)

  • save_path (str | None)

  • extra_formats (Iterable[str] | None)

  • title (str)

  • pred_uq_label (str)

  • channel_names (list[str] | None)

  • preserve_aspect (bool)

  • target_width_in (float)

  • diff_log (bool)

  • uq_log (bool)

Return type:

Figure
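The parameters of this function are undocumented, so the following sketch only illustrates how target_width_in and preserve_aspect could plausibly interact, based on the preserve_aspect description given for plot_spatiotemporal_video. panel_grid_size is a hypothetical helper, it ignores margins, titles and colorbars, and it assumes the spatial axes are ordered (W, H) as in the documented (B, T, W, H, C) shape.

```python
from typing import Tuple


def panel_grid_size(
    n_rows: int,
    n_cols: int,
    spatial_w: int,
    spatial_h: int,
    target_width_in: float = 6.3,
    preserve_aspect: bool = False,
) -> Tuple[float, float]:
    """Return (figure_width, figure_height) in inches for a grid of panels.

    Hypothetical helper: the real function may lay panels out differently.
    """
    # Split the target width evenly across the columns.
    panel_w = target_width_in / n_cols
    # Square panels by default; otherwise scale the height by the data's H/W ratio.
    panel_h = panel_w * (spatial_h / spatial_w) if preserve_aspect else panel_w
    return (panel_w * n_cols, panel_h * n_rows)


# Three rows (truth, prediction, uncertainty) at four timesteps, 128 x 64 data.
square = panel_grid_size(3, 4, spatial_w=128, spatial_h=64)
matched = panel_grid_size(3, 4, spatial_w=128, spatial_h=64, preserve_aspect=True)
```

Under these assumptions the matched layout is half as tall as the square one, because each panel takes the 2:1 shape of the data instead of stretching it.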

plot_spatiotemporal_snapshots_data_only(true, *, timesteps, channel=0, batch_idx=0, vmin=None, vmax=None, cmap='viridis', save_path=None, extra_formats=None, ylabel=None, preserve_aspect=False, target_width_in=6.3)

Single-row snapshot panel of ground-truth data with no axis ticks.

Each panel is titled $i={t}$. The leftmost panel carries ylabel if provided (typically a dataset short label, e.g. AD, CNS, GS, GPE). Sized for A4 \linewidth with Times 10pt.

Parameters:
  • true (Float[Tensor, 'batch time spatial *spatial channel'])

  • timesteps (Iterable[int])

  • channel (int)

  • batch_idx (int)

  • vmin (float | None)

  • vmax (float | None)

  • cmap (str)

  • save_path (str | None)

  • extra_formats (Iterable[str] | None)

  • ylabel (str | None)

  • preserve_aspect (bool)

  • target_width_in (float)

Return type:

Figure
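The labeling rule described above (one title per timestep, ylabel only on the leftmost panel) can be sketched as follows. panel_labels is a hypothetical helper written for illustration and is not part of the module.

```python
from typing import Dict, Iterable, List, Optional


def panel_labels(
    timesteps: Iterable[int], ylabel: Optional[str] = None
) -> List[Dict[str, Optional[str]]]:
    """Build the title and y-label for each panel in a single-row figure."""
    labels = []
    for col, t in enumerate(timesteps):
        labels.append(
            {
                # Math-mode title, e.g. "$i=4$".
                "title": f"$i={t}$",
                # Only the leftmost panel carries the dataset label.
                "ylabel": ylabel if col == 0 else None,
            }
        )
    return labels


labels = panel_labels([0, 4, 8], ylabel="AD")
```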

compute_metrics_from_dataloader(dataloader, metric_fns, predict_fn, windows=None, return_tensors=False, return_per_batch=False, device=None)

Compute metrics from a dataloader by running model forward passes.

Parameters:
  • dataloader (Iterable) – DataLoader that yields batches.

  • metric_fns (dict[str, Callable[[], Metric]]) – Dictionary of functions that return fresh metric instances, keyed by metric name.

  • predict_fn (Callable) – Function (batch) -> (preds, trues), for example a rollout or a plain model forward pass. It should return either a tuple of (preds, trues) tensors or a single tensor of predictions, in which case trues is taken from the batch.

  • windows (list[tuple[int, int] | None] | None) – List of (t_start, t_end) windows to evaluate. None means use all timesteps. If multiple windows provided, evaluates each independently.

  • return_tensors (bool) – If True, also return concatenated (pred, true) tensors.

  • return_per_batch (bool) – If True, also return per-batch metric dictionaries.

  • device (str | device | None) – Device to move metrics to before updating.

Returns:

The populated metrics, optionally the tensors, and optionally per-batch metrics.

Return type:

tuple[dict[None | tuple[int, int], dict[str, Metric]], tuple[Float[Tensor, 'batch time spatial *spatial channel ensemble'], Float[Tensor, 'batch time spatial *spatial channel']] | None, list[dict[str, float | str]] | None]
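The roles of metric_fns, predict_fn and windows can be illustrated with a small pure-Python sketch. It is not the library's code: MeanAbsError is a toy stand-in for a Metric, evaluate is a hypothetical loop, data is reduced to nested lists indexed [batch][time], and windows are assumed to be half-open (t_start included, t_end excluded), which the documentation does not state.

```python
from typing import Callable, Dict, List, Optional, Tuple

Window = Optional[Tuple[int, int]]


class MeanAbsError:
    """Toy stand-in for a metric object with update() and compute()."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, preds: List[List[float]], trues: List[List[float]]) -> None:
        for p_row, t_row in zip(preds, trues):
            for p, t in zip(p_row, t_row):
                self.total += abs(p - t)
                self.count += 1

    def compute(self) -> float:
        return self.total / self.count


def evaluate(
    dataloader,
    metric_fns: Dict[str, Callable[[], MeanAbsError]],
    predict_fn,
    windows: Optional[List[Window]] = None,
) -> Dict[Window, Dict[str, MeanAbsError]]:
    windows = windows if windows is not None else [None]
    # One fresh metric instance per (window, metric name) pair.
    metrics = {w: {name: make() for name, make in metric_fns.items()} for w in windows}
    for batch in dataloader:
        preds, trues = predict_fn(batch)
        for w, per_window in metrics.items():
            # None means "all timesteps"; otherwise slice the time axis.
            sl = slice(None) if w is None else slice(w[0], w[1])
            for metric in per_window.values():
                metric.update([row[sl] for row in preds], [row[sl] for row in trues])
    return metrics


def persistence_predict(batch):
    """Hypothetical predict_fn: repeat the first target value at every timestep."""
    trues = batch["target"]
    preds = [[row[0]] * len(row) for row in trues]
    return preds, trues


loader = [
    {"target": [[0.0, 1.0, 2.0, 3.0]]},
    {"target": [[1.0, 2.0, 3.0, 4.0]]},
]
results = evaluate(
    loader,
    metric_fns={"mae": MeanAbsError},
    predict_fn=persistence_predict,
    windows=[None, (0, 2), (2, 4)],
)
```

Passing factories rather than metric instances is what lets each window accumulate its own state independently.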

compute_metrics_per_timestep_from_dataloader(dataloader, metric_fns, predict_fn, max_timesteps=None, device=None)

Compute per-channel, per-timestep metrics from a dataloader, batch-averaged.

For each timestep t, metrics are computed on the slice (B, t:t+1, …) and averaged over batches. Returns one (T, C) array per metric (T = timesteps, C = channels). MultiCoverage is expanded to one (T, C) per coverage level (e.g. coverage_0.05, coverage_0.10, …) so reliability curves can be built per timestep.

Parameters:
  • dataloader (Iterable) – DataLoader that yields batches (e.g. rollout test dataloader).

  • metric_fns (dict[str, Callable[[], Metric]]) – Metric factory functions. Metrics should return (1, C) when updated with (B, 1, S, C) and reduce_all=False (deterministic metrics) or be MultiCoverage (expanded to one key per alpha).

  • predict_fn (Callable) – (batch) -> (preds, trues) returning tensors of shape (B, T, S, C) or (B, T, S, C, M). Return (None, None) to skip a batch.

  • max_timesteps (int | None) – Cap the number of timesteps (uses min over batches otherwise).

  • device (str | device | None) – Device to move metrics to before updating.

Returns:

Keys are metric names (and coverage_0.05, coverage_0.10, … for MultiCoverage). Values are arrays of shape (T, C), batch-averaged.

Return type:

dict[str, ndarray]
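The (T, C) layout and the batch averaging can be sketched in plain Python. This is an illustration under stated assumptions rather than the library's implementation: mae_per_timestep and const_tensor are hypothetical helpers, tensors are nested lists indexed [batch][time][space][channel], mean absolute error stands in for an arbitrary metric, and every batch is weighted equally in the average.

```python
from typing import List, Optional, Tuple

Tensor4 = List[List[List[List[float]]]]  # indexed [batch][time][space][channel]


def const_tensor(b: int, t: int, s: int, c: int, value: float) -> Tensor4:
    """Build a constant nested-list tensor of shape (b, t, s, c)."""
    return [[[[value] * c for _ in range(s)] for _ in range(t)] for _ in range(b)]


def mae_per_timestep(
    batches: List[Tuple[Tensor4, Tensor4]], max_timesteps: Optional[int] = None
) -> List[List[float]]:
    """Return a T x C nested list of batch-averaged mean absolute errors."""
    # Use the shortest batch, then apply the optional cap.
    n_t = min(len(trues[0]) for _, trues in batches)
    if max_timesteps is not None:
        n_t = min(n_t, max_timesteps)
    n_c = len(batches[0][1][0][0][0])
    out = [[0.0] * n_c for _ in range(n_t)]
    for preds, trues in batches:
        for t in range(n_t):
            for c in range(n_c):
                # Metric on the slice (B, t:t+1, S, c) of this batch.
                errs = [
                    abs(preds[b][t][s][c] - trues[b][t][s][c])
                    for b in range(len(trues))
                    for s in range(len(trues[b][t]))
                ]
                # Average over batches with equal weight per batch.
                out[t][c] += sum(errs) / len(errs) / len(batches)
    return out


batches = [
    (const_tensor(1, 2, 2, 2, 1.0), const_tensor(1, 2, 2, 2, 0.0)),  # T = 2
    (const_tensor(1, 3, 2, 2, 3.0), const_tensor(1, 3, 2, 2, 0.0)),  # T = 3
]
table = mae_per_timestep(batches)
capped = mae_per_timestep(batches, max_timesteps=1)
```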

compute_coverage_scores_from_dataloader(dataloader, predict_fn, coverage_levels=None, windows=None, return_tensors=False)

Compute coverage scores from a dataloader by running model forward passes.

Parameters:
  • dataloader (Iterable) – DataLoader that yields batches.

  • predict_fn (Callable) – Function (batch) -> (preds, trues), for example a rollout or a plain model forward pass. The predictions must include the ensemble dimension.

  • coverage_levels (list[float] | None) – Coverage levels to evaluate (default: 0.05 to 0.95).

  • windows (list[tuple[int, int] | None] | None) – List of (t_start, t_end) windows to evaluate. None means use all timesteps. If multiple windows provided, evaluates each independently.

  • return_tensors (bool) – If True, also return concatenated (pred, true) tensors.

Returns:

The populated MultiCoverage metric and optionally the tensors.

Return type:

tuple[dict[None | tuple[int, int], MultiCoverage], tuple[Float[Tensor, 'batch time spatial *spatial channel ensemble'], Float[Tensor, 'batch time spatial *spatial channel']] | None]
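As background for what a coverage score measures, here is a minimal sketch of empirical coverage for an ensemble. It assumes a central interval between the (1 - level)/2 and (1 + level)/2 ensemble quantiles and linear-interpolation quantiles; whether MultiCoverage uses exactly this definition is not stated in this page. quantile and observed_coverage are hypothetical helpers, and the key format mirrors the coverage_0.05 style used elsewhere in this module.

```python
from typing import Dict, List


def quantile(sorted_vals: List[float], q: float) -> float:
    """Linear-interpolation quantile of an already sorted list."""
    pos = q * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] * (1 - frac) + sorted_vals[hi] * frac


def observed_coverage(
    ensembles: List[List[float]], truths: List[float], levels: List[float]
) -> Dict[str, float]:
    """Fraction of truths inside the central ensemble interval, per level."""
    out = {}
    for level in levels:
        hits = 0
        for members, truth in zip(ensembles, truths):
            ordered = sorted(members)
            lower = quantile(ordered, (1 - level) / 2)
            upper = quantile(ordered, (1 + level) / 2)
            hits += lower <= truth <= upper
        out[f"coverage_{level:.2f}"] = hits / len(truths)
    return out


# Four points, each with the same five-member ensemble.
ensembles = [[0.0, 1.0, 2.0, 3.0, 4.0]] * 4
truths = [2.0, 0.5, 3.5, 5.0]
scores = observed_coverage(ensembles, truths, levels=[0.5, 1.0])
```

A well-calibrated ensemble would give observed coverage close to each nominal level; here the 50% interval contains only a quarter of the truths, which would indicate intervals that are too narrow.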

plot_coverage(pred, true, coverage_levels=None, save_path=None, title='Coverage plot')

Plot reliability diagram showing expected vs observed coverage.

This is a convenience wrapper around MultiCoverage.plot().

Parameters:
  • pred (Float[Tensor, 'batch time spatial *spatial channel ensemble']) – Ensemble predictions (last dimension is ensemble members).

  • true (Float[Tensor, 'batch time spatial *spatial channel']) – Ground truth tensor.

  • coverage_levels (list[float] | None) – Coverage levels to evaluate (default: 0.05 to 0.95).

  • save_path (str | None) – Path to save the plot.

  • title (str) – Plot title.

Returns:

matplotlib.figure.Figure