Multi-resolution registration anomaly

When using ImageRegistrationMethodv4 for multi-resolution registration, I encountered the following unusual behavior:

  1. In cases where the initial positions of the images to be registered were significantly misaligned, the multi-resolution registration process still performed normally.
  2. However, when the images were initially close to each other, the registration result of one resolution level was not propagated to the next level.

Below are my code and its output.

// Rigidly registers a moving volume onto a reference volume with ITK's
// ImageRegistrationMethodv4 (VersorRigid3DTransform, Mattes mutual information,
// regular-step gradient descent) over a multi-resolution pyramid.
//
// rigidParameter     optimizer / pyramid settings; a default-constructed RegParameter
//                    is used when nullptr.
// pRefData/uiRefSize/dRefSpacing   reference (fixed) voxels, size and spacing.
// pMovData/uiMovSize/dMovSpacing   moving voxels, size and spacing.
// dOutMat            receives the 3x3 rotation matrix, row-major (required).
// dOutTranslation    receives the translation (required).
// pOutData           optional; when non-null receives the moving image resampled by
//                    ResampleForRigid.
// dPriorMat/dPriorTranslation  optional initial rotation (row-major) / translation.
//                    When both are nullptr the transform is initialized geometrically
//                    by CenteredTransformInitializer.
// dRefPosition/dRefOrientation/dMovPosition/dMovOrientation  geometry forwarded to
//                    TransformBuffToITKImage.
// isCancel           optional cancel flag polled at every optimizer iteration.
// regProgress        optional progress output.
// Returns true on success, and false in these cases:
//   - a required buffer is missing;
//   - an ITK exception is raised;
//   - the registration is cancelled.
bool RigidRegistration(RegParameter* rigidParameter,
    short * pRefData, unsigned uiRefSize[3], double dRefSpacing[3],
    short * pMovData, unsigned uiMovSize[3], double dMovSpacing[3],
    double dOutMat[9], double dOutTranslation[3],short * pOutData,double dPriorMat[9], double dPriorTranslation[3],
    double dRefPosition[3], double dRefOrientation[6], double dMovPosition[3], double dMovOrientation[6],
    bool *isCancel, double* regProgress)
{
    // The output arrays are written unconditionally below, so they are required too.
    if (pRefData == nullptr || pMovData == nullptr ||
        dOutMat == nullptr || dOutTranslation == nullptr) {
        return false;
    }

    const auto cancelled = [isCancel]() { return isCancel != nullptr && *isCancel; };

    if (!cancelled() && regProgress != nullptr) {
        *regProgress = 0.0;
    }

    // Own the fallback parameter block so that every return path (including the
    // exception and cancel paths) releases it.
    std::unique_ptr<RegParameter> defaultParameter;
    RegParameter* RegPara = rigidParameter;
    if (RegPara == nullptr) {
        defaultParameter.reset(new RegParameter());
        RegPara = defaultParameter.get();
    }

    constexpr unsigned int Dimension = 3;
    using InterPixelType = float;

    using InterFixedImageType = itk::Image<InterPixelType, Dimension>;
    using InterMovingImageType = itk::Image<InterPixelType, Dimension>;

    using TransformType = itk::VersorRigid3DTransform<double>;
    using MetricType =
        itk::MattesMutualInformationImageToImageMetricv4<InterFixedImageType, InterMovingImageType>;
    using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
    using RegistrationType =
        itk::ImageRegistrationMethodv4<InterFixedImageType, InterMovingImageType, TransformType>;

    // ---- Convert the short buffers to float ITK images --------------------------------
    // size_t avoids overflowing the voxel count on large volumes.
    const std::size_t refPixNum =
        static_cast<std::size_t>(uiRefSize[0]) * uiRefSize[1] * uiRefSize[2];
    const std::size_t movPixNum =
        static_cast<std::size_t>(uiMovSize[0]) * uiMovSize[1] * uiMovSize[2];
    // The float buffers stay alive until the function returns. Whether
    // TransformBuffToITKImage copies them is not visible here.
    std::unique_ptr<float[]> pfRefData(new float[refPixNum]);
    std::unique_ptr<float[]> pfMovData(new float[movPixNum]);
    std::copy_n(pRefData, refPixNum, pfRefData.get());
    std::copy_n(pMovData, movPixNum, pfMovData.get());

    auto refImg = InterFixedImageType::New();
    auto movImg = InterMovingImageType::New();
    TransformBuffToITKImage(pfRefData.get(), uiRefSize, dRefSpacing, dRefPosition, dRefOrientation, refImg);
    TransformBuffToITKImage(pfMovData.get(), uiMovSize, dMovSpacing, dMovPosition, dMovOrientation, movImg);

    auto metric = MetricType::New();
    metric->SetNumberOfHistogramBins(50);
    auto optimizer = OptimizerType::New();
    auto registration = RegistrationType::New();

    // ---- Initial transform --------------------------------------------------------------
    auto initialTransform = TransformType::New();
    initialTransform->SetIdentity();

    // Rotate about the physical center of the reference volume. Going through the image
    // geometry honors the orientation given to TransformBuffToITKImage; the simple
    // "origin + extent / 2" formula is only correct for an identity direction matrix.
    {
        const auto& refRegion = refImg->GetLargestPossibleRegion();
        itk::ContinuousIndex<double, Dimension> centerIndex;
        for (unsigned d = 0; d < Dimension; ++d) {
            centerIndex[d] = refRegion.GetIndex()[d] +
                (static_cast<double>(refRegion.GetSize()[d]) - 1.0) / 2.0;
        }
        TransformType::InputPointType centerPoint;
        refImg->TransformContinuousIndexToPhysicalPoint(centerIndex, centerPoint);
        initialTransform->SetCenter(centerPoint);
    }

    if (dPriorMat != nullptr) {
        itk::Matrix<double, 3, 3> priorMat;
        for (unsigned row = 0; row < 3; ++row) {
            for (unsigned col = 0; col < 3; ++col) {
                priorMat(row, col) = dPriorMat[row * 3 + col];
            }
        }
        initialTransform->SetMatrix(priorMat, 0.001);
    }
    if (dPriorTranslation != nullptr) {
        TransformType::OutputVectorType priorTranslation;
        for (unsigned d = 0; d < Dimension; ++d) {
            priorTranslation[d] = dPriorTranslation[d];
        }
        initialTransform->SetTranslation(priorTranslation);
    }
    if (dPriorTranslation == nullptr && dPriorMat == nullptr) {
        using TransformInitializerType =
            itk::CenteredTransformInitializer<TransformType, InterFixedImageType, InterMovingImageType>;
        auto initializer = TransformInitializerType::New();
        initializer->SetTransform(initialTransform);
        initializer->SetFixedImage(refImg);
        initializer->SetMovingImage(movImg);
        initializer->GeometryOn();
        initializer->InitializeTransform();
    }

    // ---- Optimizer ----------------------------------------------------------------------
    // VersorRigid3DTransform parameters: [0..2] versor (rotation), [3..5] translation.
    OptimizerType::ScalesType optimizerScales(initialTransform->GetNumberOfParameters());
    const double translationScale = RegPara->TranslationScale;
    const double rotateScale = RegPara->RotateScale;
    for (unsigned k = 0; k < 3; ++k) {
        optimizerScales[k] = rotateScale;
        optimizerScales[k + 3] = translationScale;
    }
    optimizer->SetScales(optimizerScales);
    optimizer->SetNumberOfIterations(RegPara->NumberOfIterations);
    optimizer->SetLearningRate(RegPara->LearningRate);
    optimizer->SetRelaxationFactor(RegPara->RelaxationFactor);
    optimizer->SetMinimumStepLength(RegPara->MinimumStepLength);
    // NOTE: with this flag the optimizer rewinds, at the end of EVERY pyramid level, to the
    // best-valued iterate of that level. When the images start close to the optimum, that
    // is often an early iterate (or the starting point), so the next level will not
    // start from the last position printed by the iteration observer. This is the
    // likely source of the "not propagated" symptom. Disable it if level-to-level
    // continuity from the last iterate is preferred.
    optimizer->SetReturnBestParametersAndValue(true);

    auto observer = CommandIterationUpdate::New();
    optimizer->AddObserver(itk::IterationEvent(), observer);
    // The cancel flag is wired independently of progress reporting. Before, it was only
    // installed when regProgress was non-null, so cancellation was silently ignored
    // otherwise. A null flag is handled by the observer.
    observer->SetCancelFlag(isCancel);

    // ---- Pyramid --------------------------------------------------------------------------
    // Per-axis shrink factor of the finest level. Every entry must start at 1: "= {1}"
    // would leave y and z at 0, giving invalid zero shrink factors for small volumes.
    unsigned FineShrinkFactor[3] = { 1, 1, 1 };
    if (uiRefSize[0] > 100 || uiRefSize[1] > 100 || uiRefSize[2] > 100) {
        for (unsigned d = 0; d < Dimension; ++d) {
            FineShrinkFactor[d] = std::min(int(uiRefSize[d] / 100) + 1, int(4 / dRefSpacing[d]) + 1);
        }
    }

    // At least one level is needed for the optimizer to run at all.
    const unsigned int pyramidLevel =
        std::max(1u, static_cast<unsigned int>(RegPara->PyramidLevel));

    RegistrationType::MetricSamplingStrategyEnum samplingStrategy =
        RegistrationType::MetricSamplingStrategyEnum::REGULAR;
    registration->SetMetricSamplingStrategy(samplingStrategy);
    registration->SetNumberOfLevels(pyramidLevel);
    registration->SetMetric(metric);
    registration->SetOptimizer(optimizer);
    registration->SetFixedImage(refImg);
    registration->SetMovingImage(movImg);
    registration->SetInitialTransform(initialTransform);
    registration->InPlaceOn();

    RegistrationType::SmoothingSigmasArrayType smoothingSigmasPerLevel;
    smoothingSigmasPerLevel.SetSize(pyramidLevel);
    for (unsigned level = 0; level < pyramidLevel; ++level) {
        // Level 0 is the coarsest: it is shrunk by pyramidLevel times the fine factor.
        const unsigned levelFactor = pyramidLevel - level;
        RegistrationType::ShrinkFactorsPerDimensionContainerType shrinkFactorsPerDimension;
        for (unsigned d = 0; d < Dimension; ++d) {
            shrinkFactorsPerDimension[d] = FineShrinkFactor[d] * levelFactor;
        }
        registration->SetShrinkFactorsPerDimension(level, shrinkFactorsPerDimension);
        smoothingSigmasPerLevel[level] = 0;
    }
    registration->SetSmoothingSigmasPerLevel(smoothingSigmasPerLevel);

    if (pyramidLevel > 1)
    {
        using CommandType = RegistrationInterfaceCommand<RegistrationType>;
        auto command = CommandType::New();
        registration->AddObserver(itk::MultiResolutionIterationEvent(), command);
    }

    // Keep the iteration observer's level index in step with the pyramid, so that reported
    // progress advances across levels. Raw pointers are captured because the registration
    // owns this callback, and a smart pointer to itself would create a reference cycle.
    {
        RegistrationType* const rawRegistration = registration.GetPointer();
        CommandIterationUpdate* const rawObserver = observer.GetPointer();
        registration->AddObserver(itk::MultiResolutionIterationEvent(),
            [rawRegistration, rawObserver](const itk::EventObject&) {
                rawObserver->SetCurLevel(static_cast<unsigned>(rawRegistration->GetCurrentLevel()));
            });
    }

    if (!cancelled() && regProgress != nullptr)
    {
        *regProgress = 0.05;
        observer->SetProgressPointer(regProgress);
        observer->SetTotalLevel(pyramidLevel);
        observer->SetProgressRange(0.05, 0.95);
    }

    // ---- Run ------------------------------------------------------------------------------
    try
    {
        registration->Update();
        std::cout << "Optimizer stop condition: "
            << registration->GetOptimizer()->GetStopConditionDescription()
            << std::endl;
    }
    catch (const itk::ExceptionObject& err)
    {
        std::cerr << "ExceptionObject caught !" << std::endl;
        std::cerr << err << std::endl;
        return false;
    }

    if (!cancelled() && regProgress != nullptr)
    {
        *regProgress = 0.95;
    }
    if (cancelled())
    {
        return false;
    }

    // ---- Extract the result -----------------------------------------------------------
    const TransformType::ParametersType finalParameters =
        registration->GetOutput()->Get()->GetParameters();
    auto finalTransform = TransformType::New();
    finalTransform->SetFixedParameters(
        registration->GetOutput()->Get()->GetFixedParameters());
    finalTransform->SetParameters(finalParameters);
    const TransformType::MatrixType matrix = finalTransform->GetMatrix();
    const TransformType::OffsetType translation = finalTransform->GetTranslation();
    const auto& center = finalTransform->GetCenter();
    for (unsigned i = 0; i < 9; ++i) {
        dOutMat[i] = matrix(i / 3, i % 3);
    }
    for (unsigned i = 0; i < 3; ++i) {
        dOutTranslation[i] = translation[i];
    }

    std::cout << "Center = " << std::endl << center << std::endl;
    std::cout << "Matrix = " << std::endl << matrix << std::endl;
    std::cout << "Translation = " << std::endl << translation << std::endl;

    if (pOutData != nullptr) {
        ResampleForRigid(pMovData, uiMovSize, dMovSpacing, uiRefSize, dRefSpacing,
            dOutMat, dOutTranslation, pOutData);
    }
    return true;
}

