When using ImageRegistrationMethodv4 for multi-resolution registration, I ran into the following unusual behavior:
- When the initial positions of the images to be registered were significantly misaligned, the multi-resolution registration behaved normally.
- However, when the initial positions of the images were already close, the registration result was not propagated from one resolution level to the next.
Below are my code and its output.
/**
 * Rigidly registers a moving volume to a reference (fixed) volume with a
 * multi-resolution ITK v4 registration: VersorRigid3DTransform, Mattes mutual
 * information and RegularStepGradientDescentOptimizerv4.
 *
 * @param rigidParameter   optimizer / pyramid settings; when nullptr a
 *                         default-constructed RegParameter is used.
 * @param pRefData, uiRefSize, dRefSpacing  reference volume buffer, size, spacing.
 * @param pMovData, uiMovSize, dMovSpacing  moving volume buffer, size, spacing.
 * @param dOutMat          receives the final 3x3 matrix, row-major.
 * @param dOutTranslation  receives the final translation.
 * @param pOutData         optional; when non-null it is filled by ResampleForRigid
 *                         with the moving image resampled by the result.
 * @param dPriorMat, dPriorTranslation  optional initial transform. When both are
 *                         null, CenteredTransformInitializer (geometry mode) is used.
 * @param dRefPosition, dRefOrientation, dMovPosition, dMovOrientation  geometry
 *                         forwarded to TransformBuffToITKImage.
 * @param isCancel         optional cancel flag, polled on every optimizer iteration.
 * @param regProgress      optional progress output. It is set to 0.0, moved through
 *                         [0.05, 0.95] by the iteration observer, and set to 0.95
 *                         once the registration finishes.
 * @return true on success; false for null input buffers, an ITK exception, or
 *         cancellation.
 */
