itk::ImageRegistrationMethod does not return the best value from all iterations


(donelron) #1

I am using the following code (essentially copied from the ITK example https://itk.org/Wiki/ITK/Examples/Registration/MutualInformation) to register two synthetic images with mutual information:

#include "itkImageRegistrationMethod.h"
#include "itkTranslationTransform.h"
#include "itkMutualInformationImageToImageMetric.h"
#include "itkGradientDescentOptimizer.h"
#include "itkNormalizeImageFilter.h"
#include "itkDiscreteGaussianImageFilter.h"
#include "itkResampleImageFilter.h"
#include "itkCastImageFilter.h"
#include "itkCheckerBoardImageFilter.h"
#include "itkEllipseSpatialObject.h"
#include "itkSpatialObjectToImageFilter.h"
#include "itkImageFileWriter.h"

#include "itkRescaleIntensityImageFilter.h"
#include "itkCommandIterationUpdate.h"
#include "itkMacro.h"


// Dimensionality and pixel type of the synthetic 8-bit input images.
constexpr unsigned int Dimension = 2;
typedef unsigned char  PixelType;

typedef itk::Image< PixelType, Dimension >  ImageType;

static void CreateEllipseImage(ImageType::Pointer image);
static void CreateCircleImage(ImageType::Pointer image);


// Pixel and image type used inside the registration pipeline
// (the normalized images are floating point).
typedef float                                       InternalPixelType;
typedef itk::Image< InternalPixelType, Dimension >  InternalImageType;

// Classic (v3) registration framework: fixed and moving image share one type.
typedef itk::ImageRegistrationMethod<
    InternalImageType,
    InternalImageType >  RegistrationType;



/** Observer that prints the optimizer state after every iteration.
 *
 * Intended to be attached with
 *   optimizer->AddObserver(itk::IterationEvent(), observer);
 * where the optimizer is an itk::GradientDescentOptimizer. Calls coming from
 * any other object type, and events other than IterationEvent, are ignored.
 */
class CommandIterationUpdate : public itk::Command
{
public:
    using Self = CommandIterationUpdate;
    using Superclass = itk::Command;
    using Pointer = itk::SmartPointer<CommandIterationUpdate>;
    itkNewMacro(CommandIterationUpdate);

protected:
    CommandIterationUpdate() = default;

public:

    /** Non-const entry point; only read access is needed, so forward to the const overload. */
    void Execute(itk::Object *caller, const itk::EventObject & event) override
    {
        Execute(static_cast<const itk::Object *>(caller), event);
    }

    /** Print iteration number, current position and metric value for an IterationEvent. */
    void Execute(const itk::Object * object, const itk::EventObject & event) override
    {
        // dynamic_cast, not static_cast: a static_cast never turns a non-null
        // pointer into nullptr, so the type check below could never fire and a
        // wrongly attached observer would invoke undefined behavior.
        const auto * optimizerMethod = dynamic_cast<const itk::GradientDescentOptimizer *>(object);
        if (optimizerMethod == nullptr)
        {
            return;
        }
        if (!itk::IterationEvent().CheckEvent(&event))
        {
            return;
        }

        // NOTE(review): GetValue() is the optimizer's most recently evaluated
        // metric value; it is presumably computed *before* the step that produced
        // GetCurrentPosition(), so the two printed quantities may be offset by one
        // step -- confirm against the ITK source before pairing them.
        std::cout << "it " << optimizerMethod->GetCurrentIteration()
                  << ", transl = " << optimizerMethod->GetCurrentPosition()
                  << ", value=" << optimizerMethod->GetValue()
                  << "\n";
    }
};


int main( int argc, char *argv[] )
{

  // Generate synthetic fixed and moving images
  ImageType::Pointer  fixedImage = ImageType::New();
  CreateCircleImage(fixedImage);
  ImageType::Pointer movingImage = ImageType::New();
  CreateEllipseImage(movingImage);

  // Normalize the images
  typedef itk::NormalizeImageFilter<ImageType, InternalImageType> NormalizeFilterType;
  NormalizeFilterType::Pointer fixedNormalizer = NormalizeFilterType::New();
  NormalizeFilterType::Pointer movingNormalizer = NormalizeFilterType::New();

  fixedNormalizer->SetInput(  fixedImage);
  movingNormalizer->SetInput( movingImage);

  typedef itk::TranslationTransform< double, Dimension > TransformType;
  typedef itk::GradientDescentOptimizer                  OptimizerType;
  typedef itk::LinearInterpolateImageFunction< InternalImageType, double > InterpolatorType;
  
  typedef itk::MutualInformationImageToImageMetric<
                                          InternalImageType,
                                          InternalImageType >    MetricType;

  TransformType::Pointer      transform     = TransformType::New();
  OptimizerType::Pointer      optimizer     = OptimizerType::New();
  InterpolatorType::Pointer   interpolator  = InterpolatorType::New();
  RegistrationType::Pointer   registration  = RegistrationType::New();

  registration->SetOptimizer(     optimizer     );
  registration->SetTransform(     transform     );
  registration->SetInterpolator(  interpolator  );

  MetricType::Pointer         metric        = MetricType::New();
  registration->SetMetric( metric  );

  metric->SetFixedImageStandardDeviation(  0.4 );
  metric->SetMovingImageStandardDeviation( 0.4 );

  registration->SetFixedImage( fixedNormalizer->GetOutput()    );
  registration->SetMovingImage(movingNormalizer->GetOutput()   );

  fixedNormalizer->Update();
  ImageType::RegionType fixedImageRegion = fixedNormalizer->GetOutput()->GetBufferedRegion();
  registration->SetFixedImageRegion( fixedImageRegion );

  typedef RegistrationType::ParametersType ParametersType;
  ParametersType initialParameters( transform->GetNumberOfParameters() );

  initialParameters[0] = 0.0;  // Initial offset along X
  initialParameters[1] = 0.0;  // Initial offset along Y

  registration->SetInitialTransformParameters( initialParameters );  
  const unsigned int numberOfPixels = fixedImageRegion.GetNumberOfPixels();
  const unsigned int numberOfSamples = static_cast<unsigned int>(numberOfPixels * 0.01);
  metric->SetNumberOfSpatialSamples( numberOfSamples );

  optimizer->SetLearningRate(50.0);
  optimizer->SetNumberOfIterations(20);
  optimizer->MaximizeOn(); // We want to maximize mutual information (the default of the optimizer is to minimize)

  // Connect an observer
  CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
  optimizer->AddObserver( itk::IterationEvent(), observer );

  try
    {
    registration->Update();
    std::cout << "Optimizer stop condition: "
              << registration->GetOptimizer()->GetStopConditionDescription()
              << std::endl;
    }
  catch( itk::ExceptionObject & err )
    {
    std::cout << "ExceptionObject caught !" << std::endl;
    std::cout << err << std::endl;
    return EXIT_FAILURE;
    }

  ParametersType finalParameters = registration->GetLastTransformParameters();

  double TranslationAlongX = finalParameters[0];
  double TranslationAlongY = finalParameters[1];

  unsigned int numberOfIterations = optimizer->GetCurrentIteration();

  double bestValue = optimizer->GetValue();


  // Print out results
  std::cout << std::endl;
  std::cout << "Result = " << std::endl;
  std::cout << " Translation X = " << TranslationAlongX  << std::endl;
  std::cout << " Translation Y = " << TranslationAlongY  << std::endl;
  std::cout << " Iterations    = " << numberOfIterations << std::endl;
  std::cout << " Metric value  = " << bestValue          << std::endl;
  std::cout << " Numb. Samples = " << numberOfSamples    << std::endl;

  return EXIT_SUCCESS;
}

/** Rasterize a 100 x 100 (unit spacing) image containing an ellipse with
 *  radii 10 (axis 0) and 20 (axis 1), translated by (65, 45).
 *  Inside pixels are 255, outside pixels 0. The result is grafted onto
 *  @p image, which must be a valid (non-null) image pointer.
 */
void CreateEllipseImage(ImageType::Pointer image)
{
  typedef itk::EllipseSpatialObject< Dimension >   EllipseType;

  typedef itk::SpatialObjectToImageFilter<
    EllipseType, ImageType >   SpatialObjectToImageFilterType;

  SpatialObjectToImageFilterType::Pointer imageFilter =
    SpatialObjectToImageFilterType::New();

  // Output grid: 100 x 100 pixels with unit spacing.
  ImageType::SizeType size;
  size[ 0 ] = 100;
  size[ 1 ] = 100;
  imageFilter->SetSize( size );

  ImageType::SpacingType spacing;
  spacing.Fill(1);
  imageFilter->SetSpacing(spacing);

  // Ellipse geometry.
  EllipseType::Pointer ellipse = EllipseType::New();
  EllipseType::ArrayType radiusArray;
  radiusArray[0] = 10;
  radiusArray[1] = 20;
  ellipse->SetRadius(radiusArray);

  // Place the ellipse through its object-to-parent transform.
  typedef EllipseType::TransformType TransformType;
  TransformType::Pointer transform = TransformType::New();
  transform->SetIdentity();

  TransformType::OutputVectorType translation;
  translation[ 0 ] = 65;
  translation[ 1 ] = 45;
  transform->Translate( translation, false );

  ellipse->SetObjectToParentTransform( transform );

  imageFilter->SetInput(ellipse);

  // Binary rendering: 255 inside the object, 0 elsewhere.
  ellipse->SetDefaultInsideValue(255);
  ellipse->SetDefaultOutsideValue(0);
  imageFilter->SetUseObjectValue( true );
  imageFilter->SetOutsideValue( 0 );

  imageFilter->Update();

  // Hand the rendered output over to the caller's image object.
  image->Graft(imageFilter->GetOutput());
}

/** Rasterize a 100 x 100 (unit spacing) image containing a circle of
 *  radius 10, translated by (50, 50). Inside pixels are 255, outside
 *  pixels 0. The result is grafted onto @p image, which must be a valid
 *  (non-null) image pointer.
 */
void CreateCircleImage(ImageType::Pointer image)
{
  typedef itk::EllipseSpatialObject< Dimension >   EllipseType;

  typedef itk::SpatialObjectToImageFilter<
    EllipseType, ImageType >   SpatialObjectToImageFilterType;

  SpatialObjectToImageFilterType::Pointer imageFilter =
    SpatialObjectToImageFilterType::New();

  // Output grid: 100 x 100 pixels with unit spacing.
  ImageType::SizeType size;
  size[ 0 ] = 100;
  size[ 1 ] = 100;
  imageFilter->SetSize( size );

  ImageType::SpacingType spacing;
  spacing.Fill(1);
  imageFilter->SetSpacing(spacing);

  // A circle is an ellipse with equal radii.
  EllipseType::Pointer ellipse = EllipseType::New();
  EllipseType::ArrayType radiusArray;
  radiusArray[0] = 10;
  radiusArray[1] = 10;
  ellipse->SetRadius(radiusArray);

  // Place the circle through its object-to-parent transform.
  typedef EllipseType::TransformType TransformType;
  TransformType::Pointer transform = TransformType::New();
  transform->SetIdentity();

  TransformType::OutputVectorType translation;
  translation[ 0 ] = 50;
  translation[ 1 ] = 50;
  transform->Translate( translation, false );

  ellipse->SetObjectToParentTransform( transform );

  imageFilter->SetInput(ellipse);

  // Binary rendering: 255 inside the object, 0 elsewhere.
  ellipse->SetDefaultInsideValue(255);
  ellipse->SetDefaultOutsideValue(0);
  imageFilter->SetUseObjectValue( true );
  imageFilter->SetOutsideValue( 0 );

  imageFilter->Update();

  // Hand the rendered output over to the caller's image object.
  image->Graft(imageFilter->GetOutput());
}

Now, when I look at the output, it seems that the result of the last iteration (in this example 0.000934234) is always reported as the “best value”, even though earlier iterations did better, e.g. iterations 2 and 3 (i.e. 0.00124208 and 0.00240695); the largest value, 0.0228237, occurs at iteration 6. Of course, I could add a mechanism that keeps track of the results of all iterations and returns the best one at the very end, but that is probably not how the class is intended to be used. There is a function `virtual const ParametersType& itk::ImageRegistrationMethod< TFixedImage, TMovingImage >::GetLastTransformParameters() const`. However, I would have expected a function such as `GetBestTransformParameters()` that returns the translation that produced the maximum mutual information. What am I missing here?

it 0, transl = [3.080745490862392e-27, -2.716527225092382e-25], value=0.000612325
it 1, transl = [5.1655689727723054e-23, 3.0754345820462986e-23], value=-0.00153213
it 2, transl = [4.7576327557073874e-23, 3.0754345820462986e-23], value=0.00124208
it 3, transl = [2.2864627443559803e-23, 3.0728931531586856e-23], value=0.00240695
it 4, transl = [-4.666378034618216e-19, 2.1827462357148532e-23], value=-0.0626463
it 5, transl = [-4.66734633157285e-19, -6.388334110244383e-23], value=-0.000920195
it 6, transl = [-4.654328593482208e-19, -6.587315526432373e-23], value=0.0228237
it 7, transl = [-4.659618853553521e-19, 1.434332129935076e-21], value=-0.0218188
it 8, transl = [-4.659649774797798e-19, 1.4587940542757805e-21], value=-0.14773
it 9, transl = [-4.660465615278749e-19, 1.4587889196999623e-21], value=0.00145463
it 10, transl = [-4.661200248067382e-19, 1.4587889196999623e-21], value=-0.00058087
it 11, transl = [-4.658168507396112e-19, 1.4685938068876238e-21], value=0.00610458
it 12, transl = [-4.658170963459659e-19, 1.4661027965906913e-21], value=6.32885e-05
it 13, transl = [-4.657769902544275e-19, 1.4630956262500453e-21], value=0.0076189
it 14, transl = [-4.657769902544275e-19, 1.4630956262500453e-21], value=-0.076171
it 15, transl = [-4.657769587250847e-19, 1.4630956262500453e-21], value=0.00135417
it 16, transl = [-4.657769587250847e-19, 1.4629990013936299e-21], value=0.000306046
it 17, transl = [-4.658379842485591e-19, 1.4629990013936299e-21], value=-0.00570804
it 18, transl = [-9.334279237000942e-19, 1.6765176189990196e-21], value=-0.196626
it 19, transl = [-9.334320099934345e-19, 1.6765106878293262e-21], value=0.000934234
Optimizer stop condition: GradientDescentOptimizer: Maximum number of iterations (20) exceeded.

Result =
Translation X = -9.33432e-19
Translation Y = 1.67651e-21
Iterations = 20
Metric value = 0.000934234
Numb. Samples = 100


(Dženan Zukić) #2

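Try calling `optimizer->SetReturnBestParametersAndValue(true);`. Note that this option belongs to the ITKv4 optimizers (e.g. `itk::GradientDescentOptimizerv4`, used together with `itk::ImageRegistrationMethodv4`). The older `itk::GradientDescentOptimizer` used in the code above does not provide it. So either port the registration to the v4 framework, or track the best value and parameters yourself in the iteration observer.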