ITK functions don't work properly when used inside a Tensorflow data multiprocessing pipeline

Description

The code below implements a Tensorflow (TF) data pipeline that uses ITK functions for processing the inputs.

It seems that ITK modules can’t be properly accessed when they’re wrapped in a Tensorflow py_function that processes the samples using multiple workers (see the full error trace below). Crucially, the issue doesn’t occur if TF is forced to use a single worker by setting num_parallel_calls=1 in ds.map().

I first thought there could be a conflict between how ITK and TF use their workers (ITK functions also run in parallel), but forcing ITK to use a single worker (by passing number_of_work_units=1 to resample_image_filter) doesn’t solve the problem.

System specification:

  • Ubuntu 18.04
  • Python 3.8.0

Reproduce error:

  1. Create virtual env
python -m venv itk_tf_venv
source itk_tf_venv/bin/activate
pip install tensorflow==2.3.1 itk==5.1.1 numpy==1.19.0
  2. Create a test.py script that applies ITK’s Euler3D transform wrapped in a simple Tensorflow data pipeline:
import numpy as np
import tensorflow as tf
import itk

def apply_itk_transform(image, angles=(0.1, 0.1, 0.1), translation=(0., 0., 0.)):

    # tensor to numpy
    image = image.numpy().astype(np.int16)

    angles = np.array(angles, dtype=np.float64)
    translations = np.array(translation, dtype=np.float64)

    output_direction = itk.matrix_from_array(np.eye(3))
    image_size = image.shape[::-1]  # to ITK ordering (XYZ)
    center = [p / 2 for p in image_size]

    # transformation matrix
    rigid_euler = itk.Euler3DTransform.New()
    rigid_euler.SetRotation(*angles)
    rigid_euler.SetTranslation(translations)
    rigid_euler.SetCenter(center)

    # define interpolator type
    linear_interpolator_type = itk.LinearInterpolateImageFunction[itk.Image[itk.SS, 3], itk.D].New()

    # apply transform
    image = itk.resample_image_filter(
        image,
        size=image_size,
        transform=rigid_euler,
        output_origin=[0, 0, 0],
        output_spacing=(1.0, 1.0, 1.0),
        output_direction=output_direction,
        interpolator=linear_interpolator_type.New(),
    )

    return np.asarray(image)


image = np.zeros((200, 200, 200), dtype=np.int16)

# create data pipeline with 2 samples
ds = tf.data.Dataset.from_tensor_slices([image]).repeat(2)

# wrap ITK transform in tf.py_function
ds = ds.map(
    lambda image: tf.py_function(apply_itk_transform, [image], [tf.int16]),
    num_parallel_calls=4,
)

# take samples and print shapes
for batch in ds.take(2):
    print(batch[0].shape)
  3. Execute script
python test.py

Full error message:

Unknown: AttributeError: module 'itk.ITKCommonPython' has no attribute 'swig'
Traceback (most recent call last):
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 242, in __call__
    return func(device, token, args)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 131, in __call__
    ret = self._func(*args)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
  File "test.py", line 14, in apply_itk_transform
    output_direction = itk.matrix_from_array(np.eye(3))
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 376, in GetMatrixFromArray
    vnl_matrix = GetVnlMatrixFromArray(arr)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 365, in GetVnlMatrixFromArray
    return _GetVnlObjectFromArray(arr, "GetVnlMatrixFromArray")
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 351, in _GetVnlObjectFromArray
    PixelType = _get_itk_pixelid(arr)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 223, in _get_itk_pixelid
    numpy.complex64:itk.complex[itk.F],
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkLazy.py", line 52, in __getattribute__
    itkBase.LoadModule(module, namespace)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkBase.py", line 61, in LoadModule
    swig.update(this_module.swig)
AttributeError: module 'itk.ITKCommonPython' has no attribute 'swig'
(200, 200, 200)                                                                                                            
Traceback (most recent call last):
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/eager/context.py", line 2102, in execution_mode
    yield
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 755, in _next_internal
    ret = gen_dataset_ops.iterator_get_next(
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 2610, in iterator_get_next
    _ops.raise_from_not_ok_status(e, name)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 6843, in raise_from_not_ok_status
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.UnknownError: AttributeError: module 'itk.ITKCommonPython' has no attribute 'swig'
Traceback (most recent call last):
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 242, in __call__
    return func(device, token, args)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 131, in __call__
    ret = self._func(*args)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)
  File "test.py", line 14, in apply_itk_transform
    output_direction = itk.matrix_from_array(np.eye(3))
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 376, in GetMatrixFromArray
    vnl_matrix = GetVnlMatrixFromArray(arr)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 365, in GetVnlMatrixFromArray
    return _GetVnlObjectFromArray(arr, "GetVnlMatrixFromArray")
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 351, in _GetVnlObjectFromArray
    PixelType = _get_itk_pixelid(arr)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 223, in _get_itk_pixelid
    numpy.complex64:itk.complex[itk.F],
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkLazy.py", line 52, in __getattribute__
    itkBase.LoadModule(module, namespace)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkBase.py", line 61, in LoadModule
    swig.update(this_module.swig)
AttributeError: module 'itk.ITKCommonPython' has no attribute 'swig'
         [[{{node EagerPyFunc}}]] [Op:IteratorGetNext]
              
During handling of the above exception, another exception occurred:                                                                            
Traceback (most recent call last):                                                                                         
  File "test.py", line 53, in <module>                                                                                     
    for batch in ds.take(2):                                                                                               
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 736, in __next__
    return self.next()
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 772, in next
    return self._next_internal()
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 764, in _next_internal
    return structure.from_compatible_tensor_list(self._element_spec, ret)
  File "/home/goncalo/.pyenv/versions/3.8.0/lib/python3.8/contextlib.py", line 131, in __exit__
    self.gen.throw(type, value, traceback)
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/eager/context.py", line 2105, in execution_mode
    executor_new.wait()
  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/eager/executor.py", line 67, in wait
    pywrap_tfe.TFE_ExecutorWaitForAllPendingNodes(self._handle)
tensorflow.python.framework.errors_impl.UnknownError: AttributeError: module 'itk.ITKCommonPython' has no attribute 'swig'
Traceback (most recent call last):

  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 242, in __call__
    return func(device, token, args)

  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/ops/script_ops.py", line 131, in __call__
    ret = self._func(*args)

  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py", line 302, in wrapper
    return func(*args, **kwargs)

  File "test.py", line 14, in apply_itk_transform
    output_direction = itk.matrix_from_array(np.eye(3))

  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 376, in GetMatrixFromArray
    vnl_matrix = GetVnlMatrixFromArray(arr)

  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 365, in GetVnlMatrixFromArray
    return _GetVnlObjectFromArray(arr, "GetVnlMatrixFromArray")

  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 351, in _GetVnlObjectFromArray
    PixelType = _get_itk_pixelid(arr)

  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkExtras.py", line 223, in _get_itk_pixelid
    numpy.complex64:itk.complex[itk.F],

  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkLazy.py", line 52, in __getattribute__
    itkBase.LoadModule(module, namespace)

  File "/home/goncalo/itk_tf_issue/itk_tf_venv/lib/python3.8/site-packages/itkBase.py", line 61, in LoadModule
    swig.update(this_module.swig)

AttributeError: module 'itk.ITKCommonPython' has no attribute 'swig'


         [[{{node EagerPyFunc}}]]

This issue seems similar.
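
The trace in the original post fails inside itkLazy.py and itkBase.LoadModule, i.e. while a module is being loaded on first use. One possible reading (an assumption, not confirmed anywhere in this thread) is that a second worker reaches a module that the first worker is still initialising. The toy model below uses only the standard library and is not ITK's code: ToyLazyLoader, get_swig and count_failures are hypothetical names, and the sleep stands in for slow module initialisation. It only sketches how a first-use load can expose a half-built module to other threads, and how serialising that load avoids it.

```python
import threading
import time


class ToyLazyLoader:
    """Toy stand-in for an on-demand module loader (not ITK's real code)."""

    def __init__(self, use_lock):
        self._modules = {}
        self._lock = threading.Lock() if use_lock else None
        self.registered = threading.Event()

    def _load(self, name):
        if name in self._modules:
            # Fast path: the module is registered, but may not be populated yet.
            return self._modules[name]
        module = {}
        self._modules[name] = module  # registered before initialisation ends
        self.registered.set()
        time.sleep(0.2)  # simulate slow initialisation
        module["swig"] = {}  # the attribute only appears at the very end
        return module

    def get_swig(self, name):
        if self._lock is None:
            module = self._load(name)
        else:
            with self._lock:  # serialise first-time loading
                module = self._load(name)
        return module["swig"]  # KeyError if another thread is still loading


def count_failures(use_lock, workers=4):
    """Run several workers against one loader and count failed lookups."""
    loader = ToyLazyLoader(use_lock)
    failures = []

    def worker():
        try:
            loader.get_swig("ITKCommon")
        except KeyError:
            failures.append(1)

    first = threading.Thread(target=worker)
    first.start()
    loader.registered.wait()  # the first load is now in progress
    others = [threading.Thread(target=worker) for _ in range(workers - 1)]
    for thread in others:
        thread.start()
    for thread in [first, *others]:
        thread.join()
    return len(failures)


print("failures without lock:", count_failures(use_lock=False))
print("failures with lock:", count_failures(use_lock=True))
```

In this sketch a single worker never fails, which would be consistent with the observation that num_parallel_calls=1 avoids the error, but the model says nothing definitive about what ITK actually does internally.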

Thanks @dzenanz. The test suggested here by @hjmjohnson seems to solve the problem for the case I described.
Besides slowing down module loading, does disabling itkLazy have any downside at runtime?

There shouldn’t be downsides besides slower initial load.
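
For readers who want to try this, here is a minimal sketch of switching lazy loading off before ITK is imported. It assumes the ITK 5.x Python package exposes an itkConfig module with a LazyLoading flag that is read when itk is first imported; the helper name import_itk_without_lazy_loading is hypothetical, and the exact test suggested in the linked issue may differ. The helper returns None when ITK is not installed so the snippet runs anywhere.

```python
import importlib


def import_itk_without_lazy_loading():
    """Return the itk package with lazy loading off, or None if ITK is absent."""
    try:
        itk_config = importlib.import_module("itkConfig")
    except ImportError:
        return None  # ITK is not installed in this environment
    # Assumption: the flag is read when `itk` is first imported, so it must
    # be set before that import; it has no effect on an already imported itk.
    itk_config.LazyLoading = False
    return importlib.import_module("itk")


itk = import_itk_without_lazy_loading()
print("ITK loaded eagerly" if itk is not None else "ITK not installed")
```

In test.py this would replace the plain `import itk` line at the top of the script, so that all modules are loaded once in the main thread before ds.map() starts its workers. The cost is the slower initial load mentioned above.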