bool RigidRegistration(RegParameter* rigidParameter,
    short * pRefData, unsigned uiRefSize[3], double dRefSpacing[3],
    short * pMovData, unsigned uiMovSize[3], double dMovSpacing[3],
    double dOutMat[9], double dOutTranslation[3], short * pOutData, double dPriorMat[9], double dPriorTranslation[3],
    double dRefPosition[3], double dRefOrientation[6], double dMovPosition[3], double dMovOrientation[6],
    bool *isCancel, double* regProgress)
{
    if (pRefData == nullptr || pMovData == nullptr) {
        return false;
    }

    // True once the caller has requested cancellation.
    const auto cancelled = [isCancel]() { return isCancel != nullptr && *isCancel; };
    // Progress is only written while the run has not been cancelled.
    const auto reportProgress = [&cancelled, regProgress](double value) {
        if (!cancelled() && regProgress != nullptr) {
            *regProgress = value;
        }
    };
    reportProgress(0.0);

    // Owns the default parameters when the caller supplies none, so every return
    // path (including exceptions and cancellation) releases them.
    std::unique_ptr<RegParameter> ownedParameter;
    RegParameter* RegPara = rigidParameter;
    if (RegPara == nullptr) {
        ownedParameter.reset(new RegParameter());
        RegPara = ownedParameter.get();
    }

    constexpr unsigned int Dimension = 3;
    using InterPixelType = float;
    using InterFixedImageType = itk::Image<InterPixelType, Dimension>;
    using InterMovingImageType = itk::Image<InterPixelType, Dimension>;
    using TransformType = itk::VersorRigid3DTransform<double>;
    using MetricType =
        itk::MattesMutualInformationImageToImageMetricv4<InterFixedImageType, InterMovingImageType>;
    using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
    using RegistrationType = itk::
        ImageRegistrationMethodv4<InterFixedImageType, InterMovingImageType, TransformType>;

    // Convert the short buffers to float images. std::size_t avoids overflowing
    // a 32-bit unsigned pixel count on very large volumes.
    auto refImg = InterFixedImageType::New();
    auto movImg = InterMovingImageType::New();
    const std::size_t refPixNum = static_cast<std::size_t>(uiRefSize[0]) * uiRefSize[1] * uiRefSize[2];
    const std::size_t movPixNum = static_cast<std::size_t>(uiMovSize[0]) * uiMovSize[1] * uiMovSize[2];
    std::unique_ptr<float[]> pfRefData(new float[refPixNum]());
    std::unique_ptr<float[]> pfMovData(new float[movPixNum]());
    for (std::size_t i = 0; i < refPixNum; i++) {
        pfRefData[i] = static_cast<float>(pRefData[i]);
    }
    for (std::size_t i = 0; i < movPixNum; i++) {
        pfMovData[i] = static_cast<float>(pMovData[i]);
    }
    TransformBuffToITKImage(pfRefData.get(), uiRefSize, dRefSpacing, dRefPosition, dRefOrientation, refImg);
    TransformBuffToITKImage(pfMovData.get(), uiMovSize, dMovSpacing, dMovPosition, dMovOrientation, movImg);

    auto metric = MetricType::New();
    metric->SetNumberOfHistogramBins(50);
    auto optimizer = OptimizerType::New();
    auto registration = RegistrationType::New();

    // Initialization
    auto initialTransform = TransformType::New();
    initialTransform->SetIdentity();

    // Rotation centre = physical centre of the reference volume. Going through
    // the image geometry (instead of origin + extent / 2) also honours the
    // direction cosines that TransformBuffToITKImage presumably sets from
    // dRefOrientation. NOTE(review): confirm against TransformBuffToITKImage.
    {
        const auto refRegion = refImg->GetLargestPossibleRegion();
        itk::ContinuousIndex<double, Dimension> centerIndex;
        for (unsigned d = 0; d < Dimension; ++d) {
            centerIndex[d] = static_cast<double>(refRegion.GetIndex()[d]) +
                (static_cast<double>(refRegion.GetSize()[d]) - 1.0) / 2.0;
        }
        TransformType::InputPointType centerPoint;
        refImg->TransformContinuousIndexToPhysicalPoint(centerIndex, centerPoint);
        initialTransform->SetCenter(centerPoint);
    }

    if (dPriorMat != nullptr) {
        itk::Matrix<double, 3, 3> priorMat;
        for (unsigned i = 0; i < 3; i++) {
            for (unsigned j = 0; j < 3; j++) {
                priorMat(i, j) = dPriorMat[i * 3 + j];
            }
        }
        initialTransform->SetMatrix(priorMat, 0.001);
    }
    if (dPriorTranslation != nullptr) {
        TransformType::OutputVectorType priorTranslation;
        for (unsigned i = 0; i < 3; i++) {
            priorTranslation[i] = dPriorTranslation[i];
        }
        initialTransform->SetTranslation(priorTranslation);
    }
    if (dPriorTranslation == nullptr && dPriorMat == nullptr) {
        using TransformInitializerType =
            itk::CenteredTransformInitializer<TransformType, InterFixedImageType, InterMovingImageType>;
        auto initializer = TransformInitializerType::New();
        initializer->SetTransform(initialTransform);
        initializer->SetFixedImage(refImg);
        initializer->SetMovingImage(movImg);
        initializer->GeometryOn();
        initializer->InitializeTransform();
    }

    // Parameters 0-2 are the versor (rotation), 3-5 the translation.
    using OptimizerScalesType = OptimizerType::ScalesType;
    OptimizerScalesType optimizerScales(initialTransform->GetNumberOfParameters());
    const double translationScale = RegPara->TranslationScale;
    const double rotateScale = RegPara->RotateScale;
    optimizerScales[0] = rotateScale;
    optimizerScales[1] = rotateScale;
    optimizerScales[2] = rotateScale;
    optimizerScales[3] = translationScale;
    optimizerScales[4] = translationScale;
    optimizerScales[5] = translationScale;
    optimizer->SetScales(optimizerScales);
    optimizer->SetNumberOfIterations(RegPara->NumberOfIterations);
    optimizer->SetLearningRate(RegPara->LearningRate);
    optimizer->SetRelaxationFactor(RegPara->RelaxationFactor);
    optimizer->SetMinimumStepLength(RegPara->MinimumStepLength);
    optimizer->SetReturnBestParametersAndValue(true);
    auto observer = CommandIterationUpdate::New();
    optimizer->AddObserver(itk::IterationEvent(), observer);

    // Shrink factor of the finest level. Every dimension must start at 1:
    // "= {1}" would leave the last two elements at 0, which is an invalid shrink factor.
    unsigned FineShrinkFactor[3] = { 1, 1, 1 };
    if (uiRefSize[0] > 100 || uiRefSize[1] > 100 || uiRefSize[2] > 100) {
        for (unsigned d = 0; d < 3; d++) {
            FineShrinkFactor[d] = std::min(int(uiRefSize[d] / 100) + 1, int(4 / dRefSpacing[d]) + 1);
        }
    }

    // At least one level is required by the registration method.
    unsigned int pyramidLevel = RegPara->PyramidLevel;
    if (pyramidLevel == 0) {
        pyramidLevel = 1;
    }

    RegistrationType::MetricSamplingStrategyEnum samplingStrategy =
        RegistrationType::MetricSamplingStrategyEnum::REGULAR;
    registration->SetMetricSamplingStrategy(samplingStrategy);
    registration->SetNumberOfLevels(pyramidLevel);
    registration->SetMetric(metric);
    registration->SetOptimizer(optimizer);
    registration->SetFixedImage(refImg);
    registration->SetMovingImage(movImg);
    registration->SetInitialTransform(initialTransform);
    registration->InPlaceOn();

    // Level i (0 = coarsest) shrinks by FineShrinkFactor * (pyramidLevel - i), without smoothing.
    RegistrationType::SmoothingSigmasArrayType smoothingSigmasPerLevel;
    smoothingSigmasPerLevel.SetSize(pyramidLevel);
    for (unsigned i = 0; i < pyramidLevel; i++) {
        RegistrationType::ShrinkFactorsPerDimensionContainerType shrinkFactorsPerDimension;
        shrinkFactorsPerDimension[0] = FineShrinkFactor[0] * (pyramidLevel - i);
        shrinkFactorsPerDimension[1] = FineShrinkFactor[1] * (pyramidLevel - i);
        shrinkFactorsPerDimension[2] = FineShrinkFactor[2] * (pyramidLevel - i);
        registration->SetShrinkFactorsPerDimension(i, shrinkFactorsPerDimension);
        smoothingSigmasPerLevel[i] = 0;
    }
    registration->SetSmoothingSigmasPerLevel(smoothingSigmasPerLevel);

    if (pyramidLevel > 1) {
        using CommandType = RegistrationInterfaceCommand<RegistrationType>;
        auto command = CommandType::New();
        registration->AddObserver(itk::MultiResolutionIterationEvent(), command);
    }
    // Keep the iteration observer's level index in sync with the registration,
    // so progress accumulates across levels instead of restarting at each one.
    {
        RegistrationType* const regPtr = registration.GetPointer();
        CommandIterationUpdate* const iterObserver = observer.GetPointer();
        registration->AddObserver(itk::MultiResolutionIterationEvent(),
            [regPtr, iterObserver](const itk::EventObject&) {
                iterObserver->SetCurLevel(static_cast<unsigned>(regPtr->GetCurrentLevel()));
            });
    }

    // The cancel flag must be polled even when the caller passes no progress pointer.
    observer->SetCancelFlag(isCancel);
    observer->SetTotalLevel(pyramidLevel);
    if (!cancelled() && regProgress != nullptr) {
        *regProgress = 0.05;
        observer->SetProgressPointer(regProgress);
        observer->SetProgressRange(0.05, 0.95);
    }
    if (cancelled()) {
        return false;
    }

    try
    {
        registration->Update();
        std::cout << "Optimizer stop condition: "
            << registration->GetOptimizer()->GetStopConditionDescription()
            << std::endl;
    }
    catch (const itk::ExceptionObject& err)
    {
        std::cerr << "ExceptionObject caught !" << std::endl;
        std::cerr << err << std::endl;
        return false;
    }

    reportProgress(0.95);
    if (cancelled()) {
        return false;
    }

    const TransformType::ParametersType finalParameters =
        registration->GetOutput()->Get()->GetParameters();
    auto finalTransform = TransformType::New();
    finalTransform->SetFixedParameters(
        registration->GetOutput()->Get()->GetFixedParameters());
    finalTransform->SetParameters(finalParameters);
    TransformType::MatrixType matrix = finalTransform->GetMatrix();
    TransformType::OffsetType translation = finalTransform->GetTranslation();
    const auto& center = finalTransform->GetCenter();
    // Row-major copy of the 3x3 matrix.
    for (unsigned i = 0; i < 9; i++) {
        dOutMat[i] = matrix(i / 3, i % 3);
    }
    for (unsigned i = 0; i < 3; i++) {
        dOutTranslation[i] = translation[i];
    }
    std::cout << "Center = " << std::endl << center << std::endl;
    std::cout << "Matrix = " << std::endl << matrix << std::endl;
    std::cout << "Translation = " << std::endl << translation << std::endl;
    if (pOutData != nullptr) {
        ResampleForRigid(pMovData, uiMovSize, dMovSpacing, uiRefSize, dRefSpacing,
            dOutMat, dOutTranslation, pOutData);
    }
    return true;
}
/**
 * Observer for the optimizer's IterationEvent. On every iteration it:
 *  - calls StopOptimization() when the caller's cancel flag is set;
 *  - writes overall progress as
 *      start + ((curLevel + iteration / maxIterations) / totalLevel) * (end - start);
 *  - prints the optimizer state to std::cout.
 * The cancel flag and progress pointer are not owned and must outlive the run.
 */
