autocast.processors.base

class Processor(*, stride=1, teacher_forcing_ratio=0.0, max_rollout_steps=1, loss_func=None, **kwargs)

Bases: ABC, Module, Generic[BatchT]

Abstract base class for processors that map input states to output states.

Parameters:
  • stride (int)

  • teacher_forcing_ratio (float)

  • max_rollout_steps (int)

  • loss_func (Module | None)

  • kwargs (Any)
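The reference does not say what these parameters control. Their names suggest an autoregressive rollout, and the sketch below is one plausible reading, not the autocast implementation. It assumes that `stride` is the number of time steps produced per `map` call, that `teacher_forcing_ratio` is the per-step probability of feeding ground-truth states forward instead of the model's own prediction, and that `max_rollout_steps` bounds the number of `map` calls. The `rollout` function and its `n_in` and `rng` arguments are hypothetical, and plain Python lists stand in for tensors.

```python
import random
from typing import Callable, List, Optional, Sequence

State = float  # hypothetical stand-in for one time step of a tensor


def rollout(
    map_fn: Callable[[List[State]], List[State]],
    trajectory: Sequence[State],
    n_in: int,
    stride: int = 1,
    teacher_forcing_ratio: float = 0.0,
    max_rollout_steps: int = 1,
    rng: Optional[random.Random] = None,
) -> List[List[State]]:
    """Hypothetical autoregressive rollout (not the autocast implementation).

    Each call to map_fn sees the last n_in states and returns `stride` new ones.
    """
    if n_in < 1 or stride < 1:
        raise ValueError("n_in and stride must be positive")
    needed = n_in + max_rollout_steps * stride
    if len(trajectory) < needed:
        raise ValueError(f"trajectory needs at least {needed} steps")
    if rng is None:
        rng = random.Random()
    window = list(trajectory[:n_in])  # initial input window
    predictions: List[List[State]] = []
    for step in range(max_rollout_steps):
        pred = list(map_fn(window))
        if len(pred) != stride:
            raise ValueError("map_fn must return exactly `stride` states")
        predictions.append(pred)
        start = n_in + step * stride
        target = list(trajectory[start:start + stride])
        # Teacher forcing: with probability teacher_forcing_ratio, feed the
        # ground-truth states forward instead of the model's own prediction.
        fed = target if rng.random() < teacher_forcing_ratio else pred
        # Advance the window by `stride` steps, keeping its length at n_in.
        window = (window + fed)[-n_in:]
    return predictions
```

Under this reading, `teacher_forcing_ratio=1.0` conditions every step on ground truth, so a biased `map_fn` does not compound its own errors, while `0.0` makes the model consume its own outputs, as it would at inference time.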

learning_rate: float

abstract loss(batch)

Compute the loss between the output and the target.

Parameters:

batch (BatchT)

Return type:

Tensor
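A concrete subclass must implement `loss`. A minimal sketch, assuming a batch that carries inputs, targets, and an optional conditioning value, and assuming `loss` runs `map` on the inputs and applies `loss_func` (falling back to mean squared error when it is `None`). `Batch`, `ProcessorSketch`, `PersistenceProcessor`, and `mse` are hypothetical names. Plain Python lists replace tensors so the pattern runs without PyTorch, and the real base class additionally inherits from `Module`.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Generic, List, Optional, TypeVar

BatchT = TypeVar("BatchT")
Frames = List[List[float]]  # (B, T) of scalar states, stand-in for a Tensor


@dataclass
class Batch:
    """Hypothetical batch: inputs (B, T_in) and targets (B, T_out)."""
    input_fields: Frames
    output_fields: Frames
    global_cond: Optional[float] = None


class ProcessorSketch(ABC, Generic[BatchT]):
    """Hypothetical plain-Python mirror of the documented abstract methods."""

    @abstractmethod
    def map(self, x: Frames, global_cond: Optional[float]) -> Frames: ...

    @abstractmethod
    def loss(self, batch: BatchT) -> float: ...


def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error over flattened values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


class PersistenceProcessor(ProcessorSketch[Batch]):
    """Hypothetical processor that repeats the last input frame."""

    def __init__(self, t_out: int,
                 loss_func: Optional[Callable[[List[float], List[float]], float]] = None) -> None:
        self.t_out = t_out
        # Assumed behaviour: fall back to MSE when no loss_func is given.
        self.loss_func = loss_func if loss_func is not None else mse

    def map(self, x: Frames, global_cond: Optional[float]) -> Frames:
        # Repeat each sample's last input frame t_out times.
        return [[sample[-1]] * self.t_out for sample in x]

    def loss(self, batch: Batch) -> float:
        # Predict from the inputs, then compare against the targets.
        pred = self.map(batch.input_fields, batch.global_cond)
        flat_pred = [v for row in pred for v in row]
        flat_tgt = [v for row in batch.output_fields for v in row]
        return self.loss_func(flat_pred, flat_tgt)
```

Because `loss` is abstract, the base class cannot be instantiated directly; each subclass decides how a batch is unpacked into `map` inputs and targets.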

abstract map(x, global_cond)

Map input states to output states.

Parameters:
  • x (Tensor) – Input tensor of shape (B, T_in, …)

  • global_cond (Tensor | None) – Optional conditioning/modulation tensor.

Returns:

y – Output tensor of shape (B, T_out, …)

Return type:

Tensor
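The documented contract fixes only the leading batch and time axes: `map` takes `(B, T_in, …)` and returns `(B, T_out, …)`, and the separate names suggest `T_out` may differ from `T_in`. A sketch of a map that honours this contract, using nested Python lists shaped `(B, T, C)` in place of tensors. The linear-extrapolation rule, the extra `t_out` argument, and treating `global_cond` as a per-sample additive offset are all assumptions for illustration; the reference says only that `global_cond` is an optional conditioning/modulation tensor.

```python
from typing import List, Optional, Sequence, Tuple

Tensor3 = List[List[List[float]]]  # nested lists shaped (B, T, C), stand-in for a Tensor


def shape(x: Tensor3) -> Tuple[int, int, int]:
    """Return (B, T, C) for a rectangular nested list."""
    return (len(x), len(x[0]), len(x[0][0]))


def extrapolate_map(x: Tensor3, global_cond: Optional[Sequence[float]], t_out: int) -> Tensor3:
    """Hypothetical map: linear extrapolation from the last two input frames.

    x: (B, T_in, C) with T_in >= 2; global_cond: optional (B,) per-sample offset.
    Returns a (B, t_out, C) nested list.
    """
    b, t_in, c = shape(x)
    if t_in < 2:
        raise ValueError("need at least two input frames to extrapolate")
    if global_cond is not None and len(global_cond) != b:
        raise ValueError("global_cond must have one entry per batch element")
    y: Tensor3 = []
    for i, sample in enumerate(x):
        last, prev = sample[-1], sample[-2]
        offset = 0.0 if global_cond is None else global_cond[i]
        # Frame k steps ahead: last + k * (last - prev), shifted by the offset.
        y.append([
            [last[ch] + k * (last[ch] - prev[ch]) + offset for ch in range(c)]
            for k in range(1, t_out + 1)
        ])
    return y
```

Here a batch of two one-channel samples with `T_in = 2` maps to `T_out = 3` frames per sample, so the batch and trailing axes are preserved while only the time axis changes length.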