autocast.nn.residual

class ResBlock(channels, norm='layer', groups=16, attention_heads=None, ffn_factor=1, spatial=2, dropout=None, checkpointing=False, ffn_out_scale=None, **kwargs)

Bases: Module

Residual block with normalization, optional attention, and a feed-forward network (FFN).

Parameters:
  • channels (int) – Number of channels.

  • norm (str) – Type of normalization (‘layer’ or ‘group’).

  • groups (int) – Number of groups for GroupNorm.

  • attention_heads (int | None) – Number of attention heads (None for no attention).

  • ffn_factor (int) – Channel expansion factor in FFN.

  • spatial (int) – Number of spatial dimensions.

  • dropout (float | None) – Dropout rate.

  • checkpointing (bool) – Whether to use gradient checkpointing.

  • ffn_out_scale (float | None) – Optional multiplicative scale applied to the final FFN conv weights.

  • **kwargs (Any) – Additional arguments for convolution layers.
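As a rough illustration of how several of these parameters could fit together, here is a minimal PyTorch sketch of only the normalization and FFN branch of such a block. It is not the autocast implementation. The names SketchFFNBlock and make_norm are hypothetical, and these choices are all assumptions: pointwise (kernel_size=1) convolutions, a GELU activation, approximating 'layer' normalization with GroupNorm(1, C), and applying ffn_out_scale once at construction. Attention, checkpointing, and forwarding of **kwargs are left out.

```python
import torch
from torch import nn


def make_norm(norm: str, channels: int, groups: int) -> nn.Module:
    # Hypothetical helper. Assumption: 'layer' is approximated with GroupNorm(1, C),
    # which normalizes each sample over channels and spatial positions together.
    if norm == "layer":
        return nn.GroupNorm(1, channels)
    if norm == "group":
        return nn.GroupNorm(groups, channels)
    raise ValueError(f"unknown norm: {norm!r}")


class SketchFFNBlock(nn.Module):
    """Hypothetical pre-norm residual FFN branch (attention and **kwargs omitted)."""

    def __init__(self, channels, norm="layer", groups=16, ffn_factor=1,
                 spatial=2, dropout=None, ffn_out_scale=None):
        super().__init__()
        # Pick Conv1d/Conv2d/Conv3d from the number of spatial dimensions.
        conv = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}[spatial]
        hidden = channels * ffn_factor  # ffn_factor expands the hidden width
        self.norm = make_norm(norm, channels, groups)
        self.ffn = nn.Sequential(
            conv(channels, hidden, kernel_size=1),  # assumption: pointwise convs
            nn.GELU(),                               # assumption: GELU activation
            # None (or 0) disables dropout in this sketch.
            nn.Dropout(dropout) if dropout else nn.Identity(),
            conv(hidden, channels, kernel_size=1),
        )
        if ffn_out_scale is not None:
            # Assumption: the final FFN conv weights are scaled once, at construction.
            with torch.no_grad():
                self.ffn[-1].weight.mul_(ffn_out_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection keeps the output shape equal to the input shape.
        return x + self.ffn(self.norm(x))
```

A small ffn_out_scale (for example 0.1) shrinks the FFN branch's contribution relative to the identity path, which is a common reason to scale a residual branch's output layer.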

forward(x)

Forward pass with optional gradient checkpointing.

Parameters:

x (Tensor) – Input tensor.

Returns:

Output tensor with the same shape as the input.

Return type:

Tensor
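Gradient checkpointing in PyTorch is typically done with torch.utils.checkpoint.checkpoint, which discards intermediate activations during the forward pass and recomputes them during backward, trading compute for memory. The following is a minimal sketch assuming the block's computation can be wrapped in one function. CheckpointedResidual and its body argument are hypothetical, and the check on self.training and torch.is_grad_enabled() is an assumption, since the source does not say when checkpointing is skipped.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedResidual(nn.Module):
    """Hypothetical illustration of a forward pass with optional checkpointing."""

    def __init__(self, body: nn.Module, checkpointing: bool = False):
        super().__init__()
        self.body = body
        self.checkpointing = checkpointing

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.body(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.checkpointing and self.training and torch.is_grad_enabled():
            # Activations inside _forward are not stored; they are recomputed
            # during backward, lowering peak memory at the cost of extra compute.
            return checkpoint(self._forward, x, use_reentrant=False)
        return self._forward(x)
```

Checkpointing changes memory use, not results: outputs and gradients match the non-checkpointed path.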

class Residual(*args: Module)
class Residual(arg: OrderedDict[str, Module])

Bases: Sequential

Residual wrapper that adds the input to the output of the wrapped layers.

This wrapper can be used to add a residual connection around any sequence of layers whose output has the same shape as its input.

forward(input)

Forward pass with residual connection.

Parameters:

input (Tensor) – Input tensor.

Returns:

The input plus the output of the sequential layers.

Return type:

Tensor
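Because Residual subclasses Sequential, its behaviour can be sketched in a few lines. SketchResidual is a hypothetical re-creation, not the library source. It inherits both constructor forms (positional modules or an OrderedDict of named modules) from torch.nn.Sequential.

```python
from collections import OrderedDict

import torch
from torch import nn


class SketchResidual(nn.Sequential):
    """Hypothetical re-creation: a Sequential whose output is added to its input."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # nn.Sequential.forward runs the wrapped layers in order; the sum
        # requires the layers' output to match (or broadcast to) the input shape.
        return input + super().forward(input)


# Both constructor forms work, mirroring the two signatures above.
positional = SketchResidual(nn.Linear(4, 4), nn.ReLU())
named = SketchResidual(OrderedDict([("fc", nn.Linear(4, 4)), ("act", nn.ReLU())]))
```

If the wrapped layers output zeros, the wrapper reduces to the identity, which is why residual branches are often initialized near zero.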

class SpatialAttentionWrapper(attention_module, spatial)

Bases: Module

Wrapper to handle spatial dimensions for MultiheadSelfAttention.

Converts (B, C, W, H, …) -> (B, W*H*…, C) -> attention -> (B, C, W, H, …).
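The conversion can be sketched with flatten and transpose in PyTorch. SketchSpatialAttention and SelfAttention are hypothetical names. SelfAttention is a stand-in for MultiheadSelfAttention built on torch.nn.MultiheadAttention with batch_first=True, and the wrapped module is assumed to map (B, N, C) to (B, N, C).

```python
import torch
from torch import nn


class SelfAttention(nn.Module):
    """Hypothetical stand-in for MultiheadSelfAttention on (B, N, C) tokens."""

    def __init__(self, channels: int, heads: int):
        super().__init__()
        self.mha = nn.MultiheadAttention(channels, heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.mha(x, x, x, need_weights=False)  # query = key = value
        return out


class SketchSpatialAttention(nn.Module):
    """Hypothetical wrapper: flattens spatial dims into a token axis for attention."""

    def __init__(self, attention_module: nn.Module, spatial: int):
        super().__init__()
        self.attention = attention_module
        self.spatial = spatial

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, *dims = x.shape
        if len(dims) != self.spatial:
            raise ValueError(f"expected {self.spatial} spatial dims, got {len(dims)}")
        tokens = x.flatten(2).transpose(1, 2)  # (B, C, W, H, ...) -> (B, W*H*..., C)
        out = self.attention(tokens)           # attention over W*H*... tokens
        return out.transpose(1, 2).reshape(b, c, *dims)  # back to (B, C, W, H, ...)
```

With nn.Identity() as the attention module, the round trip returns the input unchanged, which is a quick check that the reshapes invert each other.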

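Parameters:
  • attention_module – Attention module to wrap (e.g. MultiheadSelfAttention); it receives inputs of shape (B, W*H*…, C).

  • spatial – Number of spatial dimensions.
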
forward(x)

Forward pass handling spatial dimension transformation.

Parameters:

x (Float[Tensor, 'batch channel spatial *spatial']) – Input tensor with shape (B, C, spatial_dims…).

Returns:

Output tensor with shape (B, C, spatial_dims…).

Return type:

Float[Tensor, 'batch channel spatial *spatial']