I am trying to register thorax “bones” (ribs, sternum, vertebrae, costal cartilages).
As expected image registration gets stuck in local minima and confuses ribs/vertebrae.
Therefore I am trying to use the pointset registration in ITK. So far I am quite happy with the results. Affine is quite fast, but BSpline registration is really slow.
What I don’t understand is what is actually slow. Initializing the BSpline is fast. Once it has started it runs nearly as fast as the affine part. But before it really starts it takes forever.
This is my code (yes I mix SimpleITK and ITK, but thats not the issue here):
def pointset_registration(
fixed_labels: sitk.Image,
moving_labels: sitk.Image,
transform_file: Path,
narrow_band: float = 1.0,
transform_type: Transform = Transform.affine.value,
initial_transform: Path | None = None,
num_iterations: int = 50,
):
"""Register two label maps via pointset registration"""
# Extract point sets
fixed_labels = crop_foreground(fixed_labels, fixed_labels != 0)
moving_labels = crop_foreground(moving_labels, moving_labels != 0)
D = fixed_labels.GetDimension()
fixed_set, _ = extract_pointset(fixed_labels, narrow_band=narrow_band)
moving_set, _ = extract_pointset(moving_labels, narrow_band=narrow_band)
if transform_type == Transform.affine:
AffineTransformType = itk.AffineTransform[itk.D, D]
transform = AffineTransformType.New()
transform.SetIdentity()
elif transform_type == Transform.euler3d:
EulerTransformType = itk.Euler3DTransform[itk.D]
transform = EulerTransformType.New()
transform.SetIdentity()
elif transform_type == Transform.translation:
TranslationTransformType = itk.TranslationTransform[itk.D, D]
transform = TranslationTransformType.New()
transform.SetIdentity()
elif transform_type == Transform.bspline:
transform = init_bspline_transform(fixed_labels, 75.0)
if initial_transform:
init_transform = itk.transformread(str(initial_transform))[0]
composite_transform = itk.CompositeTransform[itk.D, D].New()
composite_transform.AddTransform(init_transform)
composite_transform.AddTransform(transform)
composite_transform.SetOnlyMostRecentTransformToOptimizeOn()
transform = composite_transform
# Define types
PointSetType = type(moving_set)
PointSetMetricType = itk.LabeledPointSetToPointSetMetricv4[PointSetType]
ShiftScalesType = itk.RegistrationParameterScalesFromPhysicalShift[
PointSetMetricType
]
OptimizerType = itk.RegularStepGradientDescentOptimizerv4[itk.D]
metric = PointSetMetricType.New(
FixedPointSet=fixed_set,
MovingPointSet=moving_set,
MovingTransform=transform,
)
metric.Initialize()
shift_scale_estimator = ShiftScalesType.New(
Metric=metric, VirtualDomainPointSet=metric.GetVirtualTransformedPointSet()
)
optimizer = OptimizerType.New(
Metric=metric,
NumberOfIterations=num_iterations,
ScalesEstimator=shift_scale_estimator,
MaximumStepSizeInPhysicalUnits=fixed_labels.GetSpacing()[0],
MinimumConvergenceValue=0.0,
ConvergenceWindowSize=20, # default: 50
DoEstimateLearningRateOnce=True,
)
def print_iteration():
print(
f"It: {optimizer.GetCurrentIteration()}"
f" metric value: {optimizer.GetCurrentMetricValue():.6f} "
)
optimizer.AddObserver(itk.IterationEvent(), print_iteration)
# Run optimization to align the point sets
print("Start optimization")
optimizer.StartOptimization()
I get to “Start optimization” quite fast, but it takes minutes until the first iteration is printed. 250 iterations run in 20 second.
Should I be aware of something e.g. related the VirtualTransformedPointSet
or the shift_scale_estimator
?