autocast.decoders.dc
- class DCDecoder(in_channels, out_channels, hid_channels=(64, 128, 256), hid_blocks=(3, 3, 3), kernel_size=3, stride=2, pixel_shuffle=True, norm='layer', attention_heads=None, ffn_factor=1, spatial=2, patch_size=1, periodic=False, dropout=None, checkpointing=False, identity_init=True, ffn_out_scale=None)
Bases: Decoder
Deep Compression (DC) decoder module.
Progressively upsamples the latent representation back to the original spatial dimensions using residual blocks with optional attention.
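Each upsampling step uses either pixel shuffling or nearest-neighbor upsampling (the `pixel_shuffle` parameter). The NumPy sketch below contrasts the two operations on a single channels-first 2D feature map. `pixel_shuffle` and `nearest_upsample` are illustrative helpers, not part of autocast, and the channel ordering follows the `torch.nn.PixelShuffle` convention, which is assumed rather than confirmed for this module. In a real decoder, a convolution would produce the C * r**2 channels that pixel shuffle consumes. Here that convolution is replaced by plain channel replication.

```python
import numpy as np

def pixel_shuffle(x: np.ndarray, r: int) -> np.ndarray:
    """Rearrange (C * r**2, H, W) into (C, H * r, W * r), as torch.nn.PixelShuffle does."""
    c_r2, h, w = x.shape
    c = c_r2 // (r * r)
    # Split channels into (C, r, r), then interleave the two r-axes with H and W.
    x = x.reshape(c, r, r, h, w)
    x = x.transpose(0, 3, 1, 4, 2)  # (C, H, r, W, r)
    return x.reshape(c, h * r, w * r)

def nearest_upsample(x: np.ndarray, r: int) -> np.ndarray:
    """Repeat each pixel r times along H and W: (C, H, W) -> (C, H * r, W * r)."""
    return x.repeat(r, axis=1).repeat(r, axis=2)

# A toy 2-channel 2x2 feature map, upsampled by a factor of 2.
x = np.arange(8, dtype=np.float32).reshape(2, 2, 2)
r = 2

up_nearest = nearest_upsample(x, r)

# Pixel shuffle moves channels into space. If every channel is simply
# replicated r**2 times first, the result matches nearest upsampling exactly.
up_shuffle = pixel_shuffle(np.repeat(x, r * r, axis=0), r)

print(up_nearest.shape)                        # (2, 4, 4)
print(np.array_equal(up_nearest, up_shuffle))  # True
```

With a learned convolution in place of the replication, pixel shuffle can place different values at each of the r**2 sub-positions. Nearest upsampling always copies a single value.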
- Parameters:
in_channels (int) – Number of input (latent) channels.
out_channels (int) – Number of output channels.
hid_channels (Sequence[int]) – Number of channels at each depth level.
hid_blocks (Sequence[int]) – Number of residual blocks at each depth level.
kernel_size (int | Sequence[int]) – Kernel size for convolutions.
stride (int | Sequence[int]) – Stride for upsampling operations.
pixel_shuffle (bool) – If True, upsample with pixel shuffling; otherwise, use nearest-neighbor upsampling.
norm (str) – Type of normalization (‘layer’ or ‘group’).
attention_heads (dict[int, int] | None) – Dict mapping depth index to number of attention heads.
ffn_factor (int) – Channel expansion factor in FFN blocks.
spatial (int) – Number of spatial dimensions (2 for 2D, 3 for 3D).
patch_size (int | Sequence[int]) – Patch size for unpatchifying at the end.
periodic (bool) – Whether spatial dimensions are periodic (use circular padding).
dropout (float | None) – Dropout rate.
checkpointing (bool) – Whether to use gradient checkpointing.
identity_init (bool) – Whether to initialize the up/downsampling convolutions as identity mappings.
ffn_out_scale (float | None) – Optional multiplicative scale applied to the output convolution of each ResBlock FFN.
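The `periodic` flag is documented as using circular padding. The following minimal NumPy sketch contrasts circular padding with zero padding for a 2D, channels-first array. It assumes "same" padding for an odd kernel size and assumes that the non-periodic case pads with zeros. `pad_for_conv` is a hypothetical helper, and the 3D case (`spatial=3`) would add one more padded axis.

```python
import numpy as np

def pad_for_conv(x: np.ndarray, kernel_size: int, periodic: bool) -> np.ndarray:
    """Pad the spatial axes of a (C, H, W) array so an odd-sized convolution keeps H and W."""
    p = kernel_size // 2
    pad_width = ((0, 0), (p, p), (p, p))  # leave the channel axis unpadded
    # "wrap" copies values from the opposite edge (circular padding);
    # "constant" fills the border with zeros.
    mode = "wrap" if periodic else "constant"
    return np.pad(x, pad_width, mode=mode)

x = np.arange(9, dtype=np.float32).reshape(1, 3, 3)
wrapped = pad_for_conv(x, kernel_size=3, periodic=True)
zeroed = pad_for_conv(x, kernel_size=3, periodic=False)

print(wrapped.shape)     # (1, 5, 5)
print(wrapped[0, 0, 0])  # 8.0: the top-left corner wraps around to the bottom-right value
print(zeroed[0, 0, 0])   # 0.0
```

With circular padding, a convolution at one edge of the domain sees values from the opposite edge. This matches fields defined on periodic domains.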
Note
Based on the implementations from:
- Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models (Chen et al., 2024), https://arxiv.org/abs/2410.10733v1
- Lost in Latent Space: An Empirical Study of Latent Diffusion Models for Physics Emulation (Rozet et al., 2025), https://arxiv.org/abs/2507.02608, PolymathicAI/lola
- decode(z)
Decode a latent tensor with a time dimension back to the original space.
- Parameters:
z (Float[Tensor, 'batch time spatial *spatial channel']) – Latent tensor of shape (B, T, spatial…, C_i), with the C_i latent channels in the last dimension.
- Returns:
Decoded tensor of shape (B, T, spatial_expanded…, C_o), with the C_o output channels in the last dimension.
- Return type:
Float[Tensor, 'batch time spatial *spatial channel']
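The shape contract of `decode` can be illustrated with a minimal NumPy sketch. It assumes, without confirmation from this page, that frames are decoded independently by folding the time axis into the batch axis and that the underlying network works channels-first. `decode_frames` and `toy_decoder` are hypothetical stand-ins. The latter is a 1x1 channel projection followed by 4x nearest upsampling, not the DC decoder itself.

```python
import numpy as np

def decode_frames(z: np.ndarray, decode_one) -> np.ndarray:
    """Decode each frame of a channels-last latent of shape (B, T, *spatial, C_i).

    `decode_one` maps a channels-first batch (N, C_i, *spatial) to
    (N, C_o, *spatial_out); the result has shape (B, T, *spatial_out, C_o).
    """
    b, t = z.shape[:2]
    x = z.reshape(b * t, *z.shape[2:])    # fold time into batch: (B*T, *spatial, C_i)
    x = np.moveaxis(x, -1, 1)             # channels first: (B*T, C_i, *spatial)
    y = decode_one(x)                     # (B*T, C_o, *spatial_out)
    y = np.moveaxis(y, 1, -1)             # channels last: (B*T, *spatial_out, C_o)
    return y.reshape(b, t, *y.shape[1:])  # unfold time: (B, T, *spatial_out, C_o)

rng = np.random.default_rng(0)
w = rng.standard_normal((3, 16))  # hypothetical 1x1 projection from 16 to 3 channels

def toy_decoder(x: np.ndarray) -> np.ndarray:
    """Stand-in for the network: channel projection, then 4x nearest upsampling."""
    y = np.einsum("oc,nchw->nohw", w, x)
    return y.repeat(4, axis=2).repeat(4, axis=3)

z = rng.standard_normal((2, 5, 8, 8, 16))  # (B, T, H, W, C_i)
x_hat = decode_frames(z, toy_decoder)
print(x_hat.shape)  # (2, 5, 32, 32, 3)
```

In this sketch, folding time into the batch axis means that each frame is decoded independently of the others.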