autoencoder.datasets

Module Contents

Classes

DiffusionMRIDataset

Diffusion MRI dataset that loads voxel data quickly from an HDF5 file.

MRIDataModule

PyTorch Lightning data module that provides the training, validation, and test data loaders.

class autoencoder.datasets.DiffusionMRIDataset(parameters_file_path, data_file_path, subject_list, tissue, include_parameters=None, exclude_parameters=None, batch_size=1, return_target=False, use_spherical_data=False)

Bases: torch.utils.data.Dataset

Diffusion MRI dataset that loads voxel data quickly from an HDF5 file.

Parameters
  • parameters_file_path (pathlib.Path) – Path to the HDF5 file that contains all parameters of the MRI data.

  • data_file_path (pathlib.Path) – Path to the HDF5 file containing the voxel data.

  • subject_list (numpy.ndarray) – Subjects to include in the dataset.

  • tissue (Literal['wb', 'gm', 'wm', 'csf']) –

    The tissue type to return. Must be one of the following values:

    • wb = Whole Brain

    • gm = Grey Matter

    • wm = White Matter

    • csf = Cerebrospinal Fluid

    These tissue types should be created beforehand with the MRtrix3 5ttgen tool.

  • include_parameters (List[int]) – Parameters to include in the dataset; all other parameters are left out. Defaults to None.

  • exclude_parameters (List[int]) – Parameters to exclude from the dataset. Defaults to None.

  • batch_size (int) – Batch size. Defaults to 1.

  • return_target (bool) – If True, also return the target data with all parameters included. This is useful for computing the loss when the data is reconstructed from a subset of parameters. Defaults to False.

  • use_spherical_data (bool) – If True, transform the data for models that require spherical input. Defaults to False.
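The tissue codes line up with the five-tissue-type (5TT) image written by 5ttgen, whose volumes are, in order, cortical grey matter, sub-cortical grey matter, white matter, CSF, and pathological tissue. The sketch below shows one plausible way a voxel could be assigned to a tissue class from its 5TT fractions. The dominant-fraction rule, pooling both grey-matter volumes into gm, and the helpers voxel_in_tissue and select_voxels are all assumptions for illustration, not the dataset's actual implementation.

```python
from typing import Dict, List, Literal, Sequence, Tuple

Tissue = Literal["wb", "gm", "wm", "csf"]

# Volume indices of the 5TT image produced by MRtrix3 5ttgen:
# 0 = cortical GM, 1 = sub-cortical GM, 2 = WM, 3 = CSF, 4 = pathological tissue.
# Assumption: "gm" pools both grey-matter volumes.
TISSUE_VOLUMES: Dict[str, Tuple[int, ...]] = {
    "gm": (0, 1),
    "wm": (2,),
    "csf": (3,),
}


def voxel_in_tissue(fractions: Sequence[float], tissue: Tissue) -> bool:
    """Return True if a voxel belongs to `tissue` (hypothetical rule).

    `fractions` holds the five 5TT partial-volume fractions of one voxel.
    A voxel is assigned to the tissue with the largest summed fraction;
    "wb" (whole brain) accepts any voxel with non-zero tissue content.
    """
    if tissue == "wb":
        return sum(fractions) > 0
    scores = {
        name: sum(fractions[i] for i in volumes)
        for name, volumes in TISSUE_VOLUMES.items()
    }
    dominant = max(scores, key=scores.get)
    return dominant == tissue and scores[tissue] > 0


def select_voxels(voxel_fractions: List[Sequence[float]], tissue: Tissue) -> List[int]:
    """Indices of the voxels assigned to `tissue`."""
    return [i for i, f in enumerate(voxel_fractions) if voxel_in_tissue(f, tissue)]
```

Voxels with no tissue content are dropped for every tissue type, which keeps background voxels out of the whole-brain set as well.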

get_subject_id_by_batch_id(self, batch_id)
Parameters

batch_id (int)

Return type

int
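One way to resolve a batch ID to a subject ID is a cumulative-offset lookup. The sketch below assumes, hypothetically, that each subject's voxels are split into consecutive batches of batch_size and that subjects' batches are concatenated in subject_list order; the helpers batches_per_subject and subject_id_by_batch_id are illustrative and not part of autoencoder.datasets.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Sequence


def batches_per_subject(voxel_counts: Sequence[int], batch_size: int) -> List[int]:
    """Number of batches each subject yields (ceiling division; the last batch may be partial)."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    return [-(-n // batch_size) for n in voxel_counts]


def subject_id_by_batch_id(
    batch_id: int, subject_list: Sequence[int], subject_batch_counts: Sequence[int]
) -> int:
    """Return the ID of the subject that owns global batch `batch_id` (hypothetical helper)."""
    # offsets[k] is the first batch id after subject k's batches.
    offsets = list(accumulate(subject_batch_counts))
    if not 0 <= batch_id < offsets[-1]:
        raise IndexError(f"batch_id {batch_id} out of range [0, {offsets[-1]})")
    # bisect_right returns the first subject whose end offset exceeds batch_id.
    return int(subject_list[bisect_right(offsets, batch_id)])
```

With voxel counts [10, 4, 7] and a batch size of 4, the subjects own batches 0 to 2, 3, and 4 to 5 respectively, so the lookup costs O(log n) per call once the offsets are known.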

get_metadata_by_subject_id(self, subject_id)
Parameters

subject_id (int)

__len__(self)

__getitem__(self, index)

Generates one sample of data.
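The interaction of include_parameters, exclude_parameters, and return_target when generating a sample can be sketched as follows. This assumes, hypothetically, that the parameter lists are column indices into each voxel's parameter vector, that include is applied before exclude, and that return_target yields a (data, target) pair; parameter_columns and get_sample are illustrative names, not the class's real internals.

```python
from typing import List, Optional, Sequence, Tuple, Union

Row = List[float]


def parameter_columns(
    n_parameters: int,
    include: Optional[Sequence[int]] = None,
    exclude: Optional[Sequence[int]] = None,
) -> List[int]:
    """Column indices kept after applying include, then exclude (assumed semantics)."""
    columns = range(n_parameters) if include is None else sorted(set(include))
    excluded = set(exclude or ())
    return [c for c in columns if c not in excluded]


def get_sample(
    voxels: Sequence[Row], index: int, columns: Sequence[int], return_target: bool = False
) -> Union[Row, Tuple[Row, Row]]:
    """Return the reduced voxel row, plus the full row as target if requested."""
    full = list(voxels[index])
    data = [full[c] for c in columns]
    # With return_target, the unreduced row serves as the reconstruction target.
    return (data, full) if return_target else data
```

Keeping the full row as the target is what lets a loss compare a reconstruction over all parameters against an input that only contained a subset.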

class autoencoder.datasets.MRIDataModule(parameters_file_path, data_file_path, train_subject_ids, validate_subject_ids, include_parameters=None, exclude_parameters=None, return_target=False, use_spherical_data=False, batch_size=0, num_workers=0)

Bases: pytorch_lightning.LightningDataModule

Parameters
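  • parameters_file_path (str) – Path to the HDF5 file that contains all parameters of the MRI data.

  • data_file_path (str) – Path to the HDF5 file containing the voxel data.

  • train_subject_ids (List[int]) – IDs of the subjects used for training.

  • validate_subject_ids (List[int]) – IDs of the subjects used for validation.

  • include_parameters (str) – Parameters to include in the dataset. Defaults to None.

  • exclude_parameters (str) – Parameters to exclude from the dataset. Defaults to None.

  • return_target (bool) – If True, also return the target data with all parameters included. Defaults to False.

  • use_spherical_data (bool) – If True, transform the data for models that require spherical input. Defaults to False.

  • batch_size (int) – Batch size. Defaults to 0.

  • num_workers (int) – Number of worker processes used by the data loaders. Defaults to 0.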
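Because MRIDataModule takes training and validation subject IDs separately, it is natural to split at the subject level so that voxels from one subject do not end up in both sets. A minimal sketch of producing such lists follows; split_subjects, its 20% default, and the seeded shuffle are illustrative assumptions, not part of autoencoder.datasets.

```python
import random
from typing import List, Sequence, Tuple


def split_subjects(
    subject_ids: Sequence[int], validate_fraction: float = 0.2, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Split subjects into disjoint train/validation ID lists (hypothetical helper)."""
    if not 0.0 < validate_fraction < 1.0:
        raise ValueError("validate_fraction must be in (0, 1)")
    ids = sorted(set(subject_ids))  # deduplicate and fix the order before shuffling
    random.Random(seed).shuffle(ids)  # seeded, so the split is reproducible
    n_validate = max(1, round(len(ids) * validate_fraction))
    return sorted(ids[n_validate:]), sorted(ids[:n_validate])
```

The two returned lists could then be passed as train_subject_ids and validate_subject_ids; fixing the seed keeps the split identical across runs.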

setup(self, stage)
Parameters

stage (Optional[str])

Return type

None

train_dataloader(self)
Return type

torch.utils.data.DataLoader

val_dataloader(self)
Return type

torch.utils.data.DataLoader

test_dataloader(self)
Return type

torch.utils.data.DataLoader