class CommandIterationUpdate : public itk::Command
{
public:
    using Self = CommandIterationUpdate;
    using Superclass = itk::Command;
    using Pointer = itk::SmartPointer<Self>;
    itkNewMacro(Self);

    void SetProgressPointer(double* pProgress) { m_progress = pProgress; }
    void SetCancelFlag(bool* isCancel) { m_isCancel = isCancel; }
    void SetTotalLevel(unsigned totalLevel) { m_totalLevel = totalLevel; }
    void SetProgressRange(double startProgress, double endProgress)
    {
        m_startProgress = startProgress;
        m_endProgress = endProgress;
    }
    /** Index of the pyramid level currently being optimized (0 = coarsest). */
    void SetCurLevel(unsigned curLevel) { m_curLevel = curLevel; }

protected:
    CommandIterationUpdate() = default;

public:
    using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
    using OptimizerPointer = OptimizerType*;

    void Execute(itk::Object* caller, const itk::EventObject& event) override
    {
        if (!itk::IterationEvent().CheckEvent(&event))
        {
            return;
        }
        auto optimizer = dynamic_cast<OptimizerPointer>(caller);
        if (optimizer == nullptr)
        {
            // Attached to something other than the expected optimizer type.
            return;
        }
        if (m_isCancel != nullptr && *m_isCancel)
        {
            optimizer->StopOptimization();
        }
        if (m_progress != nullptr)
        {
            const unsigned int currentIteration = optimizer->GetCurrentIteration();
            const unsigned int maxIterations = optimizer->GetNumberOfIterations();
            // Guard both divisions: a 0 iteration budget or 0 levels would give NaN/inf.
            const double levelProgress =
                maxIterations > 0 ? static_cast<double>(currentIteration) / maxIterations : 0.0;
            const unsigned totalLevels = m_totalLevel > 0 ? m_totalLevel : 1;
            const double progressRange = m_endProgress - m_startProgress;
            *m_progress = m_startProgress + ((m_curLevel + levelProgress) / totalLevels) * progressRange;
        }
        std::cout << optimizer->GetCurrentIteration() << " ";
        std::cout << optimizer->GetValue() << std::endl;
        std::cout << " Current Position:" << optimizer->GetCurrentPosition() << std::endl;
        std::cout << " Current LR Relaxation:" << optimizer->GetCurrentLearningRateRelaxation() << std::endl;
        std::cout << " Current StepLength:" << optimizer->GetCurrentStepLength() << std::endl;
    }

