# Slightly modified from
# https://github.com/InsightSoftwareConsortium/ITKColorNormalization/blob/master/examples/ITKColorNormalization.ipynb
# with a couple of helper functions added for my local environment (enabled when JRS = True).

# Local-environment switch.
#   True:  use the matplotlib-based view() helper defined below (itkwidgets is not
#          imported), and allow selecting the local BCH test image via BCH.
#   False: import itkwidgets.view and always use the public 'Easy1.png' sample.
JRS = True

def view(itk_image, title='', x=20, y=10):
    """Display an ITK image in a matplotlib window.

    The image is converted to a NumPy array, its shape and dtype are printed,
    and it is drawn with ``plt.imshow``. The figure window is then moved to
    screen position (x, y) with ``move_fig`` (which requires the TkAgg
    backend) before ``plt.show()`` is called.

    Args:
        itk_image: image as returned by ``itk.imread(filename, PixelType)``,
            or anything else ``np.array`` can convert.
        title: figure title.
        x: horizontal screen position, in pixels, of the window's upper-left corner.
        y: vertical screen position, in pixels, of the window's upper-left corner.
    """
    import numpy as np
    import matplotlib.pyplot as plt

    img = np.array(itk_image)
    print(f' view:: img.shape={img.shape}, img.dtype={img.dtype}')
    plt.imshow(img)
    plt.title(title)
    # Position the window before show() displays it.
    move_fig(plt.gcf(), x, y)
    plt.show()

import matplotlib as mp
def move_fig(f, x, y):
    """Move figure ``f``'s window so its upper-left corner is at (x, y), in pixels.

    Only the TkAgg backend is supported. The window is repositioned with Tk's
    ``wm_geometry``.

    Raises:
        RuntimeError: if the active matplotlib backend is not TkAgg.
    """
    backend = mp.get_backend()
    # Explicit check instead of ``assert``: ``python -O`` strips asserts, which
    # would leave a non-Tk window to fail inside wm_geometry. The comparison is
    # case-insensitive because the reported backend name's capitalization may
    # vary between matplotlib versions.
    if backend.lower() != "tkagg":
        raise RuntimeError(
            f"move_fig requires the 'TkAgg' matplotlib backend, got {backend!r}")
    f.canvas.manager.window.wm_geometry(f"+{x}+{y}")

# Below: the ITK color-normalization example from the notebook linked at the top of this file.
## [1]:
# Install itk-spcn and itkwidgets, if necessary (set `necessary = True` to do so).
import sys
necessary = False
if necessary:
    import subprocess
    # Use sys.executable so the packages are installed for the interpreter that
    # is running this script; a bare "python3" may resolve to another
    # environment. check_call stops the script if an install fails, instead of
    # silently ignoring the exit status as os.system did.
    for package in ("itk-spcn", "itkwidgets"):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

## [2]:
# Import needed packages
from urllib.request import urlretrieve
import os

# Import itk, which includes itk-spcn.
import itk
if not JRS:
    from itkwidgets import view

## [3]:
# Choose and, where possible, fetch the input images.
BCH = True   # use Boston Children's Hospital test image (only honoured when JRS is True)
if JRS and BCH:
    input_image_filename = 'Test_1_0810.png'  # typical BCH (under-stained) data
else:
    input_image_filename = 'Easy1.png'

# This Kitware URL is the download for 'Easy1.png' (the default input). It must
# not be used to populate any other filename. Otherwise a missing BCH image
# would silently be replaced by Easy1 saved under the BCH name, and that copy
# would be reused on every later run because of the existence check.
input_image_url = 'https://data.kitware.com/api/v1/file/576ad39b8d777f1ecd6702f2/download'
if not os.path.exists(input_image_filename):
    if input_image_filename != 'Easy1.png':
        raise FileNotFoundError(
            f"{input_image_filename!r} not found and it has no download URL; "
            "copy it into the working directory or set BCH = False.")
    urlretrieve(input_image_url, input_image_filename)

reference_image_filename = 'Hard.png'
reference_image_url = 'https://data.kitware.com/api/v1/file/57718cc48d777f1ecd8a883f/download'
if not os.path.exists(reference_image_filename):
    urlretrieve(reference_image_url, reference_image_filename)

# NOTE(review): the name suggests Hard's structure with Easy1's colors. However,
# the filter below is called with the input image first and Hard.png as the
# reference, and with BCH enabled the input is not Easy1 at all. Confirm the
# intended naming.
output_image_filename = 'HardWithEasy1Colors.png'

# The pixels are RGB triplets of unsigned char.  The images are 2 dimensional.
PixelType = itk.RGBPixel[itk.UC]
ImageType = itk.Image[PixelType, 2]

## [4]:
# Functional ("eager") ITK interface: read both images as 2-D RGB
# unsigned-char images (PixelType).
input_image = itk.imread(input_image_filename, PixelType)
reference_image = itk.imread(reference_image_filename, PixelType)

# Show the two inputs before attempting color normalization.
for _image in (input_image, reference_image):
    view(_image)

# Normalize input_image against reference_image. The keyword arguments give
# the RGB channel index that each stain suppresses.
eager_normalized_image = itk.structure_preserving_color_normalization_filter(
    input_image,
    reference_image,
    color_index_suppressed_by_hematoxylin=0,
    color_index_suppressed_by_eosin=1,
)
itk.imwrite(eager_normalized_image, output_image_filename)

## [5]-[7]:
# Display the input, the reference, and the normalized result in turn.
for _image in (input_image, reference_image, eager_normalized_image):
    view(_image)

## [8]:
# Alternatively, express the same computation as an explicit ITK pipeline:
# two readers -> structure-preserving color normalization filter -> writer.
input_reader = itk.ImageFileReader[ImageType].New(FileName=input_image_filename)
reference_reader = itk.ImageFileReader[ImageType].New(FileName=reference_image_filename)

# New() keyword arguments are applied through the matching Set<Name> methods.
# Input 0 is the input image and input 1 is the reference image. The color-index
# settings are the same values the eager call uses.
spcn_filter = itk.StructurePreservingColorNormalizationFilter.New(
    Input=input_reader.GetOutput(),
    ColorIndexSuppressedByHematoxylin=0,
    ColorIndexSuppressedByEosin=1,
)
spcn_filter.SetInput(1, reference_reader.GetOutput())

# Writes to the same output file as the eager version, overwriting its result.
output_writer = itk.ImageFileWriter.New(
    Input=spcn_filter.GetOutput(),
    FileName=output_image_filename,
)
output_writer.Write()

## [9]:
view(output_writer.GetInput())
