autocast.processors.unet
- class UNetClassic(dim_in, dim_out, n_spatial_dims, spatial_resolution, init_features=32, gradient_checkpointing=False)
Bases: Module
Classic U-Net architecture for spatiotemporal prediction.
- Adapted from:
Takamoto et al. 2022, PDEBENCH: An Extensive Benchmark for Scientific Machine Learning
Source: github.com/pdebench/PDEBench/blob/main/pdebench/models/unet/unet.py
Via the_well repository: github.com/PolymathicAI/the_well/blob/master/the_well/benchmark/models/unet_classic/__init__.py
If you use this implementation, please cite the original work above.
- Parameters:
dim_in (int) – Number of input channels.
dim_out (int) – Number of output channels.
n_spatial_dims (int) – Number of spatial dimensions (1, 2, or 3).
spatial_resolution (Sequence[int]) – Spatial resolution of the input data.
init_features (int) – Number of features in the first encoder block. Default is 32.
gradient_checkpointing (bool) – Whether to use gradient checkpointing to reduce memory usage. Default is False.
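As a rough illustration of how `init_features` and `spatial_resolution` interact, the sketch below assumes the layout of the PDEBench U-Net that this class is adapted from: four encoder levels that double the channel count, each followed by a stride-2 pooling step, plus a bottleneck. The helper names `encoder_widths` and `check_resolution` and the `n_levels=4` default are hypothetical and not part of `autocast`; consult the source for the actual depth.

```python
from typing import Sequence


def encoder_widths(init_features: int = 32, n_levels: int = 4) -> list[int]:
    """Channel widths per encoder level plus the bottleneck (assumed doubling)."""
    return [init_features * 2**i for i in range(n_levels + 1)]


def check_resolution(spatial_resolution: Sequence[int], n_levels: int = 4) -> bool:
    """True if every axis can be halved n_levels times without a remainder."""
    factor = 2**n_levels
    return all(size % factor == 0 for size in spatial_resolution)


widths = encoder_widths(32)          # assumed widths for init_features=32
ok_64 = check_resolution([64, 64])   # 64 is divisible by 2**4
ok_50 = check_resolution([50, 50])   # 50 is not divisible by 2**4
```

Under these assumptions, a resolution that fails the check would likely produce mismatched skip-connection shapes, so padding or cropping the input beforehand may be needed.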
- class UNetProcessor(in_channels, out_channels, spatial_resolution, n_spatial_dims=2, init_features=32, gradient_checkpointing=False, loss_func=None)
Bases: Processor[EncodedBatch]
UNet Processor for spatiotemporal prediction.
This processor wraps the classic U-Net architecture for learning mappings between spatiotemporal fields.
- Adapted from:
Takamoto et al. 2022, PDEBENCH: An Extensive Benchmark for Scientific Machine Learning
Source: github.com/pdebench/PDEBench/blob/main/pdebench/models/unet/unet.py
Via the_well repository: github.com/PolymathicAI/the_well/blob/master/the_well/benchmark/models/unet_classic/__init__.py
If you use this implementation, please cite the original work above.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
spatial_resolution (Sequence[int]) – Spatial resolution of the input data (e.g., [64, 64] for 2D).
n_spatial_dims (int) – Number of spatial dimensions (1, 2, or 3). Default is 2.
init_features (int) – Number of features in the first encoder block. Default is 32.
gradient_checkpointing (bool) – Whether to use gradient checkpointing to reduce memory usage. Default is False.
loss_func (Module | None) – Loss function. Defaults to MSELoss.
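The `loss_func` default can be pictured as a simple fallback. The sketch below uses plain Python rather than PyTorch, and `mse` and `resolve_loss` are hypothetical stand-ins for `MSELoss` and for whatever the constructor does internally. It only illustrates the "use the supplied loss, otherwise mean squared error" behavior described above.

```python
from typing import Callable, Optional, Sequence

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error over two equally long sequences."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def resolve_loss(loss_func: Optional[LossFn] = None) -> LossFn:
    """Return the supplied loss, or fall back to MSE when None is given."""
    return loss_func if loss_func is not None else mse


default_loss = resolve_loss()  # no loss supplied, so MSE is used
value = default_loss([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
```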
- loss(batch)
Compute loss between output and target.
- Parameters:
batch (EncodedBatch) – Batch containing encoded inputs and output fields.
- Returns:
Loss value.
- Return type:
- class AzulaUNetProcessor(in_channels, out_channels, hid_channels=(64, 128, 256, 512), hid_blocks=(2, 2, 2, 2), norm='layer', groups=8, ffn_factor=2, dropout=0.0, periodic=False, gradient_checkpointing=False, loss_func=None, n_noise_channels=None, global_cond_channels=None, include_global_cond=False)
Bases: Processor[EncodedBatch]
UNet Processor using Azula’s modern UNet architecture.
This processor wraps TemporalUNetBackbone with an Azula UNet backbone.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
hid_channels (Sequence[int]) – Hidden channel dimensions at each level. Default is [64, 128, 256, 512].
hid_blocks (Sequence[int]) – Number of residual blocks at each level. Default is [2, 2, 2, 2].
norm (str) – Normalization type: ‘batch’, ‘group’, ‘layer’, or ‘rms’. Default is ‘layer’.
groups (int) – Number of groups for GroupNorm. Default is 8.
ffn_factor (int) – Feed-forward network expansion factor. Default is 2.
dropout (float) – Dropout probability. Default is 0.0.
periodic (bool) – Whether to use periodic boundary conditions. Default is False.
gradient_checkpointing (bool) – Whether to use gradient checkpointing. Default is False.
loss_func (Module | None) – Loss function. Defaults to MSELoss.
n_noise_channels (int | None) – Number of noise channels for conditional normalization. If None, no noise conditioning is used. Default is None.
global_cond_channels (int | None) – Width of the optional global conditioning vector. Default is None.
include_global_cond (bool) – Whether to inject global conditioning into modulation. Uses the same two-layer embedding pattern as ViT temporal backbones. Default is False.
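`hid_channels` and `hid_blocks` both describe the levels of the network, so it seems reasonable to expect them to have the same length. The sketch below is a hypothetical pre-flight check, not part of `autocast` or Azula. The divisibility rule for `groups` is an assumption borrowed from `torch.nn.GroupNorm`, which requires the channel count to be divisible by the number of groups, and it is applied here only when `norm='group'`.

```python
from typing import Sequence


def check_unet_config(
    hid_channels: Sequence[int] = (64, 128, 256, 512),
    hid_blocks: Sequence[int] = (2, 2, 2, 2),
    norm: str = "layer",
    groups: int = 8,
) -> list[str]:
    """Return a list of human-readable problems; an empty list means no issue found."""
    problems = []
    # One block count is expected per level (assumption).
    if len(hid_channels) != len(hid_blocks):
        problems.append(
            f"hid_channels has {len(hid_channels)} levels, hid_blocks has {len(hid_blocks)}"
        )
    # Normalization types listed in the parameter description.
    if norm not in {"batch", "group", "layer", "rms"}:
        problems.append(f"unknown norm {norm!r}")
    # GroupNorm-style divisibility, checked only for norm='group' (assumption).
    if norm == "group":
        for channels in hid_channels:
            if channels % groups != 0:
                problems.append(f"{channels} channels not divisible by groups={groups}")
    return problems


default_problems = check_unet_config()                     # documented defaults
mismatch_problems = check_unet_config(hid_blocks=(2, 2, 2))  # one level short
```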
- forward(x, x_noise=None, global_cond=None)
Forward pass through the Azula UNet.
- Parameters:
- Returns:
Output tensor of shape (B, C_out, *spatial_dims).
- Return type:
- loss(batch)
Compute loss between output and target.
- Parameters:
batch (EncodedBatch) – Batch containing encoded inputs and output fields.
- Returns:
Loss value.
- Return type: