autocast.decoders.dc

class DCDecoder(in_channels, out_channels, hid_channels=(64, 128, 256), hid_blocks=(3, 3, 3), kernel_size=3, stride=2, pixel_shuffle=True, norm='layer', attention_heads=None, ffn_factor=1, spatial=2, patch_size=1, periodic=False, dropout=None, checkpointing=False, identity_init=True, ffn_out_scale=None)

Bases: Decoder

Deep Compression (DC) decoder module.

Progressively upsamples a latent representation back to the original spatial resolution, using residual blocks with optional attention at selected depth levels.
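To make the progressive upsampling concrete, the sketch below traces channels and spatial size level by level for the hid_channels, stride, and patch_size parameters documented below. It assumes one upsampling step by stride between consecutive depth levels (so len(hid_channels) - 1 steps, deepest level first) followed by an unpatchify step that scales each spatial size by patch_size, and it only handles an integer stride. Both helpers are hypothetical and not part of autocast.

```python
from typing import List, Sequence, Tuple


def upsampling_schedule(
    latent_size: Sequence[int],
    hid_channels: Sequence[int] = (64, 128, 256),
    stride: int = 2,
) -> List[Tuple[int, Tuple[int, ...]]]:
    """Hypothetical helper: (channels, spatial size) per depth level, deepest first.

    Assumes a single upsampling by `stride` between consecutive depth levels.
    """
    size = tuple(latent_size)
    schedule = []
    for i in reversed(range(len(hid_channels))):
        schedule.append((hid_channels[i], size))
        if i > 0:  # upsample before moving to the next (shallower) level
            size = tuple(s * stride for s in size)
    return schedule


def output_size(
    latent_size: Sequence[int],
    hid_channels: Sequence[int] = (64, 128, 256),
    stride: int = 2,
    patch_size: int = 1,
) -> Tuple[int, ...]:
    """Hypothetical helper: final spatial size after upsampling and unpatchifying."""
    factor = stride ** (len(hid_channels) - 1) * patch_size
    return tuple(s * factor for s in latent_size)


print(upsampling_schedule((8, 8)))  # [(256, (8, 8)), (128, (16, 16)), (64, (32, 32))]
print(output_size((8, 8), patch_size=2))  # (64, 64)
```

Under these assumptions, the default configuration expands each spatial dimension by 2² = 4 before unpatchifying, so patch_size multiplies that factor further.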

Parameters:
  • in_channels (int) – Number of input (latent) channels.

  • out_channels (int) – Number of output channels.

  • hid_channels (Sequence[int]) – Number of channels at each depth level.

  • hid_blocks (Sequence[int]) – Number of residual blocks at each depth level.

  • kernel_size (int | Sequence[int]) – Kernel size for convolutions.

  • stride (int | Sequence[int]) – Stride for upsampling operations.

  • pixel_shuffle (bool) – If True, upsample with pixel shuffling; otherwise use nearest-neighbor upsampling.

  • norm (str) – Type of normalization (‘layer’ or ‘group’).

  • attention_heads (dict[int, int] | None) – Dict mapping depth index to number of attention heads.

  • ffn_factor (int) – Channel expansion factor in FFN blocks.

  • spatial (int) – Number of spatial dimensions (2 for 2D, 3 for 3D).

  • patch_size (int | Sequence[int]) – Patch size used to unpatchify the output in the final step.

  • periodic (bool) – Whether spatial dimensions are periodic (use circular padding).

  • dropout (float | None) – Dropout rate.

  • checkpointing (bool) – Whether to use gradient checkpointing.

  • identity_init (bool) – Whether to initialize the up/downsampling convolutions as identity mappings.

  • ffn_out_scale (float | None) – Optional multiplicative scale applied to each ResBlock FFN output conv.
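The pixel_shuffle and identity_init options can be illustrated with a small NumPy sketch; this is not the autocast implementation. It implements a channels-last 2D pixel shuffle using the same channel ordering as PyTorch's pixel_shuffle (input channel c·r² + i·r + j lands at output channel c, offset (i, j)), and checks that copying each channel r² times before the shuffle reproduces nearest-neighbor upsampling. Reading this as what identity_init does for the upsampling convolutions is an assumption; repeat_channels and the other helper names are hypothetical.

```python
import numpy as np


def pixel_shuffle_2d(x: np.ndarray, r: int) -> np.ndarray:
    """Rearrange (H, W, C * r * r) -> (H * r, W * r, C), channels last."""
    h, w, c_rr = x.shape
    c = c_rr // (r * r)
    x = x.reshape(h, w, c, r, r)     # split channel axis into (C, r, r)
    x = x.transpose(0, 3, 1, 4, 2)   # reorder to (H, r, W, r, C)
    return x.reshape(h * r, w * r, c)


def nearest_upsample_2d(x: np.ndarray, r: int) -> np.ndarray:
    """Nearest-neighbor upsampling of (H, W, C) by factor r along H and W."""
    return x.repeat(r, axis=0).repeat(r, axis=1)


def repeat_channels(x: np.ndarray, r: int) -> np.ndarray:
    """Hypothetical identity-style 1x1 projection: copy each channel r * r times."""
    return np.repeat(x, r * r, axis=-1)


x = np.arange(12, dtype=float).reshape(2, 3, 2)  # (H=2, W=3, C=2)
shuffled = pixel_shuffle_2d(repeat_channels(x, 2), 2)
print(shuffled.shape)  # (4, 6, 2)
print(np.array_equal(shuffled, nearest_upsample_2d(x, 2)))  # True
```

The equivalence suggests why an identity-style initialization is attractive: at initialization a pixel-shuffle upsampler then behaves like plain nearest-neighbor upsampling, and training only has to learn deviations from it.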

Note

Based on the implementation from Deep Compression Autoencoder for Efficient High-Resolution Diffusion Models (Chen et al., 2024), https://arxiv.org/abs/2410.10733v1

decoder_model: Module

decode(z)

Decode a latent tensor that has a time dimension back to the original data space.

Parameters:

z (Float[Tensor, 'batch time spatial *spatial channel']) – Latent tensor of shape (B, T, spatial…, C_i), with the C_i latent channels in the last dimension.

Returns:

Decoded tensor of shape (B, T, spatial_expanded…, C_o), where spatial_expanded… are the upsampled spatial sizes and C_o is out_channels.

Return type:

Float[Tensor, 'batch time spatial *spatial channel']
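decode accepts channels-last tensors with a time axis, while convolutional decoders usually operate on channels-first tensors without one. A common way to bridge the two, sketched below in NumPy under the assumption that each time step is decoded independently (the source does not confirm this), is to fold T into the batch axis, move channels to axis 1, decode, and then undo both steps. stand_in_decoder is a hypothetical placeholder (a channel projection plus ×4 nearest upsampling), not DCDecoder.

```python
import numpy as np


def to_frames(z: np.ndarray):
    """(B, T, *spatial, C) -> (B * T, C, *spatial), plus (B, T) for unfolding."""
    b, t, *spatial, c = z.shape
    frames = z.reshape(b * t, *spatial, c)  # fold time into batch
    frames = np.moveaxis(frames, -1, 1)     # channels last -> channels first
    return frames, (b, t)


def from_frames(frames: np.ndarray, bt) -> np.ndarray:
    """(B * T, C, *spatial) -> (B, T, *spatial, C)."""
    b, t = bt
    out = np.moveaxis(frames, 1, -1)        # channels first -> channels last
    return out.reshape(b, t, *out.shape[1:])


def stand_in_decoder(frames: np.ndarray, weight: np.ndarray, factor: int = 4) -> np.ndarray:
    """Hypothetical placeholder: project C_i -> C_o, then nearest-upsample each spatial axis."""
    out = np.einsum("oc,nc...->no...", weight, frames)
    for axis in range(2, out.ndim):
        out = out.repeat(factor, axis=axis)
    return out


rng = np.random.default_rng(0)
z = rng.standard_normal((2, 3, 8, 8, 4))  # (B, T, H, W, C_i)
weight = rng.standard_normal((3, 4))      # C_o = 3, C_i = 4
frames, bt = to_frames(z)
out = from_frames(stand_in_decoder(frames, weight), bt)
print(out.shape)  # (2, 3, 32, 32, 3)
```

Because the helpers unpack *spatial, the same folding works unchanged for 3D inputs (spatial=3).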