autocast.processors.swin_vit

modulate(x, shift, scale)

Modulate the input tensor with shift and scale parameters.

Return type:

Tensor
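
The docstring does not state the formula. The following is a minimal sketch, assuming the DiT-style convention `x * (1 + scale) + shift` and that `shift` and `scale` broadcast against `x`. The name `modulate_sketch` is illustrative, and the library's actual implementation may differ.

```python
import torch


def modulate_sketch(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Assumed convention: scale enters as (1 + scale), so zero-valued
    # parameters leave x unchanged; shift is added afterwards.
    return x * (1 + scale) + shift


x = torch.randn(2, 8, 8, 16)        # channels-last input (B, H, W, C)
shift = torch.zeros(2, 1, 1, 16)    # one shift per sample and channel
scale = torch.zeros(2, 1, 1, 16)    # one scale per sample and channel
out = modulate_sketch(x, shift, scale)
```

Under this assumed convention, zero-valued parameters reduce the modulation to the identity, which pairs naturally with zero-initialised conditioning layers.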

apply_gate(x, gate)

Apply a gating mechanism to the input tensor.

Return type:

Tensor
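
The gating rule itself is not documented here. A minimal sketch, assuming the gate scales the tensor elementwise and that a missing gate (`None`) is a no-op; `apply_gate_sketch` is a hypothetical name, not the library function.

```python
import torch


def apply_gate_sketch(x, gate=None):
    # Assumption: without a gate the tensor passes through unchanged.
    if gate is None:
        return x
    # Assumption: the gate broadcasts against x and scales it elementwise.
    return x * gate


residual = torch.ones(2, 4, 4, 8)
branch = torch.full((2, 4, 4, 8), 3.0)
gate = torch.zeros(2, 1, 1, 8)   # a zero gate silences the branch
out = residual + apply_gate_sketch(branch, gate)
```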

class AdaLNGenerator(hidden_dim, n_noise_channels, num_chunks, use_ada_ln=True, zero_init=True)

Bases: Module

Generate Adaptive Layer Norm parameters from noise embeddings.

Parameters:
  • hidden_dim (int)

  • n_noise_channels (int | None)

  • num_chunks (int)

  • use_ada_ln (bool)

  • zero_init (bool)

forward(x_noise=None)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x_noise (Tensor | None)

Return type:

tuple[Tensor | None, …]
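
To illustrate how a generator of this kind can work, here is a minimal sketch. It assumes a SiLU plus linear projection from the noise embedding to `num_chunks * hidden_dim` values, split into equal chunks, and a tuple of `None` values when no noise embedding is given. The class name `AdaLNGeneratorSketch` and its layer layout are assumptions, and the `use_ada_ln` switch is omitted.

```python
import torch
from torch import nn


class AdaLNGeneratorSketch(nn.Module):
    """Hypothetical stand-in: project a noise embedding into num_chunks tensors."""

    def __init__(self, hidden_dim, n_noise_channels, num_chunks, zero_init=True):
        super().__init__()
        self.num_chunks = num_chunks
        self.proj = nn.Sequential(
            nn.SiLU(),
            nn.Linear(n_noise_channels, num_chunks * hidden_dim),
        )
        if zero_init:
            # Zero weights and bias make every generated parameter start at 0.
            nn.init.zeros_(self.proj[1].weight)
            nn.init.zeros_(self.proj[1].bias)

    def forward(self, x_noise=None):
        if x_noise is None:
            # Without conditioning, return one None per expected chunk.
            return (None,) * self.num_chunks
        return self.proj(x_noise).chunk(self.num_chunks, dim=-1)


gen = AdaLNGeneratorSketch(hidden_dim=16, n_noise_channels=4, num_chunks=6)
params = gen(torch.randn(2, 4))   # call the module, not .forward(), so hooks run
```

With zero initialisation, every chunk starts at zero, so a block that consumes these values as shift, scale, and gate begins training from an unconditioned state.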

class SwinViTBlock(hidden_dim, num_heads, n_spatial_dims, n_noise_channels, window_size, shift=False, drop_path=0.0, use_ada_ln=True, zero_init=True)

Bases: Module

Block for Swin ViT Processor.

forward(x, x_noise=None)

Return type:

Tensor

class PatchMerging(dim, n_spatial_dims=2)

Bases: Module

Patch merging layer.

Parameters:
  • dim (int)

  • n_spatial_dims (int)

forward(x, spatial_shape)

Return type:

Tensor
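
For orientation, the sketch below shows 2D patch merging as described in the Swin Transformer paper: each 2x2 neighbourhood is concatenated along the channel axis (4C), normalised, and projected to 2C. It assumes a channels-last input with even H and W. The class name is hypothetical, and it omits the `spatial_shape` argument, any padding for odd sizes, and the 3D case that the documented layer supports.

```python
import torch
from torch import nn


class PatchMerging2DSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):
        # x: (B, H, W, C) with even H and W.
        x = torch.cat(
            [x[:, 0::2, 0::2], x[:, 1::2, 0::2], x[:, 0::2, 1::2], x[:, 1::2, 1::2]],
            dim=-1,
        )  # (B, H/2, W/2, 4C)
        return self.reduction(self.norm(x))  # (B, H/2, W/2, 2C)


merge = PatchMerging2DSketch(dim=8)
out = merge(torch.randn(2, 6, 10, 8))
```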

class PatchSplitting(dim, n_spatial_dims=2)

Bases: Module

Patch splitting layer.

Parameters:
  • dim (int)

  • n_spatial_dims (int)

forward(x, spatial_shape, crop=None)

Return type:

Tensor
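
A minimal 2D sketch of the inverse operation, in the style of the patch-expanding layer used by Swin-based U-Nets: a linear layer doubles the channels, which are then rearranged into a 2x2 spatial patch with C/2 channels each. It assumes that `crop`, when given, is the target spatial shape used to trim padding. The class name and that reading of `crop` are assumptions, and `spatial_shape` is omitted.

```python
import torch
from torch import nn


class PatchSplitting2DSketch(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.expand = nn.Linear(dim, 2 * dim, bias=False)

    def forward(self, x, crop=None):
        b, h, w, c = x.shape
        x = self.expand(x)                      # (B, H, W, 2C)
        x = x.reshape(b, h, w, 2, 2, c // 2)    # split channels into a 2x2 patch
        # Interleave the patch axes with the grid axes, then flatten to (2H, 2W).
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(b, 2 * h, 2 * w, c // 2)
        if crop is not None:
            # Assumption: crop holds the desired (H, W) after removing padding.
            x = x[:, : crop[0], : crop[1]]
        return x


split = PatchSplitting2DSketch(dim=16)
tokens = torch.randn(2, 3, 5, 16)
full = split(tokens)
cropped = split(tokens, crop=(5, 9))
```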

class BasicSwinLayer(dim, depth, num_heads, n_spatial_dims, n_noise_channels, window_size, drop_path, downsample=None, upsample=None, use_ada_ln=True, zero_init=True)

Bases: Module

A basic Swin layer comprising multiple blocks and an optional down/upsample.

forward(x, x_noise, spatial_shape, crop=None)

Return type:

tuple[Tensor, Tensor | None, tuple[int, …], tuple[int, …] | None]

class SwinViTProcessor(in_channels, out_channels, spatial_resolution, window_size=(4, 4, 4), hidden_dim=64, encoder_depths=(2, 2, 2), encoder_num_heads=(3, 6, 12), decoder_depths=(2, 2, 2), decoder_num_heads=(12, 6, 3), drop_path=0.0, groups=12, loss_func=None, n_noise_channels=None, patch_size=None, use_ada_ln=True, zero_init=True)

Bases: Processor[EncodedBatch]

ViT Processor using 2D/3D Shifted-Window Attention (Swin) and Ada-LN.

Constructs a U-Net style encoder-decoder architecture.

Features:

  • Shifted-Window Multi-Head Self Attention (SW-MSA).

  • Patch Merging for hierarchical staging (encoder downsampling).

  • Patch Splitting for hierarchical expansion (decoder upsampling).

  • Ada-LN conditioning from noise embeddings (post-norm or per-block scaled).
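
The shifted-window idea can be sketched in 2D as follows: tokens are partitioned into non-overlapping windows, and alternate blocks first roll the feature map by half a window so the new windows straddle the previous window borders. This is a sketch of the technique from the Swin paper, not this processor's code. The helper name `window_partition_2d` is hypothetical, and the attention mask normally applied to rolled windows is omitted.

```python
import torch


def window_partition_2d(x, window):
    # x: (B, H, W, C), with H and W divisible by the window size.
    b, h, w, c = x.shape
    wh, ww = window
    x = x.reshape(b, h // wh, wh, w // ww, ww, c)
    # Bring the two window-grid axes together, then flatten each window to tokens.
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, wh * ww, c)


x = torch.randn(2, 8, 8, 16)
window = (4, 4)
regular = window_partition_2d(x, window)    # (B * num_windows, tokens, C)

# Shifted variant: roll by half a window before partitioning.
shifted_input = torch.roll(x, shifts=(-2, -2), dims=(1, 2))
shifted = window_partition_2d(shifted_input, window)
```

Self-attention is then computed independently within each window, so the cost grows with the number of windows rather than quadratically with the full token count.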

References:

  • Liu, Z. et al. “Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.” ICCV 2021.

  • Liu, Z. et al. “Video Swin Transformer.” CVPR 2022.

  • Microsoft Aurora Swin3D implementation: microsoft/aurora

Shape convention:

  • Public processor boundary: channel-first
    • 2D: (B, C, H, W)

    • 3D: (B, C, H, W, D)

  • Internal Swin blocks: channels-last
    • 2D: (B, H, W, C)

    • 3D: (B, H, W, D, C)
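
The conversion between the two layouts can be sketched with `Tensor.movedim`, which handles 2D and 3D inputs alike. The helper names below are hypothetical and are not part of the documented API.

```python
import torch


def to_channels_last(x):
    # (B, C, *spatial) -> (B, *spatial, C)
    return x.movedim(1, -1)


def to_channels_first(x):
    # (B, *spatial, C) -> (B, C, *spatial)
    return x.movedim(-1, 1)


x2d = torch.randn(2, 3, 8, 8)        # (B, C, H, W)
x3d = torch.randn(2, 3, 8, 8, 4)     # (B, C, H, W, D)
inner2d = to_channels_last(x2d)      # (B, H, W, C)
inner3d = to_channels_last(x3d)      # (B, H, W, D, C)
```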

get_encoder_specs(patch_res)
Parameters:

patch_res (tuple[int, ...])

Return type:

tuple[list[tuple[int, …]], list[tuple[int, …] | None]]

forward(x, x_noise=None)

Return type:

Tensor

map(x, global_cond=None)

Map encoded inputs through Swin while preserving channel-first I/O.

Return type:

Tensor

loss(batch)

Compute loss between output and target.

Parameters:

batch (EncodedBatch)

Return type:

Tensor
