BSpline registration from a NumPy array doesn't transform points

Hi,
I'm registering two DICOM images using affine and BSpline registrations. After registration, I use the resulting transformation to transform points (in physical space). I'm building an interactive interface that lets the user register the images multiple times, so sometimes the inputs are paths to the images (in the first registration) and sometimes they are NumPy arrays. Interestingly, the two cases behave differently: in some cases, when using NumPy arrays, the images register fine but the points are not transformed. The only explanation I could think of is that, with NumPy array input, the points somehow fall outside the boundaries of the transform.
Here is a minimal example comparing the two cases (with the same images):

def register(fixed, moving, fixed_meta, moving_meta, out_path, type="bspline", fixed_mask=None, moving_mask=None):
    """Register *moving* onto *fixed* and return the results as NumPy arrays.

    ``fixed`` and ``moving`` are converted independently, so a path can be mixed
    with an array. Each may be:

    * a ``SimpleITK.Image`` - used as is, keeping its own geometry;
    * a ``numpy.ndarray`` - transposed with ``(2, 0, 1)`` before conversion and
      given spacing/origin from the matching meta dict via ``set_meta_data``;
    * anything else - forwarded to ``read_image`` (e.g. DICOM file paths).

    Args:
        fixed, moving: input images (see above).
        fixed_meta, moving_meta: metadata used only for NumPy-array inputs.
        out_path: directory the transform file ``<type>.tfm`` is written to; it
            is also passed on to the registration routine.
        type: ``"bspline"`` or ``"euler"``.
        fixed_mask, moving_mask: optional SimpleITK mask images, forwarded to the
            registration; the moving mask is also warped when both are given.

    Returns:
        ``(fixed_np, moving_np, warped_np, warped_mask_np, transform)``. Arrays
        are transposed with ``(1, 2, 0)``; ``warped_mask_np`` is ``None`` unless
        both masks were given.

    Raises:
        ValueError: if ``type`` is neither ``"bspline"`` nor ``"euler"``.
    """
    # Validate before doing any (potentially slow) image loading.
    if type not in ("bspline", "euler"):
        raise ValueError("transformation of type: %s does not exist" % type)

    fixed_f = _as_float32_image(fixed, fixed_meta, "fixed")
    moving_f = _as_float32_image(moving, moving_meta, "moving")

    if type == "bspline":
        outTx = bspline_nonrigid_registration(fixed_f, moving_f, out_path, fixed_mask, moving_mask)
    else:
        outTx = euler_rigid_registration(fixed_f, moving_f, out_path, fixed_mask, moving_mask)

    warped = warp_image(fixed_f, moving_f, outTx)
    if moving_mask is not None and fixed_mask is not None:
        # Masks are label images: nearest neighbour keeps them binary, and 0
        # (not the -1024 CT background used for intensities) marks "outside".
        warped_mask = sitk.Resample(moving_mask, fixed_mask, outTx, sitk.sitkNearestNeighbor, 0)
        warped_mask = np.transpose(sitk.GetArrayFromImage(warped_mask), (1, 2, 0))
    else:
        warped_mask = None

    sitk.WriteTransform(outTx, os.path.join(out_path, type + ".tfm"))

    # Use copying GetArrayFromImage everywhere: a GetArrayViewFromImage view of
    # the local moving_f would outlive the image that owns its buffer.
    return (np.transpose(sitk.GetArrayFromImage(fixed_f), (1, 2, 0)),
            np.transpose(sitk.GetArrayFromImage(moving_f), (1, 2, 0)),
            np.transpose(sitk.GetArrayFromImage(warped), (1, 2, 0)),
            warped_mask,
            outTx)


def _as_float32_image(image, meta, label):
    """Convert one registration input (image, array or path) to a float32 SimpleITK image."""
    if isinstance(image, sitk.Image):
        return sitk.Cast(image, sitk.sitkFloat32)
    if isinstance(image, np.ndarray):
        # Undo the caller's (1, 2, 0) layout so the last array axis becomes z.
        img = sitk.Cast(sitk.GetImageFromArray(np.transpose(image, (2, 0, 1))), sitk.sitkFloat32)
        # NOTE: geometry comes only from meta. Its origin must be the position of
        # the slice at z index 0; an origin from another slice shifts the image
        # (and a BSpline domain derived from it), and points outside that domain
        # are typically returned unchanged by the transform.
        set_meta_data(img, meta)
        return img
    img, _reader = read_image(image)
    print(label + " reading finished ", img.GetSize())
    return sitk.Cast(img, sitk.sitkFloat32)

def set_meta_data(img, meta):
    """Restore physical spacing and origin on an image built from a NumPy array.

    ``sitk.GetImageFromArray`` produces unit spacing, a zero origin and an
    identity direction, so the geometry must be re-applied before registration.
    The image is modified in place; its direction is not changed.

    Args:
        img: image to update (anything with ``SetSpacing``/``SetOrigin``).
        meta: metadata mapping. Reads ``meta["pixelSpacing"][0]`` and ``[1]``,
            ``meta["sliceSpacing"]`` (falling back to ``meta["sliceThickness"]``
            when that key is missing) and ``meta["IPP"].value``.

    Raises:
        KeyError: if ``pixelSpacing``/``IPP`` or both slice-spacing keys are missing.

    Note:
        The origin must be the ImagePositionPatient of the slice that ends up at
        z index 0 of ``img``. If ``meta["IPP"]`` belongs to another slice (e.g.
        the last one), the whole volume, and any transform whose domain is
        derived from it, is shifted along z.

        NOTE(review): DICOM PixelSpacing is ordered (row spacing, column
        spacing), i.e. (y, x), while ``SetSpacing`` expects (x, y, z). This only
        matters for non-square pixels. Confirm how ``meta["pixelSpacing"]`` is
        populated. Slice thickness is also not guaranteed to equal the
        inter-slice distance.
    """
    pixel_spacing = meta["pixelSpacing"]
    try:
        slice_spacing = meta["sliceSpacing"]
    except KeyError:
        slice_spacing = meta["sliceThickness"]
    img.SetSpacing((float(pixel_spacing[0]), float(pixel_spacing[1]), float(slice_spacing)))
    img.SetOrigin(tuple(float(v) for v in meta['IPP'].value))

def warp_image(fixed, moving, outTx, interpolator=sitk.sitkLinear, default_pixel_value=-1024):
    """Resample *moving* onto the grid of *fixed* through the transform *outTx*.

    Args:
        fixed: reference image defining the output size, spacing, origin and direction.
        moving: image to resample.
        outTx: transform mapping points from the fixed to the moving physical space.
        interpolator: SimpleITK interpolator. The linear default suits intensity
            images; pass ``sitk.sitkNearestNeighbor`` for label/mask images.
        default_pixel_value: value for output pixels that map outside *moving*.
            The -1024 default matches CT air; use 0 for masks.

    Returns:
        The resampled image on the fixed image's grid.
    """
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(fixed)
    resampler.SetInterpolator(interpolator)
    resampler.SetDefaultPixelValue(default_pixel_value)
    resampler.SetTransform(outTx)
    return resampler.Execute(moving)

def bspline_nonrigid_registration(fixed_image, moving_image, out_path, fixed_mask=None, moving_mask=None,
                                  grid_physical_spacing=(50.0, 50.0, 50.0), number_of_iterations=2):
    """Run a multi-resolution BSpline registration of *moving_image* onto *fixed_image*.

    The BSpline transform domain is initialised from *fixed_image*'s geometry,
    so a wrong origin/spacing on that image shifts the domain as well.

    Args:
        fixed_image, moving_image: float SimpleITK images.
        out_path: forwarded to ``exceute_metric_plot``.
        fixed_mask, moving_mask: optional metric masks.
        grid_physical_spacing: approximate control-point spacing (in the images'
            physical units) reached at the finest pyramid level.
        number_of_iterations: LBFGS2 iteration limit. The default of 2 is very
            low and presumably a debugging value.

    Returns:
        The transform returned by ``ImageRegistrationMethod.Execute``.
    """
    registration_method = sitk.ImageRegistrationMethod()

    # The BSpline mesh is refined by these factors over the three pyramid levels,
    # so the initial mesh is made coarser by the final factor. The finest level
    # then has roughly one control point every grid_physical_spacing.
    scale_factors = [1, 2, 4]

    image_physical_size = [size * spacing for size, spacing in zip(fixed_image.GetSize(), fixed_image.GetSpacing())]
    mesh_size = [int(image_size / grid_spacing + 0.5)
                 for image_size, grid_spacing in zip(image_physical_size, grid_physical_spacing)]
    # Clamp to 1: a zero mesh size (small extent / large spacing) defines no usable grid.
    mesh_size = [max(1, int(sz / scale_factors[-1] + 0.5)) for sz in mesh_size]
    print(mesh_size)

    initial_transform = sitk.BSplineTransformInitializer(image1=fixed_image,
                                                         transformDomainMeshSize=mesh_size, order=3)

    registration_method.SetInitialTransformAsBSpline(initial_transform,
                                                     inPlace=False,
                                                     scaleFactors=scale_factors)

    registration_method.SetMetricAsMeanSquares()
    registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
    registration_method.SetMetricSamplingPercentage(0.01, seed=42)

    if fixed_mask is not None:
        registration_method.SetMetricFixedMask(fixed_mask)
    if moving_mask is not None:
        registration_method.SetMetricMovingMask(moving_mask)

    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    registration_method.SetInterpolator(sitk.sitkLinear)
    registration_method.SetOptimizerAsLBFGS2(solutionAccuracy=1e-6, numberOfIterations=number_of_iterations,
                                             deltaConvergenceTolerance=1e-6)

    registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method))
    plot_metric(registration_method)

    final_transformation = registration_method.Execute(fixed_image, moving_image)
    exceute_metric_plot(registration_method, out_path, "bspline")
    print('\nOptimizer\'s stopping condition, {0}'.format(registration_method.GetOptimizerStopConditionDescription()))
    return final_transformation


if __name__ == "__main__":
    # Hard-coded local case folders: planning scan (0) and follow-up (30).
    fixed_path = get_all_dicom_files(r"C:\Users\ilaym\Desktop\Dart\seedsMovement\cases\panc\0", {})
    moving_path = get_all_dicom_files(r"C:\Users\ilaym\Desktop\Dart\seedsMovement\cases\panc\30", {})

    fixed_np = read_dicom(fixed_path['CT'], fixed_path['meta'])
    moving_np = read_dicom(moving_path['CT'], moving_path['meta'])

    # Register the same data twice: once from NumPy arrays (geometry rebuilt from
    # the meta dicts) and once from the DICOM files (geometry read from disk).
    fixed_np, moving_np, warped_np, warped_bbox_np, outx_np = register(fixed_np, moving_np, fixed_path['meta'], moving_path['meta'], "./")
    fixed, moving, warped, warped_bbox, outx = register(fixed_path['CT'], moving_path['CT'], fixed_path['meta'], moving_path['meta'], "./")

    overlay_images(fixed_np, warped_np)
    overlay_images(fixed, warped)
    print("*****from array*****\n", outx_np)
    print("*****from path*****\n", outx)

    # A physical point inside the path-based transform domain. If the array-based
    # domain is shifted (wrong origin), the point may fall outside it and come
    # back unchanged.
    p = [-2, 135, 52]
    print("path transform ", outx.TransformPoint(p))
    print("np transform ", outx_np.TransformPoint(p))

The result is:
from array
itk::simple::Transform
CompositeTransform (000001EC73DBCCC0)
RTTI typeinfo: class itk::CompositeTransform<double,3>
Reference Count: 1
Modified Time: 1886173
Debug: Off
Object Name:
Observers:
none
Transforms in queue, from begin to end:

BSplineTransform (000001EC75A835B0)
RTTI typeinfo: class itk::BSplineTransform<double,3,3>
Reference Count: 1
Modified Time: 1886165
Debug: Off
Object Name:
Observers:
none
CoefficientImage: [ 000001EC18756720, 000001EC18755640, 000001EC18751590 ]
TransformDomainOrigin: [-202.686, -112.3, 269.92]
TransformDomainPhysicalDimensions: [406.205, 406.205, 314]
TransformDomainDirection: 1 0 0
0 1 0
0 0 1

 TransformDomainMeshSize: [8, 8, 8]
 GridSize: [11, 11, 11]
 GridOrigin: [-253.462, -163.076, 230.67]
 GridSpacing: [50.7756, 50.7756, 39.25]
 GridDirection: 1 0 0

0 1 0
0 0 1

End of MultiTransform.
<<<<<<<<<<
TransformsToOptimizeFlags, begin() to end():
1
TransformsToOptimize in queue, from begin to end:
End of TransformsToOptimizeQueue.
<<<<<<<<<<
End of CompositeTransform.
<<<<<<<<<<

from path
itk::simple::Transform
CompositeTransform (000001EC73DBC680)
RTTI typeinfo: class itk::CompositeTransform<double,3>
Reference Count: 1
Modified Time: 3793676
Debug: Off
Object Name:
Observers:
none
Transforms in queue, from begin to end:

BSplineTransform (000001EC75A84180)
RTTI typeinfo: class itk::BSplineTransform<double,3,3>
Reference Count: 1
Modified Time: 3793668
Debug: Off
Object Name:
Observers:
none
CoefficientImage: [ 000001EC18755370, 000001EC18755BE0, 000001EC18754290 ]
TransformDomainOrigin: [-202.686, -112.3, -44.08]
TransformDomainPhysicalDimensions: [406.205, 406.205, 314]
TransformDomainDirection: 1 0 0
0 1 0
0 0 1

 TransformDomainMeshSize: [8, 8, 8]
 GridSize: [11, 11, 11]
 GridOrigin: [-253.462, -163.076, -83.33]
 GridSpacing: [50.7756, 50.7756, 39.25]
 GridDirection: 1 0 0

0 1 0
0 0 1

End of MultiTransform.
<<<<<<<<<<
TransformsToOptimizeFlags, begin() to end():
1
TransformsToOptimize in queue, from begin to end:
End of TransformsToOptimizeQueue.
<<<<<<<<<<
End of CompositeTransform.
<<<<<<<<<<

path transform (3.835009400038694, 137.46439367370127, 32.945338954151126)
np transform (-2.0, 135.0, 52.0)

Process finished with exit code 0

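As you can see, the two transforms differ in the TransformDomainOrigin field: z = 269.92 when registering from arrays vs. z = -44.08 when registering from paths. That is a shift of exactly 314 mm, the full physical z extent of the volume. I don't understand why, since I'm passing the metadata when using arrays. In addition, outx_np didn't transform the point p. Its z coordinate (52) lies outside the array-based transform domain along z ([269.92, 583.92]).
Thanks!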
Ilay

1 Like