// Optimizer observer. On every IterationEvent it does three things:
//   - honors an optional external cancel flag;
//   - updates an optional progress value;
//   - prints the optimizer state.
// Progress is mapped into [m_startProgress, m_endProgress] as
//   start + (curLevel + iteration / maxIterations) / totalLevel * (end - start).
// None of the pointers handed in are owned; they must outlive the registration run.
class CommandIterationUpdate : public itk::Command
{
public:
    using Self = CommandIterationUpdate;
    using Superclass = itk::Command;
    using Pointer = itk::SmartPointer<Self>;
    itkNewMacro(Self);

    // Where to write progress; nullptr disables progress updates.
    void SetProgressPointer(double* pProgress) { m_progress = pProgress; }

    // Flag polled every iteration; when it points to true the optimizer is stopped.
    void SetCancelFlag(bool* isCancel) { m_isCancel = isCancel; }

    // Number of pyramid levels the progress range is divided into.
    void SetTotalLevel(unsigned totalLevel) { m_totalLevel = totalLevel; }

    void SetProgressRange(double startProgress, double endProgress)
    {
        m_startProgress = startProgress;
        m_endProgress = endProgress;
    }

    // Index of the pyramid level currently being optimized (0 = first level).
    void SetCurLevel(unsigned curLevel) { m_curLevel = curLevel; }

protected:
    CommandIterationUpdate() = default;

public:

