Extracting transformed patches from 3D image

When training AI models for tasks such as CT scan segmentation, patches are often extracted from the volume because of GPU memory constraints. Additionally, the patches are augmented with transforms such as rotation and scaling.

I have implemented such a procedure using sitk.Resample, but it is quite involved: you first have to calculate how big the transformed patch must be so that a patch of the original size can be cropped from it without any padding values. It would be much simpler to sample from the original image directly, using the patch origin together with the transformed spacing and direction. This can be done with, for example, SciPy's map_coordinates:

def random_rotate_patch(self, patch, order=3, mode='nearest', cval=0.0):
    import numpy as np
    from scipy.spatial.transform import Rotation as R
    from scipy.ndimage import map_coordinates

    si, sj, sk = patch.shape
    patch_center = 0.5 * np.array([si, sj, sk])

    # Voxel coordinates of the patch as an (N, 3) array, centred on the patch.
    # Flatten in C order so the later reshape to (si, sj, sk, 3) is consistent.
    coords = np.mgrid[0:si, 0:sj, 0:sk].reshape(3, -1).T - patch_center

    # Random rotation of up to 0.08*pi radians about a random axis.
    axis = self._random_axis()
    angle = self.rng.random() * 0.08 * np.pi
    rot = R.from_rotvec(angle * axis).as_matrix()

    # Rotate the sampling grid and shift it back to patch coordinates.
    coords_rotated = coords @ rot.T + patch_center

    # map_coordinates expects coordinates with shape (3, si, sj, sk).
    coords_rotated = coords_rotated.reshape((si, sj, sk, 3))
    coords_rotated = np.transpose(coords_rotated, (3, 0, 1, 2))

    patch_rotated = map_coordinates(patch, coords_rotated,
                                    order=order, mode=mode, cval=cval)
    return patch_rotated

but this is about 10× slower than the SimpleITK implementation I mentioned above. Can I achieve something similar to the above code using SimpleITK or ITK?

Yes. It should be as simple as composing your patch-extraction transform (which you already said you have) with a transform for the random rotation, then passing the composite transform to resample. If your patch-extraction transform is the identity, just use the random rotation transform on its own (no need to compose).

Hi @dzenanz. I have already implemented what you describe. What I want to achieve instead is to extract a patch from an image by sampling it at the transformed patch coordinates. Let me demonstrate what I mean by modifying the code above to do what I intend:

# Work in voxel space; patch_origin is the corner of the patch, in voxels.
def transform_and_extract_patch(self, image, patch_origin, patch_shape,
                                order=3, mode='nearest', cval=0.0):
    import numpy as np
    from scipy.spatial.transform import Rotation as R
    from scipy.ndimage import map_coordinates

    si, sj, sk = patch_shape
    patch_center = 0.5 * np.array([si, sj, sk])

    # Voxel coordinates of the patch as an (N, 3) array, centred on the patch.
    # Flatten in C order so the later reshape to (si, sj, sk, 3) is consistent.
    coords = np.mgrid[0:si, 0:sj, 0:sk].reshape(3, -1).T - patch_center

    # Random rotation of up to 0.08*pi radians about a random axis.
    axis = self._random_axis()
    angle = self.rng.random() * 0.08 * np.pi
    rot = R.from_rotvec(angle * axis).as_matrix()

    # Rotate the grid, then shift it to the patch location in the full image.
    coords_rotated = coords @ rot.T + patch_center + patch_origin

    # map_coordinates expects coordinates with shape (3, si, sj, sk).
    coords_rotated = coords_rotated.reshape((si, sj, sk, 3))
    coords_rotated = np.transpose(coords_rotated, (3, 0, 1, 2))

    # Sample the whole image at the rotated patch coordinates.
    patch_rotated = map_coordinates(image, coords_rotated,
                                    order=order, mode=mode, cval=cval)
    return patch_rotated

Note that image here is the whole CT scan.

Invoke sitk.Resample instead of map_coordinates? Of course, you will need to construct its parameters, which are a bit different.

Ok, it's not very clear from the documentation, but after experimenting a bit I think I now understand: you can use the referenceImage argument of sitk.Resample to achieve what I'm trying to do.

# The reference image defines the output grid (size, spacing, origin,
# direction) on which the input is resampled through the transform.
dummy_patch = sitk.Image(*patch_shape, sitk.sitkUInt16)
dummy_patch.SetSpacing(image.GetSpacing())
# (origin and direction can likewise be set to position the patch)
patch_transformed = sitk.Resample(image, dummy_patch, transform, sitk.sitkBSpline, 0.0)