SimpleITK AffineTransform vs PyTorch grid_sample

I’m comparing the results of SITK’s AffineTransform and PyTorch’s grid_sample. The difference between them is that sitk treats origin as the centre of rotation while Pytorch treats the centre of the image as the centre of rotation.
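A small sketch of the PyTorch side of that convention, assuming `align_corners=False` (PyTorch’s current default): `affine_grid` places voxel centers at normalized coordinates (2i + 1)/n - 1 along each axis, so normalized 0, the point that theta rotates and scales about, is the continuous index (n - 1)/2. The helper names below are hypothetical; they only restate this convention in NumPy.

```python
import numpy as np

def index_to_normalized(index, size, align_corners=False):
    """Map a continuous voxel index along one axis to PyTorch's [-1, 1] grid coordinate."""
    index = np.asarray(index, dtype=np.float64)
    if align_corners:
        # -1 and 1 are the centers of the first and last voxels
        return 2.0 * index / (size - 1) - 1.0
    # -1 and 1 are the outer edges of the first and last voxels
    return (2.0 * index + 1.0) / size - 1.0

def normalized_to_index(coord, size, align_corners=False):
    """Inverse of index_to_normalized."""
    coord = np.asarray(coord, dtype=np.float64)
    if align_corners:
        return (coord + 1.0) * (size - 1) / 2.0
    return ((coord + 1.0) * size - 1.0) / 2.0

# Normalized 0 (the fixed point of theta's 3x3 part) sits at index (size - 1) / 2
print(normalized_to_index(0.0, 64))          # 31.5
print(index_to_normalized(np.arange(4), 4))  # values -0.75, -0.25, 0.25, 0.75
```

So on the PyTorch side the rotation center is the geometric center of the voxel grid, independent of the image’s origin, spacing or direction.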

import SimpleITK as sitk
import numpy as np
import torch

# [A | t] for one batch element, shape (1, 3, 4), expressed in PyTorch's normalized coordinates
affine_matrix = np.array([[[ 1.0170,  0.0398,  0.0435, -0.0110],
         [-0.0713,  1.0165,  0.0446,  0.0221],
         [ 0.0387,  0.0189,  0.9905, -0.0033]]], dtype=np.float64)

# affine_grid expects theta of shape (N, 3, 4); affine_matrix already has N = 1, so no extra unsqueeze
theta_3d = torch.from_numpy(affine_matrix).float()

img_3d = sitk.ReadImage("img.nii")
# (N, C, D, H, W) float tensor; GetArrayFromImage returns the voxels in (z, y, x) order
img_3d_tensor = torch.from_numpy(sitk.GetArrayFromImage(img_3d).astype(np.float32)).unsqueeze(0).unsqueeze(0)

# sitk transform (matrix given row by row, translation in physical units)
affine_3d_transform = sitk.AffineTransform(3)
affine_3d_transform.SetMatrix(affine_matrix[0, :, :3].flatten().tolist())
affine_3d_transform.SetTranslation(affine_matrix[0, :, 3].tolist())

# affine_3d_transform.SetCenter((32, 32, 32))  # center of volume

resampled_3d_sitk = sitk.Resample(img_3d, img_3d, affine_3d_transform, sitk.sitkNearestNeighbor, 0.0)
sitk.WriteImage(resampled_3d_sitk, "sitk_3d_resampled.nii")

# pytorch transform (mode="nearest" to match sitkNearestNeighbor above)
grid_3d = torch.nn.functional.affine_grid(theta_3d, img_3d_tensor.shape, align_corners=False)
resampled_3d_pytorch = torch.nn.functional.grid_sample(img_3d_tensor, grid_3d, mode="nearest", align_corners=False)

# back to a SimpleITK image with the original geometry
resampled_3d_pytorch = sitk.GetImageFromArray(resampled_3d_pytorch.squeeze().numpy())
resampled_3d_pytorch.CopyInformation(img_3d)
sitk.WriteImage(resampled_3d_pytorch, "pytorch_3d_resampled.nii")

I’m not able to figure out why they behave differently; I might have missed something here. My assumption is that transforming the same volume by the same matrix should give the same result. Any help would be highly appreciated. Thanks in advance!

Result difference (left: SITK’s output, right: PyTorch’s output):

Image file: img.nii (shared via Dropbox)
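One quick way to see why the same numbers do not give the same result: PyTorch reads the translation column in normalized units (with `align_corners=False`, one unit is half the image extent along that axis), whereas `SetTranslation` takes physical units (millimeters). The sketch below assumes an identity direction matrix and, purely for illustration, a 64³ volume with 1 mm voxels (suggested by the commented `SetCenter((32, 32, 32))` line, not confirmed). The function name is hypothetical. For non-cubic volumes the 3x3 part is affected too, since off-diagonal terms get rescaled by the ratio of axis sizes.

```python
import numpy as np

def normalized_translation_to_mm(t_norm, size, spacing, align_corners=False):
    """Convert a PyTorch affine_grid translation (normalized units) to millimeters.

    Assumes an identity direction matrix. With align_corners=False one normalized
    unit spans size / 2 voxels along that axis, with align_corners=True (size - 1) / 2.
    """
    t_norm = np.asarray(t_norm, dtype=np.float64)
    size = np.asarray(size, dtype=np.float64)
    half_extent = (size - 1) / 2.0 if align_corners else size / 2.0
    return t_norm * half_extent * np.asarray(spacing, dtype=np.float64)

# Translation column from the matrix above, assuming a 64^3 volume with 1 mm voxels
t_norm = [-0.0110, 0.0221, -0.0033]
print(normalized_translation_to_mm(t_norm, size=[64, 64, 64], spacing=[1.0, 1.0, 1.0]))
# roughly [-0.352, 0.707, -0.106] mm, while SetTranslation(t_norm) shifts by only [-0.011, 0.0221, -0.0033] mm
```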

Hello @prms,

“The difference between them is that sitk treats origin as the centre of rotation while Pytorch treats the centre of the image as the centre of rotation.”

SimpleITK does not use the image origin as the center of rotation for the global transformations (rigid, affine…). There are two options: either it uses the center of rotation specified by the user, or the default, which is [0,0,0]. For most medical images the origin is not at [0,0,0]; check with img_3d.GetOrigin(). So you should explicitly set the center of rotation to either the image center or the image origin, whichever you want (the image center makes more sense in most cases). Please take a look at this Jupyter notebook section titled Resampling.
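For reference, a sketch of what "the image center" means in physical space. It follows ITK’s index-to-physical mapping p = origin + D · diag(spacing) · index, evaluated at the continuous index (size - 1)/2, with the direction given as the flat, row-major 9-tuple that `GetDirection()` returns. The helper name is hypothetical, and the numbers in the example are made up for illustration.

```python
import numpy as np

def physical_center(origin, spacing, direction, size):
    """Physical coordinates of the image center, i.e. of continuous index (size - 1) / 2."""
    D = np.asarray(direction, dtype=np.float64).reshape(3, 3)
    center_index = (np.asarray(size, dtype=np.float64) - 1.0) / 2.0
    # p = origin + D @ diag(spacing) @ index
    return np.asarray(origin, dtype=np.float64) + D @ (np.asarray(spacing, dtype=np.float64) * center_index)

# Hypothetical 1 mm image of 181 x 217 x 181 voxels with identity direction
print(physical_center([-90.0, -126.0, -72.0], [1.0, 1.0, 1.0], np.eye(3).ravel(), [181, 217, 181]))
# -> (0, -18, 18)
```

With a real image this should agree with `img.TransformContinuousIndexToPhysicalPoint([(s - 1) / 2.0 for s in img.GetSize()])`, and either result can be passed to `SetCenter`.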

Hi @Zivy,
Thank you for pointing me to the resource; I was missing this. By origin I meant (0,0,0) in voxel space. I tried the same thing on a 2D image, setting the center of rotation to (90, 140) for a (180, 280) image, and the results are identical (it had identity direction cosines). Since SITK treats volumes as physical objects, does the center also have to be given in physical coordinates rather than voxel coordinates? And would setting the center as below
affine_3d_transform.SetCenter(np.array(img_3d.GetOrigin()) + np.array(img_3d.TransformContinuousIndexToPhysicalPoint(img_3d.GetSize()))/2)

rotate the volume about its origin? I tried this setting, but the results are still different. Also, I am using AffineTransform instead of Euler3DTransform. Is there anything else that needs to be considered?

Hello @prms,

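Yes, you need to use the physical center of the image. TransformContinuousIndexToPhysicalPoint already includes the origin, so you don’t need to add it; just map the continuous index of the image center, (size - 1)/2:

affine_3d_transform.SetCenter(img_3d.TransformContinuousIndexToPhysicalPoint([(sz - 1) / 2.0 for sz in img_3d.GetSize()]))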

Any solution to this? I have a similar problem, i.e. I have a pytorch affine transform that I want to “itk-ify”.

In other words, I have a transform that correctly transforms an image in “tensor space” (an idealized space with no direction cosines, isotropic voxel size, etc.) and I’d like to apply the same transform in physical space, where the image may have non-standard direction cosines etc.
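One way to do that mapping, sketched under a few assumptions: the tensor was built with `sitk.GetArrayFromImage` (so theta’s rows and columns are in (x, y, z) order, the same as SimpleITK’s size and spacing), `align_corners` is known, and both libraries map output points back to input points (which is how `affine_grid` and `sitk.Resample` both work). The idea is to rewrite theta in continuous-index space around the center index (size - 1)/2, then conjugate by the index-to-physical mapping D · diag(spacing). The function name is hypothetical and I have only checked it against the algebra, not against every image geometry.

```python
import numpy as np

def torch_theta_to_sitk_affine(theta, size, spacing, origin, direction, align_corners=False):
    """Express a PyTorch affine_grid theta (3x4, normalized tensor space) as an ITK-style affine.

    Returns (matrix, translation, center) such that an output physical point p samples
    the input at matrix @ (p - center) + center + translation, the form used by
    sitk.AffineTransform.
    """
    theta = np.asarray(theta, dtype=np.float64).reshape(3, 4)
    A, t = theta[:, :3], theta[:, 3]
    size = np.asarray(size, dtype=np.float64)
    # Voxels per normalized unit along each axis
    half = (size - 1.0) / 2.0 if align_corners else size / 2.0
    # Same mapping in continuous-index space, about the center index (size - 1) / 2
    A_idx = np.diag(half) @ A @ np.diag(1.0 / half)
    t_idx = half * t
    # Index -> physical: p = origin + D @ diag(spacing) @ index
    DS = np.asarray(direction, dtype=np.float64).reshape(3, 3) @ np.diag(np.asarray(spacing, dtype=np.float64))
    matrix = DS @ A_idx @ np.linalg.inv(DS)
    translation = DS @ t_idx
    center = np.asarray(origin, dtype=np.float64) + DS @ ((size - 1.0) / 2.0)
    return matrix, translation, center
```

With SimpleITK, the call would be something like `torch_theta_to_sitk_affine(theta, fixed.GetSize(), fixed.GetSpacing(), fixed.GetOrigin(), fixed.GetDirection())`, and the three results would go into `sitk.AffineTransform(3)` via `SetMatrix(matrix.ravel().tolist())`, `SetTranslation(translation.tolist())` and `SetCenter(center.tolist())`. For a fair comparison, use matching interpolators (for example `sitk.sitkLinear` with `mode='bilinear'`); small differences near the image borders may remain.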

Here’s an example where I’ve simplified the transform to a simple set of scales:

fixed = 'C:/Dev//ixi//IXI002-Guys-0828-T1.nii'
fixed = sitk.ReadImage(fixed)

theta = np.array([1.25, 0, 0, 0, 0, 1.25, 0, 0, 0, 0, 1.25, 0])

# PYTORCH
fixed_tensor = torch.tensor(sitk.GetArrayFromImage(fixed)).float().unsqueeze(0).unsqueeze(0)
grid = torch.nn.functional.affine_grid(torch.tensor(theta), fixed_tensor.size())
fixed_tensor_rs = torch.nn.functional.grid_sample(fixed_tensor, grid.float(), mode='bilinear')

fixed_rs_torch = sitk.GetImageFromArray(fixed_tensor_rs.squeeze(0).squeeze(0).cpu().detach().numpy())
fixed_rs_torch.CopyInformation(fixed)
sitk.WriteImage(fixed_rs_torch, 'fixed_rs_torch.mha')

# SITK

theta_sitk = theta.transpose(0,2,1).flatten()
affine_transform = sitk.AffineTransform(3)
affine_transform.SetParameters(theta_sitk)
print(affine_transform)

fixed_rs_sitk = sitk.Resample(fixed, fixed, affine_transform)
sitk.WriteImage(fixed_rs_sitk, 'fixed_rs_sitk.mha')

Even in this case the images (though both are scaled down) are not identical.
How can I modify the theta transform to give an output consistent with the output of PyTorch?

Did you try setting the Center as @zivy explained above?

I did; it introduced an even greater discrepancy between torch and sitk.
I’m confused about why changing the center has an impact on a scaling operation.

Hello @dapper416,

This has to do with how the ITK global transforms work. Please see this Jupyter notebook, specifically the introduction and the sections titled Similarity [2D] and Scale Transform.
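A short sketch of the point behind that notebook: ITK’s global transforms apply T(p) = A(p - c) + c + t, which equals A·p + (I - A)·c + t. Moving the center c therefore adds the translation (I - A)·c, so for any A other than the identity (a pure scale included) the choice of center changes the result. This is also why a 1.25 scale about the image center (what PyTorch’s theta does) and a 1.25 scale about the default center (0, 0, 0) shift the image differently. The helper `itk_affine_apply` is hypothetical; it only evaluates that formula with NumPy.

```python
import numpy as np

def itk_affine_apply(point, matrix, center, translation):
    """Apply an ITK-style global affine: T(p) = A (p - c) + c + t."""
    point = np.asarray(point, dtype=np.float64)
    center = np.asarray(center, dtype=np.float64)
    return np.asarray(matrix, dtype=np.float64) @ (point - center) + center + np.asarray(translation, dtype=np.float64)

S = 1.25 * np.eye(3)
p = np.array([0.0, 0.0, 0.0])
# Same scale matrix, two different centers
print(itk_affine_apply(p, S, center=[0, 0, 0], translation=[0, 0, 0]))     # maps to (0, 0, 0)
print(itk_affine_apply(p, S, center=[50, 50, 50], translation=[0, 0, 0]))  # maps to (-12.5, -12.5, -12.5)
```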

Thanks @zivy. I will take a look, but I am kinda new to this, so a lot is going straight over my head.
Do you have any tips on things to try?

Hello @dapper416,

The code snippet you provided is missing something. The line
theta_sitk = theta.transpose(0,2,1).flatten() throws an exception given
theta = np.array([1.25, 0, 0, 0, 0, 1.25, 0, 0, 0, 0, 1.25, 0]).

Did you expect it to yield:
theta_sitk = [1.25, 0, 0, 0, 1.25, 0, 0, 0, 1.25, 0, 0, 0]

The order of SimpleITK parameters is
[a00, a01, a02, a10, a11, a12, a20, a21, a22, t0, t1, t2]; as you can see, this doesn’t match the PyTorch order, which is
[a00, a01, a02, t0, a10, a11, a12, t1, a20, a21, a22, t2].

I suspect the transpose.flatten code is intended to do this rearranging, but it doesn’t seem to be working. Possibly a bug there?
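A minimal sketch of that rearrangement (the function name is hypothetical): take the 3x3 part row by row, then append the translation column. Note that `theta.transpose(0, 2, 1).flatten()` on a (1, 3, 4) array emits the columns of the 3x3 part first, i.e. the transpose of A, followed by t, so it only matches the SimpleITK order when A is symmetric, as with the pure scale here. Reordering alone does not convert units: the values are still in PyTorch’s normalized coordinates.

```python
import numpy as np

def torch_theta_to_sitk_parameters(theta):
    """Flatten a (1, 3, 4) or (3, 4) PyTorch theta into SimpleITK AffineTransform parameter order.

    SimpleITK expects the 3x3 matrix row by row, followed by the translation:
    [a00, a01, a02, a10, a11, a12, a20, a21, a22, t0, t1, t2].
    """
    theta = np.asarray(theta, dtype=np.float64).reshape(3, 4)
    return np.concatenate([theta[:, :3].ravel(), theta[:, 3]])

theta = np.array([1.25, 0, 0, 0, 0, 1.25, 0, 0, 0, 0, 1.25, 0]).reshape(1, 3, 4)
print(torch_theta_to_sitk_parameters(theta).tolist())
# [1.25, 0.0, 0.0, 0.0, 1.25, 0.0, 0.0, 0.0, 1.25, 0.0, 0.0, 0.0]
```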

Hi @zivy!

Sorry, I made an error when copying the code.

This should reproduce the problem (it also prints the SITK affine transform to confirm that it’s just a scale matrix). You are right, the transpose.flatten is intended to do the rearranging.

fixed = 'C:/Dev//ixi//IXI002-Guys-0828-T1.nii'
moving = 'C:/Dev//ixi//IXI012-HH-1211-T1.nii'

fixed = sitk.ReadImage(fixed)
moving = sitk.ReadImage(moving)

theta = np.array([1.25, 0, 0, 0, 0, 1.25, 0, 0, 0, 0, 1.25, 0]).reshape(1,3,4)

# PYTORCH
fixed_tensor = torch.tensor(sitk.GetArrayFromImage(fixed)).float().unsqueeze(0).unsqueeze(0)
grid = torch.nn.functional.affine_grid(torch.tensor(theta), fixed_tensor.size())
fixed_tensor_rs = torch.nn.functional.grid_sample(fixed_tensor, grid.float(), mode='bilinear')

fixed_rs_torch = sitk.GetImageFromArray(fixed_tensor_rs.squeeze(0).squeeze(0).cpu().detach().numpy())
fixed_rs_torch.CopyInformation(fixed)
sitk.WriteImage(fixed_rs_torch, 'fixed_rs_torch.mha')

# SITK

theta_sitk = theta.transpose(0,2,1).flatten()
affine_transform = sitk.AffineTransform(3)
affine_transform.SetParameters(theta_sitk)
print(affine_transform)

fixed_rs_sitk = sitk.Resample(fixed, fixed_rs_torch, affine_transform)
sitk.WriteImage(fixed_rs_sitk, 'fixed_rs_sitk.mha')