You came very close.
registration->SetFixedImage(fixedImage);
registration->SetMovingImage(movingImage);
needed to be changed to:
registration->SetFixedImage(fixedImage);
registration->SetMovingImage(movingImage);
registration->SetFixedImage(1, fixedImage);
registration->SetMovingImage(1, movingImage);
There needs to be a separate pair of inputs for each metric in the multi-metric, because the different metrics don’t necessarily take the same kind of input (e.g. the first metric might take images while the second takes point sets). This should probably be noted in the documentation. Would you mind contributing that?
For posterity, here is a complete example that runs on my computer:
#include "itkCastImageFilter.h"
#include "itkCorrelationImageToImageMetricv4.h"
#include "itkGradientDescentOptimizerv4.h"
#include "itkImageFileReader.h"
#include "itkImageFileWriter.h"
#include "itkImageRegistrationMethodv4.h"
#include "itkMattesMutualInformationImageToImageMetricv4.h"
#include "itkMeanSquaresImageToImageMetricv4.h"
#include "itkObjectToObjectMultiMetricv4.h"
#include "itkRegularStepGradientDescentOptimizerv4.h"
#include "itkResampleImageFilter.h"
#include "itkTranslationTransform.h"

#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <typeinfo>
/** Observer that prints the state of an optimizer each time it fires an itk::IterationEvent.
 *
 * TFilter is the concrete optimizer type being observed. It must provide GetCurrentIteration(),
 * GetGradient(), GetCurrentMetricValue() and GetCurrentPosition(), all of which are printed to
 * std::cout on one line per iteration.
 */
template <typename TFilter>
class itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate : public itk::Command
{
public:
  using Self = itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate;
  using Superclass = itk::Command;
  using Pointer = itk::SmartPointer<Self>;
  itkNewMacro(Self);

protected:
  itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate() = default;

public:
  /** Non-const overload required by itk::Command; forwards to the const overload below. */
  void
  Execute(itk::Object * caller, const itk::EventObject & event) override
  {
    Execute(static_cast<const itk::Object *>(caller), event);
  }

