#!/usr/bin/env python

# This code is based on
# https://itk.org/ITKExamples/src/Registration/Common/Perform2DTranslationRegistrationWithMeanSquares/Documentation.html
# which holds the Apache 2.0 license, which can be found here:
# https://www.apache.org/licenses/LICENSE-2.0.html
#
# What differs from that work is licensed under the GPL 2.0

import sys
import itk
import numpy as np
from distutils.version import StrictVersion as VS
# from benchmark3d.primitives import displ_series_obj3d
from skimage.morphology import disk
from skimage.transform import SimilarityTransform
from skimage.transform import warp
from scipy.ndimage import shift


# Draw a happy face: a crescent "mouth" (one disk with a second, slightly
# shifted disk carved out of it) plus two square "eyes", all extruded
# through the whole depth (last axis) of the volume.
side_size = 100
obj3d = np.zeros((side_size, side_size, side_size // 10), dtype=np.uint8)

# warp() is handed the *inverse* of each similarity transform, so the
# radius-50 template disk is mapped forward: scaled by 0.8, then moved by
# (x, y) = (x_offset, 10). np.where yields the (row, col) indices it covers.
disk1, disk2 = (
    np.where(warp(disk(50, dtype=np.uint8),
                  SimilarityTransform(
                      scale=0.8,
                      translation=(x_offset, (100 - 80) / 2)).inverse))
    for x_offset in (5, 10)
)
obj3d[disk1] = 1  # first disk on ...
obj3d[disk2] = 0  # ... then the shifted disk off, leaving a crescent

# Two 11x11 eyes in the same columns, at different rows.
for eye_rows in (slice(19, 30), slice(69, 80)):
    obj3d[eye_rows, 49:60, :] = 1

# Happily create 4 frames: every frame starts as a copy of the face, and the
# last one is then translated so the registration has a known motion to
# recover. (Previously the frames were left all-zero, i.e. blank.)
series_length = 4
holder = np.zeros((series_length, *obj3d.shape), dtype=np.uint8)
holder[:] = obj3d  # broadcast the face into every frame
# Move the last frame by (series_length - 1) voxels along numpy axis 0.
# order=0 (nearest neighbour) keeps this integer shift exact on the binary
# volume; vacated voxels are filled with 0 (scipy's default constant mode).
holder[-1] = shift(holder[-1], (series_length - 1, 0, 0), order=0)

# Fail fast on an unsupported ITK before making any other ITK call, so an
# old installation is reported here instead of failing obscurely in the I/O
# calls below. The v4 registration framework used later needs ITK >= 4.9.0.
if VS(itk.Version.GetITKVersion()) < VS("4.9.0"):
    print("ITK 4.9.0 is required.")
    sys.exit(1)

# Input frames: the first frame is the fixed image and the last frame is the
# moving image. They are written to disk as VTK images here and read back
# later with a float pixel type.
fixedImageFile = "itk_hello_registration.in1.vtk"
movingImageFile = "itk_hello_registration.in4.vtk"
itk.imwrite(itk.image_from_array(holder[0]), fixedImageFile)
itk.imwrite(itk.image_from_array(holder[-1]), movingImageFile)

# Outputs: the registered moving image, plus the fixed-minus-moving
# difference images after and before registration.
output_image = "itk_hello_registration.out.vtk"
diff_image_after = "itk_hello_registration.after.vtk"
diff_image_before = "itk_hello_registration.before.vtk"

# Inputs are the fixed file names written earlier in this script rather than
# command-line arguments. Both frames are read back as single-precision float
# images; FixedImageType/MovingImageType below are built from this pixel type.
PixelType = itk.ctype('float')
fixed_image, moving_image = (
    itk.imread(image_path, PixelType)
    for image_path in (fixedImageFile, movingImageFile)
)

# Set up and run the registration: a translation-only transform, optimised by
# regular-step gradient descent against a mean-squares metric, at a single
# resolution level (no smoothing, no shrinking).
Dimension = fixed_image.GetImageDimension()
FixedImageType = itk.Image[PixelType, Dimension]
MovingImageType = itk.Image[PixelType, Dimension]

TransformType = itk.TranslationTransform[itk.D, Dimension]
initialTransform = TransformType.New()

optimizer = itk.RegularStepGradientDescentOptimizerv4.New(
    LearningRate=4,
    MinimumStepLength=0.001,
    RelaxationFactor=0.5,
    NumberOfIterations=200)

metric = itk.MeanSquaresImageToImageMetricv4[
    FixedImageType, MovingImageType].New()

registration = itk.ImageRegistrationMethodv4[
    FixedImageType, MovingImageType].New(
        FixedImage=fixed_image,
        MovingImage=moving_image,
        Metric=metric,
        Optimizer=optimizer,
        InitialTransform=initialTransform)

# Initial offset applied to the moving image before optimisation (all zeros,
# i.e. no offset). Every component is set, not just X/Y, so this stays
# correct for the 3-D images used here.
movingInitialTransform = TransformType.New()
initialParameters = movingInitialTransform.GetParameters()
for component in range(movingInitialTransform.GetNumberOfParameters()):
    initialParameters[component] = 0
movingInitialTransform.SetParameters(initialParameters)
registration.SetMovingInitialTransform(movingInitialTransform)

identityTransform = TransformType.New()
identityTransform.SetIdentity()
registration.SetFixedInitialTransform(identityTransform)

registration.SetNumberOfLevels(1)
registration.SetSmoothingSigmasPerLevel([0])
registration.SetShrinkFactorsPerLevel([1])

registration.Update()

# Report the optimised translation. ITK indexes axes in the reverse order of
# NumPy ((x, y, z) vs. [z, y, x]), so for this 3-D volume the shift applied
# along numpy axis 0 shows up in the last (Z) component. Print every
# component instead of only X and Y.
transform = registration.GetTransform()
finalParameters = transform.GetParameters()
translationAlongX = finalParameters.GetElement(0)
translationAlongY = finalParameters.GetElement(1)

numberOfIterations = optimizer.GetCurrentIteration()

bestValue = optimizer.GetValue()

axisNames = "XYZ"
print("Result = ")
for axis in range(Dimension):
    # Fall back to a numeric label for dimensions beyond Z.
    axisLabel = axisNames[axis] if axis < len(axisNames) else str(axis)
    print(" Translation " + axisLabel + " = "
          + str(finalParameters.GetElement(axis)))
print(" Iterations    = " + str(numberOfIterations))
print(" Metric value  = " + str(bestValue))

# Chain the moving initial offset with the optimised transform (in that
# order) so the resampler maps the moving frame onto the fixed frame's grid.
CompositeTransformType = itk.CompositeTransform[itk.D, Dimension]
outputCompositeTransform = CompositeTransformType.New()
for stage in (movingInitialTransform,
              registration.GetModifiableTransform()):
    outputCompositeTransform.AddTransform(stage)

resampler = itk.ResampleImageFilter.New(
    Input=moving_image,
    Transform=outputCompositeTransform,
    UseReferenceImage=True,
    ReferenceImage=fixed_image)
# Value given to output pixels that map outside the moving image
# (for the registered-image output).
resampler.SetDefaultPixelValue(100)

OutputPixelType = itk.ctype('unsigned char')
OutputImageType = itk.Image[OutputPixelType, Dimension]

# 1) Registered moving image, cast to 8-bit and written to output_image.
caster = itk.CastImageFilter[FixedImageType, OutputImageType].New(
    Input=resampler)
writer = itk.ImageFileWriter.New(Input=caster, FileName=output_image)
writer.Update()

# 2) Fixed minus registered moving, rescaled to the full 8-bit range.
difference = itk.SubtractImageFilter.New(
    Input1=fixed_image,
    Input2=resampler)
intensityRescaler = itk.RescaleIntensityImageFilter[
    FixedImageType, OutputImageType].New(
        Input=difference,
        OutputMinimum=itk.NumericTraits[OutputPixelType].min(),
        OutputMaximum=itk.NumericTraits[OutputPixelType].max())

resampler.SetDefaultPixelValue(1)
writer.SetInput(intensityRescaler.GetOutput())
writer.SetFileName(diff_image_after)
writer.Update()

# 3) The same difference with the identity transform, i.e. before
#    registration, for comparison.
resampler.SetTransform(identityTransform)
writer.SetFileName(diff_image_before)
writer.Update()