    using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
    using OptimizerPointer = OptimizerType*;

    void Execute(itk::Object* caller, const itk::EventObject& event) override
    {
        if (!itk::IterationEvent().CheckEvent(&event))
        {
            return;
        }
        auto optimizer = dynamic_cast<OptimizerPointer>(caller);
        // Ignore callers that are not the expected optimizer type instead of
        // dereferencing a null pointer.
        if (optimizer == nullptr)
        {
            return;
        }
        if (m_isCancel != nullptr && *m_isCancel)
        {
            optimizer->StopOptimization();
        }
        if (m_progress != nullptr)
        {
            const double currentIteration = static_cast<double>(optimizer->GetCurrentIteration());
            const double maxIterations = static_cast<double>(optimizer->GetNumberOfIterations());
            // Guard the divisions: zero iterations or zero levels would otherwise
            // produce inf/NaN progress values.
            const double levelProgress =
                maxIterations > 0.0 ? std::min(1.0, currentIteration / maxIterations) : 1.0;
            const unsigned totalLevel = std::max(1u, m_totalLevel);
            const double progressRange = m_endProgress - m_startProgress;
            *m_progress = m_startProgress + ((m_curLevel + levelProgress) / totalLevel) * progressRange;
        }

        std::cout << optimizer->GetCurrentIteration() << "   ";
        std::cout << optimizer->GetValue() << std::endl;
        std::cout << "  Current Position:" << optimizer->GetCurrentPosition() << std::endl;
        std::cout << "  Current LR Relaxation:" << optimizer->GetCurrentLearningRateRelaxation() << std::endl;
        std::cout << "  Current StepLength:" << optimizer->GetCurrentStepLength() << std::endl;
    }

