autoencoder.datasets
Module Contents
Classes
DiffusionMRIDataset | Diffusion MRI dataset that quickly loads voxel data from an HDF5 file. |
MRIDataModule | |
- class autoencoder.datasets.DiffusionMRIDataset(parameters_file_path, data_file_path, subject_list, tissue, include_parameters=None, exclude_parameters=None, batch_size=1, return_target=False, use_spherical_data=False)
Bases: torch.utils.data.Dataset
Diffusion MRI dataset that quickly loads voxel data from an HDF5 file.
- Parameters
parameters_file_path (pathlib.Path) – path to the HDF5 file that contains all parameters of the MRI data.
data_file_path (pathlib.Path) – path to the HDF5 file that contains the voxel data.
subject_list (numpy.ndarray) – subjects to build the dataset from.
tissue (Literal["wb", "gm", "wm", "csf"]) – the tissue to return. Must be one of:
wb = Whole Brain
gm = Grey Matter
wm = White Matter
csf = Cerebrospinal Fluid
These tissue types should be created beforehand with the MRtrix3 5ttgen tool.
include_parameters (List[int]) – parameters to include in the dataset; all other parameters are left out. Defaults to None.
exclude_parameters (List[int]) – parameters to exclude from the dataset. Defaults to None.
batch_size (int) – batch size. Defaults to 1.
return_target (bool) – also return the target data with all parameters included. Useful for computing the loss when reconstructing the full data from a subset of parameters. Defaults to False.
use_spherical_data (bool) – whether to transform the data for models that require spherical input. Defaults to False.
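The exact interplay of include_parameters and exclude_parameters is not spelled out above. The following is a minimal sketch, assuming parameters are addressed by integer column index and that the two options are mutually exclusive, of how the kept parameter set could be computed. select_parameter_indices is a hypothetical helper written for illustration, not part of autoencoder.datasets.

```python
from typing import List, Optional, Sequence


def select_parameter_indices(
    num_parameters: int,
    include_parameters: Optional[Sequence[int]] = None,
    exclude_parameters: Optional[Sequence[int]] = None,
) -> List[int]:
    """Return the parameter (column) indices kept in the dataset.

    Hypothetical helper: assumes parameters are integer indices and that
    include_parameters and exclude_parameters are mutually exclusive.
    """
    if include_parameters is not None and exclude_parameters is not None:
        raise ValueError("pass include_parameters or exclude_parameters, not both")
    all_indices = range(num_parameters)
    if include_parameters is not None:
        # Keep only the listed parameters, in their original column order.
        wanted = set(include_parameters)
        return [i for i in all_indices if i in wanted]
    if exclude_parameters is not None:
        # Keep every parameter except the listed ones.
        unwanted = set(exclude_parameters)
        return [i for i in all_indices if i not in unwanted]
    # Neither option given: keep every parameter.
    return list(all_indices)


print(select_parameter_indices(6, include_parameters=[0, 2, 4]))  # [0, 2, 4]
print(select_parameter_indices(6, exclude_parameters=[1, 3]))     # [0, 2, 4, 5]
```

Raising on conflicting options is a design choice of this sketch; the real class may resolve the conflict differently.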
- get_subject_id_by_batch_id(self, batch_id)
- Parameters
batch_id (int) –
- Return type
int
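One plausible way to map a batch index back to its subject is a cumulative count of batches per subject followed by a binary search. The sketch below assumes each subject's voxels are batched separately (a batch never spans two subjects) and that batches are numbered consecutively, subject after subject; subject_id_for_batch and its arguments are hypothetical and only illustrate the idea.

```python
import bisect
import math
from itertools import accumulate
from typing import Sequence


def subject_id_for_batch(
    batch_id: int,
    subject_ids: Sequence[int],
    voxels_per_subject: Sequence[int],
    batch_size: int,
) -> int:
    """Map a global batch index to the subject that owns it.

    Hypothetical helper: assumes batches never span two subjects and are
    numbered consecutively, subject after subject.
    """
    # A final partial batch still counts as a batch.
    batches_per_subject = [math.ceil(n / batch_size) for n in voxels_per_subject]
    # Running totals: the first batch id *after* each subject's batches.
    batch_ends = list(accumulate(batches_per_subject))
    if not 0 <= batch_id < batch_ends[-1]:
        raise IndexError(f"batch_id {batch_id} out of range")
    # Index of the first subject whose end boundary lies beyond batch_id.
    position = bisect.bisect_right(batch_ends, batch_id)
    return subject_ids[position]


# Subjects 100, 101, 102 with 10, 4 and 7 voxels and batch size 4:
# 3, 1 and 2 batches, i.e. batch ids 0-2, 3 and 4-5.
print(subject_id_for_batch(2, [100, 101, 102], [10, 4, 7], 4))  # 100
print(subject_id_for_batch(3, [100, 101, 102], [10, 4, 7], 4))  # 101
print(subject_id_for_batch(5, [100, 101, 102], [10, 4, 7], 4))  # 102
```

The lookup costs O(log n) per call; a real implementation would likely precompute batch_ends once in the constructor.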
- get_metadata_by_subject_id(self, subject_id)
- Parameters
subject_id (int) –
- __len__(self)
- __getitem__(self, index)
Generates one sample of data.
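Because batch_size is a constructor argument, one plausible reading is that each index yields a whole batch of voxels rather than a single voxel. The class below is a hypothetical, dependency-free sketch of that reading. It uses plain lists instead of tensors and HDF5 reads, and it shows how return_target could pair the reduced input with the full-parameter target used in the loss. BatchedVoxelDataset and kept_columns are illustrative names, not the library's API.

```python
import math
from typing import List, Sequence, Tuple, Union

Row = List[float]


class BatchedVoxelDataset:
    """Toy stand-in for a batch-returning dataset (hypothetical, not the real class).

    Each index returns a whole batch of voxels restricted to the kept
    parameter columns, optionally paired with the full-parameter target.
    """

    def __init__(self, voxels: Sequence[Row], kept_columns: Sequence[int],
                 batch_size: int = 1, return_target: bool = False) -> None:
        if batch_size < 1:
            raise ValueError("batch_size must be at least 1")
        self.voxels = voxels
        self.kept_columns = list(kept_columns)
        self.batch_size = batch_size
        self.return_target = return_target

    def __len__(self) -> int:
        # Number of batches, counting a final partial batch.
        return math.ceil(len(self.voxels) / self.batch_size)

    def __getitem__(self, index: int) -> Union[List[Row], Tuple[List[Row], List[Row]]]:
        if not 0 <= index < len(self):
            raise IndexError(index)
        start = index * self.batch_size
        # Full-parameter rows for this batch (the reconstruction target).
        target = [list(row) for row in self.voxels[start:start + self.batch_size]]
        # Model input: only the kept parameter columns.
        data = [[row[c] for c in self.kept_columns] for row in target]
        return (data, target) if self.return_target else data


voxels = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
ds = BatchedVoxelDataset(voxels, kept_columns=[0, 2], batch_size=2, return_target=True)
print(len(ds))  # 2
print(ds[1])    # ([[7.0, 9.0]], [[7.0, 8.0, 9.0]])
```

If the real dataset batches internally like this, the surrounding DataLoader would typically disable its own automatic batching; that is an assumption, not something documented here.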
- class autoencoder.datasets.MRIDataModule(parameters_file_path, data_file_path, train_subject_ids, validate_subject_ids, include_parameters=None, exclude_parameters=None, return_target=False, use_spherical_data=False, batch_size=0, num_workers=0)
Bases: pytorch_lightning.LightningDataModule
- Parameters
parameters_file_path (str) –
data_file_path (str) –
train_subject_ids (List[int]) –
validate_subject_ids (List[int]) –
include_parameters (str) –
exclude_parameters (str) –
return_target (bool) –
use_spherical_data (bool) –
batch_size (int) –
num_workers (int) –
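Since MRIDataModule receives the training and validation subjects as separate ID lists, it may be worth confirming that no subject appears in both before constructing it: voxels from one subject are strongly related, so overlap would likely inflate validation scores. The helper below is a hypothetical pre-flight check written for illustration; the module is not documented to perform this check itself.

```python
from typing import Iterable, List


def overlapping_subjects(train_subject_ids: Iterable[int],
                         validate_subject_ids: Iterable[int]) -> List[int]:
    """Return subject IDs present in both splits (hypothetical helper).

    An empty list means there is no subject-level leakage between splits.
    """
    return sorted(set(train_subject_ids) & set(validate_subject_ids))


print(overlapping_subjects([1, 2, 3], [4, 5]))  # []
print(overlapping_subjects([1, 2, 3], [3, 4]))  # [3]
```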
- setup(self, stage)
- Parameters
stage (Optional[str]) –
- Return type
None
- train_dataloader(self)
- Return type
torch.utils.data.DataLoader
- val_dataloader(self)
- Return type
torch.utils.data.DataLoader
- test_dataloader(self)
- Return type
torch.utils.data.DataLoader