# Source code for autocast.decoders.identity
from autocast.decoders.base import Decoder
from autocast.types.types import TensorBNC, TensorBTSC
[docs]
class IdentityDecoder(Decoder):
    """Pass-through decoder: ``decode`` hands back exactly what it receives.

    Attributes:
        latent_channels: Channel count supplied at construction. If it is
            ``None`` when ``decode`` runs, it is filled in from the size of
            the input's last axis.
    """

    def __init__(self, in_channels: int) -> None:
        """Initialise the base decoder and remember the channel count.

        Args:
            in_channels: Value stored verbatim as ``latent_channels``.
        """
        super().__init__()
        # NOTE(review): annotated as ``int``, yet ``decode`` still guards
        # against ``None`` -- confirm whether callers may pass ``None``.
        self.latent_channels = in_channels

    def decode(self, z: TensorBNC) -> TensorBTSC:
        """Return ``z`` untouched.

        As a side effect, when ``latent_channels`` is still ``None`` it is
        set to ``z.shape[-1]`` before returning.

        Args:
            z: Input to pass through.

        Returns:
            The very same object that was passed in as ``z``.
        """
        channels_unknown = self.latent_channels is None
        if channels_unknown:
            # Only touch ``z.shape`` when the channel count has to be inferred.
            self.latent_channels = z.shape[-1]
        return z