Multi-resolution registration anomaly

When utilizing the ImageRegistrationMethodv4 for multi-resolution registration, I encountered the following unusual behavior:

  1. In cases where the initial positions of the images to be registered were significantly misaligned, the multi-resolution registration process still performed normally.
  2. However, when the initial positions of the images were relatively close, the multi-resolution registration failed to propagate the result of one resolution level to the next level.

Below are my code and its output.

bool RigidRegistration(RegParameter* rigidParameter,
    short * pRefData, unsigned uiRefSize[3], double dRefSpacing[3],
    short * pMovData, unsigned uiMovSize[3], double dMovSpacing[3],
    double dOutMat[9], double dOutTranslation[3],short * pOutData,double dPriorMat[9], double dPriorTranslation[3],
    double dRefPosition[3], double dRefOrientation[6], double dMovPosition[3], double dMovOrientation[6],
    bool *isCancel, double* regProgress)
{
    
    if (pRefData == nullptr || pMovData == nullptr) {
        return false;
    }

    if (isCancel == nullptr || *isCancel != true)
    {
        if (regProgress != nullptr)
        {
            *regProgress = 0.0;
        }
    }
    

    RegParameter* RegPara = rigidParameter;
    if (RegPara == nullptr) {
        RegPara = new RegParameter();
    }
    double refPosition[3] = { 0 };
    if (dRefPosition != nullptr) {
        refPosition[0] = dRefPosition[0];
        refPosition[1] = dRefPosition[1];
        refPosition[2] = dRefPosition[2];
    }

    constexpr unsigned int Dimension = 3;
    using PixelType = short;
    using InterPixelType = float;

    using InterFixedImageType = itk::Image<InterPixelType, Dimension>;
    using InterMovingImageType = itk::Image<InterPixelType, Dimension>;

    using TransformType = itk::VersorRigid3DTransform<double>;   
    using MetricType =
        itk::MattesMutualInformationImageToImageMetricv4<InterFixedImageType, InterMovingImageType>;  

    using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
    using RegistrationType = itk::
        ImageRegistrationMethodv4<InterFixedImageType, InterMovingImageType, TransformType>;

    auto refImg = InterFixedImageType::New();
    auto movImg = InterMovingImageType::New();
    unsigned refPixNum = uiRefSize[0] * uiRefSize[1] * uiRefSize[2];
    unsigned movPixNum = uiMovSize[0] * uiMovSize[1] * uiMovSize[2];
    std::unique_ptr<float[]> pfRefData(new float[refPixNum]());
    std::unique_ptr<float[]> pfMovData(new float[movPixNum]());
    for (unsigned i = 0; i < refPixNum; i++) {
        pfRefData[i] = (float)pRefData[i];
    }
    for (unsigned i = 0; i < movPixNum; i++) {
        pfMovData[i] = (float)pMovData[i];
    }
    TransformBuffToITKImage(pfRefData.get(), uiRefSize, dRefSpacing, dRefPosition, dRefOrientation, refImg);
    TransformBuffToITKImage(pfMovData.get(), uiMovSize, dMovSpacing, dMovPosition, dMovOrientation, movImg);

    auto metric = MetricType::New();
    metric->SetNumberOfHistogramBins(50); 
    auto optimizer = OptimizerType::New();
    auto registration = RegistrationType::New();

    //初始化
    auto initialTransform = TransformType::New();
    initialTransform->SetIdentity();
    using TransformInitializerType =
        itk::CenteredTransformInitializer<TransformType, InterFixedImageType, InterMovingImageType>;
    auto initializer = TransformInitializerType::New();

    itk::Point<double, 3> centerPoint;                     
    centerPoint[0] = refPosition[0] + (uiRefSize[0] - 1) * dRefSpacing[0] / 2; 
    centerPoint[1] = refPosition[1] + (uiRefSize[1] - 1) * dRefSpacing[1] / 2;
    centerPoint[2] = refPosition[2] + (uiRefSize[2] - 1) * dRefSpacing[2] / 2;
    initialTransform->SetCenter(centerPoint);

    if (dPriorMat != nullptr) {
        itk::Matrix<double, 3, 3> priorMat;
        for (unsigned i = 0; i < 3; i++) {
            for (unsigned j = 0; j < 3; j++) {
                priorMat(i, j) = dPriorMat[i * 3 + j];
            }
        }
        initialTransform->SetMatrix(priorMat,0.001);

    }
    if (dPriorTranslation != nullptr)
    {
        initialTransform->SetTranslation(dPriorTranslation);
    }
    if (dPriorTranslation == nullptr && dPriorMat == nullptr) {
        initializer->SetTransform(initialTransform);
        initializer->SetFixedImage(refImg);
        initializer->SetMovingImage(movImg);
        initializer->GeometryOn();           
        initializer->InitializeTransform();
    }

    using OptimizerScalesType = OptimizerType::ScalesType;
    OptimizerScalesType optimizerScales(
        initialTransform->GetNumberOfParameters());
    const double translationScale = RegPara->TranslationScale; 
    const double rotateScale = RegPara->RotateScale;         
    optimizerScales[0] = rotateScale;
    optimizerScales[1] = rotateScale;
    optimizerScales[2] = rotateScale;
    optimizerScales[3] = translationScale;
    optimizerScales[4] = translationScale; 
    optimizerScales[5] = translationScale; 
    optimizer->SetScales(optimizerScales);
    optimizer->SetNumberOfIterations(RegPara->NumberOfIterations);
    optimizer->SetLearningRate(RegPara->LearningRate);
    optimizer->SetRelaxationFactor(RegPara->RelaxationFactor);
    optimizer->SetMinimumStepLength(RegPara->MinimumStepLength);
    optimizer->SetReturnBestParametersAndValue(true);
    auto observer = CommandIterationUpdate::New();
    optimizer->AddObserver(itk::IterationEvent(), observer);


    unsigned FineShrinkFactor[3] = {1};
    if (uiRefSize[0] > 100 || uiRefSize[1] > 100 || uiRefSize[2] > 100) {
        FineShrinkFactor[0] = std::min(int(uiRefSize[0] / 100) + 1, int(4 / dRefSpacing[0]) + 1);
        FineShrinkFactor[1] = std::min(int(uiRefSize[1] / 100) + 1, int(4 / dRefSpacing[1]) + 1);
        FineShrinkFactor[2] = std::min(int(uiRefSize[2] / 100) + 1, int(4 / dRefSpacing[2]) + 1);
    }


    unsigned int pyramidLevel = RegPara->PyramidLevel;
    RegistrationType::MetricSamplingStrategyEnum samplingStrategy =
        RegistrationType::MetricSamplingStrategyEnum::REGULAR;
    registration->SetMetricSamplingStrategy(samplingStrategy);
    registration->SetNumberOfLevels(pyramidLevel);
    registration->SetMetric(metric);
    registration->SetOptimizer(optimizer);
    registration->SetFixedImage(refImg);
    registration->SetMovingImage(movImg);
    registration->SetInitialTransform(initialTransform);
    registration->InPlaceOn();                                

    RegistrationType::SmoothingSigmasArrayType smoothingSigmasPerLevel;
    smoothingSigmasPerLevel.SetSize(pyramidLevel);                   
    for (unsigned i = 0; i < pyramidLevel; i++) {
        RegistrationType::ShrinkFactorsPerDimensionContainerType shrinkFactorsPerDimension;
        shrinkFactorsPerDimension[0] = FineShrinkFactor[0] * (pyramidLevel - i);
        shrinkFactorsPerDimension[1] = FineShrinkFactor[1] * (pyramidLevel - i);
        shrinkFactorsPerDimension[2] = FineShrinkFactor[2] * (pyramidLevel - i); 
        registration->SetShrinkFactorsPerDimension(i, shrinkFactorsPerDimension);

        smoothingSigmasPerLevel[i] = 0;               
    }
    registration->SetSmoothingSigmasPerLevel(smoothingSigmasPerLevel);

    if (pyramidLevel > 1)                                  
    {
        using CommandType = RegistrationInterfaceCommand<RegistrationType>;
        auto command = CommandType::New();
        registration->AddObserver(itk::MultiResolutionIterationEvent(), command);
    }

    if (isCancel == nullptr || *isCancel != true)
    {
        if (regProgress != nullptr)
        {
            *regProgress = 0.05;
            observer->SetCancelFlag(isCancel);
            observer->SetProgressPointer(regProgress);
            observer->SetTotalLevel(pyramidLevel);
            observer->SetProgressRange(0.05, 0.95);
        }
    }


    try
    { 
        registration->Update();
        std::cout << "Optimizer stop condition: "
            << registration->GetOptimizer()->GetStopConditionDescription()
            << std::endl;
    }
    catch (const itk::ExceptionObject& err)
    {
        std::cerr << "ExceptionObject caught !" << std::endl;
        std::cerr << err << std::endl;
        return false;
    }

    if (isCancel == nullptr || *isCancel != true)
    {
        if (regProgress != nullptr)
        {
            *regProgress = 0.95;
        }
    }
    
    if (isCancel != nullptr && *isCancel == true)
    {
        return false;
    }

    const TransformType::ParametersType finalParameters =
        registration->GetOutput()->Get()->GetParameters();
    auto finalTransform = TransformType::New();
    finalTransform->SetFixedParameters(
        registration->GetOutput()->Get()->GetFixedParameters());
    finalTransform->SetParameters(finalParameters);
    TransformType::MatrixType matrix = finalTransform->GetMatrix();
    TransformType::OffsetType translation = finalTransform->GetTranslation();
    auto& center = finalTransform->GetCenter();
    for (unsigned i = 0; i < 9; i++) {
        dOutMat[i] = matrix((unsigned)(i/3),(unsigned)(i%3));
    }
    for (unsigned i = 0; i < 3; i++) {
        dOutTranslation[i] = translation[i];
    }

    std::cout << "Center = " << std::endl << center << std::endl;
    std::cout << "Matrix = " << std::endl << matrix << std::endl;
    std::cout << "Translation = " << std::endl << translation << std::endl;

    if (pOutData != nullptr) {

        ResampleForRigid(pMovData, uiMovSize, dMovSpacing, uiRefSize, dRefSpacing,
            dOutMat, dOutTranslation, pOutData);
    }
    if (rigidParameter == nullptr && RegPara != nullptr) {
        delete RegPara;
    }
    return true;
}