    // Const callers carry no optimizer that could be updated; nothing to do.
    void Execute(const itk::Object*, const itk::EventObject&) override
    {
    }

private:
    double* m_progress = nullptr;
    bool* m_isCancel = nullptr;
    unsigned m_curLevel = 0;
    unsigned m_totalLevel = 1;
    double m_startProgress = 0.0;
    double m_endProgress = 1.0;
};
/**
 * Observer for the registration's MultiResolutionIterationEvent. It runs before
 * each pyramid level is optimized. It prints the level settings and adapts the
 * optimizer:
 *  - level 0:  learning rate x4, minimum step length x10;
 *  - level >0: minimum step length /10, and the learning rate continues from the
 *              previous level's final step length when that is usable.
 * Optionally forwards the current level to a CommandIterationUpdate (not owned).
 */
template <typename TRegistration>
class RegistrationInterfaceCommand : public itk::Command
{
public:
    using Self = RegistrationInterfaceCommand;
    using Superclass = itk::Command;
    using Pointer = itk::SmartPointer<Self>;
    itkNewMacro(Self);

    /** Sets the iteration observer whose level index is updated per level. Always returns true. */
    bool SetIterationCommand(CommandIterationUpdate* commandIterationUpdate)
    {
        m_commandIterationUpdate = commandIterationUpdate;
        return true;
    }

protected:
    RegistrationInterfaceCommand() = default;

public:
    using RegistrationType = TRegistration;
    using RegistrationPointer = RegistrationType*;
    using OptimizerType = itk::RegularStepGradientDescentOptimizerv4<double>;
    using OptimizerPointer = OptimizerType*;

