autocast.data.datamodule

class TheWellDataModule(well_dataset_name, n_steps_input=1, n_steps_output=1, batch_size=4, use_normalization=False, normalization_type=<class 'the_well.data.normalization.ZScoreNormalization'>, autoencoder_mode=False, num_workers=None, normalization_path='../stats.yaml', **well_kwargs)

Bases: LightningDataModule

DataModule for datasets from The Well.

Parameters:
  • well_dataset_name (str)

  • n_steps_input (int)

  • n_steps_output (int)

  • batch_size (int)

  • use_normalization (bool)

  • normalization_type (type[ZScoreNormalization] | None)

  • autoencoder_mode (bool)

  • num_workers (int | None)

  • normalization_path (str)
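Judging by the signature, use_normalization appears to enable rescaling of the data with normalization_type, which defaults to ZScoreNormalization. The plain-Python sketch below shows what z-score normalization does, assuming per-field mean and standard-deviation statistics like those that might be read from normalization_path. The stats dictionary layout, the eps term, and the helper names compute_stats, zscore_normalize and zscore_denormalize are illustrative assumptions, not the_well's implementation.

```python
from typing import Dict, List

# Assumed layout: {field_name: {"mean": float, "std": float}}.
# Illustrative only; the real stats file format is not documented here.
Stats = Dict[str, Dict[str, float]]


def compute_stats(samples: Dict[str, List[float]]) -> Stats:
    """Compute the population mean and standard deviation of each field."""
    stats: Stats = {}
    for field, values in samples.items():
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        stats[field] = {"mean": mean, "std": var ** 0.5}
    return stats


def zscore_normalize(values: List[float], field: str, stats: Stats, eps: float = 1e-8) -> List[float]:
    """Map raw values to (x - mean) / (std + eps); eps guards against std == 0."""
    mean, std = stats[field]["mean"], stats[field]["std"]
    return [(v - mean) / (std + eps) for v in values]


def zscore_denormalize(values: List[float], field: str, stats: Stats, eps: float = 1e-8) -> List[float]:
    """Invert zscore_normalize to recover values in physical units."""
    mean, std = stats[field]["mean"], stats[field]["std"]
    return [v * (std + eps) + mean for v in values]


# Usage: normalize one field, then map it back to physical units.
stats = compute_stats({"pressure": [1.0, 2.0, 3.0, 4.0]})
z = zscore_normalize([1.0, 2.0, 3.0, 4.0], "pressure", stats)
restored = zscore_denormalize(z, "pressure", stats)
```

Keeping the statistics separate from the data lets the same values be reused for denormalizing model predictions.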

train_dataloader()

DataLoader for training.

Return type:

DataLoader

val_dataloader()

DataLoader for validation.

Return type:

DataLoader

test_dataloader()

DataLoader for testing.

Return type:

DataLoader

rollout_val_dataloader(batch_size=None)

DataLoader for full-trajectory rollouts on validation data.

Parameters:

batch_size (int | None)

Return type:

DataLoader

rollout_test_dataloader(batch_size=None)

DataLoader for full-trajectory rollouts on test data.

Parameters:

batch_size (int | None)

Return type:

DataLoader
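rollout_val_dataloader and rollout_test_dataloader serve full trajectories rather than fixed-length windows. One common way to consume such a trajectory is an autoregressive rollout: seed a model with the first n_steps_input frames, then feed its own predictions back in until the trajectory length is reached. The sketch below assumes scalar frames and a hypothetical step callable; rollout is an illustrative helper, not autocast's evaluation code.

```python
from typing import Callable, List, Sequence

# Hypothetical one-step model: maps the last n_steps_input frames
# to the next n_steps_output frames (scalars stand in for fields).
StepFn = Callable[[List[float]], List[float]]


def rollout(trajectory: Sequence[float], step: StepFn, n_steps_input: int) -> List[float]:
    """Autoregressively predict every frame after the first n_steps_input.

    Only the initial frames come from the ground truth; after that the model
    consumes its own predictions. The result matches the trajectory length.
    """
    history = list(trajectory[:n_steps_input])
    target_len = len(trajectory)
    while len(history) < target_len:
        new_frames = step(history[-n_steps_input:])
        if not new_frames:
            raise ValueError("step function returned no frames")
        history.extend(new_frames)
    # A multi-frame step may overshoot, so trim to the trajectory length.
    return history[:target_len]


# Usage: a toy "model" that doubles the last frame, one frame per step.
truth = [1.0, 2.0, 4.0, 8.0, 16.0]
predicted = rollout(truth, lambda ctx: [ctx[-1] * 2], n_steps_input=1)
```

Because errors compound as predictions are fed back in, comparing predicted with the ground-truth trajectory frame by frame measures long-horizon stability rather than one-step accuracy.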

class SpatioTemporalDataModule(data_path, data=None, dataset_cls=<class 'autocast.data.dataset.SpatioTemporalDataset'>, n_steps_input=1, n_steps_output=1, stride=1, channel_idxs=None, batch_size=4, dtype=torch.float32, ftype='torch', verbose=False, autoencoder_mode=False, full_trajectory_mode=False, use_normalization=False, normalization_type=<class 'the_well.data.normalization.ZScoreNormalization'>, normalization_path=None, normalization_stats=None, num_workers=None, pin_memory=False)

Bases: LightningDataModule

DataModule for spatio-temporal datasets.

Parameters:
train_dataloader()

DataLoader for training.

Return type:

DataLoader

val_dataloader()

DataLoader for standard validation (not full-trajectory rollouts).

Return type:

DataLoader
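val_dataloader serves standard windowed samples rather than full trajectories. Assuming that each sample pairs n_steps_input consecutive frames with the following n_steps_output frames, and that stride is the offset between successive window starts, the windowing could look like the sketch below. make_windows is a hypothetical helper, not part of autocast, and SpatioTemporalDataset may slice its data differently.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def make_windows(
    trajectory: Sequence[T],
    n_steps_input: int,
    n_steps_output: int,
    stride: int = 1,
) -> List[Tuple[List[T], List[T]]]:
    """Split one trajectory into (input, output) pairs of consecutive frames.

    Assumption: a window starts every `stride` frames and is dropped if it
    would run past the end of the trajectory.
    """
    if n_steps_input < 1 or n_steps_output < 1 or stride < 1:
        raise ValueError("n_steps_input, n_steps_output and stride must be >= 1")
    window = n_steps_input + n_steps_output
    pairs = []
    for start in range(0, len(trajectory) - window + 1, stride):
        inputs = list(trajectory[start:start + n_steps_input])
        outputs = list(trajectory[start + n_steps_input:start + window])
        pairs.append((inputs, outputs))
    return pairs


# Usage: frames 0..5, two input frames, one output frame, window start every 2 frames.
pairs = make_windows(list(range(6)), n_steps_input=2, n_steps_output=1, stride=2)
```

Here pairs is [([0, 1], [2]), ([2, 3], [4])]; with stride=1 the same trajectory yields four overlapping windows. A full-trajectory rollout, by contrast, uses the whole sequence at once.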

rollout_val_dataloader(batch_size=None)

DataLoader for full-trajectory rollouts on validation data.

Parameters:

batch_size (int | None)

Return type:

DataLoader

test_dataloader()

DataLoader for testing.

Return type:

DataLoader

rollout_test_dataloader(batch_size=None)

DataLoader for full-trajectory rollouts on test data.

Parameters:

batch_size (int | None)

Return type:

DataLoader