I would like to replace deprecated ElastixImageFilter method with the ImageRegistrationMethod method for registration. I am expecting both of these method can derive a similar resampled images, but in fact, ElastixImageFilter can derive a much better resampled images compared to ImageRegistrationMethod, why is this so? and what can i do?
Registration using ElastixImageFilter:
def registration_ElastixImageFilter(fixed_image, moving_image, fixed_mask, moving_mask, out_path):
registration_method = sitk.ElastixImageFilter()
registration_method.SetFixedImage(fixed_image)
registration_method.SetMovingImage(moving_image)
registration_method.SetFixedMask(fixed_mask)
registration_method.SetMovingMask(moving_mask)
bspline_map = sitk.ReadParameterFile("Parameters_BSpline.txt")
registration_method.SetParameterMap(bspline_map)
registration_method.SetLogToConsole(True)
registration_method.SetOutputDirectory(out_path)
registration_method.Execute()
resampled_image = registration_method.GetResultImage()
return registration_method, resampled_image
Parameters_BSpline.txt:
(FixedInternalImagePixelType "short")
(MovingInternalImagePixelType "short")
(UseDirectionCosines "true")
(Registration "MultiResolutionRegistration")
(Interpolator "BSplineInterpolator")
(ResampleInterpolator "FinalBSplineInterpolator")
(Resampler "DefaultResampler")
(FixedImagePyramid "FixedRecursiveImagePyramid")
(MovingImagePyramid "MovingRecursiveImagePyramid")
(Optimizer "AdaptiveStochasticGradientDescent")
(Transform "BSplineTransform")
(Metric "AdvancedMattesMutualInformation")
(Metric "AdvancedMeanSquares")
(FinalGridSpacingInPhysicalUnits 4)
(HowToCombineTransforms "Compose")
(NumberOfHistogramBins 32)
(ErodeMask "false")
(NumberOfResolutions 4)
(MaximumNumberOfIterations 500)
(NumberOfSpatialSamples 2048)
(NewSamplesEveryIteration "true")
(ImageSampler "Random")
(BSplineInterpolationOrder 1)
(FinalBSplineInterpolationOrder 3)
(DefaultPixelValue 2048)
(WriteResultImage "true")
(ResultImagePixelType "short")
(ResultImageFormat "mhd")
(WriteTransformParametersEachResolution "true")
Registration algorithm using ImageRegistrationMethod:
def registration_ImageRegistrationMethod(fixed_image, moving_image, fixed_mask=None, moving_mask=None, out_path=""):
registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=32)
size = fixed_image.GetSize()
total_voxels = size[0] * size[1] * size[2]
# since I want 2048 samples (based on Parameters_BSpline.txt)
R0 = 2048/(total_voxels/8)
R1 = 2048/(total_voxels/4)
R2 = 2048/(total_voxels/2)
R3 = 2048/(total_voxels/1)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentagePerLevel([R0,R1,R2,R3])
registration_method.SetInterpolator(sitk.sitkBSpline)
registration_method.SetOptimizerAsGradientDescent(
learningRate=1.0,
numberOfIterations=500,
convergenceMinimumValue=1e-6,
convergenceWindowSize=10,
)
# trying to mimic a typical 4-level pyramid settings in SITK:
registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[8, 4, 2, 1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[8, 4, 2, 1])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()
dimension = fixed_image.GetDimension()
physical_size = [
(fixed_image.GetSize()[i] - 1.0) * fixed_image.GetSpacing()[i]
for i in range(dimension)
]
# trying to mimic 4 mm final grid spacing
mesh_size = [max(1, int(math.floor(p / 4.0))) for p in physical_size]
initial_transform = sitk.BSplineTransformInitializer(fixed_image, mesh_size)
registration_method.SetInitialTransform(initial_transform, inPlace=True)
if fixed_mask is not None:
registration_method.SetMetricFixedMask(fixed_mask)
if moving_mask is not None:
registration_method.SetMetricMovingMask(moving_mask)
def command_iteration(method):
print(
f"Iteration: {method.GetOptimizerIteration()}, "
f"Metric value: {method.GetMetricValue()}"
)
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method))
final_transform = registration_method.Execute(fixed_image, moving_image)
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed_image)
resampler.SetInterpolator(sitk.sitkBSpline)
# (DefaultPixelValue 2048)
resampler.SetDefaultPixelValue(2048)
resampler.SetTransform(final_transform)
resampled_image = resampler.Execute(moving_image)
return final_transform, resampled_image
Thanks in advance