autocast.processors.unet

class UNetClassic(dim_in, dim_out, n_spatial_dims, spatial_resolution, init_features=32, gradient_checkpointing=False)

Bases: Module

Classic U-Net architecture for spatiotemporal prediction.

Adapted from:

Takamoto et al. 2022, PDEBENCH: An Extensive Benchmark for Scientific Machine Learning
Source: github.com/pdebench/PDEBench/blob/main/pdebench/models/unet/unet.py

Via the_well repository: github.com/PolymathicAI/the_well/blob/master/the_well/benchmark/models/unet_classic/__init__.py

If you use this implementation, please cite the original work above.

Parameters:
  • dim_in (int) – Number of input channels.

  • dim_out (int) – Number of output channels.

  • n_spatial_dims (int) – Number of spatial dimensions (1, 2, or 3).

  • spatial_resolution (Sequence[int]) – Spatial resolution of the input data.

  • init_features (int) – Number of features in the first encoder block. Default is 32.

  • gradient_checkpointing (bool) – Whether to use gradient checkpointing to reduce memory usage. Default is False.
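
How spatial_resolution and init_features interact can be sketched without PyTorch. The helper below is hypothetical (unet_level_shapes is not part of autocast) and assumes the layout of the PDEBench U-Net that this class is adapted from: four 2x down-sampling stages, with the feature count doubling at each stage. Under that assumption, every entry of spatial_resolution needs to be divisible by 16.

```python
from typing import List, Sequence, Tuple


def unet_level_shapes(
    spatial_resolution: Sequence[int],
    init_features: int = 32,
    n_down: int = 4,  # assumed number of 2x down-sampling stages
) -> List[Tuple[int, Tuple[int, ...]]]:
    """Return (features, resolution) per level, from the first encoder block to the bottleneck."""
    factor = 2 ** n_down
    for size in spatial_resolution:
        if size % factor != 0:
            raise ValueError(
                f"resolution {size} is not divisible by {factor}; "
                "skip connections would not line up after pooling"
            )
    levels = []
    for level in range(n_down + 1):
        # Features double and each spatial size halves at every level.
        features = init_features * 2 ** level
        resolution = tuple(size // 2 ** level for size in spatial_resolution)
        levels.append((features, resolution))
    return levels


for features, resolution in unet_level_shapes([64, 64]):
    print(features, resolution)
```

For a 64 x 64 input with init_features=32, this sketch gives 32 features at full resolution and 512 features at 4 x 4 in the bottleneck.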

optional_checkpointing(layer, *inputs, **kwargs)

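Call layer on the given inputs, applying gradient checkpointing if it is enabled.

Parameters:
  • layer – Layer to call.

  • *inputs – Positional arguments passed to layer.

  • **kwargs – Keyword arguments passed to layer.

Return type: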

Tensor
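
The dispatch described above can be sketched in a framework-free way. This is an assumption about the pattern, not the code in autocast: the free function, the enabled flag, and the checkpoint_fn argument are hypothetical names. In PyTorch, checkpoint_fn would typically be torch.utils.checkpoint.checkpoint, which drops intermediate activations in the forward pass and recomputes them during the backward pass.

```python
from typing import Any, Callable


def optional_checkpointing(
    layer: Callable[..., Any],
    *inputs: Any,
    enabled: bool,
    checkpoint_fn: Callable[..., Any],
    **kwargs: Any,
) -> Any:
    """Call layer directly, or route the call through checkpoint_fn if enabled."""
    if enabled:
        # checkpoint_fn receives the layer and its arguments and decides
        # how to run it (e.g. recomputing activations in the backward pass).
        return checkpoint_fn(layer, *inputs, **kwargs)
    return layer(*inputs, **kwargs)


calls = []


def recording_checkpoint(fn, *args, **kwargs):
    """Stand-in for a real checkpoint function: logs the call, then runs fn."""
    calls.append(fn.__name__)
    return fn(*args, **kwargs)


def double(x, offset=0):
    return 2 * x + offset


# Disabled: the layer is called directly and nothing is logged.
direct = optional_checkpointing(
    double, 3, enabled=False, checkpoint_fn=recording_checkpoint
)
# Enabled: the same call goes through the checkpoint function.
wrapped = optional_checkpointing(
    double, 3, offset=1, enabled=True, checkpoint_fn=recording_checkpoint
)
```

Both paths return the same values as a direct call, so the flag only changes the memory and compute trade-off, not the result.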

forward(x)

Forward pass through the U-Net.

Parameters:

x (Tensor) – Input tensor of shape (B, C_in, *spatial_dims).

Returns:

Output tensor of shape (B, C_out, *spatial_dims).

Return type:

Tensor

class UNetProcessor(in_channels, out_channels, spatial_resolution, n_spatial_dims=2, init_features=32, gradient_checkpointing=False, loss_func=None)

Bases: Processor[EncodedBatch]

UNet Processor for spatiotemporal prediction.

This processor wraps the classic U-Net architecture for learning mappings between spatiotemporal fields.

Adapted from:

Takamoto et al. 2022, PDEBENCH: An Extensive Benchmark for Scientific Machine Learning
Source: github.com/pdebench/PDEBench/blob/main/pdebench/models/unet/unet.py

Via the_well repository: github.com/PolymathicAI/the_well/blob/master/the_well/benchmark/models/unet_classic/__init__.py

If you use this implementation, please cite the original work above.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • spatial_resolution (Sequence[int]) – Spatial resolution of the input data (e.g., [64, 64] for 2D).

  • n_spatial_dims (int) – Number of spatial dimensions (1, 2, or 3). Default is 2.

  • init_features (int) – Number of features in the first encoder block. Default is 32.

  • gradient_checkpointing (bool) – Whether to use gradient checkpointing to reduce memory usage. Default is False.

  • loss_func (Module | None) – Loss function. Defaults to MSELoss.

forward(x)

Forward pass through the UNet.

Parameters:

x (Tensor) – Input tensor.

Return type:

Tensor

map(x, global_cond)

Map input states to output states.

Parameters:
  • x (Tensor) – Input tensor of shape (B, T_in, *spatial_dims).

  • global_cond (Tensor | None) – Optional conditioning tensor (currently unused).

Returns:

Output tensor of shape (B, T_out, *spatial_dims).

Return type:

Tensor

loss(batch)

Compute loss between output and target.

Parameters:

batch (EncodedBatch) – Batch containing encoded inputs and output fields.

Returns:

Loss value.

Return type:

Tensor
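
With the default MSELoss, the loss is the mean of the squared element-wise differences between prediction and target. A minimal pure-Python sketch on flat sequences follows. The helper mse is hypothetical, and how predictions and targets are read from an EncodedBatch is not shown here.

```python
from typing import Sequence


def mse(prediction: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over two equally sized flat sequences."""
    if len(prediction) != len(target):
        raise ValueError("prediction and target must have the same length")
    if not prediction:
        raise ValueError("cannot compute the loss of an empty batch")
    squared = [(p - t) ** 2 for p, t in zip(prediction, target)]
    return sum(squared) / len(squared)


# Only the last element differs (by 2), so the loss is 4 / 3.
loss = mse([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
```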

class AzulaUNetProcessor(in_channels, out_channels, hid_channels=(64, 128, 256, 512), hid_blocks=(2, 2, 2, 2), norm='layer', groups=8, ffn_factor=2, dropout=0.0, periodic=False, gradient_checkpointing=False, loss_func=None, n_noise_channels=None, global_cond_channels=None, include_global_cond=False)

Bases: Processor[EncodedBatch]

UNet Processor using Azula’s modern UNet architecture.

This processor wraps TemporalUNetBackbone with an Azula UNet backbone.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • hid_channels (Sequence[int]) – Hidden channel dimensions at each level. Default is (64, 128, 256, 512).

  • hid_blocks (Sequence[int]) – Number of residual blocks at each level. Default is (2, 2, 2, 2).

  • norm (str) – Normalization type: ‘batch’, ‘group’, ‘layer’, or ‘rms’. Default is ‘layer’.

  • groups (int) – Number of groups for GroupNorm. Default is 8.

  • ffn_factor (int) – Feed-forward network expansion factor. Default is 2.

  • dropout (float) – Dropout probability. Default is 0.0.

  • periodic (bool) – Whether to use periodic boundary conditions. Default is False.

  • gradient_checkpointing (bool) – Whether to use gradient checkpointing. Default is False.

  • loss_func (Module | None) – Loss function. Defaults to MSELoss.

  • n_noise_channels (int | None) – Number of noise channels for conditional normalization. If None, no noise conditioning is used. Default is None.

  • global_cond_channels (int | None) – Width of the optional global conditioning vector. Default is None.

  • include_global_cond (bool) – Whether to inject global conditioning into modulation. Uses the same two-layer embedding pattern as ViT temporal backbones. Default is False.
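
Some consistency checks implied by these parameters can be run before building a model. The sketch below is a hypothetical validator, not part of autocast or Azula. It assumes that hid_channels and hid_blocks need one entry per level, and that with norm='group' each channel width needs to be divisible by groups, as torch.nn.GroupNorm requires.

```python
from typing import Sequence

ALLOWED_NORMS = ("batch", "group", "layer", "rms")


def validate_unet_config(
    hid_channels: Sequence[int] = (64, 128, 256, 512),
    hid_blocks: Sequence[int] = (2, 2, 2, 2),
    norm: str = "layer",
    groups: int = 8,
    dropout: float = 0.0,
) -> None:
    """Raise ValueError if the settings are inconsistent with each other."""
    if len(hid_channels) != len(hid_blocks):
        raise ValueError("hid_channels and hid_blocks need one entry per level")
    if norm not in ALLOWED_NORMS:
        raise ValueError(f"norm must be one of {ALLOWED_NORMS}, got {norm!r}")
    if norm == "group":
        # Assumed constraint: group normalization splits channels evenly.
        for channels in hid_channels:
            if channels % groups != 0:
                raise ValueError(
                    f"{channels} channels are not divisible by {groups} groups"
                )
    if not 0.0 <= dropout <= 1.0:
        raise ValueError("dropout must be a probability in [0, 1]")


validate_unet_config()  # the documented defaults pass
```

Running a check like this up front surfaces a mismatched hid_channels and hid_blocks pair as a readable error rather than a shape failure deep inside the network.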

forward(x, x_noise=None, global_cond=None)

Forward pass through the Azula UNet.

Parameters:
  • x (Tensor) – Input tensor of shape (B, C_in, *spatial_dims).

  • x_noise (Tensor | None) – Optional noise conditioning tensor of shape (B, n_noise_channels).

  • global_cond (Tensor | None) – Optional global conditioning tensor of shape (B, C_global). Used only when include_global_cond=True.

Returns:

Output tensor of shape (B, C_out, *spatial_dims).

Return type:

Tensor
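
The shape contract listed above can be written down as a small check on plain shape tuples. The helper expected_output_shape is hypothetical and encodes only the shapes stated in this section. It does not inspect global_cond.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def expected_output_shape(
    x_shape: Shape,
    in_channels: int,
    out_channels: int,
    x_noise_shape: Optional[Shape] = None,
    n_noise_channels: Optional[int] = None,
) -> Shape:
    """Validate input shapes and return (B, C_out, *spatial_dims)."""
    # Raises ValueError by itself if x_shape has fewer than two entries.
    batch, channels, *spatial = x_shape
    if channels != in_channels:
        raise ValueError(f"expected {in_channels} input channels, got {channels}")
    # A noise tensor must be (B, n_noise_channels); this also rejects noise
    # passed to a model configured with n_noise_channels=None.
    if x_noise_shape is not None and x_noise_shape != (batch, n_noise_channels):
        raise ValueError(
            f"x_noise should have shape {(batch, n_noise_channels)}, "
            f"got {x_noise_shape}"
        )
    return (batch, out_channels, *spatial)


out_shape = expected_output_shape(
    (8, 4, 64, 64),
    in_channels=4,
    out_channels=2,
    x_noise_shape=(8, 16),
    n_noise_channels=16,
)
```

Only the channel axis changes between input and output; the batch size and spatial dimensions carry through unchanged.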

map(x, global_cond)

Map input states to output states.

Parameters:
  • x (Tensor) – Input tensor of shape (B, C_in, *spatial_dims).

  • global_cond (Tensor | None) – Optional conditioning vector. Used when include_global_cond=True.

Returns:

Output tensor of shape (B, C_out, *spatial_dims).

Return type:

Tensor

loss(batch)

Compute loss between output and target.

Parameters:

batch (EncodedBatch) – Batch containing encoded inputs and output fields.

Returns:

Loss value.

Return type:

Tensor