Source code for segmentation.data

import nibabel as nib
import numpy as np
from tensorflow.keras.utils import Sequence

class DataGenerator(Sequence):
    """Keras ``Sequence`` yielding batches of NIfTI volumes and one-hot labels.

    Each entry of ``id_list`` is indexed as ``ID[0]`` (path of the input
    volume) and ``ID[1]`` (path of the matching label volume); both are read
    with ``nibabel.load(...).get_fdata()``.  Every batch is a pair ``(X, Y)``:

    * ``X`` has shape ``(batch_size, *dim, 1)`` -- a trailing singleton
      channel axis is appended so it can be given directly to the model.
    * ``Y`` has shape ``(batch_size, *dim, n_classes)`` -- the label volume
      one-hot encoded by :meth:`remapLabels`.

    Only full batches are produced (see ``__len__``): entries that do not
    fill a final batch are skipped for that epoch.
    """

    def __init__(
            self,
            id_list,
            batch_size=10,
            dim=(128, 128, 64),
            shuffle=True,
            n_classes=3,
            **kwargs):
        """Store the generator configuration and build the first sample order.

        Args:
            id_list: Sequence of ``(volume_path, label_path)`` pairs (any
                objects indexable with ``[0]`` and ``[1]`` whose items are
                accepted by ``nibabel.load``).
            batch_size: Number of samples per batch.
            dim: Spatial shape every volume and label volume must have.
            shuffle: If True, reshuffle the sample order at each epoch end.
            n_classes: Number of label classes; label values
                ``0 .. n_classes - 1`` are one-hot encoded.
            **kwargs: Forwarded to the ``Sequence`` base initializer (for
                example ``workers`` / ``use_multiprocessing`` on Keras 3).
        """
        # Keras 3's PyDataset (the class behind keras.utils.Sequence there)
        # expects subclasses to call the base initializer; on older tf.keras
        # the base has no custom __init__, so a bare call is harmless.
        super().__init__(**kwargs)
        self.id_list = id_list
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.dim = dim
        self.n_classes = n_classes
        # Build the index order only once every attribute is set, so an
        # overriding on_epoch_end sees a fully initialized object.
        self.on_epoch_end()

    def on_epoch_end(self):
        """Rebuild the sample order; run at construction and after each epoch.

        If ``shuffle`` is True the order is randomly permuted (using NumPy's
        global RNG) so that batches do not repeat identically across epochs,
        which tends to make training more robust.
        """
        self.indexes = np.arange(len(self.id_list))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, list_IDs_temp):
        """Load and stack the volumes and labels for one batch.

        Args:
            list_IDs_temp: The ``id_list`` entries that make up this batch.

        Returns:
            ``(X, Y)`` where ``X`` has shape ``(n_samples, *dim, 1)`` and
            ``Y`` has shape ``(n_samples, *dim, n_classes)``, with
            ``n_samples == len(list_IDs_temp)``.
        """
        # Size by the IDs actually supplied so no row is ever left holding
        # uninitialized np.empty memory.
        n_samples = len(list_IDs_temp)
        X = np.empty((n_samples, *self.dim))
        Y = np.empty((n_samples, *self.dim))
        for index, ID in enumerate(list_IDs_temp):
            # ``X[index] = ...`` works for any number of spatial dims; NumPy
            # raises if a loaded volume cannot be broadcast to ``dim``.
            X[index] = nib.load(ID[0]).get_fdata()
            Y[index] = nib.load(ID[1]).get_fdata()
        # Trailing singleton channel axis, needed to feed X to the model.
        X = X[..., np.newaxis]
        Y = self.remapLabels(Y)
        return X, Y

    def __len__(self):
        """Return the number of full batches per epoch (remainder dropped)."""
        return len(self.id_list) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as an ``(X, Y)`` pair.

        Raises:
            IndexError: If ``index`` is not in ``range(len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                f"batch index {index} out of range for {len(self)} batches")
        start = index * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]
        id_list_temp = [self.id_list[k] for k in batch_indexes]
        return self.__data_generation(id_list_temp)

    def remapLabels(self, labels_4D):
        """One-hot encode an array of integer-valued labels along a new axis.

        Args:
            labels_4D: Array of label values, nominally ``(batch, *dim)``.

        Returns:
            float64 array of shape ``labels_4D.shape + (n_classes,)`` that is
            1 where the last-axis position equals the label value and 0
            elsewhere.  Values outside ``0 .. n_classes - 1`` map to an
            all-zero vector.
        """
        classes = np.arange(self.n_classes)
        # Broadcast (..., 1) against (n_classes,) -> (..., n_classes).
        one_hot = np.asarray(labels_4D)[..., np.newaxis] == classes
        return one_hot.astype(np.float64)