    // Const overload required by itk::Command; intentionally a no-op.
    void Execute(const itk::Object* object, const itk::EventObject& event) override
    {
    }

private:
    double* m_progress = nullptr;   // non-owning
    bool* m_isCancel = nullptr;     // non-owning
    unsigned m_curLevel = 0;
    unsigned m_totalLevel = 1;
    double m_startProgress = 0.0;
    double m_endProgress = 1.0;

};

// Observer for itk::MultiResolutionIterationEvent. Before each pyramid level it:
//   - logs the level settings;
//   - adapts the learning rate and minimum step length of the
//     RegularStepGradientDescentOptimizerv4;
//   - optionally forwards the level index to a CommandIterationUpdate.
template <typename TRegistration>
class RegistrationInterfaceCommand : public itk::Command
{

public:
    using Self = RegistrationInterfaceCommand;
    using Superclass = itk::Command;
    using Pointer = itk::SmartPointer<Self>;
    itkNewMacro(Self);

    // Attaches the (non-owned) iteration observer whose level index should follow the
    // pyramid. Returns true when a non-null observer was attached. The previous version
    // never returned a value, which is undefined behavior.
    bool SetIterationCommand(CommandIterationUpdate* commandIterationUpdate)
    {
        m_commandIterationUpdate = commandIterationUpdate;
        return m_commandIterationUpdate != nullptr;
    }

protected:
    RegistrationInterfaceCommand() = default;

public:
    using RegistrationType = TRegistration;
    using RegistrationPointer = RegistrationType*;
    using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
    using OptimizerPointer = OptimizerType*;