class CommandIterationUpdate : public itk::Command
{
public:
    using Self = CommandIterationUpdate;
    using Superclass = itk::Command;
    using Pointer = itk::SmartPointer<Self>;
    itkNewMacro(Self);

    void SetProgressPointer(double* pProgress) { m_progress = pProgress; }

    void SetCancelFlag(bool* isCancel) { m_isCancel = isCancel; }

    void SetTotalLevel(unsigned totalLevel) { m_totalLevel = totalLevel; }

    void SetProgressRange(double startProgress, double endProgress)
    {
        m_startProgress = startProgress;
        m_endProgress = endProgress;
    }

    void SetCurLevel(unsigned curLevel) { m_curLevel = curLevel; }

protected:
    CommandIterationUpdate() = default;

public:

    using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
    using OptimizerPointer = OptimizerType*;

    void Execute(itk::Object* caller, const itk::EventObject& event) override
    {
        auto optimizer = dynamic_cast<OptimizerPointer>(caller);
        if (!itk::IterationEvent().CheckEvent(&event))
        {
            return;
        }
        if (m_isCancel != nullptr && *m_isCancel)
        {
            optimizer->StopOptimization();
        }
        if (m_progress != nullptr) 
        {
            unsigned int currentIteration = optimizer->GetCurrentIteration();
            unsigned int maxIterations = optimizer->GetNumberOfIterations();
            double levelProgress = static_cast<double>(currentIteration) / maxIterations;
            double progressRange = m_endProgress - m_startProgress;
            *m_progress = m_startProgress + ((m_curLevel + levelProgress) / m_totalLevel)* progressRange;
        }

        std::cout << optimizer->GetCurrentIteration() << "   ";
        std::cout << optimizer->GetValue() << std::endl;
        std::cout << "  Current Position:" << optimizer->GetCurrentPosition() << std::endl;
        std::cout << "  Current LR Relaxation:" << optimizer->GetCurrentLearningRateRelaxation() << std::endl;
        std::cout << "  Current StepLength:" << optimizer->GetCurrentStepLength() << std::endl;
    }
    void Execute(const itk::Object* object, const itk::EventObject& event) override
    {

    }

private:
    double* m_progress = nullptr;
    bool* m_isCancel = nullptr;
    unsigned m_curLevel = 0;
    unsigned m_totalLevel = 1;
    double m_startProgress = 0.0;
    double m_endProgress = 1.0;

};

