Two-label map registration: bad results

I have a mask with 2 labels that I want to register to a segmentation map with 3 labels (in both, the background is 0).
[image]
[image]

I ran the following code but am getting bad results:

import SimpleITK as sitk

mask_itk = sitk.GetImageFromArray(mask)
label_itk = sitk.GetImageFromArray(labelMap)

fixed_image = sitk.Cast(label_itk, sitk.sitkFloat32)
moving_image = sitk.Cast(mask_itk, sitk.sitkFloat32)

initialTx = sitk.CenteredTransformInitializer(fixed_image, moving_image, sitk.AffineTransform(fixed_image.GetDimension()))
registrationMethod = sitk.ImageRegistrationMethod()

# registrationMethod.SetShrinkFactorsPerLevel([1, 1, 1])
# registrationMethod.SetSmoothingSigmasPerLevel([1, 1, 1])

# registrationMethod.SetMetricAsJointHistogramMutualInformation(20)
registrationMethod.MetricUseFixedImageGradientFilterOff()

registrationMethod.SetOptimizerAsGradientDescent(
    learningRate=10.0,
    numberOfIterations=1000,
    estimateLearningRate=registrationMethod.EachIteration,
)
registrationMethod.SetOptimizerScalesFromPhysicalShift()
registrationMethod.SetInitialTransform(initialTx)
registrationMethod.SetInterpolator(sitk.sitkLinear)

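# command_iteration and command_multiresolution_iteration are printing callbacks
# (as in SimpleITK's displacement-field registration example); not shown here.
registrationMethod.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registrationMethod))
registrationMethod.AddCommand(sitk.sitkMultiResolutionIterationEvent,
                              lambda: command_multiresolution_iteration(registrationMethod))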

outTx1 = registrationMethod.Execute(fixed_image, moving_image)

print("-------")
print(outTx1)
print(f"Optimizer stop condition: {registrationMethod.GetOptimizerStopConditionDescription()}")
print(f" Iteration: {registrationMethod.GetOptimizerIteration()}")
print(f" Metric value: {registrationMethod.GetMetricValue()}")

displacementField = sitk.Image(fixed_image.GetSize(), sitk.sitkVectorFloat64)
displacementField.CopyInformation(fixed_image)
displacementTx = sitk.DisplacementFieldTransform(displacementField)
# del displacementField
displacementTx.SetSmoothingGaussianOnUpdate(
    varianceForUpdateField=0.0, varianceForTotalField=1.5
)

registrationMethod.SetMovingInitialTransform(outTx1)
registrationMethod.SetInitialTransform(displacementTx, inPlace=True)

registrationMethod.SetMetricAsANTSNeighborhoodCorrelation(4)
registrationMethod.MetricUseFixedImageGradientFilterOff()

# registrationMethod.SetShrinkFactorsPerLevel([3, 2, 1])
# registrationMethod.SetSmoothingSigmasPerLevel([2, 1, 1])

registrationMethod.SetOptimizerScalesFromPhysicalShift()
registrationMethod.SetOptimizerAsGradientDescent(
    learningRate=1,
    numberOfIterations=300,
    estimateLearningRate=registrationMethod.EachIteration,
)

registrationMethod.Execute(fixed_image, moving_image)
print("-------")
print(displacementTx)
print(f"Optimizer stop condition: {registrationMethod.GetOptimizerStopConditionDescription()}")
print(f" Iteration: {registrationMethod.GetOptimizerIteration()}")
print(f" Metric value: {registrationMethod.GetMetricValue()}")

compositeTx = sitk.CompositeTransform([outTx1, displacementTx])

resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed_image)
resampler.SetInterpolator(sitk.sitkLinear)
resampler.SetDefaultPixelValue(0)
resampler.SetTransform(compositeTx)

out = resampler.Execute(moving_image)

The result I am getting is as follows:
[image]

I believe it is somewhat working, but it is not the kind of registration result I need: the butterfly regions and the red parts of the segmentation should match/overlap.

Any suggestions or ideas for improvement would be much appreciated.

Thanks.
seg_label_map.h5 (84.1 KB)

The sample data (both the label map and the segmentation map) is attached above. Here is the code to load it:

!pip install h5py
import os
import h5py
import numpy as np

downloadDir = '.'  # placeholder: folder where seg_label_map.h5 was downloaded
pathToTheFile = os.path.join(downloadDir, 'seg_label_map.h5')
with h5py.File(pathToTheFile, 'r') as pfile:
    segmentation = np.array(pfile['segmentation'])
    label = np.array(pfile['label'])

For label map registration, you should use nearest-neighbor interpolation, not linear. You should also use a label-appropriate metric, such as match cardinality.
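A small NumPy sketch of both points. Linear interpolation blends neighboring label values and can invent a label that does not exist in the map, while nearest-neighbor lookup only returns existing labels. `match_cardinality` is a hypothetical helper following the idea behind ITK's MatchCardinalityImageToImageMetric (the fraction of pixels whose labels agree); it is not an ITK or SimpleITK function, and SimpleITK's ImageRegistrationMethod does not appear to expose this metric.

```python
import numpy as np

def match_cardinality(fixed_labels, moving_labels):
    """Fraction of pixels whose label values agree exactly (higher is better).
    Hypothetical helper in the spirit of a match-cardinality metric."""
    fixed_labels = np.asarray(fixed_labels)
    moving_labels = np.asarray(moving_labels)
    if fixed_labels.shape != moving_labels.shape:
        raise ValueError("label maps must have the same shape")
    return float(np.mean(fixed_labels == moving_labels))

def linear_sample_1d(labels, position):
    """Linearly interpolate a 1-D label array at a fractional position."""
    left = int(np.floor(position))
    weight = position - left
    return (1.0 - weight) * labels[left] + weight * labels[left + 1]

def nearest_sample_1d(labels, position):
    """Nearest-neighbor lookup: always returns one of the existing labels."""
    return labels[int(np.rint(position))]

row = np.array([0, 0, 1, 1, 3, 3])    # only labels 0, 1 and 3 exist
print(linear_sample_1d(row, 3.5))     # halfway between label 1 and label 3
print(nearest_sample_1d(row, 3.4))
print(match_cardinality([[0, 1], [3, 3]], [[0, 1], [1, 3]]))
```

The linear sample lands on 2, a label present in neither map, which is one reason resampled label maps look wrong at region borders.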


Here is the result with the NearestNeighbor interpolator:
[image]
The results look cropped and, in this case, jittery.

On another note, does itkMatchCardinalityImageToImageMetric have a Python implementation?

You could consider registering the signed distance field of the inner label, instead of the labels directly.
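Why a signed distance field helps, shown as a 1-D NumPy sketch (the helpers and data are hypothetical). Once two labels stop overlapping, the mean squared difference between the binary masks no longer changes with the shift, so the optimizer gets no signal about which way to move; between signed distance fields it keeps growing with the shift. The sign convention (negative inside, positive outside) matches the default of SimpleITK's SignedMaurerDistanceMap, although exact boundary values may differ slightly from that filter.

```python
import numpy as np

def interval_mask(start, width, n=60):
    """1-D binary 'label': True on [start, start + width)."""
    mask = np.zeros(n, dtype=bool)
    mask[start:start + width] = True
    return mask

def signed_distance_1d(mask):
    """Brute-force signed distance: negative inside the label, positive outside."""
    idx = np.arange(mask.size)
    inside, outside = idx[mask], idx[~mask]
    # distance from every sample to the nearest inside / outside sample
    to_inside = np.abs(idx[:, None] - inside[None, :]).min(axis=1)
    to_outside = np.abs(idx[:, None] - outside[None, :]).min(axis=1)
    return np.where(mask, -to_outside, to_inside).astype(float)

fixed = interval_mask(20, 8)
fixed_sdf = signed_distance_1d(fixed)
shifts = [0, 4, 8, 12, 16]
binary_mse, sdf_mse = [], []
for s in shifts:
    moving = interval_mask(20 + s, 8)
    binary_mse.append(np.mean((fixed.astype(float) - moving.astype(float)) ** 2))
    sdf_mse.append(np.mean((fixed_sdf - signed_distance_1d(moving)) ** 2))

print("shift  binary MSE  SDF MSE")
for s, b, d in zip(shifts, binary_mse, sdf_mse):
    print(f"{s:5d}  {b:10.3f}  {d:7.1f}")
```

For shifts of 8 or more, the two intervals no longer overlap and the binary column stays constant, while the SDF column keeps increasing.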


Does the top image need to be rotated counterclockwise by 45-90 degrees? Are you performing an affine/rigid registration before the deformable one? What do the results look like before the deformation?


Hello, based on an example, I tried to implement an affine transform followed by a displacement field (composite) transform.
Please see the code below:

import SimpleITK as sitk

fixed = sitk.Cast(label_itk, sitk.sitkFloat32)
moving = sitk.Cast(mask_itk, sitk.sitkFloat32)

initialTx = sitk.CenteredTransformInitializer(
    fixed, moving, sitk.AffineTransform(fixed.GetDimension()))

R = sitk.ImageRegistrationMethod()
R.SetMetricAsMattesMutualInformation(numberOfHistogramBins=500)
R.SetOptimizerAsGradientDescent(
    learningRate=10,
    numberOfIterations=300,
    estimateLearningRate=R.EachIteration,
)
R.SetOptimizerScalesFromPhysicalShift()
R.SetInitialTransform(initialTx)
R.SetInterpolator(sitk.sitkNearestNeighbor)
outTx1 = R.Execute(fixed, moving)
displacementField = sitk.TransformToDisplacementFieldFilter()
displacementField.SetReferenceImage(fixed)
displacementTx = sitk.DisplacementFieldTransform(displacementField.Execute(sitk.Transform(2, sitk.sitkIdentity)))
displacementTx.SetSmoothingGaussianOnUpdate(varianceForUpdateField=0.0, varianceForTotalField=1.5)
R.SetMovingInitialTransform(outTx1)
R.SetInitialTransform(displacementTx, inPlace=True)

R.SetMetricAsANTSNeighborhoodCorrelation(4)
R.MetricUseFixedImageGradientFilterOff()
R.Execute(fixed, moving)
compositeTx = sitk.CompositeTransform([outTx1, displacementTx])

resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed)
resampler.SetInterpolator(sitk.sitkNearestNeighbor)
resampler.SetDefaultPixelValue(0)
resampler.SetTransform(compositeTx)

out = resampler.Execute(moving)

[image]
[image]

Do you have an example of 2D affine registration in SimpleITK?

Would it be possible to give an example or implementation of SDF in SimpleITK?
I searched but didn't find any registration-specific material.
Any help?

Here is an example of SDF registration in C++ ITK. @blowekamp or @zivy might point you to SDF in SITK.


We do not have a SimpleITK example for registering label images using a signed distance transform, but it is relatively straightforward. Pseudo-code below:

# Step 1: convert relevant label into distance map:
b_value = x # label of the "butterfly" shaped region
fixed_image_distance_map = sitk.SignedMaurerDistanceMap(fixed_label_image==b_value, squaredDistance=True, useImageSpacing=True)
moving_image_distance_map = sitk.SignedMaurerDistanceMap(moving_label_image==b_value, squaredDistance=True, useImageSpacing=True)

# Step 2: register the fixed_image_distance_map and moving_image_distance_map as usual, 
# including linear interpolation etc., same way you would register intensity images.

# Step 3: apply resulting transformation to resample moving_label_image, just use the nearest neighbor interpolator so that you do not introduce labels that weren't in the original label image.

[image]
This is the result I get after SDF registration. It is pretty similar to the following, just translated, I suppose:
[image]
The code is as follows:

b_value = 1
fixed_image_distance_map = sitk.SignedMaurerDistanceMap(fixed_image == b_value, squaredDistance=True, useImageSpacing=True)
b_value = 2
moving_image_distance_map = sitk.SignedMaurerDistanceMap(moving_image == b_value, squaredDistance=True, useImageSpacing=True)

initialTx = sitk.CenteredTransformInitializer(fixed_image_distance_map,
                                              moving_image_distance_map,
                                              sitk.AffineTransform(fixed_image.GetDimension()))
registrationMethod = sitk.ImageRegistrationMethod()

registrationMethod.MetricUseFixedImageGradientFilterOff()
registrationMethod.SetOptimizerAsGradientDescent(
    learningRate=10.0,
    numberOfIterations=1000,
    estimateLearningRate=registrationMethod.EachIteration,
)
registrationMethod.SetOptimizerScalesFromPhysicalShift()
registrationMethod.SetInitialTransform(initialTx)
registrationMethod.SetInterpolator(sitk.sitkNearestNeighbor)

outTx1 = registrationMethod.Execute(fixed_image_distance_map, moving_image_distance_map)

displacementField = sitk.Image(fixed_image_distance_map.GetSize(), sitk.sitkVectorFloat64)
displacementField.CopyInformation(fixed_image_distance_map)
displacementTx = sitk.DisplacementFieldTransform(displacementField)

displacementTx.SetSmoothingGaussianOnUpdate(
    varianceForUpdateField=0.0, varianceForTotalField=1.5
)

registrationMethod.SetMovingInitialTransform(outTx1)
registrationMethod.SetInitialTransform(displacementTx, inPlace=True)

registrationMethod.SetMetricAsANTSNeighborhoodCorrelation(4)
registrationMethod.MetricUseFixedImageGradientFilterOff()

registrationMethod.SetOptimizerScalesFromPhysicalShift()
registrationMethod.SetOptimizerAsGradientDescent(
    learningRate=1,
    numberOfIterations=300,
    estimateLearningRate=registrationMethod.EachIteration,
)

registrationMethod.Execute(fixed_image_distance_map, moving_image_distance_map)

compositeTx = sitk.CompositeTransform([outTx1, displacementTx])

resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed_image_distance_map)
resampler.SetInterpolator(sitk.sitkNearestNeighbor)
resampler.SetDefaultPixelValue(0)
# Note: only the affine outTx1 is applied here; compositeTx (affine + displacement) is not used.
resampler.SetTransform(outTx1)

out = resampler.Execute(moving_image)

The selected regions were:
[image]
[image]

When the butterfly portion of the region is used instead, the results are much the same.

When using SDF approach, you should not be masking your images.


Also, are you monitoring the registration metric at each iteration? Perhaps plot it to determine whether the registration optimization was successful.


Hello @blowekamp, @zivy and @dzenanz
Thanks for your kind suggestions.
I made some improvements!
[image]
[image]
[image]
[image]

import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk

all_orientations = np.linspace(-2*np.pi, 2*np.pi, 100)
# print(all_orientations)

# Evaluate the similarity metric using the rotation parameter space sampling, translation remains the same for all.
initial_transform = sitk.Euler2DTransform(sitk.CenteredTransformInitializer(fixed_image,
                                                                      moving_image,
                                                                      sitk.Euler2DTransform(),
                                                                      sitk.CenteredTransformInitializerFilter.GEOMETRY))
# Registration framework setup.
registration_method = sitk.ImageRegistrationMethod()
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=500)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.001)
registration_method.SetInitialTransform(initial_transform, inPlace=False)
registration_method.SetOptimizerAsRegularStepGradientDescent(learningRate=2.0,
                                                       minStep=1e-4,
                                                       numberOfIterations=500,
                                                       gradientMagnitudeTolerance=1e-8)
registration_method.SetInterpolator(sitk.sitkNearestNeighbor)
registration_method.SetOptimizerScalesFromIndexShift()
# Start from the initializer's angle so best_orientation is always defined.
best_orientation = initial_transform.GetAngle()
best_similarity_value = registration_method.MetricEvaluate(fixed_image, moving_image)
similarity_value = []
# Iterate over all other rotation parameter settings.
for orientation in all_orientations:
    initial_transform.SetAngle(orientation)
    registration_method.SetInitialTransform(initial_transform)
    current_similarity_value = registration_method.MetricEvaluate(fixed_image, moving_image)
    similarity_value.append(current_similarity_value)
    print(current_similarity_value)
    if current_similarity_value < best_similarity_value:
        best_similarity_value = current_similarity_value
        best_orientation = orientation
print('best orientation is: ' + str(best_orientation))

# Plot once, after the sweep, when x and y have the same length.
plt.plot(all_orientations, similarity_value, 'b')
plt.plot(best_orientation, best_similarity_value, 'rv')
plt.show()

initial_transform.SetAngle(best_orientation)
registration_method.SetInitialTransform(initial_transform, inPlace=False)
eulerTx = registration_method.Execute(fixed_image, moving_image)
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed_image)
resampler.SetInterpolator(sitk.sitkNearestNeighbor)
resampler.SetDefaultPixelValue(0)
resampler.SetTransform(eulerTx)

out = resampler.Execute(moving_image)
del registration_method, initial_transform, eulerTx
moving_image = sitk.Cast(out, sitk.sitkFloat32)

# +--------------------------+
# |   displacement method    |
# +--------------------------+
initial_transform = sitk.CenteredTransformInitializer(fixed_image, moving_image, sitk.AffineTransform(fixed_image.GetDimension()))
registration_method = sitk.ImageRegistrationMethod()
# registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method))
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=500)
registration_method.SetOptimizerAsGradientDescent(learningRate=10,
                                                  numberOfIterations=300,
                                                  estimateLearningRate=registration_method.EachIteration,
                                                 )
registration_method.SetOptimizerScalesFromPhysicalShift()
registration_method.SetInitialTransform(initial_transform)
registration_method.SetInterpolator(sitk.sitkNearestNeighbor)
outTx1 = registration_method.Execute(fixed_image, moving_image)
displacementField = sitk.TransformToDisplacementFieldFilter()
displacementField.SetReferenceImage(fixed_image)
displacementTx = sitk.DisplacementFieldTransform(displacementField.Execute(sitk.Transform(2, sitk.sitkIdentity)))
displacementTx.SetSmoothingGaussianOnUpdate(varianceForUpdateField=0.0, varianceForTotalField=1.5)

registration_method.SetMovingInitialTransform(outTx1)
registration_method.SetInitialTransform(displacementTx, inPlace=True)
registration_method.SetMetricAsANTSNeighborhoodCorrelation(40)
registration_method.MetricUseFixedImageGradientFilterOff()

registration_method.Execute(fixed_image, moving_image)
compositeTx = sitk.CompositeTransform([outTx1, displacementTx])
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed_image)
resampler.SetInterpolator(sitk.sitkNearestNeighbor)
resampler.SetDefaultPixelValue(0)
resampler.SetTransform(compositeTx)

out = resampler.Execute(moving_image)
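The rotation sweep above scores each angle with Mattes mutual information. For label maps, the same exhaustive search can be sketched with an overlap score such as Dice and nearest-neighbor resampling. This is a pure NumPy sketch with hypothetical helpers (`rotate_nearest`, `dice`, `ellipse`); it only rotates about the array center and ignores translation. Because rotations by θ and by θ + 2π are identical, one full turn of angles is enough; np.linspace(-2*np.pi, 2*np.pi, 100) visits every orientation twice.

```python
import numpy as np

def rotate_nearest(img, angle):
    """Rotate a 2-D label array about its center by `angle` radians using
    nearest-neighbor lookup, so only existing label values appear."""
    h, w = img.shape
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    yy, xx = np.mgrid[0:h, 0:w]
    c, s = np.cos(angle), np.sin(angle)
    # pull-back: map every output pixel to the input position it comes from
    src_x = c * (xx - cx) + s * (yy - cy) + cx
    src_y = -s * (xx - cx) + c * (yy - cy) + cy
    xi = np.rint(src_x).astype(int)
    yi = np.rint(src_y).astype(int)
    valid = (xi >= 0) & (xi < w) & (yi >= 0) & (yi < h)
    out = np.zeros_like(img)  # pixels mapped from outside get background 0
    out[valid] = img[yi[valid], xi[valid]]
    return out

def dice(a, b):
    """Dice overlap of the non-background parts of two label arrays."""
    a, b = a > 0, b > 0
    return 2.0 * np.logical_and(a, b).sum() / (a.sum() + b.sum())

def ellipse(size, radii):
    """Centered elliptical label (uint8) on a size x size grid."""
    c = (size - 1) / 2.0
    yy, xx = np.mgrid[0:size, 0:size]
    return ((((xx - c) / radii[0]) ** 2 + ((yy - c) / radii[1]) ** 2) <= 1).astype(np.uint8)

fixed = ellipse(64, (20, 8))    # horizontal ellipse
moving = ellipse(64, (8, 20))   # same shape, turned by 90 degrees

angles = np.linspace(-np.pi, np.pi, 73)   # one full turn in 5-degree steps
scores = [dice(fixed, rotate_nearest(moving, a)) for a in angles]
best = int(np.argmax(scores))
print(f"best angle: {np.degrees(angles[best]):.1f} deg, Dice = {scores[best]:.3f}")
```

The best angle found this way could then seed the Euler2DTransform before the finer optimization, as the code above does with best_orientation.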

The question I have:

How do I track the metric change and optimization parameters?
I have tried to print them using the following helper function:

import sys
import matplotlib.pyplot as plt

def command_iteration(method):
    global metric_values
    print("is this running? ")
    if method.GetOptimizerIteration() == 0:
        print(f"\tLevel: {method.GetCurrentLevel()}")
        print(f"\tScales: {method.GetOptimizerScales()}")
    print(f"#{method.GetOptimizerIteration()}")
    print(f"\tMetric Value: {method.GetMetricValue():10.5f}")
    print(f"\tLearningRate: {method.GetOptimizerLearningRate():10.5f}")
    if method.GetOptimizerConvergenceValue() != sys.float_info.max:
        print(
            "\tConvergence Value: "
            + f"{method.GetOptimizerConvergenceValue():.5e}"
        )
    metric_values.append(method.GetMetricValue())
    # print(f"{method.():3} = {method.GetMetricValue():10.5f}")
    plt.plot(metric_values, 'r')
    # plt.plot(multires_iterations, [metric_values[index] for index in multires_iterations], 'b*')
    plt.xlabel('Iteration Number', fontsize=12)
    plt.ylabel('Metric Value', fontsize=12)
    plt.show()

But when I add this function to the code above like this:

registration_method.Execute(fixed_image, moving_image)
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method))
compositeTx = sitk.CompositeTransform([outTx1, displacementTx])

It doesn't do anything; the callback doesn't seem to run.

If I use it as follows instead:

registration_method.MetricUseFixedImageGradientFilterOff()
# if registration_method.GetOptimizerIteration() == 31:
    # metric_v = []
# metric_v.append(registration_method.GetMetricValue())
registration_method.AddCommand(sitk.sitkIterationEvent, lambda : command_iteration(registration_method))
registration_method.Execute(fixed_image, moving_image)

I get the following error:

registration_method.Execute(fixed_image, moving_image)
  File "C:\Users\banikr\Miniconda3\envs\reg37\lib\site-packages\SimpleITK\SimpleITK.py", line 10859, in Execute
    val = _SimpleITK.ImageRegistrationMethod_Execute(self, fixed, moving)
RuntimeError: Exception thrown in SimpleITK ImageRegistrationMethod_Execute: D:\a\1\sitk\Wrapping\Python\sitkPyCommand.cxx:135:
sitk::ERROR: There was an error executing the Python Callable.

Thanks all again.


Hello @banikr,

The problem is that you are not creating the metric_values list variable anywhere. You need to create it before registration; command_iteration then declares it global, so it refers to the same list across multiple function calls.

The approach illustrated in this notebook, with the metric_start_plot, metric_end_plot and metric_plot_values functions, creates the list at the beginning (sitkStartEvent), adds values to it during execution (sitkIterationEvent), and deletes it when done (sitkEndEvent).


Hi,
I have plotted the registration metric over the iterations:
[image]
It pretty much plateaus at around 700 iterations.
Does this mean the Demons displacement field method can't improve the registration any further?
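A flat metric curve means the optimizer has stopped finding improvements with its current settings; on its own it does not show that no better alignment exists. One way to make "plateaued" concrete is to compare the best value in the last N iterations with the best value before them. `has_plateaued` is a hypothetical helper and its window and tolerance are arbitrary; SimpleITK's SetOptimizerAsGradientDescent has a related built-in check through its convergenceMinimumValue and convergenceWindowSize parameters.

```python
def has_plateaued(values, window=50, rel_tol=1e-3):
    """True if the best metric value improved by less than rel_tol (relative)
    over the last `window` iterations. Assumes lower metric values are better."""
    if len(values) <= window:
        return False
    best_before = min(values[:-window])
    best_recent = min(values[-window:])
    improvement = best_before - best_recent
    scale = max(abs(best_before), 1e-12)  # avoid dividing by zero
    return improvement / scale < rel_tol

decaying = [1.0 / (1 + i) for i in range(200)]   # still improving
flat = decaying[:100] + [decaying[99]] * 100     # stops improving after 100
print(has_plateaued(decaying), has_plateaued(flat))
```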


I have actually started getting good results.

What other methods would you suggest I experiment with?
Thanks to all of you again for helping me out.


Hello @banikr,

I'm not sure about your reference to the Demons registration method, or to 700 iterations. The code samples provided in this discussion do not use that algorithm, and they have numberOfIterations=300.

It is hard to help when you are sharing contradictory information. There is no need to share the code you actually used, but you do need to describe what you are doing. Are you using the Demons algorithm? Is it applied to the signed distance maps?

Assuming you are using the Demons algorithm and that the code shared previously is outdated, take a look at this Jupyter notebook, which illustrates the usage of the Demons family of algorithms in a multi-resolution setting.


@zivy, thanks for the reminder.
Here is the Demons code:

import SimpleITK as sitk

def command_iteration(filter):
    global metric_values
    print(f"{filter.GetElapsedIterations():3} = {filter.GetMetric():10.5f}")
    metric_values.append(filter.GetMetric())

metric_values = []
demons = sitk.FastSymmetricForcesDemonsRegistrationFilter()
demons.SetNumberOfIterations(5000)
demons.SetStandardDeviations(1.2)
demons.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(demons))

transform_to_displacment_field_filter = sitk.TransformToDisplacementFieldFilter()
transform_to_displacment_field_filter.SetReferenceImage(fixed_image)
displacementTx = sitk.DisplacementFieldTransform(
    transform_to_displacment_field_filter.Execute(sitk.Transform(2, sitk.sitkIdentity)))
displacementTx.SetSmoothingGaussianOnUpdate(varianceForUpdateField=0.0, varianceForTotalField=1.5)

displacementField = transform_to_displacment_field_filter.Execute(displacementTx)
displacementField = demons.Execute(fixed_image, moving_image, displacementField)
outTx = sitk.DisplacementFieldTransform(displacementField)
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed_image)
resampler.SetInterpolator(sitk.sitkNearestNeighbor)
resampler.SetDefaultPixelValue(0)
resampler.SetTransform(outTx)
out = resampler.Execute(moving_image)

Also, I apply the same Euler2D alignment before Demons as in the displacement field code I shared previously.
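To judge whether further changes actually help, it can be useful to score the resampled label map directly instead of relying only on the metric curve. Because the mask (2 labels) and the segmentation (3 labels) do not necessarily use the same label values, the correspondence has to be given explicitly; `dice_per_label` and the label pairs and toy arrays below are hypothetical.

```python
import numpy as np

def dice_per_label(fixed_labels, moving_labels, label_pairs):
    """Dice overlap for each (fixed_label, moving_label) pair.
    Pairs are needed because the two maps may use different label values."""
    fixed_labels = np.asarray(fixed_labels)
    moving_labels = np.asarray(moving_labels)
    scores = {}
    for f_lab, m_lab in label_pairs:
        f = fixed_labels == f_lab
        m = moving_labels == m_lab
        denom = int(f.sum() + m.sum())
        # NaN when neither map contains the label at all
        scores[(f_lab, m_lab)] = (float(2.0 * np.logical_and(f, m).sum() / denom)
                                  if denom else float("nan"))
    return scores

fixed_map = np.array([[0, 1, 1, 0],
                      [0, 1, 2, 2],
                      [0, 0, 2, 2]])
moving_map = np.array([[0, 2, 2, 0],
                       [0, 2, 1, 1],
                       [0, 0, 0, 1]])
print(dice_per_label(fixed_map, moving_map, [(1, 2), (2, 1)]))
```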