I am trying to register two 3D volumes. I know there is a significant (rigid) shift between the two images, but the registration keeps stopping after the first iteration. The gradient is zero at the first iteration, although the metric value is non-zero.
The core of the code is shown below. After a whole day of debugging, my guess is that I am not setting something correctly, but I can’t figure out what it is!
Is there any obvious mistake that could lead to this problem?
// Image type the registration components (metric, interpolator, method) work on.
using InternalPixelType = float;
using InternalImageType = itk::Image< InternalPixelType, ImageType::ImageDimension >;

// Fixed and moving volumes are both read as ImageType.
using FixedImageType = ImageType;
using MovingImageType = ImageType;
using FixedImageReaderType = itk::ImageFileReader< FixedImageType >;
using MovingImageReaderType = itk::ImageFileReader< MovingImageType >;

// read input images
// Each reader is configured right after creation; Update() is forced inside the
// try-block so read errors are reported here instead of later in the pipeline.
auto fixedImageReader = FixedImageReaderType::New();
fixedImageReader->SetFileName( filename1 );
auto movingImageReader = MovingImageReaderType::New();
movingImageReader->SetFileName( filename2 );
std::cout << "Setup input file readers."<<std::endl;

try
{
  fixedImageReader->Update();
  movingImageReader->Update();
}
catch (itk::ExceptionObject & error)
{
  std::cerr << "Error while reading images: " << error << std::endl;
  return EXIT_FAILURE;
}

