autocast.nn.spatial_attention

class SpatialAttentionWrapper(attention_module, spatial)

Bases: Module

Wrapper that lets a MultiheadSelfAttention module operate on inputs with spatial dimensions.

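The spatial dimensions are flattened into a token sequence, passed through the attention module, and reshaped back: (B, C, W, H, …) -> (B, W*H*…, C) -> attention -> (B, C, W, H, …).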
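A minimal sketch of the reshaping described above, written in PyTorch. It is not the library's implementation: the names `SpatialAttentionSketch` and `SelfAttention` are hypothetical, `torch.nn.MultiheadAttention` (with `batch_first=True`) stands in for `MultiheadSelfAttention`, whose signature is not documented here, and the number of spatial dimensions is read from the input shape rather than from the documented `spatial` argument, whose exact role this page does not describe.

```python
import torch
from torch import nn


class SelfAttention(nn.Module):
    """Hypothetical stand-in for MultiheadSelfAttention: (B, N, C) -> (B, N, C)."""

    def __init__(self, channels: int, num_heads: int) -> None:
        super().__init__()
        self.mha = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention: queries, keys and values all come from x.
        out, _ = self.mha(x, x, x, need_weights=False)
        return out


class SpatialAttentionSketch(nn.Module):
    """Hypothetical wrapper: applies a sequence attention module to (B, C, *spatial)."""

    def __init__(self, attention_module: nn.Module) -> None:
        super().__init__()
        self.attention_module = attention_module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, *spatial = x.shape
        # (B, C, W, H, ...) -> (B, C, W*H*...) -> (B, W*H*..., C)
        tokens = x.flatten(start_dim=2).transpose(1, 2)
        out = self.attention_module(tokens)
        # (B, W*H*..., C) -> (B, C, W*H*...) -> (B, C, W, H, ...)
        return out.transpose(1, 2).reshape(b, c, *spatial)


x = torch.randn(2, 16, 8, 8)  # (B, C, W, H)
wrapper = SpatialAttentionSketch(SelfAttention(16, num_heads=4))
print(wrapper(x).shape)  # torch.Size([2, 16, 8, 8])
```

The same sketch works unchanged for 1D or 3D spatial inputs, because only the dimensions after the channel axis are flattened and restored.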

forward(x)

Forward pass: flatten the spatial dimensions into a sequence, apply attention, and restore the spatial layout.

Parameters:

x (Float[Tensor, 'batch channel spatial *spatial']) – Input tensor with shape (B, C, spatial_dims…).

Returns:

Output tensor with the same shape as the input, (B, C, spatial_dims…).

Return type:

Float[Tensor, 'batch channel spatial *spatial']
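The shape contract can be checked in isolation. The sketch below assumes the wrapper flattens spatial dimensions in PyTorch's default row-major order, as `Tensor.flatten` does; under that assumption, position (w, h) becomes token w * H + h, and undoing the transpose and reshape recovers the input exactly.

```python
import torch

# Small (B, C, W, H) tensor whose values encode their position.
B, C, W, H = 1, 3, 2, 4
x = torch.arange(B * C * W * H, dtype=torch.float32).reshape(B, C, W, H)

# (B, C, W, H) -> (B, W*H, C): one token per spatial position.
tokens = x.flatten(start_dim=2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 8, 3])

# Row-major flattening: position (w, h) becomes token n = w * H + h.
w, h = 1, 2
n = w * H + h
print(torch.equal(tokens[0, n], x[0, :, w, h]))  # True

# Inverse: (B, W*H, C) -> (B, C, W*H) -> (B, C, W, H) recovers x exactly.
restored = tokens.transpose(1, 2).reshape(B, C, W, H)
print(torch.equal(restored, x))  # True
```

Any attention module placed between the two reshapes therefore sees each spatial position as one token carrying its C channel values.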

class Residual(*args: Module)
class Residual(arg: OrderedDict[str, Module])

Bases: Sequential

Residual wrapper that adds the input to the output of the wrapped layers.

Use it to add a residual (skip) connection around any sequence of layers. The layers' output must be addable to the input, which in practice means it must have the same shape.
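A minimal sketch of the residual behaviour, assuming only what this page states: the class subclasses `Sequential` and returns the input plus the output of its layers. `ResidualSketch` is a hypothetical name, not the library class. Because it inherits from `torch.nn.Sequential`, it accepts the same two constructor forms listed in the signatures above.

```python
from collections import OrderedDict

import torch
from torch import nn


class ResidualSketch(nn.Sequential):
    """Hypothetical residual block: returns input + layers(input)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # nn.Sequential.forward applies the wrapped layers in order;
        # the skip connection then adds the unmodified input back.
        return input + super().forward(input)


# Mirrors the positional form Residual(*args: Module).
block = ResidualSketch(nn.Linear(8, 8), nn.ReLU())
# Mirrors the OrderedDict form Residual(arg: OrderedDict[str, Module]).
named = ResidualSketch(OrderedDict([("proj", nn.Linear(8, 8)), ("act", nn.ReLU())]))

x = torch.randn(4, 8)
print(block(x).shape)  # torch.Size([4, 8])
print(named.proj)  # Linear(in_features=8, out_features=8, bias=True)
```

Subclassing `Sequential` keeps indexing, named submodules, and `state_dict` keys the same as for a plain `Sequential`, so adding the skip connection does not change how the layers are addressed.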

forward(input)

Forward pass with a residual connection.

Parameters:

input (Tensor) – Input tensor.

Returns:

The input plus the output of the sequential layers applied to it.

Return type:

Tensor
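Because `SpatialAttentionWrapper` returns a tensor with the same shape as its input, its output can be added to its input inside `Residual`. The sketch below composes hypothetical stand-ins for both classes (`SpatialSelfAttention`, `ResidualSketch`) into a residual attention block over a 2D feature map, again using `torch.nn.MultiheadAttention` in place of `MultiheadSelfAttention`. Whether the library combines the two classes this way is an assumption, not something this page states.

```python
import torch
from torch import nn


class SpatialSelfAttention(nn.Module):
    """Hypothetical stand-in for SpatialAttentionWrapper around self-attention."""

    def __init__(self, channels: int, num_heads: int) -> None:
        super().__init__()
        self.mha = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, *spatial = x.shape
        tokens = x.flatten(start_dim=2).transpose(1, 2)  # (B, N, C)
        out, _ = self.mha(tokens, tokens, tokens, need_weights=False)
        return out.transpose(1, 2).reshape(b, c, *spatial)  # (B, C, *spatial)


class ResidualSketch(nn.Sequential):
    """Hypothetical stand-in for Residual: returns input + layers(input)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return input + super().forward(input)


# Residual attention over a 2D feature map: y = x + attention(x).
block = ResidualSketch(SpatialSelfAttention(channels=32, num_heads=4))
x = torch.randn(2, 32, 16, 16)  # (B, C, W, H)
y = block(x)
print(y.shape)  # torch.Size([2, 32, 16, 16])
```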