Operator to flatten masks based on an intensity value

I find myself in need of a filter that takes a set of (binary mask, real-valued image) pairs and returns a single label mask where each voxel's value is the index of the pair with the largest real value at that voxel. I use this constantly for flattening logits and for label interpolation.

Currently, I do this in NumPy with argmax.

# Let `logits` come from a neural network I have
# It is [512 x 512 x 128] with 4 channels
# Each channel is a semantic class - say, grey matter, white matter, skull, skin
logits.shape  # (512, 512, 128, 4)

# Get mask, which is a semantic segmentation
mask = np.argmax(logits, axis=-1)
mask.shape  # (512, 512, 128)
np.min(mask)  # 0
np.max(mask)  # 3

I can then do some additional filtering to map [0, 1, 2, 3] to my class labels and to suppress low-confidence voxels. Say, for instance, if the logit is less than 0 for all labels, I usually assign that voxel to background. I do the exact same operation in other situations, like label correction.
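A minimal sketch of that post-processing, assuming `logits` as above (a toy random array stands in here) and hypothetical `class_labels` / `background_label` values:

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 8, 8, 4)).astype(np.float32)  # toy stand-in for network output

class_labels = np.array([10, 20, 30, 40], dtype=np.int16)  # hypothetical label mapping
background_label = 0

mask = np.argmax(logits, axis=-1)   # winning channel index per voxel
labels = class_labels[mask]         # map indices 0..3 to class labels

# Suppress voxels where no class is confident (all logits < 0)
labels[np.max(logits, axis=-1) < 0] = background_label
```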

I imagine a C++ interface something like this:

// Types
constexpr unsigned int Dimension = 3;
using MaskPixelType = short;
using IntensityPixelType = float;

using MaskImageType = itk::Image<MaskPixelType, Dimension>;
using IntensityImageType = itk::Image<IntensityPixelType, Dimension>;

// Basic interface using pairs
using ImagePair = std::pair<MaskImageType::Pointer, IntensityImageType::Pointer>;
std::vector<ImagePair> image_pairs = get_image_pairs();  // Not defined here

auto filter = itk::FlattenLabelsImageFilter<MaskImageType, IntensityImageType>::New();
filter->SetPairs(image_pairs);
filter->Update();
MaskImageType::Pointer segmentation = filter->GetOutput();

Additional interfaces (e.g., an AddPair method) as needed.

The real benefit of this filter would come when the images align in physical space but do not share the same extent. The issue is that this operation is very memory intensive. Say my network outputs images at [128, 128, 128] and I have 8 classes to segment. That's 128 x 128 x 128 x 8 x sizeof(float) of memory (I could compress to int, yes). However, these images might be subsamples of a larger image, so I actually have to resample the logits to their native spacing first, which might be [512, 512, 1028], before joining, which carries an extreme memory burden.
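One way to soften that memory burden is a running argmax over per-class volumes, so only two full-resolution volumes are held at once instead of the full stacked array. A minimal NumPy sketch, assuming the per-class logits arrive one at a time (faked here as a list of toy arrays):

```python
import numpy as np

rng = np.random.default_rng(1)
per_class_logits = [rng.normal(size=(16, 16, 16)).astype(np.float32) for _ in range(8)]

# Running best value and best index: two volumes in memory at any time,
# instead of a [16, 16, 16, 8] stack.
best_value = np.full(per_class_logits[0].shape, -np.inf, dtype=np.float32)
best_index = np.zeros(per_class_logits[0].shape, dtype=np.int16)

for index, logit in enumerate(per_class_logits):
    better = logit > best_value   # voxels where this class beats the running best
    best_value[better] = logit[better]
    best_index[better] = index
```

In the resampling scenario, each per-class volume could be resampled to the native spacing just before its update step, so only one full-resolution logit volume exists at a time.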

I think the basic “label flattening map” is not so difficult to implement. Is there already a solution for this? Is my approach wrong? The second, resampling-aware method would take more effort and a more careful definition.

I don’t understand what MaskImageType is for. The argmax along channels can be done without it. And what does spacing have to do with it? Could you provide some sample inputs and desired outputs (e.g., as NRRD images)? It sounds like an easy filter to make, though.

Yes, sorry, got ahead of myself. No MaskImageType needed. Let me see if I can get some examples together.
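A toy input/output pair for the requested example could be generated like this (NumPy only; actually writing the volumes out as NRRD, e.g. via SimpleITK, is omitted, and the -1 background label is hypothetical):

```python
import numpy as np

rng = np.random.default_rng(2)

# Toy "logits" input: 4 classes over a small volume
logits = rng.normal(size=(4, 4, 4, 4)).astype(np.float32)

# Desired output: one label volume whose voxel value is the winning class index,
# with background (-1, hypothetically) wherever every class logit is negative
expected = np.argmax(logits, axis=-1).astype(np.int16)
expected[np.max(logits, axis=-1) < 0] = -1
```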