Hello ITK community! I am in desperate need for help as I am finding no resources on this.
I am trying to implement a custom ImageToImage similarity metric for the purpose of registration. For a project, I want to register two 3D binary segmentation masks as a first attempt. For this, I want a simple DICE metric to be used as the similarity metric, however, this does not exist in the C++ ITK library. How do I implement this? So far I simply tried to inherit from the ImageToImageMetricv4 class and implement my own value and derivative calculation (see implementation below) but I just keep getting errors that I honestly do not quite understand. Can anybody help me out here? Or at least point me in the right direction on how to do this?
template <typename TFixedImage, typename TMovingImage>
class DiceMetric : public itk::ImageToImageMetricv4<TFixedImage, TMovingImage>
{
public:
using Self = DiceMetric;
using Superclass = itk::ImageToImageMetricv4<TFixedImage, TMovingImage>;
using Pointer = itk::SmartPointer<Self>;
itkTypeMacro(DiceMetric, ImageToImageMetricv4);
// Define some type aliases for convenience
using FixedImageType = TFixedImage;
using MovingImageType = TMovingImage;
using FixedImagePixelType = typename FixedImageType::PixelType;
using MovingImagePixelType = typename MovingImageType::PixelType;
using MeasureType = typename Superclass::MeasureType;
using DerivativeType = typename Superclass::DerivativeType;
using FixedImageRegionType = typename FixedImageType::RegionType;
// Override GetValue function
MeasureType GetValue() const override
{
// Get the fixed and moving images
const FixedImageType *fixedImage = this->GetFixedImage();
const MovingImageType *movingImage = this->GetMovingImage();
// Get the region of overlap between the two images
FixedImageRegionType region = fixedImage->GetBufferedRegion();
region.Crop(movingImage->GetBufferedRegion());
// Compute the Dice coefficient
MeasureType dice = 0.0;
MeasureType intersection = 0.0;
MeasureType sum = 0.0;
itk::ImageRegionConstIterator<FixedImageType> fixedIt(fixedImage, region);
itk::ImageRegionConstIterator<MovingImageType> movingIt(movingImage, region);
while (!fixedIt.IsAtEnd())
{
FixedImagePixelType fixedValue = fixedIt.Get();
MovingImagePixelType movingValue = movingIt.Get();
intersection += static_cast<MeasureType>(fixedValue && movingValue);
sum += static_cast<MeasureType>(fixedValue || movingValue);
++fixedIt;
++movingIt;
}
if (sum > 0.0)
{
dice = (2.0 * intersection) / sum;
}
return dice;
}
// Override GetDerivative function
void GetDerivative(DerivativeType &derivative) const override
{
// Set the derivative to zero manually
derivative.Fill(0.0);
}
protected:
DiceMetric() = default;
~DiceMetric() override = default;
};
// Command Observer to print optimizer paramters at each iteration
// to monitor the evolution of the registration process
class CommandIterationUpdate : public itk::Command {
public:
using Self = CommandIterationUpdate;
using Superclass = itk::Command;
using Pointer = itk::SmartPointer<Self>;
itkNewMacro(Self);
protected:
CommandIterationUpdate() = default;
public:
using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
using OptimizerPointer = const OptimizerType *;
void Execute(itk::Object * caller, const itk::EventObject & event) override {
Execute((const itk::Object *)caller, event);
}
void Execute(const itk::Object * object, const itk::EventObject & event) override {
auto optimizer = static_cast<OptimizerPointer>(object);
if (!itk::IterationEvent().CheckEvent(&event)) {
return;
}
std::cout << optimizer->GetCurrentIteration() << " ";
std::cout << optimizer->GetValue() << " ";
std::cout << optimizer->GetCurrentPosition() << std::endl;
}
};
When I then try to set this metric in the registration method with
registration->SetMetric(metric)
I get the error:
no suitable conversion function from “itk::Object::Pointer” to "itk::ObjectToObjectMetricBaseTemplate
Any help is appreciated!