    void Execute(itk::Object* object, const itk::EventObject& event) override
    {
        if (!(itk::MultiResolutionIterationEvent().CheckEvent(&event)))
        {
            return;
        }
        auto registration = dynamic_cast<RegistrationPointer>(object);
        if (registration == nullptr)
        {
            return;
        }
        auto optimizer = dynamic_cast<OptimizerPointer>(registration->GetModifiableOptimizer());
        if (optimizer == nullptr)
        {
            return;
        }
        unsigned int currentLevel = registration->GetCurrentLevel();
        typename RegistrationType::ShrinkFactorsPerDimensionContainerType
            shrinkFactors =
            registration->GetShrinkFactorsPerDimension(currentLevel);
        typename RegistrationType::SmoothingSigmasArrayType smoothingSigmas =
            registration->GetSmoothingSigmasPerLevel();
        if (m_commandIterationUpdate != nullptr)
        {
            m_commandIterationUpdate->SetCurLevel(currentLevel);
        }
        std::cout << "-------------------------------------" << std::endl;
        std::cout << " Current level = " << currentLevel << std::endl;
        std::cout << " shrink factor = " << shrinkFactors << std::endl;
        std::cout << " smoothing sigma = ";
        std::cout << smoothingSigmas[currentLevel] << std::endl;
        std::cout << std::endl;
        if (currentLevel == 0)
        {
            optimizer->SetLearningRate(optimizer->GetLearningRate() * 4);
            optimizer->SetMinimumStepLength(optimizer->GetMinimumStepLength() * 10);
        }
        else
        {
            const double minimumStepLength = optimizer->GetMinimumStepLength() / 10;
            optimizer->SetMinimumStepLength(minimumStepLength);
            // Continue from the step length the previous level finished with.
            // If the previous level stopped (e.g. on the gradient-magnitude
            // tolerance, typical when the images start nearly aligned) before
            // taking any step, GetCurrentStepLength() can be 0 or below the new
            // minimum. Using it as the learning rate would make this level and
            // every later one stop immediately as "step too small", so the
            // coarse result would never be refined. In that case keep the
            // previous level's learning rate.
            const double lastStepLength = optimizer->GetCurrentStepLength();
            if (lastStepLength > minimumStepLength)
            {
                optimizer->SetLearningRate(lastStepLength);
            }
        }
    }

    // Pure virtual in itk::Command, so it must be overridden; nothing to do for const callers.
    void Execute(const itk::Object*, const itk::EventObject&) override
    {
        return;
    }

private:
    CommandIterationUpdate* m_commandIterationUpdate = nullptr;
};
Here I registered two identical images, and the results are as follows:
Could someone help me analyze the cause of this behavior and suggest a solution? I would be extremely grateful. @dzenanz