  /** Print the optimizer state for an IterationEvent; ignore every other event.
   *
   * @throws itk::ExceptionObject if the caller is not a TFilter.
   */
  void
  Execute(const itk::Object * object, const itk::EventObject & event) override
  {
    // Exact type comparison: events whose dynamic type merely derives from IterationEvent
    // are ignored too.
    if (typeid(event) != typeid(itk::IterationEvent))
    {
      return;
    }
    const auto * optimizer = dynamic_cast<const TFilter *>(object);
    if (optimizer == nullptr)
    {
      itkGenericExceptionMacro("Error dynamic_cast failed");
    }
    // NOTE(review): the const_cast presumably exists because GetCurrentPosition() was not
    // declared const in some ITK versions — TODO confirm before removing it.
    std::cout << "It- " << optimizer->GetCurrentIteration() << " gradient: " << optimizer->GetGradient()
              << " metric value: " << optimizer->GetCurrentMetricValue()
              << " Params: " << const_cast<TFilter *>(optimizer)->GetCurrentPosition() << std::endl;
  }
};
//////////////////////////////////////////////////////////////////////////////////////////////////////////////
int
main()
{
constexpr int Dimension = 2;
using ImageType = itk::Image<float, Dimension>;
using PixelType = float;
using FixedImageType = itk::Image<PixelType, Dimension>;
using MovingImageType = itk::Image<PixelType, Dimension>;
using FixedImageReaderType = itk::ImageFileReader<FixedImageType>;
using MovingImageReaderType = itk::ImageFileReader<MovingImageType>;
FixedImageReaderType::Pointer fixedImageReader = FixedImageReaderType::New();
MovingImageReaderType::Pointer movingImageReader = MovingImageReaderType::New();
fixedImageReader->SetFileName("C:/Dev/ITK-git/Examples/Data/BrainProtonDensitySliceShifted13x17y.png");
movingImageReader->SetFileName("C:/Dev/ITK-git/Examples/Data/BrainT1SliceBorder20.png");
auto fixedImage = fixedImageReader->GetOutput();
auto movingImage = movingImageReader->GetOutput();
int numberOfIterations = 50;
/* if (argc > 1)
{
numberOfIterations = std::stoi(argv[1]);
}*/
// create an affine transform
using TranslationTransformType = itk::TranslationTransform<double, Dimension>;
auto translationTransform = TranslationTransformType::New();
translationTransform->SetIdentity();
using OptimizerType = itk::GradientDescentOptimizerv4;
auto optimizer = OptimizerType::New();
using RegistrationType = itk::ImageRegistrationMethodv4<ImageType, ImageType>;
auto registration = RegistrationType::New();
registration->SetFixedImage(fixedImage);
registration->SetMovingImage(movingImage);
registration->SetFixedImage(1, fixedImage);
registration->SetMovingImage(1, movingImage);
using MultiMetricType = itk::ObjectToObjectMultiMetricv4<Dimension, Dimension, ImageType>;
using MetricType = itk::CorrelationImageToImageMetricv4<ImageType, ImageType>;
auto correlationMetric = MetricType::New();
correlationMetric->SetFixedImage(fixedImage);
correlationMetric->SetMovingImage(movingImage);
correlationMetric->SetMovingTransform(translationTransform);
correlationMetric->Initialize();
/*translationTransform->SetIdentity();*/
// Test with two different metric types
/* using MeanSquaresMetricType = itk::MeanSquaresImageToImageMetricv4<ImageType, ImageType>;
auto meanSquaresMetric = MeanSquaresMetricType::New();
meanSquaresMetric->SetFixedImage(fixedImage);
meanSquaresMetric->SetMovingImage(movingImage);
meanSquaresMetric->SetMovingTransform(translationTransform);*/
using MattesMutualInformationMetricType = itk::MattesMutualInformationImageToImageMetricv4<ImageType, ImageType>;
auto MattesMutualInformationMetric = MattesMutualInformationMetricType::New();
MattesMutualInformationMetric->SetFixedImage(fixedImage);
MattesMutualInformationMetric->SetMovingImage(movingImage);
MattesMutualInformationMetric->SetMovingTransform(translationTransform);
MattesMutualInformationMetric->Initialize();
auto multiMetric2 = MultiMetricType::New();
multiMetric2->AddMetric(correlationMetric);
multiMetric2->AddMetric(MattesMutualInformationMetric);
multiMetric2->Initialize();
translationTransform->SetIdentity();
MultiMetricType::WeightsArrayType oriMetricWeights(2);
oriMetricWeights[0] = 0.5;
oriMetricWeights[1] = 0.5;
multiMetric2->SetMetricWeights(oriMetricWeights);
std::cout << "*** Multi-metric with different metric types: " << std::endl;
using ParametersType = TranslationTransformType::ParametersType;
ParametersType initialParameters(translationTransform->GetNumberOfParameters());
initialParameters[0] = 1.0; // Initial offset in mm along X
initialParameters[1] = 1.0; // Initial offset in mm along Y
translationTransform->SetParameters(initialParameters);
registration->SetInitialTransform(translationTransform);
registration->SetOptimizer(optimizer);
registration->SetMetric(multiMetric2);
optimizer->SetMetric(multiMetric2);
optimizer->SetNumberOfIterations(numberOfIterations);
optimizer->SetLearningRate(0.01);
optimizer->SetMaximumStepSizeInPhysicalUnits(1.0);
using CommandType = itkObjectToObjectMultiMetricv4RegistrationTestCommandIterationUpdate<OptimizerType>;
auto observer = CommandType::New();
optimizer->AddObserver(itk::IterationEvent(), observer);
//optimizer->StartOptimization();
try
{
registration->Update();
std::cout << "Optimizer stop condition: " << registration->GetOptimizer()->GetStopConditionDescription()
<< std::endl;
}
catch (itk::ExceptionObject & err)
{
std::cout << "ExceptionObject caught !" << std::endl;
std::cout << err << std::endl;
system("pause");
return EXIT_FAILURE;
}
using ResampleFilterType = itk::ResampleImageFilter<MovingImageType, FixedImageType>;
OptimizerType::ParametersType finalParameters = translationTransform->GetParameters();
double TranslationAlongX = finalParameters[0];
double TranslationAlongY = finalParameters[1];
unsigned int numberOfIteration = optimizer->GetCurrentIteration();
double bestValue = optimizer->GetValue();
std::cout << std::endl;
std::cout << "Result = " << std::endl;
std::cout << " Translation X = " << TranslationAlongX << std::endl;
std::cout << " Translation Y = " << TranslationAlongY << std::endl;
std::cout << " Iterations = " << numberOfIteration << std::endl;
std::cout << " Metric value = " << bestValue << std::endl;
TranslationTransformType::Pointer finalTransform = TranslationTransformType::New();
finalTransform->SetParameters(finalParameters);
finalTransform->SetFixedParameters(translationTransform->GetFixedParameters());
ResampleFilterType::Pointer resample = ResampleFilterType::New();
resample->SetTransform(finalTransform);
resample->SetInput(movingImageReader->GetOutput());
resample->SetSize(fixedImage->GetLargestPossibleRegion().GetSize());
resample->SetOutputOrigin(fixedImage->GetOrigin());
resample->SetOutputSpacing(fixedImage->GetSpacing());
resample->SetOutputDirection(fixedImage->GetDirection());
resample->SetDefaultPixelValue(100);
using OutputPixelType = unsigned char;
using OutputImageType = itk::Image<OutputPixelType, Dimension>;
using CastFilterType = itk::CastImageFilter<FixedImageType, OutputImageType>;
using WriterType = itk::ImageFileWriter<OutputImageType>;
WriterType::Pointer writer = WriterType::New();
CastFilterType::Pointer caster = CastFilterType::New();
writer->SetFileName("out.png");
caster->SetInput(resample->GetOutput());
writer->SetInput(caster->GetOutput());
writer->Update();
system("pause");
return EXIT_SUCCESS;
}