# file: "pp.py": a slightly modified version of the original ITK color-normalization demo code.
# Modified to use a typical U Texas (UT) training patch as the reference image and
# a typical Boston Children's Hospital patch of (older) H&E-stained osteosarcoma data as the input image.
#
# Original code from:
# https://github.com/InsightSoftwareConsortium/ITKColorNormalization/blob/master/examples/ITKColorNormalization.ipynb

def np_stats(fname, dat):
    """Print a short summary (container type, dtype, shape) of an array.

    Args:
        fname: label printed in the header line; in this script it is the
            file the array was loaded from.
        dat: array-like exposing ``dtype`` and ``shape`` (e.g. a NumPy ndarray).
    """
    print(f'\n *** {fname} stats')
    print(f'\t type {type(dat)}: dtype {dat.dtype}: shape {dat.shape}')

# Install itk-spcn and itkwidgets, if necessary
import sys

# The upstream notebook installed these with IPython shell magics
# (``!{sys.executable} -m pip install ...``), which are not valid syntax in a
# plain .py script.  If the packages are missing, run from a shell instead:
#     python -m pip install itk-spcn itkwidgets

# Import needed packages
from urllib.request import urlretrieve
import os

# Import itk, which includes itk-spcn.
import itk
from itkwidgets import view

# Data-source switch: True -> the upstream ITK demo PNGs (downloaded on
# demand); False -> local .npy patches (UT reference, BCH input).
original = False
if not original:
    # use U Texas data as reference and Boston Children's Hospital as input images
    import numpy as np

    input_image_filename = 'inp0_(2,1).npy'  # source tile with patch coords
    reference_image_filename = 'ref0.npy'
    output_image_filename = 'UTrefWithBCHinp.png'
else:
    input_image_filename = 'Easy1.png'
    input_image_url = 'https://data.kitware.com/api/v1/file/576ad39b8d777f1ecd6702f2/download'
    reference_image_filename = 'Hard.png'
    reference_image_url = 'https://data.kitware.com/api/v1/file/57718cc48d777f1ecd8a883f/download'

    # Fetch each input image only when it is not already on disk.
    for local_path, source_url in (
        (input_image_filename, input_image_url),
        (reference_image_filename, reference_image_url),
    ):
        if not os.path.exists(local_path):
            urlretrieve(source_url, local_path)

    output_image_filename = 'HardWithEasy1Colors.png'

# The pixels are RGB triplets of unsigned char.  The images are 2 dimensional.
PixelType = itk.RGBPixel[itk.UC]
ImageType = itk.Image[PixelType, 2]


def _npy_to_itk_rgb(fname):
    """Load a ``.npy`` patch, print its stats and wrap it as a 2-D ITK color image.

    ``itk.image_from_array`` defaults to ``is_vector=False``, which turns a
    channels-last (H, W, C) array into a 3-D *scalar* volume.  Passing
    ``is_vector=True`` makes the trailing axis the per-pixel components,
    yielding the 2-D color image the normalization filter operates on.

    Args:
        fname: path of the ``.npy`` file.  Assumed to hold a channels-last
            (H, W, 3) or (H, W, 4) array; NOTE(review): the filter is set up
            for unsigned-char pixels (see PixelType), so the dtype is
            presumably uint8 -- confirm with the printed stats.

    Returns:
        The ITK image built from the array.

    Raises:
        ValueError: if the array is not 3-D with 3 or 4 trailing channels.
    """
    arr = np.load(fname)
    np_stats(fname, arr)
    if arr.ndim != 3 or arr.shape[-1] not in (3, 4):
        raise ValueError(
            f'{fname}: expected a channels-last (H, W, 3|4) array, got shape {arr.shape}')
    # ITK takes the component axis from the last dimension only for
    # C-contiguous input, so make sure the buffer is laid out that way.
    return itk.image_from_array(np.ascontiguousarray(arr), is_vector=True)


# Load the input and reference images as 2-D RGB ITK images.
if original:
    input_image = itk.imread(input_image_filename, PixelType)
    reference_image = itk.imread(reference_image_filename, PixelType)
else:
    # read .npy data, convert to ITK compatible format
    input_image = _npy_to_itk_rgb(input_image_filename)
    print(f' type(input_image) = {type(input_image)}')

    reference_image = _npy_to_itk_rgb(reference_image_filename)
    print(f' type(reference_image) = {type(reference_image)}')

print('\n !!! Invoke the functional, eager interface for ITK\n')
# Indices into the RGB pixel naming the color channel suppressed by each stain.
spcn_kwargs = {
    'color_index_suppressed_by_hematoxylin': 0,
    'color_index_suppressed_by_eosin': 1,
}
# Normalize the input's colors toward the reference image, then save the result.
eager_normalized_image = itk.structure_preserving_color_normalization_filter(
    input_image, reference_image, **spcn_kwargs)
itk.imwrite(eager_normalized_image, output_image_filename)

# Interactive viewers for input, reference and result.  NOTE(review): these
# presumably only render inside Jupyter; kept as separate statements so each
# can live in its own notebook cell.
view(input_image)

view(reference_image)

view(eager_normalized_image)

