autocast.models.variational_autoencoder

class VAELoss(beta=1.0)

Bases: Module

Variational Autoencoder Loss Function.

forward(model, batch)

Compute VAE loss as reconstruction loss + beta * KL divergence.

Parameters:
  • model – The model whose output is scored.

  • batch – The input batch.

Return type:

Tensor
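The objective computed by forward can be written as follows. This is a sketch under stated assumptions: the KL term is taken against a standard normal prior N(0, I) and σ² = exp(log_var); the reconstruction term is left generic because its concrete form is not specified on this page. With beta = 1.0 (the default) the expression has the form of the standard VAE objective, and larger beta weights the KL term more heavily.

```latex
\mathcal{L} = \mathcal{L}_{\mathrm{recon}}(x, \hat{x})
  + \beta \, D_{\mathrm{KL}}\!\left(\mathcal{N}\big(\mu, \operatorname{diag}(\sigma^2)\big) \,\middle\|\, \mathcal{N}(0, I)\right),
\qquad
D_{\mathrm{KL}} = \tfrac{1}{2} \sum_{j} \left(\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2\right)
```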

kl_divergence(encoded)

Compute the KL divergence loss.

Parameters:

encoded (Float[Tensor, 'batch *optional_dims channel']) – Encoded tensor containing mean and log variance. Shape: [B, 2*C, H, W, …] for spatial or [B, 2*latent_channels] for flat.

Returns:

KL divergence loss.

Return type:

Tensor
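The closed-form KL term between a diagonal Gaussian and a standard normal prior can be sketched for a single flat encoded vector. The helper name `kl_divergence_flat`, the assumption that the first half of `encoded` holds the mean and the second half the log variance, and the sum over channels are all illustrative; the library's channel order and reduction (sum or mean, over which dimensions) may differ. Plain Python lists stand in for tensors.

```python
import math
from typing import Sequence


def kl_divergence_flat(encoded: Sequence[float]) -> float:
    """Hypothetical sketch: KL(N(mean, exp(log_var)) || N(0, 1)) for one sample.

    Assumes `encoded` has length 2 * C, with the mean in the first half and
    the log variance in the second half. Sums over the C latent channels.
    """
    if len(encoded) % 2 != 0:
        raise ValueError("encoded must have an even length (2 * C)")
    c = len(encoded) // 2
    mean, log_var = encoded[:c], encoded[c:]
    # Per channel: 0.5 * (sigma^2 + mean^2 - 1 - log sigma^2), sigma^2 = exp(log_var)
    return 0.5 * sum(
        math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mean, log_var)
    )


# A latent that already matches the prior has zero KL.
print(kl_divergence_flat([0.0, 0.0, 0.0, 0.0]))  # 0.0
# Shifting the mean by 1 with unit variance costs 0.5.
print(kl_divergence_flat([1.0, 0.0]))  # 0.5
```

Each per-channel term is non-negative (exp(lv) - 1 - lv ≥ 0 and m² ≥ 0), so the sketch never returns a negative KL.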

class VAE(encoder, decoder, spatial=None, norm=None)

Bases: EncoderDecoder

Variational Autoencoder Model.

Supports both flat (B, latent_channels) and spatial (B, C, H, W, …) latent representations.

Parameters:
  • encoder

  • decoder

  • spatial

  • norm

encoder: EncoderWithCond
decoder: Decoder
fc_mean: Module
fc_log_var: Module
forward(batch)

Run the model on a batch and return the decoded output.

Parameters:

batch (Batch)

Returns:

The decoded output for the batch.

Return type:

Float[Tensor, 'batch time spatial *spatial channel']

forward_with_latent(batch)
Parameters:

batch (Batch)

Return type:

tuple[Float[Tensor, 'batch time spatial *spatial channel'], Float[Tensor, 'batch *optional_dims channel']]

reparametrize(mean, log_var)

Reparameterisation trick.

During training, samples z ~ N(mean, σ²) with σ = exp(log_var / 2); in evaluation mode, returns the mean deterministically. As a result, model.eval() produces deterministic reconstructions while training remains stochastic.

Parameters:
  • mean (Float[Tensor, 'batch *optional_dims channel'])

  • log_var (Float[Tensor, 'batch *optional_dims channel'])

Return type:

Float[Tensor, 'batch *optional_dims channel']
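The training/evaluation behaviour described above can be sketched in plain Python. This is a hypothetical stand-in, not the library's implementation: lists replace tensors, an explicit `training` flag replaces the module's training mode, and `random.Random.gauss` replaces a tensor noise generator such as `torch.randn_like`.

```python
import math
import random
from typing import List, Sequence


def reparametrize(
    mean: Sequence[float],
    log_var: Sequence[float],
    training: bool,
    rng: random.Random,
) -> List[float]:
    """Sketch of the reparameterisation trick with a deterministic eval mode."""
    if not training:
        # Evaluation mode: return the mean so reconstructions are deterministic.
        return list(mean)
    # Training mode: z = mean + sigma * eps, sigma = exp(log_var / 2), eps ~ N(0, 1).
    # The noise is drawn outside the parameters, so gradients can flow
    # through mean and log_var in a differentiable framework.
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mean, log_var)
    ]


rng = random.Random(0)
print(reparametrize([1.0, -2.0], [0.0, 0.0], training=False, rng=rng))  # [1.0, -2.0]
sample = reparametrize([1.0, -2.0], [0.0, 0.0], training=True, rng=rng)
```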

encode(batch)
Parameters:

batch (Batch)

Return type:

Float[Tensor, 'batch *optional_dims channel']

training_step(batch, batch_idx)

Compute and return the training loss, optionally together with additional metrics for the progress bar or logger.

Parameters:
  • batch (Batch) – The output of your data iterable, normally a DataLoader.

  • batch_idx (int) – The index of this batch.

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

Return type:

Tensor

This step normally runs the forward pass and computes the loss for a batch. It can also do more involved work, such as multiple forward passes or other model-specific logic.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

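def __init__(self):
    super().__init__()
    # Disable automatic optimization so optimizer steps are controlled manually
    self.automatic_optimization = False


# Multiple optimizers (e.g. GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder: compute loss1 (placeholder)
    ...
    opt1.zero_grad()
    self.manual_backward(loss1)
    opt1.step()
    # do training_step with decoder: compute loss2 (placeholder)
    ...
    opt2.zero_grad()
    self.manual_backward(loss2)
    opt2.step()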

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
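The effect of that normalization can be checked with a toy example. Because gradients are linear in the loss, dividing each micro-batch loss by the number of accumulated batches k and summing the resulting gradients gives the same gradient as one large batch, provided the micro-batches are of equal size. The squared-error model and the helper names below are hypothetical and only illustrate the arithmetic; they are not Lightning code.

```python
from typing import List, Tuple


def grad_single(w: float, x: float, y: float) -> float:
    # d/dw of the squared error (w * x - y) ** 2
    return 2.0 * x * (w * x - y)


def accumulated_grad(w: float, micro_batches: List[List[Tuple[float, float]]]) -> float:
    """Sum gradients over micro-batches, each scaled by 1 / k."""
    k = len(micro_batches)
    total = 0.0
    for mb in micro_batches:
        # The per-micro-batch loss is a mean over its samples...
        mb_grad = sum(grad_single(w, x, y) for x, y in mb) / len(mb)
        # ...then divided by k, mirroring loss / accumulate_grad_batches.
        total += mb_grad / k
    return total


data = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 4.0)]
full = sum(grad_single(0.5, x, y) for x, y in data) / len(data)
acc = accumulated_grad(0.5, [data[:2], data[2:]])
print(full, acc)  # -12.0 -12.0
```

Without the 1 / k scaling, the accumulated gradient would be k times larger than the single-batch gradient, which would effectively multiply the learning rate by k.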