autocast.processors.swin_vit
- class AdaLNGenerator(hidden_dim, n_noise_channels, num_chunks, use_ada_ln=True, zero_init=True)
Bases: Module
Generate Adaptive Layer Norm (Ada-LN) parameters from noise embeddings.
- Parameters:
- forward(x_noise=None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
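What the Ada-LN generator does can be illustrated with a small NumPy sketch modeled on DiT-style adaptive layer norm. It assumes a SiLU activation followed by a single linear projection to `num_chunks * hidden_dim` values, split into per-chunk shift/scale (or gate) vectors; the helper names `ada_ln_params`, `layer_norm`, and `modulate` are hypothetical and not this module's API. With zero-initialized projection weights, which is presumably what `zero_init=True` controls, every chunk starts at zero and the modulated norm reduces to a plain layer norm.

```python
import numpy as np


def ada_ln_params(x_noise, weight, bias, num_chunks):
    """Project a noise embedding to `num_chunks` modulation vectors (hypothetical helper).

    x_noise: (B, n_noise_channels)
    weight:  (n_noise_channels, num_chunks * hidden_dim)
    bias:    (num_chunks * hidden_dim,)
    Returns a list of `num_chunks` arrays, each (B, hidden_dim).
    """
    # Assumption: SiLU before the projection, as in DiT-style adaLN.
    h = x_noise / (1.0 + np.exp(-x_noise))
    out = h @ weight + bias  # (B, num_chunks * hidden_dim)
    return np.split(out, num_chunks, axis=-1)


def layer_norm(x, eps=1e-6):
    """Parameter-free layer norm over the last (channel) axis."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def modulate(x, shift, scale):
    """Apply Ada-LN to tokens x of shape (B, N, hidden_dim).

    shift and scale are (B, hidden_dim) and broadcast over the token axis.
    """
    return layer_norm(x) * (1.0 + scale[:, None, :]) + shift[:, None, :]
```

Starting from an identity modulation is a common choice because the conditioning path can then be learned without perturbing the block at initialization.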
- class SwinViTBlock(hidden_dim, num_heads, n_spatial_dims, n_noise_channels, window_size, shift=False, drop_path=0.0, use_ada_ln=True, zero_init=True)
Bases: Module
Block for the Swin ViT processor.
- Parameters:
- forward(x, x_noise=None)
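Since `SwinViTBlock` takes `window_size` and a `shift` flag, its attention presumably follows the standard Swin recipe: attend within non-overlapping windows, and in shifted blocks first roll the grid by half a window so that information crosses window borders. The 2D NumPy sketch below shows only that partition/shift bookkeeping on channels-last tensors. The helper names are hypothetical, the attention computation and the mask Swin applies to wrapped-around tokens are omitted, and a 3D version would add one more windowed axis.

```python
import numpy as np


def window_partition(x, ws):
    """Split a channels-last grid (B, H, W, C) into non-overlapping ws x ws windows.

    Returns (B * num_windows, ws * ws, C); H and W must be divisible by ws.
    """
    B, H, W, C = x.shape
    x = x.reshape(B, H // ws, ws, W // ws, ws, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, nH, nW, ws, ws, C)
    return x.reshape(-1, ws * ws, C)


def window_reverse(windows, ws, B, H, W):
    """Inverse of window_partition: reassemble windows into (B, H, W, C)."""
    C = windows.shape[-1]
    x = windows.reshape(B, H // ws, W // ws, ws, ws, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, nH, ws, nW, ws, C)
    return x.reshape(B, H, W, C)


def cyclic_shift(x, ws, reverse=False):
    """Roll the grid by ws // 2 along H and W (the SW-MSA shift); reverse undoes it."""
    s = ws // 2 if reverse else -(ws // 2)
    return np.roll(x, shift=(s, s), axis=(1, 2))
```

In a shifted block the order would be: shift, partition, attend per window, reverse the partition, then undo the shift.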
- class PatchMerging(dim, n_spatial_dims=2)
Bases: Module
Patch merging layer.
- forward(x, spatial_shape)
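In the original Swin Transformer, patch merging gathers each 2×2 neighborhood of tokens, concatenates them to 4C channels, and projects the result linearly (to 2C in the paper). The sketch below follows that scheme on a channels-last 2D grid. Whether this `PatchMerging` uses the same neighbor ordering, output width, or a preceding LayerNorm (omitted here) is not stated in this reference, and `patch_merge_2d` is a hypothetical helper. The `spatial_shape` argument of `forward` suggests the real layer receives flattened tokens and reshapes them first, which is also an assumption.

```python
import numpy as np


def patch_merge_2d(x, weight):
    """Merge each 2x2 neighborhood of a channels-last grid into one token.

    x:      (B, H, W, C) with even H and W
    weight: (4 * C, C_out) linear projection (Swin uses C_out = 2 * C)
    Returns (B, H // 2, W // 2, C_out).
    """
    x0 = x[:, 0::2, 0::2, :]  # top-left of each 2x2 group
    x1 = x[:, 1::2, 0::2, :]  # bottom-left
    x2 = x[:, 0::2, 1::2, :]  # top-right
    x3 = x[:, 1::2, 1::2, :]  # bottom-right
    merged = np.concatenate([x0, x1, x2, x3], axis=-1)  # (B, H/2, W/2, 4C)
    return merged @ weight
```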
- class PatchSplitting(dim, n_spatial_dims=2)
Bases: Module
Patch splitting layer.
- forward(x, spatial_shape, crop=None)
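Patch splitting plays the inverse role in the decoder. A common pattern (for example, patch expanding in Swin-UNet-style decoders) is a linear projection to four times the output channels followed by a depth-to-space rearrangement that doubles each spatial dimension. The sketch assumes that pattern and treats `crop` as a target spatial size used to trim the upsampled grid, for instance when an encoder stage had an odd size; both that interpretation and the helper `patch_split_2d` are hypothetical.

```python
import numpy as np


def patch_split_2d(x, weight, crop=None):
    """Expand each token of a channels-last grid into a 2x2 neighborhood.

    x:      (B, H, W, C)
    weight: (C, 4 * C_out) linear projection
    crop:   optional (H_out, W_out) to trim the result (assumed semantics)
    Returns (B, 2H, 2W, C_out), or the cropped shape.
    """
    B, H, W, _ = x.shape
    y = x @ weight  # (B, H, W, 4 * C_out)
    c_out = y.shape[-1] // 4
    # Interpret the channels as (dh, dw, C_out) sub-positions, then interleave.
    y = y.reshape(B, H, W, 2, 2, c_out)
    y = y.transpose(0, 1, 3, 2, 4, 5).reshape(B, 2 * H, 2 * W, c_out)
    if crop is not None:
        y = y[:, : crop[0], : crop[1], :]
    return y
```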
- class BasicSwinLayer(dim, depth, num_heads, n_spatial_dims, n_noise_channels, window_size, drop_path, downsample=None, upsample=None, use_ada_ln=True, zero_init=True)
Bases: Module
A basic Swin layer comprising multiple blocks followed by an optional downsampling or upsampling step.
- Parameters:
- forward(x, x_noise, spatial_shape, crop=None)
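How a `BasicSwinLayer` composes its blocks can be summarized by a small planning helper. It assumes the standard Swin alternation, in which odd-indexed blocks use shifted windows, and a factor-2 spatial change from the optional downsample or upsample step. `plan_stage` is a hypothetical name, and the real layer may differ, for example in how it handles odd sizes (the `crop` argument hints at this) or how it distributes `drop_path` across blocks.

```python
from typing import List, Sequence, Tuple


def plan_stage(
    depth: int,
    spatial_shape: Sequence[int],
    downsample: bool = False,
    upsample: bool = False,
) -> Tuple[List[bool], Tuple[int, ...]]:
    """Return per-block shift flags and the stage's output spatial shape.

    Assumes odd-indexed blocks use shifted windows and that resampling
    changes each spatial dimension by a factor of 2.
    """
    if downsample and upsample:
        raise ValueError("a stage either downsamples or upsamples, not both")
    shifts = [i % 2 == 1 for i in range(depth)]
    if downsample:
        out = tuple(s // 2 for s in spatial_shape)
    elif upsample:
        out = tuple(2 * s for s in spatial_shape)
    else:
        out = tuple(spatial_shape)
    return shifts, out
```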
- class SwinViTProcessor(in_channels, out_channels, spatial_resolution, window_size=(4, 4, 4), hidden_dim=64, encoder_depths=(2, 2, 2), encoder_num_heads=(3, 6, 12), decoder_depths=(2, 2, 2), decoder_num_heads=(12, 6, 3), drop_path=0.0, groups=12, loss_func=None, n_noise_channels=None, patch_size=None, use_ada_ln=True, zero_init=True)
Bases: Processor[EncodedBatch]
ViT processor using 2D/3D shifted-window attention (Swin) and Ada-LN conditioning.
Constructs a U-Net-style encoder-decoder architecture.
Features:
- Shifted-window multi-head self-attention (SW-MSA).
- Patch merging for hierarchical staging (encoder downsampling).
- Patch splitting for hierarchical expansion (decoder upsampling).
- Ada-LN conditioning from noise embeddings (post-norm or per-block scaled).
References:
Liu, Z. et al. “Swin Transformer: Hierarchical Vision Transformer using Shifted Windows.” ICCV 2021.
Liu, Z. et al. “Video Swin Transformer.” CVPR 2022.
Microsoft Aurora Swin3D implementation: microsoft/aurora
Shape convention:
- Public processor boundary: channels-first
  - 2D: (B, C, H, W)
  - 3D: (B, C, H, W, D)
- Internal Swin blocks: channels-last
  - 2D: (B, H, W, C)
  - 3D: (B, H, W, D, C)
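The two layouts differ only in the position of the channel axis, so the conversion at the processor boundary can be sketched as a pair of axis moves that work for both 2D and 3D tensors. NumPy is used here for a dependency-light example; in PyTorch, `torch.movedim(x, 1, -1)` and `torch.movedim(x, -1, 1)` perform the same permutations. The helper names are hypothetical.

```python
import numpy as np


def to_channels_last(x):
    """(B, C, *spatial) -> (B, *spatial, C) for 2D or 3D inputs."""
    return np.moveaxis(x, 1, -1)


def to_channels_first(x):
    """(B, *spatial, C) -> (B, C, *spatial), the inverse of to_channels_last."""
    return np.moveaxis(x, -1, 1)
```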
- forward(x, x_noise=None)
- map(x, global_cond=None)
Map encoded inputs through the Swin encoder-decoder while preserving channels-first input and output.
- loss(batch)
Compute the loss between the output and the target.
- Parameters:
batch (EncodedBatch)
- Return type: