ElastixImageFilter() vs ImageRegistrationMethod()

I would like to replace the ElastixImageFilter method, which I understand is deprecated, with ImageRegistrationMethod for registration. I expected both methods to produce similar resampled images, but in fact ElastixImageFilter produces much better resampled images than ImageRegistrationMethod. Why is this so, and what can I do?

Registration using ElastixImageFilter:

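# ElastixImageFilter requires a SimpleITK build with elastix enabled (e.g. the SimpleITK-SimpleElastix package)
import SimpleITK as sitk

def registration_ElastixImageFilter(fixed_image, moving_image, fixed_mask, moving_mask, out_path):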
    registration_method = sitk.ElastixImageFilter()
    registration_method.SetFixedImage(fixed_image)
    registration_method.SetMovingImage(moving_image)
    registration_method.SetFixedMask(fixed_mask)
    registration_method.SetMovingMask(moving_mask)

    bspline_map = sitk.ReadParameterFile("Parameters_BSpline.txt")
    registration_method.SetParameterMap(bspline_map)
    
    registration_method.SetLogToConsole(True)
    registration_method.SetOutputDirectory(out_path)
    registration_method.Execute()
    resampled_image = registration_method.GetResultImage()
    return registration_method, resampled_image

Parameters_BSpline.txt:

(FixedInternalImagePixelType "short")
(MovingInternalImagePixelType "short")
(UseDirectionCosines "true")
(Registration "MultiResolutionRegistration")
(Interpolator "BSplineInterpolator")
(ResampleInterpolator "FinalBSplineInterpolator")
(Resampler "DefaultResampler")
(FixedImagePyramid "FixedRecursiveImagePyramid")
(MovingImagePyramid "MovingRecursiveImagePyramid")
(Optimizer "AdaptiveStochasticGradientDescent")
(Transform "BSplineTransform")
(Metric "AdvancedMattesMutualInformation")
(Metric "AdvancedMeanSquares")
(FinalGridSpacingInPhysicalUnits 4)
(HowToCombineTransforms "Compose")
(NumberOfHistogramBins 32)
(ErodeMask "false")
(NumberOfResolutions 4)
(MaximumNumberOfIterations 500)
(NumberOfSpatialSamples 2048)
(NewSamplesEveryIteration "true")
(ImageSampler "Random")
(BSplineInterpolationOrder 1)
(FinalBSplineInterpolationOrder 3)
(DefaultPixelValue 2048)
(WriteResultImage "true")
(ResultImagePixelType "short")
(ResultImageFormat "mhd")
(WriteTransformParametersEachResolution "true")

Registration algorithm using ImageRegistrationMethod:

import math

import SimpleITK as sitk

def registration_ImageRegistrationMethod(fixed_image, moving_image, fixed_mask=None, moving_mask=None, out_path=""):
    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=32)

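    size = fixed_image.GetSize()
    dimension = fixed_image.GetDimension()
    total_voxels = math.prod(size)
    # Aim for 2048 samples per level (NumberOfSpatialSamples in Parameters_BSpline.txt).
    # Shrinking by f along every axis divides the voxel count by f**dimension,
    # and each sampling percentage must be a fraction in (0, 1].
    # Note: as far as I know, SimpleITK draws the random samples once per level,
    # whereas elastix redraws them every iteration (NewSamplesEveryIteration "true").
    sampling_fractions = [
        min(1.0, 2048.0 * f**dimension / total_voxels) for f in [8, 4, 2, 1]
    ]
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentagePerLevel(sampling_fractions)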

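    # elastix's BSplineInterpolator with (BSplineInterpolationOrder 1) is first-order,
    # i.e. linear, interpolation during registration
    registration_method.SetInterpolator(sitk.sitkLinear)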
    registration_method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=500,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10,
    )

    # trying to mimic typical 4-level pyramid settings in SITK:
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[8, 4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[8, 4, 2, 1])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    dimension = fixed_image.GetDimension()
    physical_size = [
        (fixed_image.GetSize()[i] - 1.0) * fixed_image.GetSpacing()[i]
        for i in range(dimension)
    ]

    # trying to mimic 4 mm final grid spacing
    mesh_size = [max(1, int(math.floor(p / 4.0))) for p in physical_size]

    initial_transform = sitk.BSplineTransformInitializer(fixed_image, mesh_size)
    registration_method.SetInitialTransform(initial_transform, inPlace=True)
    if fixed_mask is not None:
        registration_method.SetMetricFixedMask(fixed_mask)
    if moving_mask is not None:
        registration_method.SetMetricMovingMask(moving_mask)
    def command_iteration(method):
        print(
            f"Iteration: {method.GetOptimizerIteration()}, "
            f"Metric value: {method.GetMetricValue()}"
        )

    registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method))
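    # ImageRegistrationMethod works on real-valued images, so cast both to float32
    final_transform = registration_method.Execute(
        sitk.Cast(fixed_image, sitk.sitkFloat32), sitk.Cast(moving_image, sitk.sitkFloat32)
    )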
    
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed_image)
    resampler.SetInterpolator(sitk.sitkBSpline)  

    # (DefaultPixelValue 2048)
    resampler.SetDefaultPixelValue(2048)
    resampler.SetTransform(final_transform)
    resampled_image = resampler.Execute(moving_image)

    return final_transform, resampled_image
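One possible difference worth checking: the mesh above reproduces only the final 4 mm grid spacing, so the B-spline grid is equally fine at every pyramid level. Parameters_BSpline.txt does not set GridSpacingSchedule, and if I read the elastix documentation correctly, its default schedule coarsens the grid by a factor of 2 per lower resolution (8, 4, 2, 1 for four resolutions), so 4 mm is reached only at the last level. The sketch below, assuming that default, computes an initial mesh size and per-level refinement factors; `bspline_mesh_schedule` is a hypothetical helper, not a SimpleITK or elastix function.

```python
import math
from typing import List, Sequence, Tuple


def bspline_mesh_schedule(
    physical_size: Sequence[float],
    final_grid_spacing: float,
    grid_spacing_schedule: Sequence[float] = (8, 4, 2, 1),
) -> Tuple[List[int], List[int]]:
    """Return (initial_mesh_size, scale_factors) for a coarse-to-fine B-spline grid.

    Hypothetical helper: assumes elastix-style grid spacing factors, where the
    grid spacing at level k is final_grid_spacing * grid_spacing_schedule[k].
    """
    coarsest = max(grid_spacing_schedule)
    coarse_spacing = final_grid_spacing * coarsest
    # Mesh size (number of grid cells per axis) at the coarsest level.
    initial_mesh_size = [max(1, math.floor(p / coarse_spacing)) for p in physical_size]
    # Factor by which the initial mesh is refined at each level; the grid
    # spacing shrinks by the same factor.
    scale_factors = [int(round(coarsest / s)) for s in grid_spacing_schedule]
    return initial_mesh_size, scale_factors
```

In the function above, this would replace the `SetInitialTransform` call: build `initial_transform = sitk.BSplineTransformInitializer(fixed_image, initial_mesh_size)` and pass it to `registration_method.SetInitialTransformAsBSpline(initial_transform, inPlace=True, scaleFactors=scale_factors)`. As I understand the SimpleITK B-spline examples, each scale factor multiplies the initial mesh size at the matching pyramid level, so the list needs one entry per level (four here, as in `SetShrinkFactorsPerLevel`). Because the mesh is refined only by integer factors, the final spacing is approximately, not exactly, 4 mm.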

Thanks in advance!

@Niels_Dekker might want to comment.


The reason I want to replace ElastixImageFilter() with ImageRegistrationMethod() is that I am experimenting with different registration configurations, and MetricEvaluate() would be a convenient way to evaluate their overall performance, but it is only available on ImageRegistrationMethod(). So my other question is: is there a function similar to MetricEvaluate() that I can apply to the ElastixImageFilter() output?

Hi @Niels_Dekker, I would greatly appreciate any advice you could give me on this matter. Thank you so much!