Point indices after 3D rotation

Hello there,

I am trying to rotate a CT scan and then locate the centroid of a segmented lesion. I have a SimpleITK image named image and a second SimpleITK image named ann that contains a binary segmentation mask.

The idea is as follows:

  1. Convert both the image and the annotation to isotropic spacing
  2. Calculate the centroid indices of the annotation and convert this centroid to physical coordinates
  3. Apply a rotation transformation (Euler transform) to the image volume
  4. Apply the same rotation transformation to the physical centroid
  5. Calculate the new indices of the centroid by converting the physical coordinates back to indices (both index/physical conversions are sketched below)
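
Steps 2 and 5 rely on SimpleITK's mapping between index space and physical space. As a minimal sketch of those two conversions (using a made-up volume; note that SimpleITK indices are ordered (x, y, z), while arrays from sitk.GetArrayFromImage are ordered (z, y, x)):

import SimpleITK as sitk

img = sitk.Image(64, 64, 32, sitk.sitkInt16)  # hypothetical volume
img.SetSpacing((1.0, 1.0, 2.5))

# step 2: (x, y, z) index -> physical point, honouring origin, spacing, and direction
point = img.TransformIndexToPhysicalPoint((10, 20, 5))

# step 5: physical point -> (x, y, z) index, rounded to the nearest voxel
index = img.TransformPhysicalPointToIndex(point)  # (10, 20, 5) again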

For convenience, assume that step 1 has already been completed. I created the following function for steps 2, 3, 4, and 5.

import itertools

import numpy as np
import SimpleITK as sitk


def rotation3d(image, ann, theta_x, theta_y, theta_z, background_value=0.0):
    """
    This function rotates an image across each of the x, y, z axes by theta_x, theta_y, and theta_z degrees
    respectively (euler ZXY orientation) and resamples it to be isotropic.
    :param image: An sitk 3D image
    :param theta_x: The amount of degrees the user wants the image rotated around the x axis
    :param theta_y: The amount of degrees the user wants the image rotated around the y axis
    :param theta_z: The amount of degrees the user wants the image rotated around the z axis
    :param background_value: The value that should be used to pad the rotated volume
    :return: The rotated image
    """
    k = sitk.GetArrayFromImage(ann)
    spot_indices = np.argwhere(k)
    centroid = [int(q) for q in np.round(np.mean(spot_indices.astype(float), axis=0))]  # Z,X,Y #
    centroid = [centroid[1],centroid[2],centroid[0]]#ZXY to XYZ
    physical_centroid  = image.TransformIndexToPhysicalPoint(centroid)

    euler_transform = sitk.Euler3DTransform(
        image.TransformContinuousIndexToPhysicalPoint([sz / 2.0 for sz in image.GetSize()]),
        np.deg2rad(theta_x),
        np.deg2rad(theta_y),
        np.deg2rad(theta_z))

    # compute the resampling grid for the transformed image: transform the corner
    # indices of the input volume and take their bounding box
    max_indexes = list(image.GetSize())
    extreme_indexes = list(itertools.product(*zip([0] * image.GetDimension(), max_indexes)))
    extreme_points_transformed = [euler_transform.TransformPoint(image.TransformContinuousIndexToPhysicalPoint(p)) for p
                                  in extreme_indexes]

    output_min_coordinates = np.min(extreme_points_transformed, axis=0)
    output_max_coordinates = np.max(extreme_points_transformed, axis=0)
    rotated_physical_centroid = euler_transform.TransformPoint(physical_centroid)
    # isotropic output spacing
    output_spacing = min(image.GetSpacing())
    output_spacing = [output_spacing] * image.GetDimension()

    output_origin = output_min_coordinates
    output_size = [int(((omx - omn) / ospc) + 0.5) for ospc, omn, omx in
                   zip(output_spacing, output_min_coordinates, output_max_coordinates)]

    output_direction = [1, 0, 0, 0, 1, 0, 0, 0, 1]
    output_pixeltype = image.GetPixelIDValue()

    resampled_image = sitk.Resample(image,
                         output_size,
                         euler_transform.GetInverse(),
                         sitk.sitkLinear,
                         output_origin,
                         output_spacing,
                         output_direction,
                         background_value,
                         output_pixeltype)
    new_centroid = resampled_image.TransformPhysicalPointToIndex(rotated_physical_centroid)

    return resampled_image, new_centroid

The centroid indices returned by the function seem to be incorrect. Does anyone have an idea why?

For convenience, I use the following code that plots the centroid as a red cross on top of the axial slice:

import matplotlib.pyplot as plt


def plot_axial_slice_with_centroid(image, centroid):
    # Convert the SimpleITK image to a NumPy array for easier manipulation
    image_array = sitk.GetArrayFromImage(image)

    # Get the axial slice containing the centroid
    z, x, y = centroid
    axial_slice = image_array[z, :, :]

    # Create a plot
    plt.imshow(axial_slice, cmap='gray')

    # Add a red cross to indicate the centroid
    plt.plot(y, x, 'rx', markersize=10)

    # Set appropriate axis labels and title
    plt.xlabel('Y')
    plt.ylabel('X')
    plt.title('Axial Slice with Centroid')

    # Display the plot
    plt.show()

rotated_image, rotated_centroid = rotation3d(image, ann, 0, 0, 0, -1024)
plot_axial_slice_with_centroid(rotated_image, rotated_centroid)

Any help is appreciated, thanks in advance!

Why not use LabelGeometryImageFilter’s centroid functionality? GetCentroid() should return a point in physical space; transform that point using the inverse of your image transform, and then use TransformPhysicalPointToIndex.
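
Roughly like this (an untested sketch; resample_transform here stands for whatever transform you hand to sitk.Resample — that transform maps the output space back to the input space, so physical points travel the other way, through its inverse):

# physical_centroid is the label centroid as a physical (x, y, z) point
moved_centroid = resample_transform.GetInverse().TransformPoint(physical_centroid)
new_index = resampled_image.TransformPhysicalPointToIndex(moved_centroid)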

With linear transforms, it is both better and faster to use TransformGeometryImageFilter than resampling.

Thanks for your reply!

Why not use LabelGeometryImageFilter’s centroid functionality?

LabelGeometryImageFilter is not included when installing itk using pip (AttributeError: module 'itk' has no attribute 'LabelGeometryImageFilter'). I read somewhere that I would have to compile ITK myself to get that functionality, but I would rather use the default ITK and SimpleITK packages.

With linear transforms, it is both better and faster to use TransformGeometryImageFilter than resampling.

Perhaps TransformGeometryImageFilter is better and faster, but so far I have not been able to achieve the desired rotation and translation using that class. Could you perhaps give some example code that uses TransformGeometryImageFilter to perform the same transform as the resampling operation I provided?

Regardless, I did find a mistake in my previously posted function, namely that the centroid coordinates were switched around. I have moved the centroid calculation code to the bottom of the function for clarity. The function is now working for all angles of theta_z!

However, the location of the transformed centroid is wrong when theta_x and theta_y are non-zero, and I do not understand why. Any insight would be appreciated!

def rotation3d(image, ann, theta_x, theta_y, theta_z, background_value=0.0):
    """
    This function rotates an image across each of the x, y, z axes by theta_x, theta_y, and theta_z degrees
    respectively (euler ZXY orientation) and resamples it to be isotropic.
    :param image: An sitk 3D image
    :param theta_x: The amount of degrees the user wants the image rotated around the x axis
    :param theta_y: The amount of degrees the user wants the image rotated around the y axis
    :param theta_z: The amount of degrees the user wants the image rotated around the z axis
    :param background_value: The value that should be used to pad the rotated volume
    :return: The rotated image
    """
    euler_transform = sitk.Euler3DTransform(
        image.TransformContinuousIndexToPhysicalPoint([sz / 2.0 for sz in image.GetSize()]),
        np.deg2rad(theta_x),
        np.deg2rad(theta_y),
        np.deg2rad(theta_z))
    max_indexes = list(image.GetSize())
    extreme_indexes = list(itertools.product(*zip([0] * image.GetDimension(), max_indexes)))
    extreme_points_transformed = [euler_transform.TransformPoint(image.TransformContinuousIndexToPhysicalPoint(p)) for p
                                  in extreme_indexes]

    output_min_coordinates = np.min(extreme_points_transformed, axis=0)
    output_max_coordinates = np.max(extreme_points_transformed, axis=0)
    # isotropic output spacing
    output_spacing = min(image.GetSpacing())
    output_spacing = [output_spacing] * image.GetDimension()

    output_origin = output_min_coordinates
    output_size = [int(((omx - omn) / ospc) + 0.5) for ospc, omn, omx in
                   zip(output_spacing, output_min_coordinates, output_max_coordinates)]

    output_direction = [1, 0, 0, 0, 1, 0, 0, 0, 1]
    output_pixeltype = image.GetPixelIDValue()

    resampled_image = sitk.Resample(image,
                         output_size,
                         euler_transform.GetInverse(),
                         sitk.sitkLinear,
                         output_origin,
                         output_spacing,
                         output_direction,
                         background_value,
                         output_pixeltype)
    # Calculate the new centroid position
    k = sitk.GetArrayFromImage(ann)
    spot_indices = np.argwhere(k)
    centroid = [int(q) for q in np.round(np.mean(spot_indices.astype(float), axis=0))]  # Z,X,Y #
    centroid = [centroid[1], centroid[2], centroid[0]]  # ZXY to XYZ

    physical_centroid = image.TransformIndexToPhysicalPoint(centroid)
    rotated_physical_centroid = euler_transform.GetInverse().TransformPoint(physical_centroid)

    new_centroid = resampled_image.TransformPhysicalPointToIndex(rotated_physical_centroid)
    new_centroid = [new_centroid[2], new_centroid[0], new_centroid[1]]  # XYZ to ZXY
    return resampled_image, new_centroid

Thank you in advance!

In SimpleITK, the LabelShapeStatisticsImageFilter has a GetCentroid method, which may be useful for your work.
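
For example (a minimal sketch, assuming ann is a binary mask whose foreground label is 1):

shape_stats = sitk.LabelShapeStatisticsImageFilter()
shape_stats.Execute(ann)
print(shape_stats.GetLabels())       # e.g. (1,)
print(shape_stats.GetCentroid(1))    # centroid of label 1 as a physical (x, y, z) point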

Thanks for your quick response! That one is indeed useful. Using that method for finding the centroid, I found out that I had been switching coordinates around because I was applying euler_transform.GetInverse() to both the image and the centroid! By removing the GetInverse() from the centroid calculation and removing all coordinate switching, I now have a function that works for all angles!

def rotation3d(image, ann, theta_x, theta_y, theta_z, background_value=0.0):
    """
    This function rotates an image across each of the x, y, z axes by theta_x, theta_y, and theta_z degrees
    respectively (euler ZXY orientation) and resamples it to be isotropic.
    :param image: An sitk 3D image
    :param theta_x: The amount of degrees the user wants the image rotated around the x axis
    :param theta_y: The amount of degrees the user wants the image rotated around the y axis
    :param theta_z: The amount of degrees the user wants the image rotated around the z axis
    :param background_value: The value that should be used to pad the rotated volume
    :return: The rotated image
    """
    # image=ann
    euler_transform = sitk.Euler3DTransform(
        image.TransformContinuousIndexToPhysicalPoint([sz / 2.0 for sz in image.GetSize()]),
        np.deg2rad(theta_x),
        np.deg2rad(theta_y),
        np.deg2rad(theta_z))
    max_indexes = list(image.GetSize())
    extreme_indexes = list(itertools.product(*zip([0] * image.GetDimension(), max_indexes)))
    extreme_points_transformed = [euler_transform.TransformPoint(image.TransformContinuousIndexToPhysicalPoint(p)) for p
                                  in extreme_indexes]

    output_min_coordinates = np.min(extreme_points_transformed, axis=0)
    output_max_coordinates = np.max(extreme_points_transformed, axis=0)
    # isotropic output spacing
    output_spacing = min(image.GetSpacing())
    output_spacing = [output_spacing] * image.GetDimension()

    output_origin = output_min_coordinates
    output_size = [int(((omx - omn) / ospc) + 0.5) for ospc, omn, omx in
                   zip(output_spacing, output_min_coordinates, output_max_coordinates)]

    output_direction = [1, 0, 0, 0, 1, 0, 0, 0, 1]
    output_pixeltype = image.GetPixelIDValue()

    resampled_image = sitk.Resample(image,
                         output_size,
                         euler_transform.GetInverse(),
                         sitk.sitkLinear,
                         output_origin,
                         output_spacing,
                         output_direction,
                         background_value,
                         output_pixeltype)
    # Calculate the new centroid position
    shape_stats = sitk.LabelShapeStatisticsImageFilter()
    shape_stats.Execute(ann)
    if shape_stats.GetNumberOfLabels() > 1:
        raise ValueError("Image has more than one label, so there are multiple centroids...")
    # GetCentroid returns the centroid of the label as a physical (x, y, z) point
    physical_centroid = shape_stats.GetCentroid(1)

    rotated_physical_centroid = euler_transform.TransformPoint(physical_centroid)
    new_centroid = resampled_image.TransformPhysicalPointToIndex(rotated_physical_centroid)
    return resampled_image, new_centroid

Thank you once again! Now the only question that remains is how I can use TransformGeometryImageFilter instead of resampling to decrease the function's execution time.

TransformGeometryImageFilter might be available in ITK 5.4RC1. To get it, do pip install --pre --upgrade itk.

Yes, I have TransformGeometryImageFilter, but I wonder how to produce the same result as with the resampling operation. When I do

transformer = sitk.TransformGeometryImageFilter()
resampled_image = transformer.Execute(image, euler_transform.GetInverse())

then the image does not appear to rotate. My guess is that this is because the rotation happens in physical space and the filter automatically adjusts the Direction and Origin of the image accordingly. I tried applying the parameters that I calculated previously to the new image,

# After executing the TransformGeometryImageFilter
resampled_image.SetDirection(output_direction)
resampled_image.SetOrigin(output_origin)
resampled_image.SetSpacing(output_spacing)

but unfortunately that does not seem to help. The resampled_image does not appear to be rotated when converting to NumPy and plotting.

Yes, it should appear the same when plotting using NumPy. If you are working in physical space, why are you using NumPy/matplotlib? It is better to visualize your images using Slicer or ITK-SNAP.
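
You can verify that only the metadata changes with a quick check (a sketch; euler_transform is the transform from your earlier posts):

moved = sitk.TransformGeometryImageFilter().Execute(image, euler_transform)

# the voxel buffer is untouched ...
assert (sitk.GetArrayFromImage(moved) == sitk.GetArrayFromImage(image)).all()

# ... only the physical-space metadata (origin, direction) is updated
print(image.GetOrigin(), image.GetDirection())
print(moved.GetOrigin(), moved.GetDirection())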

I am using NumPy/matplotlib because, in the end, the goal is to feed non-orthogonal slices to a 2.5D detection model that takes NumPy arrays as input. The idea is that by first rotating and then slicing, the content of each slice changes, which can act as a form of data augmentation.

When employing the rotation3d function that I proposed previously, which uses a resampling operation, the rotation is visible when visualizing the data using NumPy/matplotlib. This is great, as we can feed those slices into our model! The question is now whether we can achieve the same result using TransformGeometryImageFilter, as it has a potential execution-time benefit over resampling.

In that case resampling is unavoidable.

But why are you rolling your own? Both MONAI and TorchIO have random rotation augmentation transforms.
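
For instance, the TorchIO route is roughly (a sketch with a hypothetical tensor; MONAI's RandRotate works similarly):

import torch
import torchio as tio

image = tio.ScalarImage(tensor=torch.rand(1, 64, 64, 32))  # hypothetical (C, W, H, D) volume
augment = tio.RandomAffine(scales=0, degrees=10)           # rotation only, up to ±10° per axis
rotated = augment(image)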

Sorry for not replying for a bit! I have been investigating the torchIO.transforms.Affine() function to apply the 3D rotation. However, the main downside for me is that the dimensions of the “viewing cube” do not change, so when the volume is rotated in physical space, part of it falls outside of the “viewing cube”.

In the function that I proposed previously, the viewing cube is resized by calculating the coordinates of the extremes of the physical volume after rotation, and then deriving a new origin and size from them, which are applied in the resampling step.
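
For reference, here is that viewing-cube computation distilled from the rotation3d function above (output_spacing is the isotropic spacing computed there):

corners = itertools.product(*zip((0, 0, 0), image.GetSize()))
corners_transformed = [euler_transform.TransformPoint(
    image.TransformContinuousIndexToPhysicalPoint(c)) for c in corners]

output_origin = np.min(corners_transformed, axis=0)  # smallest transformed corner -> new origin
output_max = np.max(corners_transformed, axis=0)
output_size = [int((mx - mn) / spc + 0.5) for mn, mx, spc in
               zip(output_origin, output_max, output_spacing)]  # new size in voxels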