autocast.data.datamodule
- class TheWellDataModule(well_dataset_name, n_steps_input=1, n_steps_output=1, batch_size=4, use_normalization=False, normalization_type=<class 'the_well.data.normalization.ZScoreNormalization'>, autoencoder_mode=False, num_workers=None, normalization_path='../stats.yaml', **well_kwargs)
Bases: LightningDataModule
DataModule for The Well datasets.
- Parameters:
- rollout_val_dataloader(batch_size=None)
DataLoader for full trajectory rollouts on validation data.
- Parameters:
batch_size (int | None)
- Return type:
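When use_normalization=True, TheWellDataModule applies the class given by normalization_type, which defaults to the_well's ZScoreNormalization; the statistics presumably come from normalization_path ('../stats.yaml' by default). Z-score normalization standardizes each field as (x - mean) / std. The stand-alone sketch below illustrates only that arithmetic. It is not the_well's class: the name ZScoreSketch, the per-field dictionary layout, and the eps guard are assumptions made for illustration.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ZScoreSketch:
    """Hypothetical per-field z-score normalizer (not the_well's API)."""

    # Assumed layout: one scalar mean/std per field name.
    mean: Dict[str, float]
    std: Dict[str, float]
    eps: float = 1e-8  # guards against division by a zero std

    def normalize(self, field: str, values: List[float]) -> List[float]:
        # Standardize: subtract the field mean, divide by the field std.
        m, s = self.mean[field], self.std[field]
        return [(v - m) / (s + self.eps) for v in values]

    def denormalize(self, field: str, values: List[float]) -> List[float]:
        # Exact inverse of normalize(), e.g. for mapping predictions back.
        m, s = self.mean[field], self.std[field]
        return [v * (s + self.eps) + m for v in values]


norm = ZScoreSketch(mean={"pressure": 2.0}, std={"pressure": 4.0})
z = norm.normalize("pressure", [2.0, 6.0])  # approximately [0.0, 1.0]
restored = norm.denormalize("pressure", z)  # approximately [2.0, 6.0]
```

Keeping denormalize() next to normalize() reflects the usual pattern of training in normalized space and mapping model outputs back to physical units for evaluation.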
- class SpatioTemporalDataModule(data_path, data=None, dataset_cls=<class 'autocast.data.dataset.SpatioTemporalDataset'>, n_steps_input=1, n_steps_output=1, stride=1, channel_idxs=None, batch_size=4, dtype=torch.float32, ftype='torch', verbose=False, autoencoder_mode=False, full_trajectory_mode=False, use_normalization=False, normalization_type=<class 'the_well.data.normalization.ZScoreNormalization'>, normalization_path=None, normalization_stats=None, num_workers=None, pin_memory=False)
Bases: LightningDataModule
DataModule for spatio-temporal datasets.
- Parameters:
data_path (str | None)
dataset_cls (type[SpatioTemporalDataset])
n_steps_input (int)
n_steps_output (int)
stride (int)
batch_size (int)
dtype (dtype)
ftype (str)
verbose (bool)
autoencoder_mode (bool)
full_trajectory_mode (bool)
use_normalization (bool)
normalization_type (type[ZScoreNormalization] | None)
normalization_path (None | str)
normalization_stats (dict | DictConfig | None)
num_workers (int | None)
pin_memory (bool)
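The n_steps_input, n_steps_output and stride parameters suggest that training samples are sliding windows cut from each trajectory. The pure-Python sketch below shows one plausible reading of these parameters; it is not autocast's implementation. The function name make_windows and the window layout (input frames immediately followed by target frames, with consecutive windows starting stride frames apart) are assumptions. Frames are plain list items for brevity, whereas the real dataset works on tensors.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def make_windows(
    trajectory: Sequence[T],
    n_steps_input: int = 1,
    n_steps_output: int = 1,
    stride: int = 1,
) -> List[Tuple[List[T], List[T]]]:
    """Hypothetical helper: split one trajectory into (inputs, targets) pairs.

    Assumed semantics: each window holds n_steps_input context frames followed
    by n_steps_output target frames, and windows start `stride` frames apart.
    Windows that would run past the end of the trajectory are dropped.
    """
    if n_steps_input < 1 or n_steps_output < 1 or stride < 1:
        raise ValueError("n_steps_input, n_steps_output and stride must be >= 1")
    window = n_steps_input + n_steps_output
    pairs: List[Tuple[List[T], List[T]]] = []
    for start in range(0, len(trajectory) - window + 1, stride):
        inputs = list(trajectory[start : start + n_steps_input])
        targets = list(trajectory[start + n_steps_input : start + window])
        pairs.append((inputs, targets))
    return pairs


# Six frames, two context frames, one target frame, windows two frames apart.
print(make_windows(list(range(6)), n_steps_input=2, n_steps_output=1, stride=2))
# → [([0, 1], [2]), ([2, 3], [4])]
```

Under this reading, a larger stride yields fewer, less overlapping samples per trajectory, trading dataset size for reduced redundancy between neighbouring windows.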
- val_dataloader()
DataLoader for standard validation (not full trajectory rollouts).
- Return type:
- rollout_val_dataloader(batch_size=None)
DataLoader for full trajectory rollouts on validation data.
- Parameters:
batch_size (int | None)
- Return type:
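Full-trajectory rollout data is typically used to evaluate a model autoregressively: the model predicts the next frames from a short context, and its own predictions are fed back as context until the whole trajectory has been generated. The sketch below illustrates that loop under stated assumptions. step_fn stands in for a hypothetical model that maps the last n_steps_input frames to one or more new frames, frames are scalars for brevity, and autoregressive_rollout is an illustrative name, not part of autocast.

```python
from typing import Callable, List, Sequence


def autoregressive_rollout(
    step_fn: Callable[[List[float]], List[float]],
    initial: Sequence[float],
    n_steps_input: int,
    total_steps: int,
) -> List[float]:
    """Hypothetical rollout loop: predict total_steps frames from `initial`.

    step_fn (an assumed model interface) receives the most recent
    n_steps_input frames and returns the next predicted frame(s).
    """
    history = list(initial)
    if len(history) < n_steps_input:
        raise ValueError("need at least n_steps_input initial frames")
    predictions: List[float] = []
    while len(predictions) < total_steps:
        context = history[-n_steps_input:]  # sliding context window
        new_frames = step_fn(context)
        if not new_frames:
            raise ValueError("step_fn must return at least one frame")
        history.extend(new_frames)  # feed predictions back as context
        predictions.extend(new_frames)
    # A multi-frame step may overshoot, so trim to the requested length.
    return predictions[:total_steps]


# Toy "model" that adds 1 to the last frame.
print(autoregressive_rollout(lambda ctx: [ctx[-1] + 1.0], [0.0], n_steps_input=1, total_steps=3))
# → [1.0, 2.0, 3.0]
```

Comparing such a rollout against the ground-truth trajectory shows how errors accumulate over many steps, which one-step validation windows cannot reveal.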