Fixed Mask Not Working as Expected

Unfortunately this isn’t the same issue. I’m using the simplest registration setup I could (the patch appears to be specific to mutual information which I’m not using).

The current issue is more generic and fundamental to the registration framework: it is not a problem with a specific metric, it is a “feature” of the ImageRegistrationMethodv4 (see the link in my original post). Using the mask in the larger image is expected to limit the sampling to the region inside the mask, rather than spreading it across the whole image. Because rejection sampling is used to do this, too few samples end up being used, so registration fails even when the user is providing enough information for it to succeed. In this example I expected the mask to be equivalent to cropping the fixed image to the ROI.

The SimpleITK code illustrating the issue is below so that you have all of the information (you can see which optimizer/transform/mask I use, plus any bugs I have :wink:).

import SimpleITK as sitk
import numpy as np


# Blend weight for the optional before-registration overlay; it is only
# referenced by the commented-out sitk.Show code further down.
alpha = 0.5

####################################
# Register the two images, starting with the identity transform. The two
# images completely overlap.
####################################
def estimate_translation(fixed_image, moving_image, fixed_mask, number_samples=256):
    """Estimate the 2D translation that registers moving_image to fixed_image.

    The registration uses the correlation metric with random sampling and a
    gradient descent optimizer, starting from the identity translation.

    Parameters
    ----------
    fixed_image, moving_image
        Images to register; both are cast to sitkFloat32 before registration.
    fixed_mask
        Mask passed to the metric via SetMetricFixedMask.
    number_samples
        Requested number of metric samples. It is converted into a sampling
        percentage relative to the total pixel count of the *whole* fixed
        image (not the masked region). Defaults to 256.

    Returns
    -------
    The offset of the estimated translation, as returned by
    TranslationTransform.GetOffset().
    """
    transform = sitk.TranslationTransform(2)

    registration_method = sitk.ImageRegistrationMethod()
    registration_method.SetMetricAsCorrelation()
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)

    # Fraction of the full fixed image that corresponds to `number_samples`
    # pixels. float() forces true division (the pixel count is an integer)
    # and yields a plain Python float.
    total_pixels = float(np.prod(fixed_image.GetSize()))
    percentage = number_samples / total_pixels
    registration_method.SetMetricSamplingPercentage(percentage)

    registration_method.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10,
    )
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # inPlace=True: the optimizer updates `transform` itself, so the value
    # returned by Execute() does not need to be kept.
    registration_method.SetInitialTransform(transform, inPlace=True)

    registration_method.SetMetricFixedMask(fixed_mask)

    registration_method.Execute(
        sitk.Cast(fixed_image, sitk.sitkFloat32),
        sitk.Cast(moving_image, sitk.sitkFloat32),
    )
    return transform.GetOffset()


####################################
# Create images, fixed and moving and the fixed image mask.
# Image content is a rectangle which is shifted by dx,dy in the moving image.
####################################
def create_images_and_fixed_mask(img_height, img_width, dx, dy):
    """Build the synthetic fixed image, moving image and fixed-image mask.

    The fixed image contains a 30x20 (width x height) rectangle of value 255
    whose top-left corner is at x=40, y=20 on a zero background. The moving
    image contains the same rectangle shifted by (dx, dy). The mask covers
    the fixed rectangle enlarged by 5 pixels on every side.

    Returns the tuple (fixed_image, moving_image, fixed_mask).
    """
    # Rectangle corners as half-open [start, end) index ranges.
    rect_x0, rect_y0 = 40, 20
    rect_x1 = rect_x0 + 30
    rect_y1 = rect_y0 + 20
    margin = 5

    fixed_pixels = np.zeros((img_height, img_width))
    fixed_pixels[rect_y0:rect_y1, rect_x0:rect_x1] = 255
    fixed_image = sitk.GetImageFromArray(fixed_pixels)

    # The mask starts from the fixed-image pixels and additionally marks the
    # margin around the rectangle; every pixel equal to 255 is inside the mask.
    mask_pixels = fixed_pixels.copy()
    mask_pixels[rect_y0 - margin:rect_y1 + margin,
                rect_x0 - margin:rect_x1 + margin] = 255
    fixed_mask = sitk.GetImageFromArray(mask_pixels) == 255

    moving_pixels = np.zeros((img_height, img_width))
    moving_pixels[rect_y0 + dy:rect_y1 + dy, rect_x0 + dx:rect_x1 + dx] = 255
    moving_image = sitk.GetImageFromArray(moving_pixels)

    return (fixed_image, moving_image, fixed_mask)

# Known shift (in pixels) of the rectangle between the fixed and moving image.
dx = 0
dy = 5

# Registration uses random sampling, so each experiment is repeated many
# times and the error statistics are reported.
number_of_runs = 1000

# The same experiment is run on a small and a large image. The rectangle and
# the mask have the same size in both; only the surrounding image grows.
for label, img_height, img_width in (("small", 64, 128), ("large", 128, 256)):
    fixed_image, moving_image, fixed_mask = create_images_and_fixed_mask(img_height, img_width, dx, dy)
    sitk.WriteImage(sitk.Cast(fixed_image, sitk.sitkUInt8), label + "_fixed_image.png")
    sitk.WriteImage(sitk.Cast(moving_image, sitk.sitkUInt8), label + "_moving_image.png")
    sitk.WriteImage(sitk.Cast(fixed_mask*255, sitk.sitkUInt8), label + "_fixed_image_mask.png")
    #sitk.Show((1.0 - alpha)*fixed_image + alpha*moving_image, label + " combined image before registration")

    # Absolute error of the estimated translation against the known shift.
    errors_x = []
    errors_y = []
    for _ in range(number_of_runs):
        est_x, est_y = estimate_translation(fixed_image, moving_image, fixed_mask)
        errors_x.append(np.abs(est_x - dx))
        errors_y.append(np.abs(est_y - dy))

    print('{0} image errors x mean(std), errors y mean(std): {1} ({2}) {3} ({4})'.format(
        label, np.mean(errors_x), np.std(errors_x), np.mean(errors_y), np.std(errors_y)))