Transforming points does not work when flipping

I am working on a data augmentation pipeline with translation, rotation, scaling, flipping, and elastic deformation. I need to track how landmarks are modified during the transformation and resampling process, but there seems to be an issue when I use flipping with the affine transformation.

import SimpleITK as sitk

flipping_trans = sitk.AffineTransform(3)

# Flipping in the z axis
flipping_trans.Scale([1, 1, -1])

# Add other transforms
...
composite_trans = sitk.CompositeTransform(transformation_list)


# composite_trans is a combination of -> 
# [translate image center to origin (so rotation and scaling are about the image center),
# random translation, random rotation, random scaling, flipping in z axis, 
# translate origin to output image center,
# elastic deformation]


# Resample image
...
output_image = sitk_resample.Execute(input_image)


displacement_field = sitk.TransformToDisplacementField(composite_trans, sitk.sitkVectorFloat64,
                                                       size=output_image.GetSize(),
                                                       outputSpacing=output_image.GetSpacing(),
                                                       outputDirection=output_image.GetDirection(),
                                                       outputOrigin=output_image.GetOrigin())

invert_displacement_field = sitk.InvertDisplacementField(displacement_field)
inverted_transform = sitk.DisplacementFieldTransform(invert_displacement_field)

# coord is a landmark's continuous index in input_image
coord_physical = input_image.TransformContinuousIndexToPhysicalPoint(coord)
transformed_physical = inverted_transform.TransformPoint(coord_physical)
transformed_coord = output_image.TransformPhysicalPointToContinuousIndex(transformed_physical)

The output landmarks are correct when flipping is not used, but wrong when it is. I am not sure what the issue is or what I might be doing wrong.

If you flip, a large part of the displacement field will likely map points outside of your image. That might be the problem.

How is the elastic deformation represented? If it is a BSpline, maybe you should “flatten” your transforms into a BSpline instead of a displacement field.

If my theory is correct, using a different origin and a larger size along the z axis for the displacement field should “resolve” the problem.

Hello @kleingeo,

What you are encountering is an issue with computing the inverse of the displacement field. Explanations follow the code (the image used here is part of the data distributed with the SimpleITK notebooks).

TL;DR: code that works around the issue:

import SimpleITK as sitk

def invert_bspline_transform(tx, output_size, output_origin, output_spacing, output_direction):
    displacement_field_image = sitk.TransformToDisplacementField(tx,
                                                                 sitk.sitkVectorFloat64,
                                                                 output_size, 
                                                                 output_origin,
                                                                 output_spacing,
                                                                 output_direction)
    return invert_displacement_field_image(displacement_field_image)

def invert_displacement_field_transform(tx):
    return invert_displacement_field_image(sitk.DisplacementFieldTransform(tx).GetDisplacementField())

def invert_displacement_field_image(displacement_field_image):
    # SimpleITK supports three different filters for inverting a displacement field
    # arbitrary selection used with default values
    return sitk.DisplacementFieldTransform(sitk.InvertDisplacementField(displacement_field_image))


input_image = sitk.ReadImage('training_001_ct.mha')
point_indexes = [[247, 120, 8], [161,171,4]]

# translation
translation = sitk.TranslationTransform(3,[30,0,0])

# deformation, just a translation in y for easy validation
dx = sitk.Image(input_image.GetSize(), sitk.sitkFloat64)
dy = sitk.Image(input_image.GetSize(), sitk.sitkFloat64) + 20
dz = sitk.Image(input_image.GetSize(), sitk.sitkFloat64)
displacement_image = sitk.Compose(dx,dy,dz)
displacement_image.CopyInformation(input_image)
deformation = sitk.DisplacementFieldTransform(displacement_image)

# reflection in z
flipping_trans = sitk.AffineTransform(3)
flipping_trans.Scale([1, 1, -1])
flipping_trans.SetCenter(input_image.TransformContinuousIndexToPhysicalPoint([sz/2 for sz in input_image.GetSize()]))

composite_trans = sitk.CompositeTransform([flipping_trans, translation, deformation])

output_image = sitk.Resample(input_image, composite_trans)

sitk.Show(input_image, 'input')
sitk.Show(output_image, 'output')

original_transform = composite_trans

inverted_transform_list = []
composite_trans.FlattenTransform()
for i in range(composite_trans.GetNumberOfTransforms()-1, -1, -1):
    tx = composite_trans.GetNthTransform(i)
    ttype = tx.GetTransformEnum()
    if ttype == sitk.sitkDisplacementField:
        inverted_transform_list.append(invert_displacement_field_transform(tx))
    elif ttype == sitk.sitkBSplineTransform:
        inverted_transform_list.append(invert_bspline_transform(tx, output_image.GetSize(), output_image.GetOrigin(),
                                                                output_image.GetSpacing(), output_image.GetDirection()))
    else:
        inverted_transform_list.append(tx.GetInverse())
inverted_transform = sitk.CompositeTransform(inverted_transform_list)
        
for pnt_indexes in point_indexes:
    original_point = input_image.TransformContinuousIndexToPhysicalPoint(pnt_indexes)
    print(f'Original point: {original_point}')
    print(f'Transformed point: {inverted_transform.TransformPoint(original_point)}')
    print(f'Inverted transformed point: {original_transform.TransformPoint(inverted_transform.TransformPoint(original_point))}')

  1. If a CompositeTransform contains only transformations with a global domain (reflections included), call tx.GetInverse() and it works fine. If one of the transformations inside the composite is a bounded-domain transform (e.g. a BSpline or displacement field transform), this does not work and an exception is thrown.

  2. ITK has no direct way to invert BSpline transformations, so we convert the BSpline to a displacement field and use one of the displacement field inversion filters.

  3. The displacement field inversion algorithms usually assume that the field represents a smooth and continuous function. With a reflection, neighboring points are transformed in opposite directions, so inverting a displacement field that represents a reflection will not yield a good result (which is what you are experiencing).
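Point 3 can be illustrated in 1D with a generic fixed-point scheme: for a forward map x -> x + u(x), the inverse displacement v must satisfy v(y) = -u(y + v(y)), and plain iteration of that equation only converges when |u'| < 1. This is an illustrative sketch under that assumption, not the algorithm behind any particular SimpleITK inversion filter. The reflection x -> 2c - x has u(x) = 2c - 2x, so u' = -2: the exact inverse exists (here v = 6), but the iteration moves away from it.

```python
import math

def invert_displacement_1d(u, y, iterations=50):
    """Fixed-point estimate of v(y) such that forward(y + v(y)) == y.

    Solves v = -u(y + v) by plain iteration, starting from v = 0.
    """
    v = 0.0
    for _ in range(iterations):
        v = -u(y + v)
    return v

def residual(u, y, v):
    """Distance between y and forward(y + v); zero for an exact inverse."""
    x = y + v
    return abs(x + u(x) - y)

c = 5.0  # mirror point, chosen for illustration

def smooth(x):
    return 0.3 * math.sin(x)  # |u'| <= 0.3, so the iteration is a contraction

def reflection(x):
    return 2 * c - 2 * x      # x -> 2c - x, so u' = -2 everywhere

y = 2.0
smooth_res = residual(smooth, y, invert_displacement_1d(smooth, y))
reflection_res = residual(reflection, y, invert_displacement_1d(reflection, y))
print(smooth_res, reflection_res)
```

The smooth field is inverted to machine precision, while the reflection's estimate roughly doubles its error at every step.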


OK, thank you for that. So the easiest solution is to take the inverse of each component of the composite transform and stitch them together (including converting the BSpline to a displacement field and inverting that)?

Correct. Two things to notice: (1) the inverted transforms are composed in reverse order to get the inverse of the original composite; (2) before starting the process, the composite transform is flattened, so any nested composite transforms are replaced by the basic transforms they contain.