template <typename TRegistration>
class RegistrationInterfaceCommand : public itk::Command
{

public:
    using Self = RegistrationInterfaceCommand;
    using Superclass = itk::Command;
    using Pointer = itk::SmartPointer<Self>;
    itkNewMacro(Self);

    bool SetIterationCommand(CommandIterationUpdate* commandIterationUpdate)
    {
        m_commandIterationUpdate = commandIterationUpdate;
    }

protected:
    RegistrationInterfaceCommand() = default;

public:
    using RegistrationType = TRegistration;
    using RegistrationPointer = RegistrationType*;
    using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
    using OptimizerPointer = OptimizerType*;

    void Execute(itk::Object* object, const itk::EventObject& event) override
    {
        if (!(itk::MultiResolutionIterationEvent().CheckEvent(&event)))
        {
            return;
        }
        auto registration = static_cast<RegistrationPointer>(object);
        auto optimizer =
            static_cast<OptimizerPointer>(registration->GetModifiableOptimizer());

        unsigned int currentLevel = registration->GetCurrentLevel();
        typename RegistrationType::ShrinkFactorsPerDimensionContainerType
            shrinkFactors =
            registration->GetShrinkFactorsPerDimension(currentLevel);
        typename RegistrationType::SmoothingSigmasArrayType smoothingSigmas =
            registration->GetSmoothingSigmasPerLevel();

        // 
        if (m_commandIterationUpdate != nullptr)
        {
            m_commandIterationUpdate->SetCurLevel(currentLevel);
        }

        std::cout << "-------------------------------------" << std::endl;
        std::cout << " Current level = " << currentLevel << std::endl;
        std::cout << "    shrink factor = " << shrinkFactors << std::endl;
        std::cout << "    smoothing sigma = ";
        std::cout << smoothingSigmas[currentLevel] << std::endl;
        std::cout << std::endl;
        if (registration->GetCurrentLevel() == 0)
        {
            optimizer->SetLearningRate(optimizer->GetLearningRate()*4);
            optimizer->SetMinimumStepLength(optimizer->GetMinimumStepLength() * 10);
        }
        else
        {
            optimizer->SetLearningRate(optimizer->GetCurrentStepLength());
            optimizer->SetMinimumStepLength(optimizer->GetMinimumStepLength() / 10);
        }

    }

    // pure virtual method, must be realized.
    void Execute(const itk::Object*, const itk::EventObject&) override
    {
        return;
    }

private:
    CommandIterationUpdate* m_commandIterationUpdate = nullptr;
};

Here I registered two identical images against each other. The results are as follows:



Could anyone help me analyze the cause of this behavior and suggest a solution? I would be extremely grateful. @dzenanz

Can anyone help? I would be extremely grateful.

I have identified the cause: the call `optimizer->SetReturnBestParametersAndValue(true)`. This setting makes the optimizer return the parameters with the best metric value rather than the parameters from the last iteration. The parameters with the best metric value are not always the ones that give the best alignment, which leads to this problem.

1 Like