FixedImageType::Pointer fixedImage = fixedImageReader->GetOutput();
MovingImageType::Pointer movingImage = movingImageReader->GetOutput();
// Setup registration framework...
using TransformType = itk::Euler3DTransform< double >;//, ImageType::ImageDimension >;
using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
using InterpolatorType = itk::LinearInterpolateImageFunction< InternalImageType, double>;
using RegistrationType = itk::ImageRegistrationMethodv4< InternalImageType, InternalImageType >;
//using MetricType = itk::MattesMutualInformationImageToImageMetricv4< InternalImageType, InternalImageType >;
using MetricType = itk::MeanSquaresImageToImageMetricv4< InternalImageType, InternalImageType >;
typename TransformType::Pointer transform = TransformType::New();
typename OptimizerType::Pointer optimizer = OptimizerType::New();
typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
typename RegistrationType::Pointer registration = RegistrationType::New();
typename MetricType::Pointer metric = MetricType::New();
registration->SetOptimizer(optimizer);
registration->SetMetric( metric );
registration->SetFixedImage( fixedImageReader->GetOutput() );
registration->SetMovingImage( movingImageReader->GetOutput() );
registration->SetObjectName("Registration Method");
using ParametersType = OptimizerType::ParametersType;
TransformType::NumberOfParametersType numberOfParameters = transform->GetNumberOfParameters();
using TransformInitializerType = itk::CenteredTransformInitializer<TransformType, FixedImageType, MovingImageType>;
TransformInitializerType::Pointer initializer = TransformInitializerType::New();
initializer->SetTransform(transform);
initializer->SetFixedImage(fixedImageReader->GetOutput());
initializer->SetMovingImage(movingImageReader->GetOutput());
initializer->GeometryOn();
initializer->InitializeTransform();
using OptimizerScalesType = OptimizerType::ScalesType;
OptimizerScalesType optimizerScales(transform->GetNumberOfParameters());
const double translationScale = 1.0 / 1000.0;
optimizerScales[0] = 1.0;
optimizerScales[1] = 1.0;
optimizerScales[2] = 1.0;
optimizerScales[3] = translationScale;
optimizerScales[4] = translationScale;
optimizerScales[4] = translationScale;
optimizer->SetScales(optimizerScales);
registration->SetMovingInitialTransform( transform );
typename TransformType::Pointer identityTransform = TransformType::New();
identityTransform->SetIdentity();
registration->SetFixedInitialTransform( identityTransform );
std::cout << "Setup registration method"<<std::endl;
// Single-level schedule: the only level runs on images shrunk by a factor of 3
// and smoothed with a sigma of 2.
constexpr unsigned int numberOfLevels1 = 1;
RegistrationType::ShrinkFactorsArrayType shrinkFactorsPerLevel1;
shrinkFactorsPerLevel1.SetSize(numberOfLevels1);
shrinkFactorsPerLevel1[0] = 3;
RegistrationType::SmoothingSigmasArrayType smoothingSigmasPerLevel1;
smoothingSigmasPerLevel1.SetSize(numberOfLevels1);
smoothingSigmasPerLevel1[0] = 2;
// The number of levels is set before the per-level arrays and before the
// sampling percentage; both are per-level settings and must agree with it.
registration->SetNumberOfLevels(numberOfLevels1);
registration->SetShrinkFactorsPerLevel(shrinkFactorsPerLevel1);
registration->SetSmoothingSigmasPerLevel(smoothingSigmasPerLevel1);
// NOTE(review): in ImageRegistrationMethodv4 the sampling percentage appears to
// be honored only when a metric sampling strategy other than the default NONE
// (e.g. RANDOM or REGULAR) is set via SetMetricSamplingStrategy(); without it
// this 0.5 (and the seed below) may have no effect — confirm for the ITK
// version in use.
registration->SetMetricSamplingPercentage(0.5);
registration->MetricSamplingReinitializeSeed(121212);
// Build binary masks for both images by thresholding intensities into
// [10, 4098] and attach them to the metric, so the metric is restricted to
// masked points.
using MaskObjectType = itk::ImageMaskSpatialObject< FixedImageType::ImageDimension >;
MaskObjectType::Pointer fixedMaskObject = MaskObjectType::New();
MaskObjectType::Pointer movingMaskObject = MaskObjectType::New();
// Fixed-image mask: voxels in [10, 4098] become the mask type's max value (inside), others 0.
using FixedThresholdFilterType = itk::BinaryThresholdImageFilter<FixedImageType, MaskObjectType::ImageType >;
FixedThresholdFilterType::Pointer fixedImageThresholdFilter = FixedThresholdFilterType::New();
fixedImageThresholdFilter->SetInput(fixedImageReader->GetOutput());
fixedImageThresholdFilter->SetOutsideValue(0);
fixedImageThresholdFilter->SetInsideValue( itk::NumericTraits< MaskObjectType::ImageType::PixelType>::max() );
fixedImageThresholdFilter->SetLowerThreshold(10);
// NOTE(review): 4098 is an unusual upper bound (4095 is the 12-bit max) — confirm
// that voxels brighter than 4098 are meant to be excluded from the mask.
fixedImageThresholdFilter->SetUpperThreshold(4098);
// Moving-image mask: same thresholds as the fixed mask.
using MovingThresholdFilterType = itk::BinaryThresholdImageFilter<MovingImageType, MaskObjectType::ImageType >;
MovingThresholdFilterType::Pointer movingImageThresholdFilter = MovingThresholdFilterType::New();
movingImageThresholdFilter->SetInput(movingImageReader->GetOutput());
movingImageThresholdFilter->SetOutsideValue(0);
movingImageThresholdFilter->SetInsideValue( itk::NumericTraits< MaskObjectType::ImageType::PixelType >::max() );
movingImageThresholdFilter->SetLowerThreshold(10);
movingImageThresholdFilter->SetUpperThreshold(4098);
// Run each threshold filter before handing its output to the mask object.
fixedImageThresholdFilter->Update();
fixedMaskObject->SetImage( fixedImageThresholdFilter->GetOutput() );
movingImageThresholdFilter->Update();
movingMaskObject->SetImage( movingImageThresholdFilter->GetOutput() );
//output of threshold filter looks OK
fixedMaskObject->Update();
movingMaskObject->Update();
// not sure how to verify if this object is generated correctly
// NOTE(review): one check is to query each mask object at a physical point known
// to be foreground and at one known to be background.
// NOTE(review): the moving mask is presumably tested at transformed sample
// positions, so with a large initial misalignment many samples may be rejected.
metric->SetFixedImageMask(fixedMaskObject);
metric->SetMovingImageMask(movingMaskObject);
optimizer->SetNumberOfIterations( 200 );
optimizer->SetRelaxationFactor(0.5);
optimizer->SetLearningRate( 0.15 );
optimizer->SetMinimumStepLength(0.00001);
optimizer->SetGradientMagnitudeTolerance(0.0000000001);
std::cout << "Setup optimizer settings" << std::endl;
CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
optimizer->AddObserver( itk::IterationEvent(), observer );
using TranslationCommandType = RegistrationInterfaceCommand<RegistrationType>;
TranslationCommandType::Pointer command1 = TranslationCommandType::New();
registration->AddObserver(itk::MultiResolutionIterationEvent(), command1);
// Run the registration and report why the optimizer stopped.
// On an ITK exception, report it on stderr (matching the image-reading error
// path) and fail.
try
{
  registration->Update();
  std::cout << "Optimizer stop condition: "
            << registration->GetOptimizer()->GetStopConditionDescription()
            << std::endl;
}
catch( itk::ExceptionObject & err )
{
  std::cerr << "ExceptionObject caught !" << std::endl;
  std::cerr << err << std::endl;
  return EXIT_FAILURE;
}