autocast.data.encoded_dataset

class EncodedBatchMixin

Bases: object

A mixin class to provide EncodedBatch conversion functionality.

static to_sample(data)

Convert a dictionary of tensors to an EncodedSample object.

Parameters:

data (dict)

Return type:

EncodedSample
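The conversion can be pictured with a minimal sketch. It assumes the dict keys and the sample's fields are named encoded_inputs, encoded_output_fields and global_cond. Those names, and the EncodedSampleSketch class standing in for EncodedSample, are hypothetical and may not match autocast.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class EncodedSampleSketch:
    """Hypothetical stand-in for EncodedSample; the field names are assumptions."""
    encoded_inputs: Any
    encoded_output_fields: Any
    global_cond: Optional[Any] = None


def to_sample(data: dict) -> EncodedSampleSketch:
    """Build a sample from a dict of tensors (or any array-like values).

    Required keys are indexed directly, so a missing key raises KeyError;
    the optional conditioning key falls back to None.
    """
    return EncodedSampleSketch(
        encoded_inputs=data["encoded_inputs"],
        encoded_output_fields=data["encoded_output_fields"],
        global_cond=data.get("global_cond"),
    )


sample = to_sample({"encoded_inputs": [[0.1, 0.2]], "encoded_output_fields": [[0.3, 0.4]]})
print(sample.global_cond)  # None
```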

class EncodedDataset

Bases: Dataset, EncodedBatchMixin

A base class for encoded datasets.

class MiniWellDataset(file, steps=1, stride=1)

Bases: Dataset

Creates a mini-Well dataset.

From LOLA (PolymathicAI/lola).

Parameters:
  • file

  • steps

  • stride

static from_files(files, **kwargs)
Parameters:

files (Iterable[str])

Return type:

Dataset
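One way to picture MiniWellDataset and from_files together is sketched below. It assumes steps is the number of frames per window and stride the gap between consecutive frames, and that from_files builds one dataset per file and concatenates them in the style of torch.utils.data.ConcatDataset. Plain lists stand in for the HDF5 files; WindowedTrajectory, ConcatSketch and from_sources are hypothetical names, not LOLA's or autocast's API.

```python
import bisect
from typing import Iterable, List, Sequence


class WindowedTrajectory:
    """Hypothetical per-file dataset: each item is `steps` frames spaced `stride` apart."""

    def __init__(self, frames: Sequence, steps: int = 1, stride: int = 1):
        self.frames = frames
        self.steps = steps
        self.stride = stride
        span = (steps - 1) * stride + 1  # number of frames one window covers
        self.n_windows = max(len(frames) - span + 1, 0)

    def __len__(self) -> int:
        return self.n_windows

    def __getitem__(self, i: int) -> list:
        if not 0 <= i < self.n_windows:
            raise IndexError(i)
        return [self.frames[i + k * self.stride] for k in range(self.steps)]


class ConcatSketch:
    """Minimal stand-in for torch.utils.data.ConcatDataset."""

    def __init__(self, datasets: Iterable):
        self.datasets = list(datasets)
        self.cumulative: List[int] = []  # running total of items per dataset
        total = 0
        for ds in self.datasets:
            total += len(ds)
            self.cumulative.append(total)

    def __len__(self) -> int:
        return self.cumulative[-1] if self.cumulative else 0

    def __getitem__(self, i: int):
        if not 0 <= i < len(self):
            raise IndexError(i)
        d = bisect.bisect_right(self.cumulative, i)  # index of the dataset holding item i
        offset = self.cumulative[d - 1] if d > 0 else 0
        return self.datasets[d][i - offset]


def from_sources(sources: Iterable[Sequence], **kwargs) -> ConcatSketch:
    """Hypothetical analogue of from_files: one windowed dataset per source, concatenated."""
    return ConcatSketch(WindowedTrajectory(src, **kwargs) for src in sources)


combined = from_sources([list(range(6)), list(range(10, 15))], steps=2, stride=2)
print(len(combined), combined[4])  # 7 [10, 12]
```

Concatenating per-file datasets keeps each file's windows from crossing a file boundary, which matters when each file holds independent trajectories.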

class MiniWellInputOutput(file_name, n_steps_input, n_steps_output, steps=1, stride=1)

Bases: EncodedDataset, EncodedBatchMixin

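A wrapper around MiniWellDataset (mini-Well data derived from The Well) that yields EncodedSample objects; the datamodules below collate these into EncodedBatch objects.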

Parameters:
  • file_name (str)

  • n_steps_input (int)

  • n_steps_output (int)

  • steps (int)

  • stride (int)

miniwell_dataset: MiniWellDataset

static from_files(files, **kwargs)
Parameters:

files (Iterable[str])

Return type:

Dataset
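Each item pairs an input context with a prediction target. A minimal sketch of that split, assuming every window taken from the wrapped dataset holds n_steps_input + n_steps_output frames in time order; split_window is a hypothetical helper.

```python
from typing import Sequence, Tuple


def split_window(window: Sequence, n_steps_input: int, n_steps_output: int) -> Tuple[list, list]:
    """Split one time-ordered window into (inputs, outputs)."""
    expected = n_steps_input + n_steps_output
    if len(window) != expected:
        raise ValueError(f"window has {len(window)} frames, expected {expected}")
    inputs = list(window[:n_steps_input])    # context frames fed to the model
    outputs = list(window[n_steps_input:])   # frames the model should predict
    return inputs, outputs


inputs, outputs = split_window(["x0", "x1", "x2", "x3", "x4"], n_steps_input=2, n_steps_output=3)
print(inputs, outputs)  # ['x0', 'x1'] ['x2', 'x3', 'x4']
```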

class CachedLatentDataset(cache_dir, n_steps_input=1, n_steps_output=1, stride=1, in_memory=True, **kwargs)

Bases: EncodedDataset

Dataset that reads pre-encoded latent trajectories from a cache directory.

Each .pt file contains a dict whose encoded_fields key holds the full encoded trajectory, plus an optional global_cond key. Windowing (n_steps_input, n_steps_output, stride) is applied at load time, so these settings can be changed at runtime without re-encoding.

These files are produced by autocast.scripts.cache_latents.cache_latents().
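Because windowing happens at load time, one cache can serve several (n_steps_input, n_steps_output, stride) settings. The sketch below assumes stride is the gap between consecutive frames inside a window and that a window may start at every frame. A list stands in for the encoded_fields tensor that would normally come from torch.load on a .pt file; window_starts and get_window are hypothetical helpers.

```python
from typing import List, Sequence, Tuple


def window_starts(n_frames: int, n_steps_input: int, n_steps_output: int, stride: int = 1) -> List[int]:
    """Start indices of every window that fits inside a trajectory of n_frames."""
    span = (n_steps_input + n_steps_output - 1) * stride + 1  # frames one window covers
    return list(range(max(n_frames - span + 1, 0)))


def get_window(trajectory: Sequence, start: int, n_steps_input: int,
               n_steps_output: int, stride: int = 1) -> Tuple[list, list]:
    """Slice one (inputs, outputs) pair out of a full encoded trajectory."""
    n = n_steps_input + n_steps_output
    frames = [trajectory[start + k * stride] for k in range(n)]
    return frames[:n_steps_input], frames[n_steps_input:]


cached = {"encoded_fields": [f"z{t}" for t in range(8)]}  # stand-in for a loaded .pt dict
traj = cached["encoded_fields"]
print(window_starts(len(traj), 2, 1, stride=2))  # [0, 1, 2, 3]
print(get_window(traj, 0, 2, 1, stride=2))       # (['z0', 'z2'], ['z4'])
```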

Parameters:
  • cache_dir

  • n_steps_input

  • n_steps_output

  • stride

  • in_memory

  • kwargs

class EncodedDataModule(data_path=None, n_steps_input=1, n_steps_output=1, stride=1, batch_size=16, num_workers=0, dataset_cls=None, in_memory=True, **dataset_kwargs)

Bases: LightningDataModule

DataModule for encoded datasets that produce EncodedBatch objects.

This datamodule wraps datasets that produce EncodedSample objects (like MiniWellInputOutput) and provides train/val/test dataloaders that collate samples into EncodedBatch objects.
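Collation turns a list of per-sample objects into one batch object. A minimal sketch with hypothetical SampleSketch and BatchSketch classes and field names; it only groups fields into lists, whereas a tensor-based collate would stack each field along a new leading batch dimension. A function of this shape is what DataLoader accepts through its collate_fn argument.

```python
from dataclasses import dataclass
from typing import Any, List, Optional


@dataclass
class SampleSketch:
    """Hypothetical stand-in for EncodedSample."""
    encoded_inputs: Any
    encoded_output_fields: Any
    global_cond: Optional[Any] = None


@dataclass
class BatchSketch:
    """Hypothetical stand-in for EncodedBatch."""
    encoded_inputs: List[Any]
    encoded_output_fields: List[Any]
    global_cond: Optional[List[Any]]


def collate(samples: List[SampleSketch]) -> BatchSketch:
    """Group per-sample fields into per-batch lists."""
    conds = [s.global_cond for s in samples]
    return BatchSketch(
        encoded_inputs=[s.encoded_inputs for s in samples],
        encoded_output_fields=[s.encoded_output_fields for s in samples],
        # Keep conditioning only when every sample in the batch has it.
        global_cond=None if any(c is None for c in conds) else conds,
    )
```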

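Parameters:
  • data_path

  • n_steps_input

  • n_steps_output

  • stride

  • batch_size

  • num_workers

  • dataset_cls

  • in_memory

  • dataset_kwargs
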
setup(stage=None)

Set up datasets for the given stage.

Parameters:

stage (str | None)

Return type:

None
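LightningDataModule.setup is called with a stage string such as "fit", "validate" or "test", or with None. The mapping below is one common convention for deciding which splits to build; it is an assumption, not necessarily what this setup implements, and stages missing from the table (including "predict") raise an error. STAGE_SPLITS and splits_for_stage are hypothetical names.

```python
from typing import Dict, FrozenSet, Optional

# Hypothetical mapping from Lightning stage to the data splits it needs.
STAGE_SPLITS: Dict[Optional[str], FrozenSet[str]] = {
    "fit": frozenset({"train", "valid"}),         # training also runs validation
    "validate": frozenset({"valid"}),
    "test": frozenset({"test"}),
    None: frozenset({"train", "valid", "test"}),  # no stage given: build everything
}


def splits_for_stage(stage: Optional[str]) -> FrozenSet[str]:
    """Return the splits to build for `stage`, rejecting stages not in the table."""
    if stage not in STAGE_SPLITS:
        raise ValueError(f"unsupported stage: {stage!r}")
    return STAGE_SPLITS[stage]
```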

train_dataloader()

Return the training dataloader.

Return type:

DataLoader

val_dataloader()

Return the validation dataloader.

Return type:

DataLoader

test_dataloader()

Return the test dataloader.

Return type:

DataLoader

rollout_test_dataloader(batch_size=None)

DataLoader for rollout evaluation on test data.

For cached latent datasets, creates a full-trajectory dataset so that ground truth is available for the entire rollout horizon (analogous to full_trajectory_mode=True in the non-cached datamodules).

Parameters:

batch_size (int | None)

Return type:

DataLoader
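For rollout evaluation, each item needs the initial context plus ground truth for every later step. A minimal sketch, assuming stride 1 and a model that predicts n_steps_output frames per call; full_trajectory_item and n_rollout_calls are hypothetical helpers.

```python
from typing import Sequence, Tuple


def full_trajectory_item(trajectory: Sequence, n_steps_input: int) -> Tuple[list, list]:
    """Split a whole trajectory into (initial context, ground truth for all later steps)."""
    if len(trajectory) <= n_steps_input:
        raise ValueError("trajectory leaves no steps to roll out")
    return list(trajectory[:n_steps_input]), list(trajectory[n_steps_input:])


def n_rollout_calls(n_frames: int, n_steps_input: int, n_steps_output: int) -> int:
    """Model calls needed to cover every ground-truth step (ceiling division)."""
    remaining = n_frames - n_steps_input
    return -(-remaining // n_steps_output)


context, truth = full_trajectory_item(list(range(10)), n_steps_input=2)
print(len(context), len(truth), n_rollout_calls(10, 2, 3))  # 2 8 3
```

A windowed dataset with n_steps_output=3 would only supply three target frames per item, so scoring all eight rollout steps requires the full trajectory.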

class MiniWellDataModule(data_path=None, n_steps_input=1, n_steps_output=1, stride=1, batch_size=16, num_workers=0)

Bases: LightningDataModule

DataModule for MiniWell datasets.

This datamodule wraps MiniWellInputOutput datasets and provides train/val/test dataloaders that collate samples into EncodedBatch objects. data_path must point to a directory whose train, valid and test subdirectories each contain a data.h5 file.

Parameters:
  • data_path (str | None)

  • n_steps_input (int)

  • n_steps_output (int)

  • stride (int)

  • batch_size (int)

  • num_workers (int)
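A minimal sketch of checking the expected directory layout before building the datamodule. Only the train/valid/test subdirectories and the data.h5 file name come from the description above; resolve_split_files is a hypothetical helper.

```python
from pathlib import Path
from typing import Dict, Union

SPLITS = ("train", "valid", "test")


def resolve_split_files(data_path: Union[str, Path]) -> Dict[str, Path]:
    """Map each split to data_path/<split>/data.h5, failing early if any file is missing."""
    root = Path(data_path)
    files = {split: root / split / "data.h5" for split in SPLITS}
    missing = [str(p) for p in files.values() if not p.is_file()]
    if missing:
        raise FileNotFoundError(f"missing split files: {missing}")
    return files
```

Failing before any dataset is constructed gives one error listing every missing split instead of a separate failure per stage.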

setup(stage=None)

Set up datasets for the given stage.

Parameters:

stage (str | None)

Return type:

None

train_dataloader()

Return the training dataloader.

Return type:

DataLoader

val_dataloader()

Return the validation dataloader.

Return type:

DataLoader

test_dataloader()

Return the test dataloader.

Return type:

DataLoader