Differentiable B-spline transformation

I’ve been trying to implement a B-spline transformation with neural networks in PyTorch, using the model from here. However, when I apply the predicted parameters with SITK, the transformation results do not match. Is there a differentiable B-spline transformation that works the same way as the BSplineTransform in SITK?

Thanks in advance

Hello @prms,

I'm not sure I understand your question; could you provide additional details or figures to clarify the issue? As far as I can tell, you have a deep learning “black box” that predicts the displacements of B-spline control points, and when you apply these using the SimpleITK BSplineTransform, they do not produce the expected deformation?

Do you also have the grid structure used by the “black box”? The initial grid locations are required for working with the SimpleITK transformation. Bottom line: you need to translate between the B-spline grid structure and displacements used by the deep learning model and those used by SimpleITK.

Hi @zivy,

I have been using the BSplineSTN3D model from here, which is based on the BSplineTransformation from here. The NN outputs the parameters, and those parameters are used to deform the image. When I applied the same parameters with SITK's BSplineTransform, I got a different image. The linked code computes the sampling grid from the displacement field using convolution operations and ultimately reduces it to the shape of the input volume.

I also wanted to visualize the deformation field, thinking that might help. Is there a tool in SITK to visualize the grid?

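```python
import pickle

import numpy as np
import SimpleITK as sitk
import torch  # lets pickle restore the saved torch tensor

# Control-point parameters predicted by the network, saved as a torch tensor.
with open("bspline_params.pkl", "rb") as f:
    bspline_params = pickle.load(f).detach().cpu().numpy().squeeze()

img = sitk.ReadImage("mri.nii")
image_size = img.GetSize()
spacing = (10, 10, 10)
# np.int was removed in NumPy 1.24; the builtin int works on all versions.
mesh_size = np.ceil(np.array(image_size) / spacing).astype(int).tolist()

# Quadratic (order 2) 3D B-spline transform. Note that neither mesh_size nor
# bspline_params_list is ever passed to it, so it keeps its default all-zero
# parameters and Resample below applies an identity mapping.
bspline_transform = sitk.BSplineTransform(3, 2)

bspline_params_list = bspline_params.flatten().tolist()

resampled_image = sitk.Resample(img, img, bspline_transform, sitk.sitkNearestNeighbor, 0.0)
sitk.WriteImage(resampled_image, "transformed_using_sitk.nii")
```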

The output is attached below (left: the SITK transformation, which is close to the original image; right: the PyTorch transformation, which looks heavily deformed).

Here is the parameter dump I used, and here is the image.

I also posted a question in the project repo, and it turns out these two transformations cannot be directly related. Now I am wondering whether I am using BSplineTransform correctly at all, since it does not seem to warp the image.

Slicer can visualize transforms nicely. See some examples in the docs.


@dzenanz Thanks for the response. My problem is that I can't figure out why the two deformations are so different. Is there any way to visualize the deformations in SimpleITK? I also need to find where points end up after deformation, using the deformation field. Is this possible?

Hello @prms,

I would go with visualization using Slicer as recommended by @dzenanz.

If you are willing to use a crude way of visualizing, you can use GridImageSource to create a grid image and apply the deformation fields to that image to see how they vary in space. This approach is used in this Jupyter notebook, in the section titled Radial Distortion.

A deformation field, represented as a DisplacementFieldTransform, has the same interface as the rest of the transformations, so just use its TransformPoint method to map your points. Remember to use physical points and not image indexes (convert between the two representations with the image’s TransformIndexToPhysicalPoint and TransformPhysicalPointToIndex methods).

Yes. In Slicer, go to the Transforms module, select your transform, and then add your points to the list of transformed nodes. ITK and SimpleITK do not have visualization capabilities of their own; there is a VTKGlue module, but it has only basic capabilities.

Thank you @zivy and @dzenanz. I’ll go through your suggestions.