autocast.processors.flow_matching

class FlowMatchingProcessor(*, backbone, flow_ode_steps=1, n_steps_output=4, n_channels_out=1)

Bases: Processor

Processor that wraps a flow-matching generative model.

Parameters:
  • backbone (nn.Module)

  • flow_ode_steps (int)

  • n_steps_output (int)

  • n_channels_out (int)

flow_field(z, t, x, global_cond=None)

Flow-matching vector field.

Evaluates the learned vector field on the tangent space of the output states (z), conditioned on the input states (x), at time (t).

Parameters:
  • z (Tensor) – Current output states of shape (B, T_out, *spatial, C_out).

  • t (Tensor) – Time tensor of shape (B,).

  • x (Tensor) – Conditioning inputs of shape (B, T_in, *spatial, C_in).

  • global_cond (Tensor | None) – Optional non-spatial conditioning/modulation tensor.

Returns:

Time derivative of output states with the same shape as z.

Return type:

Tensor
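The shape contract above can be illustrated with a minimal sketch of a callable that has the same signature as flow_field. ToyFlowField and its per-position linear projection are hypothetical, not the autocast API or its backbone. The sketch assumes z and x share the same spatial dimensions, conditions every output step on the last input step, ignores global_cond, and appends t as one extra channel. Its only purpose is to show that the returned time derivative has the same shape as z.

```python
import torch
from torch import nn


class ToyFlowField(nn.Module):
    """Hypothetical stand-in for a backbone-driven flow field (not the autocast API)."""

    def __init__(self, c_in: int, c_out: int):
        super().__init__()
        # Per-position linear map from [z channels, x channels, time] to C_out.
        self.proj = nn.Linear(c_out + c_in + 1, c_out)

    def forward(self, z, t, x, global_cond=None):
        # z: (B, T_out, *spatial, C_out), t: (B,), x: (B, T_in, *spatial, C_in).
        # global_cond is accepted for signature parity but ignored in this toy.
        b, t_out = z.shape[0], z.shape[1]
        # Condition every output step on the last input step (assumes same spatial dims).
        x_last = x[:, -1:].expand(b, t_out, *x.shape[2:])
        # Broadcast the per-sample time into one extra channel.
        t_map = t.view(b, *([1] * (z.dim() - 1))).expand(*z.shape[:-1], 1)
        h = torch.cat([z, x_last, t_map], dim=-1)
        # nn.Linear acts on the last (channel) axis, so the output matches z's shape.
        return self.proj(h)


B, T_in, T_out, H, W, C_in, C_out = 2, 3, 4, 8, 8, 5, 1
field = ToyFlowField(c_in=C_in, c_out=C_out)
z = torch.randn(B, T_out, H, W, C_out)
x = torch.randn(B, T_in, H, W, C_in)
t = torch.rand(B)
dz_dt = field(z, t, x)
print(dz_dt.shape)  # torch.Size([2, 4, 8, 8, 1])
```

Keeping the channel axis last means a plain nn.Linear can act per position without any reshaping.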

forward(x, global_cond)

Alias for map, provided for Lightning/PyTorch compatibility.

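Parameters:
  • x – Conditioning inputs of shape (B, T_in, *spatial, C_in); passed unchanged to map.

  • global_cond – Optional non-spatial conditioning/modulation tensor; passed unchanged to map.

Return type: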

Tensor

map(x, global_cond)

Map input states (x) to output states (z) by integrating the flow ODE.

Starting from noise, Euler-integrate the learned vector field until t=1.

Parameters:
  • x (Tensor) – Conditioning inputs of shape (B, T_in, *spatial, C_in).

  • global_cond (Tensor | None) – Optional non-spatial conditioning/modulation tensor.

Returns:

Generated outputs of shape (B, T_out, *spatial, C_out).

Return type:

Tensor
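The integration loop described for map can be sketched as follows, under stated assumptions rather than as the library's implementation: euler_sample is a hypothetical helper, the initial noise is assumed to be standard Gaussian, the time grid is assumed uniform on [0, 1), and n_steps is assumed to play the role of flow_ode_steps.

```python
import torch


def euler_sample(flow_field, x, out_shape, n_steps=1, global_cond=None, generator=None):
    """Explicit-Euler integration of dz/dt = flow_field(z, t, x, global_cond) over [0, 1].

    Hypothetical helper: z(0) is standard Gaussian noise of shape out_shape.
    """
    z = torch.randn(out_shape, generator=generator, dtype=x.dtype, device=x.device)
    b = out_shape[0]
    dt = 1.0 / n_steps
    for k in range(n_steps):
        # Uniform grid t_k = k * dt, the same for every sample, shape (B,).
        t = torch.full((b,), k * dt, dtype=x.dtype, device=x.device)
        z = z + dt * flow_field(z, t, x, global_cond)
    return z  # state at t = n_steps * dt = 1


# Usage: a straight-line field towards a fixed target reaches it at t=1.
target = torch.arange(6.0).view(1, 2, 3, 1)  # (B, T_out, X, C_out)


def oracle_field(z, t, x, global_cond=None):
    return (target - z) / (1 - t).view(-1, 1, 1, 1)


x = torch.zeros(1, 1, 3, 1)  # conditioning is unused by this toy field
z1 = euler_sample(oracle_field, x, out_shape=(1, 2, 3, 1), n_steps=4)
print(torch.allclose(z1, target, atol=1e-5))  # True
```

The oracle field is constant along the straight path from the initial noise to the target, so each Euler step stays exactly on that path and the final step lands on the target. This makes it a convenient check for the integrator, independent of any learned backbone.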

loss(batch)

Compute the flow-matching loss for a batch.

Parameters:

batch (EncodedBatch)

Return type:

Tensor
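The docstring does not state which flow-matching objective is used. A common choice is conditional flow matching with a linear interpolation path (as in rectified flow), sketched below under that assumption. The sketch samples t uniformly, forms z_t = (1 - t) z0 + t y from Gaussian noise z0 and the target y, and regresses the flow field onto the constant velocity y - z0. flow_matching_loss is a hypothetical helper that takes plain tensors x and y in place of an EncodedBatch, whose fields are not documented here.

```python
import torch
import torch.nn.functional as F


def flow_matching_loss(flow_field, x, y, global_cond=None, generator=None):
    """Conditional flow-matching loss on a linear noise-to-data path (assumed objective).

    Hypothetical helper: x is (B, T_in, *spatial, C_in), y is (B, T_out, *spatial, C_out).
    """
    b = y.shape[0]
    # z0 ~ N(0, I) is the source sample at t=0; y is the data sample at t=1.
    z0 = torch.randn(y.shape, generator=generator, dtype=y.dtype, device=y.device)
    t = torch.rand(b, generator=generator, dtype=y.dtype, device=y.device)
    t_b = t.view(b, *([1] * (y.dim() - 1)))  # broadcast t over the non-batch axes
    z_t = (1 - t_b) * z0 + t_b * y
    target = y - z0  # time derivative of z_t along the straight path
    pred = flow_field(z_t, t, x, global_cond)
    return F.mse_loss(pred, target)


# Usage: an oracle field that knows y recovers the target velocity exactly.
gen = torch.Generator().manual_seed(0)
y = torch.randn(2, 4, 8, 8, 1, generator=gen)
x = torch.randn(2, 3, 8, 8, 5, generator=gen)


def oracle_field(z_t, t, x, global_cond=None):
    # (y - z_t) / (1 - t) equals y - z0 on the linear path.
    return (y - z_t) / (1 - t.view(-1, *([1] * (z_t.dim() - 1))))


loss = flow_matching_loss(oracle_field, x, y, generator=gen)
print(loss.item())
```

Because z_t lies on the straight line between z0 and y, the oracle prediction matches the regression target in exact arithmetic, so the printed loss is zero up to floating-point error. A learned flow field can only approximate this, since it sees z_t, t, and x but not y.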