    void Execute(itk::Object* object, const itk::EventObject& event) override
    {
        if (!(itk::MultiResolutionIterationEvent().CheckEvent(&event)))
        {
            return;
        }
        auto registration = dynamic_cast<RegistrationPointer>(object);
        if (registration == nullptr)
        {
            return;
        }
        // Checked downcast: the learning-rate logic below only makes sense for the
        // regular-step optimizer.
        auto optimizer =
            dynamic_cast<OptimizerPointer>(registration->GetModifiableOptimizer());
        if (optimizer == nullptr)
        {
            return;
        }

        const unsigned int currentLevel = registration->GetCurrentLevel();
        typename RegistrationType::ShrinkFactorsPerDimensionContainerType
            shrinkFactors =
            registration->GetShrinkFactorsPerDimension(currentLevel);
        typename RegistrationType::SmoothingSigmasArrayType smoothingSigmas =
            registration->GetSmoothingSigmasPerLevel();

        // Keep the iteration observer's progress bookkeeping in step with the pyramid.
        if (m_commandIterationUpdate != nullptr)
        {
            m_commandIterationUpdate->SetCurLevel(currentLevel);
        }

        std::cout << "-------------------------------------" << std::endl;
        std::cout << " Current level = " << currentLevel << std::endl;
        std::cout << "    shrink factor = " << shrinkFactors << std::endl;
        std::cout << "    smoothing sigma = ";
        std::cout << smoothingSigmas[currentLevel] << std::endl;
        // The parameters this level actually starts from. With
        // SetReturnBestParametersAndValue(true) this is the best-valued iterate of the
        // previous level, which need not be the last position printed there.
        std::cout << "    starting parameters = " << optimizer->GetCurrentPosition() << std::endl;
        std::cout << std::endl;
        if (currentLevel == 0)
        {
            // Coarsest level: larger steps and an earlier stop.
            optimizer->SetLearningRate(optimizer->GetLearningRate()*4);
            optimizer->SetMinimumStepLength(optimizer->GetMinimumStepLength() * 10);
        }
        else
        {
            // Continue from the step length reached at the end of the previous level.
            // If the reported step length is not positive (e.g. no step was taken),
            // keep the current learning rate. A zero rate would make this level stop
            // immediately with "step too small".
            const double lastStepLength = optimizer->GetCurrentStepLength();
            if (lastStepLength > 0.0)
            {
                optimizer->SetLearningRate(lastStepLength);
            }
            optimizer->SetMinimumStepLength(optimizer->GetMinimumStepLength() / 10);
        }

    }

    // Pure virtual in itk::Command; intentionally a no-op.
    void Execute(const itk::Object*, const itk::EventObject&) override
    {
        return;
    }

private:
    CommandIterationUpdate* m_commandIterationUpdate = nullptr;  // non-owning
};

Here I registered two identical images; the output is as follows:



Could anyone help me analyze the cause of this behavior and suggest a solution? I would be extremely grateful. @dzenanz